Matplotlib：同时（并排）绘制多个数字矩阵

时间:2018-06-05 16:09:58

标签: python matplotlib plot

我正在用 Python 和 matplotlib 编写代码，用来显示两个轴之间的关系。我还想在紧挨着第一个图的另一个图中绘制每一行的总和。主要问题是：我希望把行和画成另一个矩阵，并且它与第一个矩阵之间的距离尽可能小。下面附上绘图的完整代码：

import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np

# Example 4x3 binary matrix; plot_tools uses it as its default `matrix`.
matr = [
    [1, 0, 0],
    [0, 0, 1],
    [1, 1, 0],
    [0, 0, 1],
]

def plot_tools(matrix=matr, sigma=[1, 2, 3, 4], m=[1, 2, 3], name='a'):
    """Draw a 0/1 matrix as a two-colour image and number its 1-cells.

    Parameters
    ----------
    matrix : 2-D array-like
        Matrix to draw. Every cell equal to 1 is annotated with a running
        counter (1, 2, 3, ...) assigned in row-major order.
    sigma : sequence
        Tick labels for the y axis, one per matrix row.
    m : sequence
        Tick labels for the x axis, one per matrix column.
    name : object
        Base name of the output image, saved as ``Images/<name>.png``.

    The list defaults are only read, never mutated, so sharing them across
    calls is harmless.
    """
    id_matrix = np.asarray(matrix)
    n_rows, n_cols = id_matrix.shape

    fig, ax = plt.subplots()
    # Two-colour map: for a 0/1 matrix, 0 -> 'lavender' and 1 -> 'purple'
    # (imshow normalises the data range onto the colour list).
    cmap = colors.ListedColormap(['lavender', 'purple'])
    ax.imshow(id_matrix, interpolation='nearest', cmap=cmap)
    fig.suptitle('Plot:')

    # Tick positions come from the matrix shape so labels always line up
    # with rows/columns.
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(sigma)
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(m, rotation=0)
    ax.xaxis.tick_top()
    ax.set_ylabel('Y axis', fontsize=13)
    ax.set_xlabel('X axis', fontsize=13)

    # np.argwhere yields (row, col) pairs in row-major order, so the counter
    # numbers the 1-cells exactly as a nested row/column loop would.
    # `range`/`enumerate` also work on Python 3, unlike `xrange`.
    for count, (row, col) in enumerate(np.argwhere(id_matrix == 1), start=1):
        ax.annotate(str(count), xy=(col, row),
                    horizontalalignment='center', verticalalignment='center')

    # The 'Images' directory must already exist; savefig does not create it.
    fig.savefig('Images/' + str(name) + '.png')
    plt.show()

我想实现如下效果（参见原帖中的链接 “This question”）：

1 个答案:

答案 0（得分：0）

最简单的方法之一是使用 subplots，并让两个图共享 Y 轴，如 this example 所示。

大致如下：

import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np

# Example 4x3 binary matrix; plot_tools uses it as its default `matrix`.
matr = [
    [1, 0, 0],
    [0, 0, 1],
    [1, 1, 0],
    [0, 0, 1],
]

def plot_tools(matrix=matr, sigma=[1, 2, 3, 4], m=[1, 2, 3], name='a'):
    """Draw a 0/1 matrix and, right beside it, a one-column panel of row sums.

    Parameters
    ----------
    matrix : 2-D array-like
        Matrix to draw on the left panel. Every cell equal to 1 is annotated
        with a running counter (1, 2, 3, ...) in row-major order.
    sigma : sequence
        Tick labels for the (shared) y axis, one per matrix row.
    m : sequence
        Tick labels for the x axis of the matrix panel, one per column.
    name : object
        Base name for the image file if saving is re-enabled below.

    The list defaults are only read, never mutated.
    """
    id_matrix = np.asarray(matrix)
    n_rows, n_cols = id_matrix.shape
    row_sums = id_matrix.sum(axis=1)

    # width_ratios makes the sum panel exactly one cell wide relative to the
    # matrix, and a small wspace keeps the two panels close together.
    fig, (ax1, ax2) = plt.subplots(
        1, 2, sharey=True,
        gridspec_kw={'width_ratios': [n_cols, 1], 'wspace': 0.05})
    fig.suptitle('Plot:')

    # --- Left panel: the matrix itself. ---------------------------------
    # Configure ax1 explicitly: after plt.subplots the *current* axes is ax2,
    # so plt.xticks/plt.ylabel would have labelled the wrong panel.
    cmap = colors.ListedColormap(['lavender', 'purple'])
    ax1.imshow(id_matrix, interpolation='nearest', cmap=cmap)
    ax1.set_yticks(range(n_rows))
    ax1.set_yticklabels(sigma)
    ax1.set_xticks(range(n_cols))
    ax1.set_xticklabels(m, rotation=0)
    ax1.xaxis.tick_top()
    ax1.set_ylabel('Y axis', fontsize=13)
    ax1.set_xlabel('X axis', fontsize=13)

    # np.argwhere yields (row, col) pairs in row-major order, matching the
    # numbering of a nested row/column loop.
    for count, (row, col) in enumerate(np.argwhere(id_matrix == 1), start=1):
        ax1.annotate(str(count), xy=(col, row),
                     horizontalalignment='center', verticalalignment='center')

    # --- Right panel: row sums as a one-column "matrix". ----------------
    # Single-colour map so only the annotated numbers carry information.
    ax2.imshow(row_sums.reshape(-1, 1), interpolation='nearest',
               cmap=colors.ListedColormap(['lavender']))
    ax2.set_xticks([0])
    ax2.set_xticklabels(['Sum'])
    ax2.xaxis.tick_top()
    for row, total in enumerate(row_sums):
        ax2.annotate(str(total), xy=(0, row),
                     horizontalalignment='center', verticalalignment='center')

    # Equal-aspect images shrink inside their grid cells; anchoring them
    # toward each other keeps the visible gap minimal.
    ax1.set_anchor('E')
    ax2.set_anchor('W')

    # Saving is disabled so the example runs without an existing 'Images/'
    # directory; to re-enable:
    # fig.savefig('Images/' + str(name) + '.png')
    plt.show()