首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow2.0是否还有'trainable‘参数?

tensorflow2.0是否还有'trainable‘参数?
EN

Stack Overflow用户
提问于 2019-08-26 22:44:46
回答 1查看 604关注 0票数 1

在tensorboard中,我找不到像tensorflow1.X一样更新参数的梯度运算。

并且在keras api中找不到'trainable‘参数。

如果tf2.0仍然可以在tensorboard中显示渐变操作,我如何将其添加到我的tensorboard中。

ps.my tensorflow版本为2.0-rc0。

下面是我向tensorboard文件添加内容的代码。

代码语言:javascript
复制
logdir = "testlogs"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
.....
model.fit(x=train_x, y=train_y,
      batch_size=256,
      epochs=6,
      shuffle=True,
      callbacks=[tensorboard_callback])
EN

回答 1

Stack Overflow用户

发布于 2019-08-30 14:36:14

tensorflow2.0是否还有参数'trainable'?

model中,决定哪些变量是可训练的是组成keras的各个层的责任。开箱即用的层有很多,但为了说明一些可训练变量的使用,这里有一个简单的密集层实现

代码语言:javascript
复制
class MyLayer(tf.keras.layers.Layer):
  def __init__(self, units=8, input_dim=8):
    super(MyLayer,self).__init__()
    self.w = tf.Variable(initial_value=tf.random_normal_initializer()(shape=(input_dim, units)),
                              trainable=True)
    self.b = tf.Variable(initial_value=tf.zeros_initializer()(shape=(units,)),
                           trainable=True)

  def call(self, inputs):
    return tf.matmul(inputs, self.w) + self.b

例如,您可以在如下所示的keras模型中使用:

代码语言:javascript
复制
my_layer = MyLayer(units=8,input_dim=2)
my_model = tf.keras.models.Sequential([
    my_layer
])
my_model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.binary_crossentropy)

当然,在实践中最好使用开箱即用的tf.keras.layers.Dense,这只是为了说明可训练的变量my_layer.wmy_layer.b

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57660163

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档