我正在尝试绘制MNIST数据集中的10个样本。每一个数字中的一个。代码如下:
import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data
for i in range(10):
im_idx = np.argwhere(y == i)[0]
print(im_idx)
plottable_image = np.reshape(X[im_idx], (28, 28))
plt.imshow(plottable_image, cmap='gray_r')
plt.subplot(2, 5, i + 1)
plt.plot()由于某种原因,绘图中跳过了零位数。
为什么?
发布于 2018-12-19 00:55:49
好的,我知道了。问题是您是在绘制imshow之后定义子图的。所以你的第一个子情节被第二个改写了。要使您的代码正常工作,只需交换两个命令的顺序,如下所示。另外,我不明白你为什么要在最后使用plt.plot()。
plt.subplot(2, 5, i + 1) # <-- You have put this command after imshow
plt.imshow(plottable_image, cmap='gray_r')这里是你所知道的另一个备选方案:
fig = plt.figure()
for i in range(10):
im_idx = np.argwhere(y == i)[0]
plottable_image = np.reshape(X[im_idx], (28, 28))
ax = fig.add_subplot(2, 5, i+1)
ax.imshow(plottable_image, cmap='gray_r')您还可以使用以下代码进一步缩短Scott的代码(如下所示):
fig, ax = plt.subplots(2,5)
for i, ax in enumerate(ax.flatten()):
im_idx = np.argwhere(y == i)[0]
plottable_image = np.reshape(X[im_idx], (28, 28))
ax.imshow(plottable_image, cmap='gray_r')

发布于 2018-12-19 00:54:01
试试这个:
import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data
fig, ax = plt.subplots(2,5)
ax = ax.flatten()
for i in range(10):
im_idx = np.argwhere(y == i)[0]
print(im_idx)
plottable_image = np.reshape(X[im_idx], (28, 28))
ax[i].imshow(plottable_image, cmap='gray_r')输出:

https://stackoverflow.com/questions/53837545
复制相似问题