首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow2.0重用层

Tensorflow2.0重用层
EN

Stack Overflow用户
提问于 2019-10-11 22:17:51
回答 1查看 258关注 0票数 1
代码语言:javascript
复制
def head(self, input, num_anchors, name, flatten=False):
    out_channels = (self.num_classes + 4) * num_anchors
    conv = layers.Conv2D(256, 3, 1, 'same', activation='relu', name=name+'_conv1')(input)
    conv = layers.Conv2D(256, 3, 1, 'same', activation='relu', name=name+'_conv2')(conv)
    conv = layers.Conv2D(256, 3, 1, 'same', activation='relu', name=name+'_conv3')(conv)
    out = layers.Conv2D(out_channels, 3, 1, 'same', name=name+'output')(conv)
    if flatten is True:
        batch_size = tf.shape(out)[0]
        out = tf.reshape(out, [batch_size, -1, num_anchors, self.num_classes+4])
        out = tf.reshape(out, [batch_size, -1, self.num_classes+4])
    return out

我想知道如何在tensorflow1中将这些层重用为tf.variable_scope(scope resue=tf.AUTO_REUSE

在tensorflow1中

with tf.variable_scope('', resue=tf.AUTO_REUSE) as scope: all layers here could be auto reuse

EN

回答 1

Stack Overflow用户

发布于 2019-10-21 01:27:42

您可以通过拥有一个公共引用来重用这些层。我已经附上了下面的示例代码。我使用了一个名为common_layer的变量,它将在三个独立的模型(顺序模型和函数模型)中使用。训练第一个模型,然后从所有三个模型的common_layer中减去权重。它证明了在第一个模型的层中发生的变化会反映在其他模型的公共层中。

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

common_layer = tf.keras.layers.Dense(100, name='common_layer')

model1 = tf.keras.models.Sequential([
    tf.keras.layers.Input((100)),
    common_layer,
    tf.keras.layers.Dense(1)
])

model2 = tf.keras.models.Sequential([
    tf.keras.layers.Input((100)),
    common_layer,
    tf.keras.layers.Dense(10)
])

input_layer = tf.keras.layers.Input((100))
output_layer = common_layer(input_layer)
output_layer = tf.keras.layers.Dense(20)(output_layer)
model3 = tf.keras.Model(inputs=[input_layer], outputs=[output_layer])

model1.compile('adam', loss='mse')
model1.fit(np.random.rand(128, 100), np.random.rand(128, 1))

weights1 = model1.get_weights()[0]
weights2 = model2.get_weights()[0]
weights3 = model3.get_weights()[0]
print(np.sum(weights1 - weights2))  # 0.0
print(np.sum(weights1 - weights3))  # 0.0
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58343191

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档