首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >vae预测角化的NotImplementedError方法

vae预测角化的NotImplementedError方法
EN

Stack Overflow用户
提问于 2022-05-05 21:14:49
回答 1查看 356关注 0票数 0

所以我一直在用这个例子来表示卷积vae,这里是mnist:https://keras.io/examples/generative/vae/

代码语言:javascript
复制
vae.predict(mnist_digits)


NotImplementedError                       Traceback (most recent call last)

<ipython-input-8-2e6bf7edcacc> in <module>()
----> 1 vae.predict(mnist_digits)

1 frames

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

NotImplementedError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1801, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1790, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1783, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1751, in predict_step
        return self(x, training=False)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 475, in call
        raise NotImplementedError('Unimplemented `tf.keras.Model.call()`: if you '

    NotImplementedError: Exception encountered when calling layer "vae" (type VAE).
    
    Unimplemented `tf.keras.Model.call()`: if you intend to create a `Model` with the Functional API, please provide `inputs` and `outputs` arguments. Otherwise, subclass `Model` with an overridden `call()` method.
    
    Call arguments received:
      • inputs=tf.Tensor(shape=(None, 28, 28, 1), dtype=float32)
      • training=False
      • mask=None

同样地,与

代码语言:javascript
复制
vae(mnist_digits)

我得到以下信息:

代码语言:javascript
复制
---------------------------------------------------------------------------

NotImplementedError                       Traceback (most recent call last)

<ipython-input-8-aa5f4fb52e20> in <module>()
----> 1 vae(mnist_digits)

1 frames

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in call(self, inputs, training, mask)
    473         a list of tensors if there are more than one outputs.
    474     """
--> 475     raise NotImplementedError('Unimplemented `tf.keras.Model.call()`: if you '
    476                               'intend to create a `Model` with the Functional '
    477                               'API, please provide `inputs` and `outputs` '

NotImplementedError: Exception encountered when calling layer "vae" (type VAE).

Unimplemented `tf.keras.Model.call()`: if you intend to create a `Model` with the Functional API, please provide `inputs` and `outputs` arguments. Otherwise, subclass `Model` with an overridden `call()` method.

Call arguments received:   • inputs=tf.Tensor(shape=(70000, 28, 28, 1), dtype=float32)   • training=None   • mask=None

我是否必须创建一个自定义预测函数来解决这个问题。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-05-05 21:38:21

predict方法不是在VAE类中,而是在encoderdecoder中,因为它们是您的Keras模型。

如果要查看不同的数字类是如何群集的,请使用

z, _, _ = vae.encoder.predict(data)

如果要从潜在空间中采样数字,请使用

decoded = vae.decoder.predict(z)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72133736

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档