首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow中打印重量?

如何在Tensorflow中打印重量?
EN

Stack Overflow用户
提问于 2018-03-16 07:42:37
回答 5查看 21.8K关注 0票数 8

我试图得到我的hidden_layer2的权重矩阵并打印出来。似乎我可以得到重量矩阵,但我不能打印它。

当使用tf.Print(w, [w])时,它不会输出任何内容。在使用print(tf.Print(w,[w])时,它至少会打印关于张量的信息:

代码语言:javascript
复制
Tensor("hidden_layer2_2/Print:0", shape=(3, 2), dtype=float32)

我还尝试在with-Statement,之外使用tf.Print(),结果也是一样。

完整的代码在这里,我只是在一个前馈NN:https://pastebin.com/KiQUBqK4中处理随机数据。

“我的守则”的一部分是:

代码语言:javascript
复制
hidden_layer2 = tf.layers.dense(
        inputs=hidden_layer1,
        units=2,
        activation=tf.nn.relu,
        name="hidden_layer2")


with tf.variable_scope("hidden_layer2", reuse=True):
        w = tf.get_variable("kernel")
        tf.Print(w, [w])
        # Also tried tf.Print(hidden_layer2, [w])
EN

回答 5

Stack Overflow用户

发布于 2020-01-06 08:00:50

更新为TENSORFLOW 2.X

从TensorFlow 2.0 (>= 2.0)开始,由于Session对象已被删除,推荐的高级后端是Keras,因此获取权重的方法如下:

代码语言:javascript
复制
from tensorflow.keras.applications import MobileNetV2

model = MobileNetV2(input_shape=[128, 128, 3], include_top=False) #or whatever model
print(model.layers[0].get_weights()[0])
票数 8
EN

Stack Overflow用户

发布于 2018-03-16 08:30:22

试着做这个,

代码语言:javascript
复制
w = tf.get_variable("kernel")
print(w.eval())
票数 0
EN

Stack Overflow用户

发布于 2019-10-29 14:00:05

我采取了不同的方法。首先,我列出所有可培训的变量,并使用所需变量的索引,并使用当前会话运行它。守则附后如下:

代码语言:javascript
复制
variables = tf.trainable_variables()

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

print("Weight matrix: {0}".format(sess.run(variables[0])))  # For me the variable at the 0th index was the one I required
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49315432

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档