首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf keras中图神经网络多输入模型的误差

tf keras中图神经网络多输入模型的误差
EN

Stack Overflow用户
提问于 2020-06-06 21:28:57
回答 1查看 98关注 0票数 2

我正在使用Spektral训练一个带有辅助输入层的Graph神经网络。我正在拼接这些层。该模型可以完美地编译。但是,当将数据拟合到模型中时,我得到了以下错误。

代码语言:javascript
复制
ValueError: No data provided for "input_10". Need data for each key in: ['input_10', 'input_12']

代码如下所示

代码语言:javascript
复制
X_in = Input(shape=(1375, 3))
A_in = Input(tensor=sp_matrix_to_sp_tensor(adj_mat))

Feat_input = Input(shape=(55,8))

Feat_layer = Bidirectional(LSTM(32, return_sequences=True,),name='lstm_input')(Feat_input)
Feat_layer = Dense(512,activation='relu')(Feat_layer)
Feat_layer = Flatten()(Feat_layer)

graph_conv = GraphConvSkip(64, activation='relu',kernel_regularizer=l2(l2_reg),name='graph_input')([X_in, A_in])
graph_conv = Dropout(0.5)(graph_conv)

graph_conv = ChebConv(32, activation='relu', kernel_regularizer=l2(l2_reg)([graph_conv,A_in])

graph_conv = Dropout(0.5)(graph_conv)

graph_conv = GraphConvSkip(64, activation='relu', kernel_regularizer=l2(l2_reg)([graph_conv,A_in])
graph_conv = Dropout(0.5)(graph_conv)
graph_conv = ChebConv(32, activation='relu', kernel_regularizer=l2(l2_reg))([graph_conv, A_in])

flatten = Flatten()(graph_conv)

concatenated = concatenate([flatten, Feat_layer])

fc = Dense(512, activation='relu')(concatenated)
fc = Dense(256, activation='relu')(FC)
output = Dense(n_out, activation='softmax')(FC)

model = Model(inputs={'graph_input':[X_in, A_in], 'lstm_input':Feat_input}, outputs=output)

optimizer = RMSprop(lr=learning_rate)

model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
model.summary()
history = model.fit({'graph_input': [X_train], 'lstm_input': x_train_feat }, y_train, batch_size=28, epochs=250,steps_per_epoch=10)
EN

回答 1

Stack Overflow用户

发布于 2020-06-08 17:23:07

在这里,当您定义Model时,您已经为graph_input - X_inA_in定义了two inputs

代码语言:javascript
复制
model = Model(inputs={'graph_input':[X_in, A_in], 'lstm_input':Feat_input}, outputs=output)

但是在调用model.fit时,您只传递了graph_input .i.e的one inputX_train和另一个输入丢失。这就是它抛出错误的原因。

代码语言:javascript
复制
history = model.fit({'graph_input': [X_train], 'lstm_input': x_train_feat }, y_train, batch_size=28, epochs=250,steps_per_epoch=10)

请传递graph_inputSecond input,错误应该会被修复。

希望这能回答你的问题。快乐学习。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62232435

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档