首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在matplotlib中绘制动画矩阵

如何在matplotlib中绘制动画矩阵
EN

Stack Overflow用户
提问于 2018-08-29 08:22:51
回答 2查看 3.2K关注 0票数 9

我需要一步一步地可视化地做一些数值计算算法,如下图所示:(gif)

Font

如何使用matplotlib制作此动画?有没有办法在视觉上呈现这些过渡?例如矩阵的变换,求和,转置,使用循环和它表示转换等。我的目标不是使用图形,而是使用相同的矩阵表示。这是为了便于理解算法。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-08-31 08:02:55

由于使用imshow可以轻松地绘制矩阵,因此可以使用imshow绘图创建这样的表,并根据当前动画步长调整数据。

代码语言:javascript
复制
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.animation

#####################
# Array preparation
#####################

#input array
a = np.random.randint(50,150, size=(5,5))
# kernel
kernel = np.array([[ 0,-1, 0], [-1, 5,-1], [ 0,-1, 0]])

# visualization array (2 bigger in each direction)
va = np.zeros((a.shape[0]+2, a.shape[1]+2), dtype=int)
va[1:-1,1:-1] = a

#output array
res = np.zeros_like(a)

#colorarray
va_color = np.zeros((a.shape[0]+2, a.shape[1]+2)) 
va_color[1:-1,1:-1] = 0.5

#####################
# Create inital plot
#####################
fig = plt.figure(figsize=(8,4))

def add_axes_inches(fig, rect):
    w,h = fig.get_size_inches()
    return fig.add_axes([rect[0]/w, rect[1]/h, rect[2]/w, rect[3]/h])

axwidth = 3.
cellsize = axwidth/va.shape[1]
axheight = cellsize*va.shape[0]

ax_va  = add_axes_inches(fig, [cellsize, cellsize, axwidth, axheight])
ax_kernel  = add_axes_inches(fig, [cellsize*2+axwidth,
                                   (2+res.shape[0])*cellsize-kernel.shape[0]*cellsize,
                                   kernel.shape[1]*cellsize,  
                                   kernel.shape[0]*cellsize])
ax_res = add_axes_inches(fig, [cellsize*3+axwidth+kernel.shape[1]*cellsize,
                               2*cellsize, 
                               res.shape[1]*cellsize,  
                               res.shape[0]*cellsize])
ax_kernel.set_title("Kernel", size=12)

im_va = ax_va.imshow(va_color, vmin=0., vmax=1.3, cmap="Blues")
for i in range(va.shape[0]):
    for j in range(va.shape[1]):
        ax_va.text(j,i, va[i,j], va="center", ha="center")

ax_kernel.imshow(np.zeros_like(kernel), vmin=-1, vmax=1, cmap="Pastel1")
for i in range(kernel.shape[0]):
    for j in range(kernel.shape[1]):
        ax_kernel.text(j,i, kernel[i,j], va="center", ha="center")


im_res = ax_res.imshow(res, vmin=0, vmax=1.3, cmap="Greens")
res_texts = []
for i in range(res.shape[0]):
    row = []
    for j in range(res.shape[1]):
        row.append(ax_res.text(j,i, "", va="center", ha="center"))
    res_texts.append(row)    


for ax  in [ax_va, ax_kernel, ax_res]:
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    ax.yaxis.set_major_locator(mticker.IndexLocator(1,0))
    ax.xaxis.set_major_locator(mticker.IndexLocator(1,0))
    ax.grid(color="k")

###############
# Animation
###############
def init():
    for row in res_texts:
        for text in row:
            text.set_text("")

def animate(ij):
    i,j=ij
    o = kernel.shape[1]//2
    # calculate result
    res_ij = (kernel*va[1+i-o:1+i+o+1, 1+j-o:1+j+o+1]).sum()
    res_texts[i][j].set_text(res_ij)
    # make colors
    c = va_color.copy()
    c[1+i-o:1+i+o+1, 1+j-o:1+j+o+1] = 1.
    im_va.set_array(c)

    r = res.copy()
    r[i,j] = 1
    im_res.set_array(r)

i,j = np.indices(res.shape)
ani = matplotlib.animation.FuncAnimation(fig, animate, init_func=init, 
                                         frames=zip(i.flat, j.flat), interval=400)
ani.save("algo.gif", writer="imagemagick")
plt.show()

票数 4
EN

Stack Overflow用户

发布于 2018-08-30 07:02:30

此示例在Jupyter notebook中设置内联动画。我想可能也有一种方法可以导出为gif,但到目前为止我还没有研究过。

不管怎样,首先要做的就是摆好桌子。为了编写render_mpl_table代码,我从Export a Pandas dataframe as a table image那里借了很多钱。

这个问题的(改编)版本是:

代码语言:javascript
复制
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import six

width = 8
data = pd.DataFrame([[0]*width,
                     [0, *np.random.randint(95,105,size=width-2), 0],
                     [0, *np.random.randint(95,105,size=width-2), 0],
                     [0, *np.random.randint(95,105,size=width-2), 0]])

def render_mpl_table(data, col_width=3.0, row_height=0.625, font_size=14,
                     row_color="w", edge_color="black", bbox=[0, 0, 1, 1],
                     ax=None, col_labels=data.columns, 
                     highlight_color="mediumpurple",
                     highlights=[], **kwargs):
    if ax is None:
        size = (np.array(data.shape[::-1]) + np.array([0, 1])) *
                np.array([col_width, row_height])
        fig, ax = plt.subplots(figsize=size)
        ax.axis('off')

    mpl_table = ax.table(cellText=data.values, bbox=bbox, colLabels=col_labels,
                         **kwargs)

    mpl_table.auto_set_font_size(False)
    mpl_table.set_fontsize(font_size)

    for k, cell in six.iteritems(mpl_table._cells):
        cell.set_edgecolor(edge_color)
        if k in highlights:
            cell.set_facecolor(highlight_color)
        elif data.iat[k] > 0:
            cell.set_facecolor("lightblue")
        else:
            cell.set_facecolor(row_color)
    return fig, ax, mpl_table

fig, ax, mpl_table = render_mpl_table(data, col_width=2.0, col_labels=None,
                                  highlights=[(0,2),(0,3),(1,2),(1,3)])

在这种情况下,要以不同颜色突出显示的单元格由指定行和列的元组数组提供。

对于动画,我们需要设置一个函数来绘制具有不同高亮显示的表格:

代码语言:javascript
复制
def update_table(i, *args, **kwargs):
    r = i//(width-1)
    c = i%(width-1)
    highlights=[(r,c),(r,c+1),(r+1,c),(r+1,c+1)]
    for k, cell in six.iteritems(mpl_table._cells):
        cell.set_edgecolor("black")
        if k in highlights:
            cell.set_facecolor("mediumpurple")
        elif data.iat[k] > 0:
            cell.set_facecolor("lightblue")
        else:
            cell.set_facecolor("white")
    return (mpl_table,)

这将强制更新表中所有单元格的颜色。基于当前帧计算highlights数组。在本例中,表格的宽度和高度是硬编码的,但根据输入数据的形状进行更改并不是很难。

我们基于现有的图和更新函数创建一个动画:

代码语言:javascript
复制
a = animation.FuncAnimation(fig, update_table, (width-1)*3,
                               interval=750, blit=True)

最后,我们将其内联显示在我们的笔记本中:

代码语言:javascript
复制
HTML(a.to_jshtml())

我把这些放在github上的笔记本里,见https://github.com/gurudave/so_examples/blob/master/mpl_animation.ipynb

希望这足以让你朝着正确的方向前进!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52067833

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档