使用matplotlib堆积条形图

时间:2017-06-01 13:55:14

标签: python matplotlib plot bar-chart

我正在使用matplotlib生成堆积条形图,但得到的结果看起来不对。每个垂直堆栈的总和应为100,但是在X轴刻度65、70、75和80处,得到的结果完全不合理。我不明白问题出在哪里。下面是一个最小可复现示例(MWE)。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Legend labels, one per stacked series.
header = ['a','b','c','d']
# dataset[0] holds the x-axis tick labels (as strings); dataset[1]..dataset[4]
# hold one series each, with one value per tick label.
dataset= [('60.0', '65.0', '70.0', '75.0', '80.0', '85.0', '90.0', '95.0', '100.0', '105.0', '110.0', '115.0', '120.0', '125.0', '130.0', '135.0', '140.0', '145.0', '150.0', '155.0', '160.0', '165.0', '170.0', '175.0', '180.0', '185.0', '190.0', '195.0', '200.0'), (0.0, 25.0, 48.93617021276596, 83.01886792452831, 66.66666666666666, 66.66666666666666, 70.96774193548387, 84.61538461538461, 93.33333333333333, 85.0, 92.85714285714286, 93.75, 95.0, 100.0, 100.0, 100.0, 100.0, 80.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0), (0.0, 50.0, 36.17021276595745, 11.320754716981133, 26.666666666666668, 33.33333333333333, 29.03225806451613, 15.384615384615385, 6.666666666666667, 15.0, 7.142857142857142, 6.25, 5.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 12.5, 10.638297872340425, 3.7735849056603774, 4.444444444444445, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (100.0, 12.5, 4.25531914893617, 1.8867924528301887, 2.2222222222222223, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]
X_AXIS = dataset[0]

matplotlib.rc('font', serif='Helvetica Neue')
matplotlib.rc('text', usetex='false')
matplotlib.rcParams.update({'font.size': 40})

# Resize the current pyplot figure.
fig = matplotlib.pyplot.gcf()
fig.set_size_inches(18.5, 10.5)

configs = dataset[0]
N = len(configs)
ind = np.arange(N)  # integer bar positions 0..N-1, one per tick label
width = 0.4

# NOTE: this is the stacking bug the question is about. `bottom` for p3 and
# p4 is only the single previous series, not the running total of every
# series already drawn, so those bars do not start on top of the stack.
p1 = plt.bar(ind, dataset[1], width, color='r')
p2 = plt.bar(ind, dataset[2], width, bottom=dataset[1], color='b')
p3 = plt.bar(ind, dataset[3], width, bottom=dataset[2], color='g')
p4 = plt.bar(ind, dataset[4], width, bottom=dataset[3], color='c')

plt.ylim([0,120])
plt.yticks(fontsize=12)
# NOTE(review): `output` is not defined anywhere in this snippet; run
# standalone, this line raises NameError. Presumably it is set elsewhere in
# the asker's full script -- confirm.
plt.ylabel(output, fontsize=12)
plt.xticks(ind, X_AXIS, fontsize=12, rotation=90)
plt.xlabel('test', fontsize=12)
# The first bar of each series serves as that series' legend handle.
plt.legend((p1[0], p2[0], p3[0], p4[0]), (header[0], header[1], header[2], header[3]), fontsize=12, ncol=4, framealpha=0, fancybox=True)
plt.show()

enter image description here

5 个答案:

答案 0 :(得分:24)

每个数据集的bottom必须是它之前所有数据集的总和。您可能还需要先把各数据集转换为numpy数组,这样才能将它们逐元素相加。

# The first series sits on the axis; every later series starts at the
# element-wise sum of all series drawn beneath it.
p1 = plt.bar(ind, dataset[1], width, color='r')
p2 = plt.bar(ind, dataset[2], width, bottom=dataset[1], color='b')
# Each series is wrapped in np.array so that `+` adds the values
# element-wise.
p3 = plt.bar(ind, dataset[3], width, 
             bottom=np.array(dataset[1])+np.array(dataset[2]), color='g')
p4 = plt.bar(ind, dataset[4], width,
             bottom=np.array(dataset[1])+np.array(dataset[2])+np.array(dataset[3]),
             color='c')

enter image description here

或者,您可以在开始绘图之前将它们转换为numpy数组。

# Convert each series to a numpy array once, up front, so the running
# totals below can be written with plain `+`.
dataset1 = np.array(dataset[1])
dataset2 = np.array(dataset[2])
dataset3 = np.array(dataset[3])
dataset4 = np.array(dataset[4])

# `bottom` of each series is the sum of all series drawn before it.
p1 = plt.bar(ind, dataset1, width, color='r')
p2 = plt.bar(ind, dataset2, width, bottom=dataset1, color='b')
p3 = plt.bar(ind, dataset3, width, bottom=dataset1+dataset2, color='g')
p4 = plt.bar(ind, dataset4, width, bottom=dataset1+dataset2+dataset3,
             color='c')

最后,如果您不想转换为numpy数组,也可以使用列表推导式:

# Same stacking without numpy: zip() groups the values of the lower series
# position by position, and sum() adds each group to give the running total.
p1 = plt.bar(ind, dataset[1], width, color='r')
p2 = plt.bar(ind, dataset[2], width, bottom=dataset[1], color='b')
p3 = plt.bar(ind, dataset[3], width,
             bottom=[sum(x) for x in zip(dataset[1],dataset[2])], color='g')
p4 = plt.bar(ind, dataset[4], width,
             bottom=[sum(x) for x in zip(dataset[1],dataset[2],dataset[3])],
             color='c')

答案 1 :(得分:17)

我觉得这件事做起来太麻烦,所以写了一个函数来完成它。在此分享,希望对其他人有用:

import numpy as np
import matplotlib.pyplot as plt

def stacked_bar(data, series_labels, category_labels=None,
                show_values=False, value_format="{}", y_label=None,
                grid=True, reverse=False):
    """Plot a stacked bar chart on the current pyplot axes.

    Keyword arguments:
    data            -- 2-dimensional numpy array or nested list
                       containing data for each series in rows
                       (one column per category)
    series_labels   -- list of series labels (these appear in
                       the legend), one per row of ``data``
    category_labels -- optional sequence of category labels (these
                       appear on the x-axis), one per column of
                       ``data``
    show_values     -- If True then numeric value labels will
                       be shown on each bar
    value_format    -- Format string for numeric value labels
                       (default is "{}")
    y_label         -- Label for y-axis (str)
    grid            -- If True display grid
    reverse         -- If True reverse the order of the categories
                       along the x-axis (the columns of ``data``
                       and ``category_labels`` are flipped together)

    Returns None; everything is drawn through ``matplotlib.pyplot``.
    """

    data = np.array(data)
    ny = data.shape[1]
    ind = list(range(ny))

    # Materialise the labels once so they can be reversed and tested for
    # emptiness whatever kind of sequence the caller passed in.
    if category_labels is not None:
        category_labels = list(category_labels)

    if reverse:
        data = np.flip(data, axis=1)
        # Labels are optional, so only flip them when they were supplied.
        if category_labels is not None:
            category_labels = category_labels[::-1]

    containers = []
    cum_size = np.zeros(ny)

    for i, row_data in enumerate(data):
        containers.append(plt.bar(ind, row_data, bottom=cum_size,
                                  label=series_labels[i]))
        # Rebind rather than update in place so the array handed to
        # plt.bar above is never modified afterwards.
        cum_size = cum_size + row_data

    if category_labels:
        plt.xticks(ind, category_labels)

    if y_label:
        plt.ylabel(y_label)

    plt.legend()

    if grid:
        plt.grid()

    if show_values:
        # Write each segment's height at the centre of its rectangle.
        for container in containers:
            for bar in container:
                w, h = bar.get_width(), bar.get_height()
                plt.text(bar.get_x() + w/2, bar.get_y() + h/2,
                         value_format.format(h), ha="center",
                         va="center")

示例:

# Demo: two series over four categories, with every segment labelled.
series_labels = ['Series 1', 'Series 2']
category_labels = ['Cat A', 'Cat B', 'Cat C', 'Cat D']
data = [[0.2, 0.3, 0.35, 0.3], [0.8, 0.7, 0.6, 0.5]]

plt.figure(figsize=(6, 4))
stacked_bar(data, series_labels, category_labels=category_labels,
            show_values=True, value_format="{:.1f}",
            y_label="Quantity (units)")

# Write the chart to disk, then display it.
plt.savefig('bar.png')
plt.show()

enter image description here

答案 2 :(得分:13)

如果您愿意使用Pandas,这可能是您最方便的解决方案:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Tick labels '60.0', '65.0', ..., '200.0' (steps of 5), kept as strings.
X_AXIS = tuple('{:.1f}'.format(tick) for tick in range(60, 201, 5))

# Index named 'test', built from the tick labels.
index = pd.Index(X_AXIS, name='test')

# One tuple of values per series, keyed by the series name; each tuple has
# one value per entry of X_AXIS.
data = {
    'a': (0.0, 25.0, 48.93617021276596, 83.01886792452831, 66.66666666666666, 66.66666666666666, 70.96774193548387, 84.61538461538461, 93.33333333333333, 85.0, 92.85714285714286, 93.75, 95.0, 100.0, 100.0, 100.0, 100.0, 80.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0),
    'b': (0.0, 50.0, 36.17021276595745, 11.320754716981133, 26.666666666666668, 33.33333333333333, 29.03225806451613, 15.384615384615385, 6.666666666666667, 15.0, 7.142857142857142, 6.25, 5.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
    'c': (0.0, 12.5, 10.638297872340425, 3.7735849056603774, 4.444444444444445, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
    'd': (100.0, 12.5, 4.25531914893617, 1.8867924528301887, 2.2222222222222223, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
}

# One column per series, one row per x-axis label.
df = pd.DataFrame(data, index=index)

# Let pandas do the stacking instead of computing `bottom` by hand.
ax = df.plot.bar(stacked=True, figsize=(18.5, 10.5))
ax.set(ylabel='foo')

# Write the chart to disk, then display it.
plt.savefig('stacked.png')
plt.show()

输出:

enter image description here

答案 3 :(得分:0)

如果您想要有序堆叠(每根柱子中数值最大的一段位于底部),可以这样做:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Example frame: each row is drawn as one stacked bar (x position = row
# index) and each column is one series.
a = pd.DataFrame({'a':[0.25, 0.5, 0.15, 0], 'b':[0.15, 0.25, 0.35, 0.15], 
                  'c':[0.50, 0.15, 0.5, 0.35], 'd':[0.35, 0.35, 0.25, 0.5],})

#       a     b     c     d
# 0  0.25  0.15  0.50  0.35
# 1  0.50  0.25  0.15  0.35
# 2  0.15  0.35  0.50  0.25
# 3  0.00  0.15  0.35  0.50

fig, ax = plt.subplots()
x = a.index
# Sort every row in ascending order, then transpose so that axis 0 is the
# rank within a row (0 = smallest) and axis 1 is the DataFrame row.
# `indexes` holds the column positions, `heights` the sorted values.
indexes = np.argsort(a.values).T
heights = np.sort(a.values).T
order = -1  # slice step used to walk the ranks from largest to smallest
# Running total of the heights taken largest-first; a row of zeros is
# prepended so that bottoms[k] is the base of the k-th layer drawn.
bottoms = heights[::order].cumsum(axis=0)
bottoms = np.insert(bottoms, 0, np.zeros(len(bottoms[0])), axis=0)
# One fixed colour per column, taken from the rcParams colour cycle.
mpp_colors = dict(zip(a.columns, plt.rcParams['axes.prop_cycle'].by_key()['color']))
# Draw one layer per rank, largest values first so they end up at the bottom.
for btms, (idxs, vals) in enumerate(list(zip(indexes, heights))[::order]):
    # Name of the column that supplies each bar's segment in this layer.
    mps = np.take(np.array(a.columns), idxs)
    ax.bar(x, height=vals, bottom=bottoms[btms], color=[mpp_colors[m] for m in mps])
ax.set_ylim(bottom=0, top=2)
# Legend labels: column names ordered by the values of the first row,
# largest first.
plt.legend((np.take(np.array(a.columns), np.argsort(a.values)[0]))[::order], loc='upper right')

enter image description here

答案 4 :(得分:0)

这是一个 API 风格类似 seaborn 的解决方案。示例用法见:https://github.com/michaelsilverstein/Pandas-and-Plotting/blob/master/lessons/stacked_bar_chart.ipynb

def stackedbarplot(data, stack_order=None, palette=None, **barplot_kws):
    """
    Create a stacked barplot
    Inputs:
    | data <pd.DataFrame>: A wideform dataframe where the index is the variable to stack, the columns are different samples (x-axis), and the cells the counts (y-axis)
    | stack_order <array-like>: The order for bars to be stacked (Default: given order)
    | palette <dict, array-like or palette name>: The colors to use for each value of `stack_order`.
    |     A dict maps stack level -> color; anything else is expanded with
    |     sns.color_palette() and paired with `stack_order` by position (Default: husl)
    | barplot_kws: Arguments to pass to sns.barplot()
    Returns:
    | The Axes the bars were drawn on (the current Axes if nothing was drawn)

    Author: Michael Silverstein
    Usage: https://github.com/michaelsilverstein/Pandas-and-Plotting/blob/master/lessons/stacked_bar_chart.ipynb
    """
    # Order df
    if stack_order is None:
        stack_order = data.index
    # Always hand seaborn a {stack level: color} dict. Each barplot call
    # below only sees one stack level, so a list or named palette passed
    # through as-is would give every layer that palette's first color.
    if palette is None:
        palette = dict(zip(stack_order, sns.husl_palette(len(stack_order))))
    elif not isinstance(palette, dict):
        palette = dict(zip(stack_order,
                           sns.color_palette(palette, len(stack_order))))
    # Compute cumsum
    cumsum = data.loc[stack_order].cumsum()
    # Melt for passing to seaborn
    cumsum_stacked = cumsum.stack().reset_index(name='count')
    # Get name of variable to stack and sample
    stack_name, sample_name = cumsum_stacked.columns[:2]

    # Plot bar plot: draw the tallest cumulative layer first so that each
    # shorter layer is painted over it.
    ax = None
    for s in stack_order[::-1]:
        # Subset to this stack level
        d = cumsum_stacked[cumsum_stacked[stack_name].eq(s)]
        ax = sns.barplot(x=sample_name, y='count', hue=stack_name, palette=palette, data=d, **barplot_kws)
    # Return the Axes seaborn actually drew on (it may have been supplied
    # via `ax=` in barplot_kws); fall back to the current Axes otherwise.
    return ax if ax is not None else plt.gca()