首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >`Tensor`丢失传递时,需要输入`tape`

`Tensor`丢失传递时,需要输入`tape`
EN

Stack Overflow用户
提问于 2021-01-27 13:00:29
回答 1查看 2.9K关注 0票数 2

关于tf的一些问题。

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
from tensorflow import keras

x_train = [1,2,3]
y_train = [1,2,3]

W = tf.Variable(tf.random.normal([1]), name = 'weight')
b = tf.Variable(tf.random.normal([1]), name = 'bias')
hypothesis = W*x_train+b

optimizer = tf.optimizers.SGD (learning_rate=0.01)

train = tf.keras.optimizers.Adam().minimize(cost, var_list=[W, b])

当我开始我的代码的最后一行时,出现了下面的错误。

代码语言:javascript
复制
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-52-cd6e22f66d09> in <module>()
----> 1 train = tf.keras.optimizers.Adam().minimize(cost, var_list=[W, b])

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in _compute_gradients(self, loss, var_list, grad_loss, tape)
    530     # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    531     if not callable(loss) and tape is None:
--> 532       raise ValueError("`tape` is required when a `Tensor` loss is passed.")
    533     tape = tape if tape is not None else backprop.GradientTape()
    534 

ValueError: `tape` is required when a `Tensor` loss is passed.

我知道它与tensorflow版本2相关,但不想修改为版本1。

需要tensorflow ver2的解决方案。谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-01-27 13:22:42

由于您没有提供成本函数,因此我添加了一个。以下是代码

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
from tensorflow import keras

 
x_train = [1,2,3]
y_train = [1,2,3]

W = tf.Variable(tf.random.normal([1]), name = 'weight')
b = tf.Variable(tf.random.normal([1]), name = 'bias')
hypothesis = W*x_train+b

@tf.function
def cost():

    y_model = W*x_train+b
    error = tf.reduce_mean(tf.square(y_train- y_model))
    return error


optimizer = tf.optimizers.SGD (learning_rate=0.01)

train = tf.keras.optimizers.Adam().minimize(cost, var_list=[W, b])

tf.print(W)
tf.print(b)
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65913108

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档