
Keras get_weights and set_weights during training: "Cannot get value inside Tensorflow graph function"
Stack Overflow user
Asked on 2021-03-22 18:18:04
1 answer · 216 views · 0 followers · 0 votes

Using TensorFlow 2/Keras, I want to modify the weights of a model component during training according to an update rule. For this I use the get_weights() and set_weights() methods. I tried to implement it as follows:

class CAD_model(keras.Model):
    
    def __init__(self, online_encoder, target_encoder, predictor, **kwargs):
        super(CAD_model, self).__init__(**kwargs)
        self.online_encoder = online_encoder
        self.target_encoder = target_encoder
        self.predictor = predictor
        
    def call(self, x):
        z = self.target_encoder(x)
        return z
    
    def compile(self, optimizer):
        super(CAD_model, self).compile()
        self.opt = optimizer
    
    def compute_loss(self, x1, x2):
        
        online_encoder = self.online_encoder
        target_encoder = self.target_encoder
        
        y = online_encoder(x1)
        z1 = self.predictor(y)
        # Stop gradient
        z2 = tf.stop_gradient(target_encoder(x2))        
        loss = tf.reduce_mean((z1 - z2)**2)
        return loss
    
    def update_ema(self, decay=0.999):
        online_vars = self.online_encoder.get_weights()
        target_vars = self.target_encoder.get_weights()
        ema_vars = [decay * var1 + (1 - decay) * var2 for var1, var2 in zip(target_vars, online_vars)]
        self.target_encoder.set_weights(ema_vars)

    def train_step(self, data):
        
        x1, x2 = data
            
        with tf.GradientTape() as tape:
            loss = self.compute_loss(x1, x2)        
        
        grads = tape.gradient(loss, self.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.trainable_weights))
        
        self.update_ema()
        
        return {
            "loss": loss,
        }

When running CAD_model.fit, I get:

RuntimeError: Cannot get value inside Tensorflow graph function.

which comes from the get_weights and set_weights calls. How do I correctly extract and assign the weights so that these operations execute inside the graph?
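One graph-compatible alternative (not part of the original post; the layer shapes and `update_ema` helper below are illustrative) is to skip the NumPy round-trip of get_weights()/set_weights() entirely and update the tf.Variable objects in place with assign(), which is an ordinary graph op and therefore legal inside a tf.function-compiled train_step:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical sketch: EMA update via tf.Variable.assign() instead of the
# eager-only get_weights()/set_weights() NumPy round-trip. assign() is a
# graph op, so this helper can be called from inside train_step.
def update_ema(target_layer, online_layer, decay=0.999):
    for t_var, o_var in zip(target_layer.variables, online_layer.variables):
        t_var.assign(decay * t_var + (1 - decay) * o_var)

# Quick demonstration with two tiny layers of matching shape.
online = layers.Dense(4)
target = layers.Dense(4)
online.build((None, 3))
target.build((None, 3))

# Give both kernels known values so the EMA result is easy to check.
online.kernel.assign(tf.ones((3, 4)))
target.kernel.assign(tf.zeros((3, 4)))

update_ema(target, online, decay=0.9)
# Every target kernel entry is now 0.9 * 0 + 0.1 * 1 = 0.1.
print(float(target.kernel[0, 0]))
```

The same pattern applies to whole models: iterate over `model.variables` of the target and online encoders in parallel, assuming their variable lists line up one-to-one.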


1 Answer

Stack Overflow user
Answered on 2021-03-24 16:13:58

Move the train_step logic out of the model class, i.e. use a custom training loop.

The following example may help you:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class seq_model(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(64, activation="relu", name="dense_1")
        self.dense2 = layers.Dense(64, activation="relu", name="dense_2")
        self.classifier = layers.Dense(784, activation="softmax", name="predictions")
        self.mse_loss_fn = tf.keras.losses.MeanSquaredError()

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        outputs = self.classifier(x)
        return outputs

    def compile(self, optimizer):
        super().compile()
        self.opt = optimizer
        return self

    def train_step(self, data):
        x_batch_train = data
        with tf.GradientTape() as tape:
            reconstructed = self(x_batch_train)
            loss = self.mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(self.losses)

        grads = tape.gradient(loss, self.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.trainable_weights))

        # This line would give error:
        # RuntimeError: Cannot get value inside Tensorflow graph function.
        # online_vars = self.dense1.get_weights()
        # print(online_vars)

        return {
            "loss": loss,
        }


model = seq_model().compile(tf.keras.optimizers.Adam(learning_rate=1e-3))


(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

history = model.fit(
    x_train,
    batch_size=64,
    epochs=2,
)


# workaround

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))

    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = model(x_batch_train)
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(model.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # this line runs well
        online_vars = model.dense1.get_weights()
        print(online_vars)

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))
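For completeness, another commonly used workaround (an assumption beyond the posted answer; the `TinyModel` class below is a made-up example): compiling with `run_eagerly=True` disables the tf.function graph compilation of the training step, so the eager-only get_weights()/set_weights() calls work inside train_step again, at the cost of training speed:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical minimal model. With run_eagerly=True, train_step is not
# traced into a graph, so get_weights()/set_weights() are allowed in it.
class TinyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = layers.Dense(2)

    def call(self, x):
        return self.dense(x)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean((self(x) - y) ** 2)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Legal here only because the step runs eagerly.
        w = self.dense.get_weights()
        self.dense.set_weights(w)
        return {"loss": loss}

model = TinyModel()
model.compile(optimizer=keras.optimizers.Adam(1e-3), run_eagerly=True)
x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 2).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
print(history.history["loss"])
```

This keeps the original get_weights/set_weights style at the price of eager-mode speed; the custom training loop above avoids that trade-off.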
Votes: 0
Original page content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/66744107
