首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >NeuPy:输入形状问题

NeuPy:输入形状问题
EN

Stack Overflow用户
提问于 2017-11-18 18:58:23
回答 1查看 124关注 0票数 1

我想用neupy建立一个神经网络。因此,我构建了以下体系结构:

代码语言:javascript
复制
 network = layers.join(
                    layers.Input(10),

                    layers.Linear(500),
                    layers.Relu(),

                    layers.Linear(300),
                    layers.Relu(),

                    layers.Linear(10),
                    layers.Softmax(),
                )

我的数据被塑造成对折:

代码语言:javascript
复制
x_train.shape = (32589,10)
y_train.shape = (32589,1)

当我尝试使用以下方法训练这个网络时:

代码语言:javascript
复制
model.train(x_train, y_trian)

我得到了一个错误:

代码语言:javascript
复制
ValueError: Input dimension mis-match. (input[0].shape[1] = 10, input[1].shape[1] = 1)
Apply node that caused the error: Elemwise{sub,no_inplace}(SoftmaxWithBias.0, algo:network/var:network-output)
Toposort index: 26
Inputs types: [TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(32589, 10), (32589, 1)]
Inputs strides: [(80, 8), (8, 8)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[Elemwise{Composite{((i0 * i1) / i2)}}(TensorConstant{(1, 1) of 2.0}, Elemwise{sub,no_inplace}.0, Elemwise{mul,no_inplace}.0), Elemwise{Sqr}[(0, 0)](Elemwise{sub,no_inplace}.0)]]

我如何编辑我的网络来映射这类数据?

非常感谢你!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-11-19 15:19:02

您的架构有10个输出,而不是1个。我假设您的y_train函数是一个0-1类标识符。如果是这样,则需要将结构更改为:

代码语言:javascript
复制
network = layers.join(
   layers.Input(10),

   layers.Linear(500),
   layers.Relu(),

   layers.Linear(300),
   layers.Relu(),

   layers.Linear(1),  # Single output
   layers.Sigmoid(),  # Sigmoid works better for 2-class classification
)

你可以让它变得更简单

代码语言:javascript
复制
network = layers.join(
   layers.Input(10),
   layers.Relu(500),
   layers.Relu(300),
   layers.Sigmoid(1),
)

它工作的原因是因为layers.Liner(10) > layers.Relu()layers.Relu(10)是一样的。您可以在官方文档中了解更多信息:http://neupy.com/docs/layers/basics.html#mutlilayer-perceptron-mlp

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47369861

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档