首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何用角角建立三维输入/三维输出卷积模型?

如何用角角建立三维输入/三维输出卷积模型?
EN

Stack Overflow用户
提问于 2018-12-03 20:44:26
回答 1查看 2.4K关注 0票数 0

我有个小问题我解决不了。

我想实现CNN模型,完全连接MLP到我的蛋白质数据库,其中有2589个蛋白质。每个蛋白质有1287行和69列作为输入,1287行和8列作为输出。实际上有1287x1输出,但我在模型中使用了类标签的热编码来使用交叉熵损失。

还有我想要

如果把输入**的三维矩阵** X_train =(2589,1287,69)和y_train = (2589,1287,8)输出考虑为图像,则输出也是矩阵。

下面是我的克拉斯代码:

代码语言:javascript
复制
model = Sequential()
model.add(Conv2D(64, kernel_size=3, activation="relu", input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(Conv2D(32, kernel_size=3, activation="relu"))
model.add(Flatten())
model.add(Dense((8), activation="softmax"))

但是我遇到了关于稠密层的错误:

代码语言:javascript
复制
ValueError: Error when checking target: expected dense_1 to have 2 dimensions, but got array with shape (2589, 1287, 8)

好的,我知道稠密应该是正整数单位(用Keras解释)。)。但是我如何实现矩阵输出到我的模型呢?

我试过了:

代码语言:javascript
复制
model.add(Dense((1287,8), activation="softmax"))

但我找不到任何解决办法。

非常感谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-12-03 21:07:50

Conv2D层需要(batch_size, height, width, channels)的输入形状。这意味着每个样本都是一个3D数组。

您的实际输入是(2589, 1287, 8),这意味着每个示例都是形状(1289, 8) --一个2D形状。因此,您应该使用Conv1D而不是Conv2D

其次,您需要(2589, 1287, 8)的输出。因为每个样本都是2D形状,所以输入Flatten()是没有意义的-- Flatten()会将每个样本的形状减少到1D,并且您希望每个样本都是2D的。

最后,取决于您的Conv层的填充,形状可能会根据kernel_size发生变化。由于您希望保留1287的中间维度,所以使用padding='same'来保持大小不变。

代码语言:javascript
复制
from keras.models import Sequential
from keras.layers import Conv1D, Flatten, Dense
import numpy as np

X_train = np.random.rand(2589, 1287, 69)
y_train = np.random.rand(2589, 1287, 8)


model = Sequential()
model.add(Conv1D(64, 
                 kernel_size=3, 
                 activation="relu", 
                 padding='same',
                 input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(Conv1D(32, 
                 kernel_size=3, 
                 activation="relu",
                 padding='same'))
model.add(Dense((8), activation="softmax"))

model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train)
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53601593

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档