我正在尝试将MNIST数据集加载到数组中。当我使用(X_train,y_train),(X_test,y_test)= mnist.load_data()时,我得到一个数组y_test(10000,),但我希望它的形状为(10000,1)。数组(10000,1 )和数组(10000,1)的区别是什么?如何将第一个数组转换为第二个数组?
发布于 2018-11-07 15:38:28
您的第一个具有形状(10000,)的数组是一个一维np.ndarray。由于shape数组的元组属性是一个元组,而长度为1的元组需要一个尾随逗号,因此形状是(10000,)而不是(10000) (应该是一个整数)。因此,目前您的数据如下所示:
import numpy as np
a = np.arange(5) # >>> array([0, 1, 2, 3, 4]
print(a.shape) # >>> (5,)您需要的是一个形状为(10000, 1)的二维数组。添加一个长度为1的维度不需要任何额外的数据,它基本上是一个“空”维度。要向现有数组添加维度,可以使用np.expand_dims()或np.reshape()。
使用np.expand_dims
import numpy as np
b = np.array(np.arange(5)) # >>> array([0, 1, 2, 3, 4])
b = np.expand_dims(b, axis=1) # >>> array([[0],[1],[2],[3],[4]])该函数是专门为数组添加空维而创建的。axis关键字指定新添加的维度将占据的位置。
使用np.reshape
import numpy as np
a = np.arange(5)
X_test_reshaped = np.reshape(a, shape=[-1, 1]) # >>> array([[0],[1],[2],[3],[4]])shape=[-1, 1]指定整形操作后新形状的外观。-1本身将被numpy在内部“拟合数据”的形状所取代。Reshape是一个比expand_dims更强大的功能,可以通过许多不同的方式使用。你可以在numpy文档中阅读更多关于它的其他用法。numpy.reshape()
https://stackoverflow.com/questions/53185055
复制相似问题