
Input shapes for a GNN model in Keras

Stack Overflow user
Asked 2020-04-27 17:52:16
Answers: 1, Views: 136, Followers: 0, Votes: 0

My dataset is X, and its shape is (423, 320, 3). There are 423 samples, each of length 320. I'm using the Python package spektral.

```python
X.shape # (423,320,3)
```

The adjacency matrix is A, with shape (423, 423).

```python
A.shape # (423,423)
```
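Before training, the code below passes A through `GraphConv.preprocess`. Assuming that in the Spektral 0.x releases of the time this applies the symmetric GCN normalization from Kipf & Welling, Â = D^-1/2 (A + I) D^-1/2, a minimal NumPy sketch of that step looks like this. The function name `normalize_adjacency` and the random graph are hypothetical; only N = 423 comes from the question.

```python
import numpy as np

def normalize_adjacency(A: np.ndarray) -> np.ndarray:
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2 (assumed to match GraphConv.preprocess)."""
    A_tilde = A + np.eye(A.shape[0], dtype=A.dtype)  # add a self-loop to every node
    deg = A_tilde.sum(axis=1)                         # degrees including the self-loop, so >= 1
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^-1/2 and column j by d_j^-1/2
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

# Hypothetical sparse, symmetric random graph with the question's N = 423 nodes
rng = np.random.default_rng(0)
N = 423
upper = np.triu(rng.random((N, N)) < 0.01, k=1)
A = (upper | upper.T).astype("float32")

A_hat = normalize_adjacency(A)
print(A_hat.shape)  # (423, 423)
```

The normalized matrix keeps the (N, N) shape of A, so it is still one node-by-node matrix for the whole graph.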

My labels are y, with shape (320, 1).

```python
y.shape # (320,1)
```

My model is shown below. I think it's quite simple, but it doesn't work.

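```python
from tensorflow.keras.layers import Input, Dropout
from tensorflow.keras.models import Model
from spektral.layers import GraphConv

N = A.shape[0]     # number of nodes (423)
F = X.shape[-1]    # last axis of X (3 here)
n_classes = 1

X_in = Input(shape=(423,320,))      # node-feature input
A_in = Input((N, ), sparse=True)    # adjacency input

# Two graph-convolution layers with dropout in between
X_1 = GraphConv(16, 'relu')([X_in, A_in])
X_1 = Dropout(0.5)(X_1)
X_2 = GraphConv(n_classes, 'relu')([X_1, A_in])

model = Model(inputs=[X_in, A_in], outputs=X_2)
A = GraphConv.preprocess(A).astype('f4')   # normalize the adjacency before feeding it

model.compile(optimizer='adam',
              loss='mean_squared_error',
              weighted_metrics=['accuracy'])
model.summary()

model.fit([X, A], y)
```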

The model summary is as follows:

```text
Model: "model_32"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_63 (InputLayer)           [(None, 423, 320)]   0                                            
__________________________________________________________________________________________________
input_64 (InputLayer)           [(None, None)]       0                                            
__________________________________________________________________________________________________
graph_conv_52 (GraphConv)       (None, 423, 16)      5136        input_63[0][0]                   
                                                                 input_64[0][0]                   
__________________________________________________________________________________________________
dropout_25 (Dropout)            (None, 423, 16)      0           graph_conv_52[0][0]              
__________________________________________________________________________________________________
graph_conv_53 (GraphConv)       (None, 423, 1)       17          dropout_25[0][0]                 
                                                                 input_64[0][0]                   
==================================================================================================
Total params: 5,153
Trainable params: 5,153
Non-trainable params: 0
__________________________________________________________________________________________________
```
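The parameter counts in the summary can be reproduced by hand, which shows how the model interpreted `Input(shape=(423, 320))`: the last axis (320) became the per-node feature size and the 423 axis became the node axis, so every training sample is expected to be a whole (423, 320) matrix. The sketch assumes each GraphConv layer holds an (in_features × units) kernel plus a bias vector, which is consistent with the numbers above; the helper name is hypothetical.

```python
def graph_conv_params(in_features: int, units: int, use_bias: bool = True) -> int:
    """Parameters of a GCN-style layer: an (in_features x units) kernel plus an optional bias."""
    return in_features * units + (units if use_bias else 0)

# First GraphConv: 320 input features per node -> 16 units
p1 = graph_conv_params(320, 16)
# Second GraphConv: 16 features -> n_classes = 1
p2 = graph_conv_params(16, 1)

print(p1, p2, p1 + p2)  # 5136 17 5153
```

These match `graph_conv_52`, `graph_conv_53` and "Total params: 5,153" in the summary.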

The error is below:

```text
ValueError                                Traceback (most recent call last)
<ipython-input-272-d00160881a92> in <module>
     18 model.summary()
     19 
---> 20 model.fit([X, A], y)

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    233           max_queue_size=max_queue_size,
    234           workers=workers,
--> 235           use_multiprocessing=use_multiprocessing)
    236 
    237       total_samples = _get_total_number_of_samples(training_data_adapter)

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_training_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, steps_per_epoch, validation_split, validation_data, validation_steps, shuffle, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    591         max_queue_size=max_queue_size,
    592         workers=workers,
--> 593         use_multiprocessing=use_multiprocessing)
    594     val_adapter = None
    595     if validation_data:

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    644     standardize_function = None
    645     x, y, sample_weights = standardize(
--> 646         x, y, sample_weight=sample_weights)
    647   elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
    648     standardize_function = standardize

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2381         is_dataset=is_dataset,
   2382         class_weight=class_weight,
-> 2383         batch_size=batch_size)
   2384 
   2385   def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2408           feed_input_shapes,
   2409           check_batch_axis=False,  # Don't enforce the batch size.
-> 2410           exception_prefix='input')
   2411 
   2412     # Get typespecs for the input data and sanitize it if necessary.

~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    580                              ': expected ' + names[i] + ' to have shape ' +
    581                              str(shape) + ' but got array with shape ' +
--> 582                              str(data_shape))
    583   return data
    584 

ValueError: Error when checking input: expected input_63 to have shape (423, 320) but got array with shape (320, 3)
```
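The message follows from how Keras reads an input array: axis 0 is the batch (sample) axis, and the remaining axes must match the shape given to `Input`. A minimal NumPy sketch of that comparison, using a zero array with the question's shape as a stand-in for the real data:

```python
import numpy as np

# Stand-in with the question's shape (423, 320, 3)
X = np.zeros((423, 320, 3), dtype="float32")

# Input(shape=(423, 320)) describes one sample and excludes the batch axis
expected_per_sample = (423, 320)

# Keras drops axis 0 (the batch axis) and compares what is left
got_per_sample = X.shape[1:]

print(got_per_sample)                         # (320, 3)
print(got_per_sample == expected_per_sample)  # False, hence the ValueError above
```

So Keras read X as 423 samples of shape (320, 3), while the model was built to take samples of shape (423, 320).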

1 Answer

Stack Overflow user

Answered 2020-12-11 18:31:38

Not sure whether you still need help, but the problem is in the input.

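Your node features X have shape (423, 320, 3), but your data represents only a single graph with 423 nodes. Spektral does not support multi-dimensional node attributes, so you should reshape X to (423, 320 * 3):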

```python
X = X.reshape(423, 320 * 3)
```
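Flattening keeps each node's data in its own row, so the node axis stays first and still lines up with the rows and columns of A. The sketch below uses a stand-in array with the question's shape and checks that row i of the result is node i's (320, 3) block flattened. It also computes the new feature size F = 960; assuming Spektral's single-graph mode as used in its 0.x examples, that is the size the node-feature input would be declared with, i.e. `Input(shape=(F,))` rather than `Input(shape=(423, 320))`.

```python
import numpy as np

N, L, C = 423, 320, 3
# Stand-in node features with the question's shape: 423 nodes x 320 steps x 3 channels
X = np.arange(N * L * C, dtype="float32").reshape(N, L, C)

X_flat = X.reshape(N, L * C)   # one 960-dimensional feature vector per node
F = X_flat.shape[-1]           # per-node feature size after flattening

print(X_flat.shape, F)         # (423, 960) 960

# Row i of X_flat is node i's (320, 3) block flattened in C (row-major) order
assert np.array_equal(X_flat[5], X[5].ravel())
```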

Also, since you are using model.fit(), you should set the batch size to N, or do the following instead:

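```python
epochs = 100  # hypothetical value; the answer does not specify the number of epochs

# Each step trains on the whole graph: all N nodes and the full adjacency matrix
for epoch in range(epochs):
    model.train_on_batch([X, A], y)
```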
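The reason the whole graph has to go in at once can be seen from the propagation step itself. Assuming the GCN rule H = Â X W, every layer multiplies by the N × N normalized adjacency, so it needs the features of all N nodes in the same call; a mini-batch of node rows no longer matches Â. In this NumPy illustration the ring graph, the random weights and the batch of 32 rows are hypothetical; N = 423 comes from the question, F = 960 from the reshape, and 16 units from the question's first layer.

```python
import numpy as np

rng = np.random.default_rng(0)
N, F, units = 423, 960, 16

# Hypothetical ring graph: node i is linked to i-1 and i+1 (wrapping around)
ring = np.roll(np.eye(N), 1, axis=1) + np.roll(np.eye(N), -1, axis=1)
# Every degree of (ring + I) is 3, so D^-1/2 (A + I) D^-1/2 reduces to (A + I) / 3
A_hat = ((ring + np.eye(N)) / 3.0).astype("float32")

X = rng.standard_normal((N, F)).astype("float32")
W = rng.standard_normal((F, units)).astype("float32")

# One propagation step: each output row averages the transformed features of a node and its neighbours
H = A_hat @ X @ W
print(H.shape)  # (423, 16)

# Feeding only 32 node rows breaks the N x N multiplication
try:
    A_hat @ X[:32] @ W
    mismatch = False
except ValueError as err:
    mismatch = True
    print("shape mismatch:", err)
```

This is why `model.fit` would need `batch_size=N` here: with the default batch size of 32, Keras would split the 423 node rows into smaller batches that no longer line up with A.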
Votes: 0
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/61456086