使用Tensorflow2/Keras,我想在训练期间根据一些更新规则修改模型组件的权重。为此,我使用了get_weights()和set_weights()方法。我试着按如下方式实现:
class CAD_model(keras.Model):
def __init__(self, online_encoder, target_encoder, predictor, **kwargs):
super(CAD_model, self).__init__(**kwargs)
self.online_encoder = online_encoder
self.target_encoder = target_encoder
self.predictor = predictor
def call(self, x):
z = self.target_encoder(x)
return z
def compile(self, optimizer):
super(CAD_model, self).compile()
self.opt = optimizer
def compute_loss(self, x1, x2):
online_encoder = self.online_encoder
target_encoder = self.target_encoder
y = online_encoder(x1)
z1 = self.predictor(y)
# Stop gradient
z2 = tf.stop_gradient(target_encoder(x2))
loss = tf.reduce_mean((z1 - z2)**2)
return loss
def update_ema(self, decay=0.999):
online_vars = self.online_encoder.get_weights()
target_vars = self.target_encoder.get_weights()
ema_vars = [decay * var1 + (1 - decay) * var2 for var1, var2 in zip(target_vars, online_vars)]
self.target_encoder.set_weights(ema_vars)
def train_step(self, data):
x1, x2 = data
with tf.GradientTape() as tape:
loss = self.compute_loss(x1, x2)
grads = tape.gradient(loss, self.trainable_weights)
self.opt.apply_gradients(zip(grads, self.trainable_weights))
self.update_ema()
return {
"loss": loss,
}

当运行 CAD_model.fit 时,我得到:

RuntimeError: Cannot get value inside Tensorflow graph function.

该错误来自 get_weights 和 set_weights 这两个调用。如何正确地读取和赋值权重,使这些操作能够在图(graph)内执行?
发布于 2021-03-24 16:13:58
解决方法:将权重的读取/赋值逻辑移出模型类的 train_step(即移出被 tf.function 追踪的图函数),改在图外的自定义训练循环中执行。
下面的示例可能会对您有所帮助
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class seq_model(keras.Model):
    """Minimal reconstruction model used to demonstrate where
    get_weights() can and cannot be called during training."""

    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(64, activation="relu", name="dense_1")
        self.dense2 = layers.Dense(64, activation="relu", name="dense_2")
        self.classifier = layers.Dense(784, activation="softmax", name="predictions")
        self.mse_loss_fn = tf.keras.losses.MeanSquaredError()

    def call(self, inputs):
        # Two hidden layers followed by the output head.
        hidden = self.dense1(inputs)
        hidden = self.dense2(hidden)
        return self.classifier(hidden)

    def compile(self, optimizer):
        super().compile()
        self.opt = optimizer
        # Return self so construction and compilation chain fluently.
        return self

    def train_step(self, data):
        x_batch_train = data
        with tf.GradientTape() as tape:
            reconstructed = self(x_batch_train)
            loss = self.mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(self.losses)
        grads = tape.gradient(loss, self.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.trainable_weights))
        # This line would give error:
        # RuntimeError: Cannot get value inside Tensorflow graph function.
        # online_vars = self.dense1.get_weights()
        # print(online_vars)
        return {
            "loss": loss,
        }
# Build and compile in one fluent chain (compile() returns self).
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
model = seq_model().compile(adam)

# MNIST digits, flattened to 784-dim float vectors scaled into [0, 1].
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

history = model.fit(x_train, batch_size=64, epochs=2)
# workaround: run the same training as an eager custom loop, outside the
# tf.function graph that Model.fit traces — there get_weights() succeeds.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()
loss_metric = tf.keras.metrics.Mean()
# Shuffled, batched input pipeline over the training images.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
epochs = 2
for epoch in range(epochs):
print("Start of epoch %d" % (epoch,))
for step, x_batch_train in enumerate(train_dataset):
with tf.GradientTape() as tape:
reconstructed = model(x_batch_train)
loss = mse_loss_fn(x_batch_train, reconstructed)
loss += sum(model.losses) # Add KLD regularization loss
grads = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# this line runs well (eager execution, not inside a traced graph)
online_vars = model.dense1.get_weights()
print(online_vars)
loss_metric(loss)
if step % 100 == 0:
print("step %d: mean loss = %.4f" % (step, loss_metric.result()))

来源:https://stackoverflow.com/questions/66744107
复制相似问题