首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.train.Checkpoint是否正在恢复?

tf.train.Checkpoint是否正在恢复?
EN

Stack Overflow用户
提问于 2021-04-04 14:40:48
回答 1查看 187关注 0票数 0

我正在colab上运行tensorflow 2.4。我试图使用tf.train.Checkpoint()保存模型,因为它包含了模型子类,但是在恢复之后,我看到它没有恢复模型的任何权重。

下面是几个片段:

代码语言:javascript
复制
### From tensorflow tutorial nmt_with_attention
class Encoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    ...
    self.gru = tf.keras.layers.GRU(self.enc_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')

.
.
.

class NMT_Train(tf.keras.Model):
  def __init__(self, inp_vocab_size, tar_vocab_size, max_length_inp, max_length_tar, emb_dims, units, batch_size, source_tokenizer, target_tokenizer):
    super(NMT_Train, self).__init__()
    self.encoder = Encoder(inp_vocab_size, emb_dims, units, batch_size)
    ...

.
.
.

model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))
model.fit(dataset, epochs=2)

checkpoint = tf.train.Checkpoint(model = model)
manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)
manager.save()

model.encoder.gru.get_weights() ### get the output
##[array([[-0.0627057 ,  0.05900152,  0.06614069, ...

model.optimizer.get_weights() ### get the output
##[90, array([[ 6.6851695e-05, -4.6736805e-06, -2.3183979e-05, ...

当我后来修复它时,我没有得到任何gru重量:

代码语言:javascript
复制
model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))

checkpoint = tf.train.Checkpoint(model = model)
manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)

manager.restore_or_initialize()

model.encoder.gru.get_weights() ### empty list
## []

model.optimizer.get_weights() ### empty list
## []

我也尝试过checkpoint.restore(manager.latest_checkpoint),但是没有什么改变。

我做错什么了吗??或者建议其他方法来拯救模型,这样我就可以重新训练它,以适应更多的时代。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-04 14:48:34

您正在定义keras模型,那么为什么不使用keras模型契点呢?

来自Keras文档

代码语言:javascript
复制
model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66942378

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档