首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Stellargraph中的不相容形状

Stellargraph中的不相容形状
EN

Stack Overflow用户
提问于 2022-06-03 09:42:06
回答 1查看 71关注 0票数 0

我正在尝试使用Stellargraph库来实现GCN模型的一个小原型。我已经准备好了我的StellarGraph图形对象,我正在试图解决一个多类多标签分类问题。这意味着我试图预测多个列(确切地说是19列),每个列被编码为0或1。

以下是我所做的:

代码语言:javascript
复制
from sklearn.model_selection import train_test_split
from stellargraph.mapper import FullBatchNodeGenerator

train_subjects, test_subjects = train_test_split(nodelist, test_size = .25)
generator = FullBatchNodeGenerator(graph, method="gcn")
代码语言:javascript
复制
from stellargraph.layer import GCN

train_gen = generator.flow(train_subjects['ID'], train_subjects.drop(['ID'], axis = 1))
gcn = GCN(layer_sizes=[16, 16], activations=["relu", "relu"], generator=generator, dropout=0.5)
代码语言:javascript
复制
from tensorflow.keras import layers, optimizers, losses, metrics, Model

x_inp, x_out = gcn.in_out_tensors()
predictions = layers.Dense(units = 1, activation="sigmoid")(x_out)
代码语言:javascript
复制
from tensorflow.keras.metrics import Precision as Precision
​
model = Model(inputs=x_inp, outputs=predictions)
model.compile(
    optimizer=optimizers.Adam(learning_rate = 0.01),
    loss=losses.categorical_crossentropy,
    metrics= [Precision()])

val_gen = generator.flow(test_subjects['ID'], test_subjects.drop(['ID'], axis = 1))
代码语言:javascript
复制
from tensorflow.keras.callbacks import EarlyStopping

es_callback = EarlyStopping(monitor="val_precision", patience=200, restore_best_weights=True)

history = model.fit(
    train_gen,
    epochs=200,
    validation_data=val_gen,
    verbose=2,
    shuffle=False,  
    callbacks=[es_callback])

我有271045个边& 16354个节点,包括12265个训练节点。我得到的问题是来自Keras的形状不匹配。它规定如下。我怀疑这是由于插入了多个列作为目标列。我只使用了一个列(类)来尝试这个模型&它工作得很好。

代码语言:javascript
复制
InvalidArgumentError:  Incompatible shapes: [1,12265] vs. [1,233035]
     [[node LogicalAnd_1 (defined at tmp/ipykernel_52/2745570431.py:7) ]] [Op:__inference_train_function_1405]

值得一提的是,233035 = 12265 (列车节点数)乘以19 (类数)。你知道这里出了什么问题吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-04 06:37:45

我解决了问题。

这是一个新手的错误,我用一个单元来初始化密集的分类层,而不是19个(类的数量)。

我只需要修正这句话:

代码语言:javascript
复制
predictions = layers.Dense(units = 19, activation="sigmoid")(x_out)

祝您今天愉快!

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72487723

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档