首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >多输入GAN返回错误ValueError:图形断开连接:

多输入GAN返回错误ValueError:图形断开连接:
EN

Stack Overflow用户
提问于 2021-02-22 10:27:58
回答 1查看 37关注 0票数 0

我整个周末都在试着解决这个问题。我希望有人能帮上忙。

我有一个模型,它可以接受一个普通的数组,并在GAN中处理它,它是有效的,但是一旦我把它改为多个输入,我就开始得到:

代码语言:javascript
复制
ValueError: Graph disconnected:

我的原始代码:

代码语言:javascript
复制
# Build stacked GAN model
gan_input = Input(shape=Xtrain.shape[1])
H = generator(gan_input)
gd_input=Concatenate()([gan_input,H])
gan_V = discriminator(gd_input)
GAN = Model(gan_input, [gan_V,H])
GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
GAN.summary()

然后我将其修改为多输入:

代码语言:javascript
复制
gan_dataframe_input = Input(shape=Xtrain[1][:-2].shape) #new testing
numpy_input = Input(shape=Xtrain[1][-1].shape)

gan_input = layers.concatenate([gan_dataframe_input, numpy_input])

print(gan_input)
print(mergedLayer)

H = generator([gan_dataframe_input,numpy_input]) <<--two shapes being imputed
gd_input=Concatenate()([gan_input,H])   <<--merged layer + above two shapes being imputed
gan_V = discriminator(gd_input) 
GAN = Model(gan_input, [gan_V,H])  <<--this line returns an error
GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
GAN.summary()

堆栈跟踪:

代码语言:javascript
复制
KerasTensor(type_spec=TensorSpec(shape=(None, 736), dtype=tf.float32, name=None), name='concatenate_28/concat:0', description="created by layer 'concatenate_28'")
KerasTensor(type_spec=TensorSpec(shape=(None, 736), dtype=tf.float32, name=None), name='concatenate_27/concat:0', description="created by layer 'concatenate_27'")
WARNING:tensorflow:Functional model inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to "model_34" was not an Input tensor, it was generated by layer concatenate_28.
Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.
The tensor that caused the issue was: concatenate_28/concat:0
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-94-ac83091846e6> in <module>()
     69 gd_input=Concatenate()([gan_input,H])
     70 gan_V = discriminator(gd_input)
---> 71 GAN = Model(gan_input, [gan_V,H])
     72 GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
     73 GAN.summary()

4 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in _map_graph_network(inputs, outputs)
    988                              'The following previous layers '
    989                              'were accessed without issue: ' +
--> 990                              str(layers_with_complete_input))
    991         for x in nest.flatten(node.outputs):
    992           computable_tensors.add(id(x))

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 659), dtype=tf.float32, name='input_71'), name='input_71', description="created by layer 'input_71'") at layer "concatenate_28". The following previous layers were accessed without issue: []

奇怪的是,在我在层上打印数据后,似乎数组中的项的数量没有对齐?(659,)是其中一个输入的大小,而另一个是(77,)。我不确定我到底做错了什么。有什么建议吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-02-22 12:42:48

在构建多输入/多输出模型时,必须将模型输入和输出作为数组进行编译和馈送,而不是像以前那样将它们连接在一起。此外,模型的输入必须始终为tf.keras.layers.Input。所以正确的代码应该是

代码语言:javascript
复制
gan_dataframe_input = Input(shape=Xtrain[1][:-2].shape) #new testing
numpy_input = Input(shape=Xtrain[1][-1].shape)

gan_input = layers.concatenate([gan_dataframe_input, numpy_input])

print(gan_input)
print(mergedLayer)

H = generator([gan_dataframe_input,numpy_input]) <<--two shapes being imputed
gd_input=Concatenate()([gan_input,H])   <<--merged layer + above two shapes being imputed
gan_V = discriminator(gd_input) 
GAN = Model([gan_dataframe_input, numpy_input ], [gan_V,H])  <<--this line is modified
GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
GAN.summary()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66309302

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档