首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras卷积神经网络

Keras卷积神经网络
EN

Stack Overflow用户
提问于 2015-12-04 07:13:22
回答 2查看 1.1K关注 0票数 3

现在,我正在尝试构建一个基本的卷积神经网络,以便使用keras对mnist数据集进行简单分类。最后,我想把我自己的图像放进去,我只想先构建一个简单的网络,以确保我的结构正常工作。因此,我下载了mnist数据作为mnint.pkl.gz解压,并加载到元组,并最终颠簸数组。下面是我的代码:

代码语言:javascript
复制
import numpy as np
from keras.models import Sequential

from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD
from PIL import Image as IM
import theano
from sklearn.cross_validation import train_test_split
import cPickle
import gzip
f=gzip.open('mnist.pkl.gz')
data1,data2,data3=cPickle.load(f)
f.close()

X=data1[0]
Y=data1[1]

x=X[0:15000,:]
y=Y[0:15000]

X_train,X_test,y_train,y_test=train_test_split(x,y,test_size
=0.33,random_state=99)


model=Sequential()
model.add(Convolution2D(10,5,5,border_mode='valid', 
input_shape=   (1,28,28)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(10))
model.add(Activation('softmax'))
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)
model.fit(X_train,y_train, batch_size=10, nb_epoch=10)
score=model.evaluate(X_test,y_test,batch_size=10)
print(score)

我得到一个错误,如下:

代码语言:javascript
复制
'Wrong number of dimensions: expected 4, got 2 with shape 
(10,   784).')

我认为这意味着我需要将它放入theano 4d张量中,即具有(样本,通道,行,列),但我不知道如何做到这一点。此外,当我特别想在加载'.png‘文件后解决这个问题时,我打算把它们放入numpy矩阵中,但看起来好像行不通。谁能告诉我如何将图像转换为theano4d张量,以便在此代码中使用?谢谢

EN

回答 2

Stack Overflow用户

发布于 2015-12-04 16:27:55

代码需要一个tensor4,这一点是正确的。传统的结构是(batch, channel, width, height)。在本例中,图像是单色的,所以channel=1看起来像是在使用批大小为10的图像,MNIST图像的宽度为28像素,高度为28像素。

您可以简单地将数据重塑为所需的格式。如果x的形状为(10,784),则x.reshape(10, 1, 28, 28)将具有所需的格式。

票数 4
EN

Stack Overflow用户

发布于 2016-09-25 17:21:47

代码需要一个四维Theano张量,而不是一个张量(在幕后执行所有Theano张量操作)。

您的输入、X_train和X_test需要重塑,如下所示:

代码语言:javascript
复制
X_train = X_train.reshape(-1, 1, 28, 28)
X_test = X_test.reshape(-1, 1, 28, 28)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/34078063

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档