如何在matplotlib中的tight_layout下调整独立颜色条的大小(高度和宽度)

时间:2019-01-31 06:24:29

标签: python-3.x matplotlib

我正在尝试在单个图形中把两个3D线框图和一个颜色条绘制在同一行。为了给线框应用颜色渐变,我参考了这个 Stack Overflow 回答,因此不得不单独生成颜色条。由于我使用了 tight_layout,颜色条占满了整个图形高度,而且显示得很宽。

到目前为止,我还没找到任何控制颜色条大小的办法。我尝试修改网格的 width_ratios,但宽度没有变化;在 tight_layout 下似乎也无法调整高度。

我的绘图代码如下。在这方面的任何帮助,我将不胜感激。

def plot_signature(hh, hv, vh, vv, wireframe=False):
    """Plot co-pol and cross-pol signatures side by side with a compact colorbar.

    Layout: a 1x3 grid (colorbar column, then the two 3D plots, width ratios
    1:10:10). The colorbar is placed in the centre cell of a 3x3 sub-grid
    nested in the first column, so under ``tight_layout`` it no longer fills
    the whole figure height and column width.

    Both plots and the colorbar share a single ``Normalize(0, 1)`` (matching
    the z-axis limits and the colorbar ticks), so a given color means the
    same relative intensity in every panel.

    Parameters
    ----------
    hh, hv, vh, vv :
        Forwarded unchanged to ``synthesize`` (presumably the HH/HV/VH/VV
        channel terms -- confirm against ``synthesize``).
    wireframe : bool
        If True, draw colored wireframes (transparent faces, colored edges);
        otherwise draw filled surfaces.
    """
    x_c, y_c, z_c = synthesize(hh=hh, hv=hv, vh=vh, vv=vv, channel=False)
    x_x, y_x, z_x = synthesize(hh=hh, hv=hv, vh=vh, vv=vv, channel=True)

    xticks = np.linspace(0, 180, 13)
    yticks = np.linspace(-45, 45, 7)
    zticks = np.linspace(0, 1, 11)

    # Degree-sign tick labels ("0°", "15°", ...). Plain string formatting
    # replaces np.core.defchararray, which lives in NumPy's private np.core
    # namespace (deprecated since NumPy 2.0).
    xt_labels = ["{}\u00b0".format(int(t)) for t in xticks]
    yt_labels = ["{}\u00b0".format(int(t)) for t in yticks]

    plt.ion()

    fig = plt.figure(tight_layout=True)
    grid = fig.add_gridspec(nrows=1, ncols=3, width_ratios=[1, 10, 10])
    # Under tight_layout the colorbar axes fills its entire grid cell, and
    # width_ratios alone does not make it narrower. Nest a 3x3 grid in the
    # first column and use only its centre cell: the outer rows/columns act
    # as padding that bounds the colorbar's height and width.
    cbar_grid = mpl.gridspec.GridSpecFromSubplotSpec(
        nrows=3,
        ncols=3,
        height_ratios=[2, 5, 2],
        width_ratios=[1.1, 0.8, 1.1],
        subplot_spec=grid[0, 0],
    )
    ax_cbar = fig.add_subplot(cbar_grid[1, 1])
    ax_cpol = fig.add_subplot(grid[0, 1], projection='3d')
    ax_xpol = fig.add_subplot(grid[0, 2], projection='3d')

    cfont = {'fontname': 'CMU Serif'}
    color_map = cm.rainbow
    # One normalization shared by both plots and the colorbar, so colors are
    # comparable across panels and agree with the colorbar's 0..1 ticks.
    norm = mpl.colors.Normalize(vmin=0, vmax=1)

    panels = (
        (ax_cpol, x_c, y_c, z_c, "Co-pol Signature"),
        (ax_xpol, x_x, y_x, z_x, "Cross-pol Signature"),
    )
    for ax, x, y, z, title in panels:
        ax.set_xlim([0, 180])
        ax.set_ylim([-45, 45])
        ax.set_zlim([0, 1])

        if wireframe:
            # Color the faces from the shared norm, then make the faces
            # transparent so only the colored edges remain visible.
            face_colors = color_map(norm(z))
            rcount, ccount, _ = face_colors.shape
            surface = ax.plot_surface(
                x, y, z,
                rcount=rcount, ccount=ccount,
                facecolors=face_colors, shade=False,
            )
            surface.set_facecolor((0, 0, 0, 0))
        else:
            ax.plot_surface(
                x, y, z,
                cmap=color_map, norm=norm,
                linewidth=0, antialiased=True,
            )

        ax.set_xticks(xticks)
        ax.set_xticklabels(xt_labels)
        ax.set_yticks(yticks)
        ax.set_yticklabels(yt_labels)
        ax.set_zticks(zticks)
        # Raw strings: "\p" / "\c" are invalid escape sequences otherwise.
        ax.set_xlabel(r'Orientation Angle ($\psi$)', labelpad=16, fontsize=16, **cfont)
        ax.set_ylabel(r'Ellipticity Angle ($\chi$)', labelpad=16, fontsize=16, **cfont)
        ax.set_zlabel("Relative Intensity", labelpad=16, fontsize=16, **cfont)
        ax.set_title(title, fontsize=20, pad=20, **cfont)

    # A bare ScalarMappable carrying the shared norm/cmap drives the colorbar;
    # ColorbarBase is not needed.
    mappable = cm.ScalarMappable(norm=norm, cmap=color_map)
    mappable.set_array([])  # still required by matplotlib <= 3.0
    cb = fig.colorbar(
        mappable,
        cax=ax_cbar,
        ticks=np.linspace(0.0, 1.0, 11),
        orientation='vertical',
    )
    cb.set_label(
        'Color Map',
        labelpad=10,
        fontproperties=mpl.font_manager.FontProperties(family='CMU Serif', size=12),
    )
    plt.show()

当前的绘图效果如下:

Current Plot

而我希望得到的效果类似下图:

Expected Plot

2 个答案:

答案 0 :(得分:1)

当前您使用的是1 x 3网格。

enter image description here

可以考虑改用 3 x 3 网格:让两个 3D 图跨越全部三行,而颜色条只占中间一行。

enter image description here

请注意,颜色条当然应该使用与曲面图相同的归一化,而且两个曲面图彼此之间也应使用同一个归一化。简而言之:只应创建一个归一化对象并在各处共用,否则图会传达错误的信息。

最后,这里不需要使用 ColorbarBase,更好的做法是:

sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# matplotlib <= 3.0 needs some array attached to a bare ScalarMappable
# before it can drive a colorbar; an empty one is enough.
sm.set_array([])
fig.colorbar(mappable=sm, cax=cax)

这样做或许还能使用其他选项来缩小颜色条。

理想情况下,回答应当把以上所有内容写成代码,但问题中没有可直接运行的代码,因此不便实现。

答案 1 :(得分:0)

ImportanceOfBeingErnest 已经给出了可行的解决方案,下面展示我对他的答案的实现。请注意,这次我用一种略有不同的方法实现了带颜色映射的线框图(Source)。

def plot_signature(hh, hv, vh, vv, wireframe=False):
    """Plot co-pol and cross-pol signatures with a compact shared colorbar.

    Layout: a 1x3 root grid (colorbar column plus two 3D plots, width ratios
    1:10:10). The colorbar occupies only the centre cell of a 3x3 sub-grid
    nested in the first column, which keeps ``tight_layout`` from stretching
    it over the full height/width of that column.

    A single ``Normalize(0, 1)`` is shared by both plots and the colorbar, and
    the colorbar is drawn in both modes, so a color always means the same
    relative intensity.

    Parameters
    ----------
    hh, hv, vh, vv :
        Forwarded unchanged to ``synthesize`` (presumably the HH/HV/VH/VV
        channel terms -- confirm against ``synthesize``).
    wireframe : bool
        If True, draw wireframes whose segments are colored by height;
        otherwise draw filled surfaces.

    Returns
    -------
    int
        Always 0.
    """
    x_c, y_c, z_c = synthesize(hh=hh, hv=hv, vh=vh, vv=vv, channel=False)
    x_x, y_x, z_x = synthesize(hh=hh, hv=hv, vh=vh, vv=vv, channel=True)

    xticks = np.linspace(0, 180, 13)
    yticks = np.linspace(-45, 45, 7)
    zticks = np.linspace(0, 1, 11)

    # Degree-sign tick labels; avoids NumPy's private np.core namespace.
    xt_labels = ["{}\u00b0".format(int(t)) for t in xticks]
    yt_labels = ["{}\u00b0".format(int(t)) for t in yticks]

    cfont = {'fontname': 'CMU Serif'}
    color_map = cm.rainbow
    # One normalization for both plots and the colorbar.
    norm = mpl.colors.Normalize(vmin=0, vmax=1)

    def add_colored_wireframe(ax, x, y, z, rstride=1, cstride=2):
        """Draw a wireframe on *ax* whose line segments are colored by z.

        Uses the same line layout as ``Axes3D.plot_wireframe`` (every
        ``rstride``-th row and every ``cstride``-th column of the grid, plus
        the last row/column), but builds the segments directly from the
        arrays. It does not draw a throw-away wireframe and read its private
        ``_segments3d``, which only works when rows and columns have equal
        length. Each segment joins point k to point k-1 of its line and is
        colored by the z value of point k.
        """
        x, y, z = np.broadcast_arrays(x, y, z)
        n_rows, n_cols = z.shape

        row_idx = list(range(0, n_rows, rstride))
        if n_rows > 0 and row_idx[-1] != n_rows - 1:
            row_idx.append(n_rows - 1)
        col_idx = list(range(0, n_cols, cstride))
        if n_cols > 0 and col_idx[-1] != n_cols - 1:
            col_idx.append(n_cols - 1)

        polylines = (
            [np.column_stack((x[i], y[i], z[i])) for i in row_idx]
            + [np.column_stack((x[:, j], y[:, j], z[:, j])) for j in col_idx]
        )
        segments = []
        seg_values = []
        for line in polylines:
            # Pair every point with its predecessor -> (n-1, 2, 3) segments.
            segments.extend(np.stack((line[1:], line[:-1]), axis=1))
            seg_values.extend(line[1:, 2])

        wire = art3d.Line3DCollection(segments, cmap=color_map, norm=norm)
        wire.set_array(np.asarray(seg_values))
        ax.add_collection(wire)
        return wire

    plt.ion()
    fig = plt.figure(num='Polarimetric Signatures', tight_layout=True)

    gs_root = mpl.gridspec.GridSpec(
        nrows=1,
        ncols=3,
        width_ratios=[1, 10, 10],
    )
    # 3x3 sub-grid in the colorbar column; only the centre cell is used, the
    # surrounding cells are padding that limit the colorbar's size.
    cb_gs = mpl.gridspec.GridSpecFromSubplotSpec(
        nrows=3,
        ncols=3,
        height_ratios=[2, 5, 2],
        width_ratios=[1.1, 0.8, 1.1],
        subplot_spec=gs_root[0],
    )

    ax_cbar = fig.add_subplot(cb_gs[1, 1])
    ax_cpol = fig.add_subplot(gs_root[0, 1], projection='3d')
    ax_xpol = fig.add_subplot(gs_root[0, 2], projection='3d')

    panels = (
        (ax_cpol, x_c, y_c, z_c, "Co-pol Signature"),
        (ax_xpol, x_x, y_x, z_x, "Cross-pol Signature"),
    )
    for ax, x, y, z, title in panels:
        ax.set_xlim([0, 180])
        ax.set_ylim([-45, 45])
        ax.set_zlim([0, 1])

        if wireframe:
            add_colored_wireframe(ax, x, y, z)
        else:
            ax.plot_surface(
                x, y, z,
                cmap=color_map, norm=norm,
                linewidth=0, antialiased=True,
            )

        ax.set_xticks(xticks)
        ax.set_xticklabels(xt_labels)
        ax.set_yticks(yticks)
        ax.set_yticklabels(yt_labels)
        ax.set_zticks(zticks)
        # Raw strings: "\p" / "\c" are invalid escape sequences otherwise.
        ax.set_xlabel(r'Orientation Angle ($\psi$)', labelpad=16, fontsize=16, **cfont)
        ax.set_ylabel(r'Ellipticity Angle ($\chi$)', labelpad=16, fontsize=16, **cfont)
        ax.set_zlabel("Relative Intensity", labelpad=16, fontsize=16, **cfont)
        ax.set_title(title, fontsize=20, pad=20, **cfont)

    # Build the colorbar from the shared norm/cmap (in both modes). Passing
    # norm= to colorbar() alongside a plotted mappable is ignored, so a bare
    # ScalarMappable is used instead.
    mappable = cm.ScalarMappable(norm=norm, cmap=color_map)
    mappable.set_array([])  # still required by matplotlib <= 3.0
    cb = fig.colorbar(
        mappable,
        cax=ax_cbar,
        ticks=np.linspace(0.0, 1.0, 11),
        orientation='vertical',
        ticklocation='left',
    )
    cb.set_label(
        "Color Map",
        labelpad=10,
        fontproperties=mpl.font_manager.FontProperties(family='CMU Serif', size=12),
    )

    plt.show()

    return 0

当前的输出如下所示:

Latest Plot

相关问题