Matplotlib - 基于数据的动态(条形图)高度?

时间:2017-02-12 18:42:39

标签: python matplotlib

在与matplotlib挣扎的时间超过我之后,我想通过尝试在我曾经使用过的任何其他绘图库中做一些轻而易举的事情,我决定向Stackiverse询问有一些见解。简而言之,我需要的是创建多个水平条形图,所有水平条形图共享x轴,y轴上的值不同,所有条形高度相同,而图表本身将调整为条目。我需要绘制的简化数据结构如下:

[
    {"name": "Category 1", "entries": [
        {"name": "Entry 1", "value": 5},
        {"name": "Entry 2", "value": 2},
    ]},
    {"name": "Category 2", "entries": [
        {"name": "Entry 1", "value": 1},
    ]},
    {"name": "Category 3", "entries": [
        {"name": "Entry 1", "value": 1},
        {"name": "Entry 2", "value": 10},
        {"name": "Entry 3", "value": 4},
    ]},    
]

最后我得到的结果是:

import matplotlib.pyplot as plt

def plot_data(data):
    total_categories = len(data)  # holds how many charts to create
    max_values = 1  # holds the maximum number of bars to create
    for category in data:
        max_values = max(max_values, len(category["entries"]))
    fig = plt.figure(1)
    ax = None
    for index, category in enumerate(data):
        entries = []
        values = []
        for entry in category["entries"]:
            entries.append(entry["name"])
            values.append(entry["value"])
        if not entries:
            continue  # do not create empty charts
        y_ticks = range(1, len(entries) + 1)
        ax = fig.add_subplot(total_categories, 1, index + 1, sharex=ax)
        ax.barh(y_ticks, values)
        ax.set_ylim(0, max_values + 1)  # limit the y axis for fixed height
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(entries)
        ax.invert_yaxis()
        ax.set_title(category["name"], loc="left")
    fig.tight_layout()

由于y限制(set_ylim())设置为数据中的最大条数,因此无论有多少条目具有特定类别,这将使条形高度保持相同(至少在整个图中)。但是,它也会在具有少于最大条目数的类别中留下令人讨厌的空白。或者从视觉角度看所有内容,我试图从实际预期

IMG LINK

我已经尝试通过gridspec和不同的尺度去除间隙,这取决于条目的数量,但最终看起来更加“倾斜”了。并且不一致。我尝试生成多个图表并操纵图形大小,然后在后处理中将它们拼接在一起,但我无法找到一种方法来确保条形高度保持不变,无论数据如何。我确定有一种方法可以从matplotlib中的一些不起眼的对象中提取所需的精确缩放指标,但此时我担心如果我尝试的话,我会继续进行另一次野鹅追逐追踪整个绘图程序。

此外,如果可以围绕数据折叠各个子图,我如何根据数据使图形增长?例如,如果我要在上面的数据中添加第四个类别,而不是让数字增长'在另一个图表的高度,它实际上会缩小所有图表以适应默认数字大小的所有内容。现在,我想我理解matplotlib背后的逻辑与轴单位和所有这些,我知道我可以设置数字大小来增加整体高度,但我不知道如何在图表中保持一致,即如何无论数据如何,条形高度都完全相同?

我是否真的需要手动绘制所有内容以获得我想要的内容?如果是这样,我可能只是转储整个matplotlib包并从头开始创建自己的SVG。事后看来,考虑到我花在这上面的时间,我可能应该首先做到这一点,但现在我太顽固不能放弃(或者我是受害者)可怕的沉没成本谬误)。

有什么想法吗?

由于

1 个答案:

答案 0 :(得分:2)

我认为同时具有相同条宽(垂直方向宽度)和不同子图的唯一方法是手动将轴定位在图中。

为此,您可以指定条形的大小(以英寸为单位)以及在此条形宽度单位的子图之间要有多少间距。从这些数字以及要绘制的数据量,您可以计算以英寸为单位的总图形高度。 然后根据数据量和先前子图中的间距定位每个子图(通过fig.add_axes)。因此,你很好地填补了情节。 添加一组新数据将使数字变大。

data = [
    {"name": "Category 1", "entries": [
        {"name": "Entry 1", "value": 5},
        {"name": "Entry 2", "value": 2},
    ]},
    {"name": "Category 2", "entries": [
        {"name": "Entry 1", "value": 1},
    ]},
    {"name": "Category 3", "entries": [
        {"name": "Entry 1", "value": 1},
        {"name": "Entry 2", "value": 10},
        {"name": "Entry 3", "value": 4},
    ]}, 
    {"name": "Category 4", "entries": [
        {"name": "Entry 1", "value": 6},
    ]},
]

import matplotlib.pyplot as plt
import numpy as np

def plot_data(data,
              barwidth = 0.2, # inch per bar
              spacing = 3,    # spacing between subplots in units of barwidth
              figx = 5,       # figure width in inch
              left = 4,       # left margin in units of bar width
              right=2):       # right margin in units of bar width

    tc = len(data) # "total_categories", holds how many charts to create
    max_values = []  # holds the maximum number of bars to create
    for category in data:
        max_values.append( len(category["entries"]))
    max_values = np.array(max_values)

    # total figure height:
    figy = ((np.sum(max_values)+tc) + (tc+1)*spacing)*barwidth #inch

    fig = plt.figure(figsize=(figx,figy))
    ax = None
    for index, category in enumerate(data):
        entries = []
        values = []
        for entry in category["entries"]:
            entries.append(entry["name"])
            values.append(entry["value"])
        if not entries:
            continue  # do not create empty charts
        y_ticks = range(1, len(entries) + 1)
        # coordinates of new axes [left, bottom, width, height]
        coord = [left*barwidth/figx, 
                 1-barwidth*((index+1)*spacing+np.sum(max_values[:index+1])+index+1)/figy,  
                 1-(left+right)*barwidth/figx,  
                 (max_values[index]+1)*barwidth/figy ] 

        ax = fig.add_axes(coord, sharex=ax)
        ax.barh(y_ticks, values)
        ax.set_ylim(0, max_values[index] + 1)  # limit the y axis for fixed height
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(entries)
        ax.invert_yaxis()
        ax.set_title(category["name"], loc="left")


plot_data(data)
plt.savefig(__file__+".png")
plt.show()

enter image description here

相关问题