首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在StellarGraph PaddedGraphGenerator中提供特定的培训、验证和测试集-

如何在StellarGraph PaddedGraphGenerator中提供特定的培训、验证和测试集-
EN

Stack Overflow用户
提问于 2022-07-09 13:43:31
回答 1查看 73关注 0票数 -1

我试图使用StellarGraph库来训练一个图形卷积神经网络。我想运行这个例子https://stellargraph.readthedocs.io/en/stable/demos/graph-classification/gcn-supervised-graph-classification.html,但没有N折叠交叉验证,通过提供我自己的培训,验证和测试集。这是我正在使用的代码(取自这个post)

代码语言:javascript
复制
generator = PaddedGraphGenerator(graphs=graphs)

train_gen = generator.flow([x for x in range(0, len(graphs_train))],
                           targets=graphs_train_labels,
                           batch_size=35)

test_gen = generator.flow([x for x in range(len(graphs_train),len(graphs_train) + len(graphs_test))],
                          targets=graphs_test_labels,
                          batch_size=35)

# Stopping criterium
es = EarlyStopping(monitor="val_loss",
                   min_delta=0,
                   patience=20,
                   restore_best_weights=True)

# Model definition
gc_model = GCNSupervisedGraphClassification(layer_sizes=[64, 64],
                                            activations=["relu", "relu"],
                                            generator=generator,
                                            dropout=0.5)

x_inp, x_out = gc_model.in_out_tensors()
predictions = Dense(units=32, activation="relu")(x_out)
predictions = Dense(units=16, activation="relu")(predictions)
predictions = Dense(units=1, activation="sigmoid")(predictions)

# Creating Keras model and preparing it for training
model = Model(inputs=x_inp, outputs=predictions)
model.compile(optimizer=Adam(0.001), loss=binary_crossentropy, metrics=["acc"])

# GNN Training
history = model.fit(train_gen, epochs=10, validation_data=test_gen, verbose=1)
model.fit(x=graphs_train,
          y=graphs_train_labels,
          epochs=10,
          verbose=1,
          callbacks=[es])


# Calculate performance on the validation data
test_metrics = model.evaluate(valid_gen, verbose=1)
valid_acc = test_metrics[model.metrics_names.index("acc")]

print(f"Test Accuracy model = {valid_acc}")

但到了最后,我得到了这个错误

ValueError:未能找到能够处理输入的数据适配器:(包含类型{"}“的值),

我在这里错过了什么?是因为我创建图表的方式吗?在我的例子中,图是一个包含恒星图的列表。

EN

回答 1

Stack Overflow用户

发布于 2022-07-09 18:11:09

问题解决了。我在打电话

代码语言:javascript
复制
model.fit(x=graphs_train,
          y=graphs_train_labels,
          epochs=10,
          verbose=1,
          callbacks=[es])

在线后

代码语言:javascript
复制
history = model.fit(train_gen, epochs=10, validation_data=test_gen, verbose=1)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72921857

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档