首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow 2变量不可训练

Tensorflow 2变量不可训练
EN

Stack Overflow用户
提问于 2020-02-04 05:18:19
回答 1查看 1.2K关注 0票数 1

我已经用tf2创建了一个简单的模型,它将输入'a‘乘以变量'b’(初始化为1),然后返回输出'c‘。然后我试着在简单的数据集a=1,c=5上训练它,我希望它学习b=5。

代码语言:javascript
复制
import tensorflow as tf
from tensorflow.keras.models import Model

a = Input(shape=(1,))
b = tf.Variable(1., trainable=True)
c = a*b
model = Model(a,c)

loss = tf.keras.losses.MeanAbsoluteError()
model.compile(optimizer='adam', loss=loss)

model.fit([1.],[5.],batch_size=1, epochs=1)

然而,tf2并不认为变量'b‘是可训练的。摘要显示没有可训练的参数。

代码语言:javascript
复制
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
tf_op_layer_mul (TensorFlowO [(None, 1)]               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

为什么变量'b‘不是训练?

EN

回答 1

Stack Overflow用户

发布于 2020-02-04 14:20:29

Keras模型是Layer类的包装器。您必须将此变量包装为keras层,以便在模型中将其显示为可训练参数。

您可以为此创建一个小型自定义层,如下所示:

代码语言:javascript
复制
class MyLayer(tf.keras.layers.Layer):
  def __init__(self):
    super(MyLayer, self).__init__()

    #your variable goes here
    self.variable = tf.Variable(1., trainable=True, dtype=tf.float64)

  def call(self, inputs, **kwargs):

    # your mul operation goes here
    x = inputs * self.variable

    return x

这里的call方法会做乘法运算。我们可以像out模型中的任何其他层一样使用这一层。在这里,我创建了一个Sequential模型,并添加了乘法操作作为模型层。

代码语言:javascript
复制
model = tf.keras.models.Sequential()
mylayer_object = MyLayer()
model.add(mylayer_object)

loss = tf.keras.losses.MeanAbsoluteError()
model.compile("adam", loss)

model.fit([1.],[5.],batch_size=1, epochs=1)
model.summary()
'''
Train on 1 samples
1/1 [==============================] - 0s 426ms/sample - loss: 4.0000
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
my_layer (MyLayer)           multiple                  1         
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
'''

在此之后,如果你能列出模型的可训练参数。

代码语言:javascript
复制
print(model.trainable_variables)
# [<tf.Variable 'Variable:0' shape=() dtype=float64, numpy=1.0009999968852092>]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60047291

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档