
Loading weights from a checkpoint does not work in a Keras model

Asked by a Stack Overflow user on 2019-09-18 13:23:21
1 answer · 2.5K views · 0 followers · 8 votes

This is driving me crazy.

I defined a sequential model using TensorFlow Keras:

```python
import tensorflow as tf
model = tf.keras.Sequential([tf.keras.layers.Dense(128, input_shape=(784,), activation="relu"),
                             tf.keras.layers.Dense(10, activation="softmax")])
model.compile(optimizer="adam", loss="mse")
# Export as a SavedModel (experimental API, already marked deprecated)
tf.keras.experimental.export_saved_model(model, "keras_model")
```

I then train this model in a C program using c_api.h.

The C program saves the weights in a checkpoint file.
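
One way to see what kind of checkpoint the C program wrote is to list its keys with `tf.train.list_variables(prefix)`, which returns `(name, shape)` pairs. The sketch below classifies a checkpoint from those pairs. It assumes that object-based checkpoints (written by `tf.train.Checkpoint` or Keras `save_weights` in TF format) always contain the `_CHECKPOINTABLE_OBJECT_GRAPH` key and suffix variable keys with `/.ATTRIBUTES/VARIABLE_VALUE`, while name-based `tf.train.Saver` checkpoints contain neither. The helper name `describe_checkpoint` is hypothetical.

```python
from typing import Dict, Iterable, List, Sequence, Tuple, Union

# Key that object-based checkpoints (tf.train.Checkpoint, Keras TF-format
# save_weights) contain; name-based tf.train.Saver checkpoints lack it.
OBJECT_GRAPH_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
# Suffix that object-based checkpoints append to each variable's key.
OBJECT_ATTR_SUFFIX = "/.ATTRIBUTES/VARIABLE_VALUE"


def describe_checkpoint(
    variables: Iterable[Tuple[str, Sequence[int]]]
) -> Dict[str, Union[str, List[str]]]:
    """Classify a checkpoint from the (name, shape) pairs of tf.train.list_variables."""
    names = [name for name, _shape in variables]
    return {
        "format": "object-based" if OBJECT_GRAPH_KEY in names else "name-based",
        # Keys that look like plain graph names, e.g. "dense/kernel" or "dense/kernel/Adam".
        "plain_names": sorted(
            n for n in names
            if n != OBJECT_GRAPH_KEY and not n.endswith(OBJECT_ATTR_SUFFIX)
        ),
    }
```

For example, `describe_checkpoint(tf.train.list_variables("keras_model/variables/variables"))`. If the result is `"name-based"`, `plain_names` shows the exact variable names the C program used. Those names are what a restore has to match.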

When I try to restore the weights from that checkpoint file in Python, using any of these:

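```python
import tensorflow as tf
# Option 1: rebuild the model from the SavedModel directory (deprecated experimental API)
model = tf.keras.experimental.load_from_saved_model("keras_model/")
# Option 2: rebuild the architecture, then load weights from the checkpoint prefix
model = tf.keras.Sequential([tf.keras.layers.Dense(128, input_shape=(784,), activation="relu"),
                             tf.keras.layers.Dense(10, activation="softmax")])
model.load_weights("keras_model/variables/variables")
# Option 3: restore through the object-based checkpoint API
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore("keras_model/variables/variables")
```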
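
All three attempts pass `keras_model/variables/variables`. For a TensorFlow-format checkpoint, that path is not a file. It is the shared stem of a `.index` file and one or more `.data-XXXXX-of-YYYYY` shard files. The following is a minimal sketch for confirming that those files exist on disk, assuming that standard layout. The helper name `checkpoint_files_present` is hypothetical.

```python
import glob
import os


def checkpoint_files_present(prefix: str) -> bool:
    """Return True if `prefix` is the stem of a TF-format checkpoint on disk.

    Expects "<prefix>.index" plus at least one "<prefix>.data-*-of-*" shard.
    """
    has_index = os.path.isfile(prefix + ".index")
    # glob.escape guards against glob metacharacters in the directory name.
    shards = glob.glob(glob.escape(prefix) + ".data-*-of-*")
    return has_index and bool(shards)
```

If `checkpoint_files_present("keras_model/variables/variables")` is False, the C program probably wrote its checkpoint under a different prefix.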

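I end up with an error, and the weights are not restored.

Yet I am able to restore the weights and continue training in my C program. Here is the full output of the first attempt: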

```text
keras.experimental.load_from_saved_model("keras_model/")
WARNING: Logging before flag parsing goes to stderr.
W0918 15:18:04.350199 140418474760000 deprecation.py:323] From <ipython-input-2-06ea110fdc8e>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been  deprecated. Please switch to `tf.keras.models.load_model`.
2019-09-18 15:18:04.390271: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1696040000 Hz
2019-09-18 15:18:04.390913: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4bf4790 executing computations on platform Host. Devices:
2019-09-18 15:18:04.390961: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): Host, Default Version
W0918 15:18:04.436281 140418474760000 deprecation.py:323] From /home/jregalado/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py:1249: NameBasedSaverStatus.__init__ (from tensorflow.python.training.tracking.util) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-2-06ea110fdc8e> in <module>
----> 1 keras.experimental.load_from_saved_model("keras_model/")

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py in new_func(*args, **kwargs)
322               'in a future version' if date is None else ('after %s' % date),
323               instructions)
--> 324       return func(*args, **kwargs)
325     return tf_decorator.make_decorator(
326         func, new_func, 'deprecated',

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saved_model_experimental.py in load_from_saved_model(saved_model_path, custom_objects)
425       compat.as_text(constants.VARIABLES_DIRECTORY),
426       compat.as_text(constants.VARIABLES_FILENAME))
--> 427   model.load_weights(checkpoint_prefix)
428   return model

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in load_weights(self, filepath, by_name)
179         raise ValueError('Load weights is not yet supported with TPUStrategy '
180                          'with steps_per_run greater than 1.')
--> 181     return super(Model, self).load_weights(filepath, by_name)
182
183   @trackable.no_automatic_dependency_tracking

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in load_weights(self, filepath, by_name)
1372         # streaming restore for any variables created in the future.
1373         trackable_utils.streaming_restore(status=status, session=session)
-> 1374       status.assert_nontrivial_match()
1375       return status
1376     if h5py is None:

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in assert_nontrivial_match(self)
964     # assert_nontrivial_match and assert_consumed (and both are less
965     # useful since we don't touch Python objects or Python state).
--> 966     return self.assert_consumed()
967
968   def _gather_saveable_objects(self):

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in assert_consumed(self)
941       raise AssertionError(
942           "Some objects had attributes which were not restored:{}".format(
--> 943               "".join(unused_attribute_strings)))
944     for trackable in self._graph_view.list_objects():
945       # pylint: disable=protected-access

AssertionError: Some objects had attributes which were not restored:
<tf.Variable 'a/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.03716458, -0.04911711, -0.01023878, ...,  0.0636776 ,
0.02892563, -0.05542086],
[-0.02324755, -0.07362694, -0.0399951 , ...,  0.0680329 ,
0.05201877, -0.05149256],
[ 0.00954343,  0.05673491,  0.05108347, ...,  0.01994208,
-0.01107961,  0.06192174],
...,
[ 0.07091486, -0.07734856, -0.04417738, ...,  0.01921409,
-0.01908814, -0.05070668],
[ 0.01353646, -0.05189713, -0.01391671, ..., -0.05795977,
0.04801518,  0.00801209],
[-0.05304915,  0.01870193,  0.05657425, ..., -0.06819408,
-0.00760372, -0.0106293 ]], dtype=float32)>: ['a/kernel']
<tf.Variable 'a/bias:0' shape=(128,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>: ['a/bias']
<tf.Variable 'b/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[-0.1759212 , -0.09282549, -0.11045764, ..., -0.13727605,
-0.02849793,  0.14510198],
[ 0.06857841, -0.01459177,  0.08369003, ...,  0.05089156,
-0.05319159, -0.08594933],
[-0.180914  , -0.18932283,  0.20551099, ..., -0.17210156,
-0.10069884,  0.06433241],
...,
[ 0.09097584, -0.03930017, -0.15125516, ...,  0.02359283,
-0.16158347, -0.13176063],
[-0.04145582, -0.03205152,  0.20097663, ..., -0.15124482,
0.16874255, -0.15434337],
[-0.13188484,  0.04145408,  0.05036192, ..., -0.10489662,
0.12316228,  0.08794598]], dtype=float32)>: ['b/kernel']
<tf.Variable 'b/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>: ['b/bias']
```
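
The log warns that a name-based `tf.train.Saver` checkpoint is being restored with the object-based API. The assertion then lists every Keras variable that found no match by name. One possible workaround skips the name matching altogether. It reads the tensors with `tf.train.load_checkpoint` and copies them into the layers with `layer.set_weights`. This is offered only as a sketch, not a confirmed fix. The helper names `group_by_layer` and `restore_dense_layers` are hypothetical. The sketch also assumes the C program stores each Dense layer's weights under keys ending in `<scope>/kernel` and `<scope>/bias`. Check the real keys with `tf.train.list_variables` first.

```python
from collections import defaultdict
from typing import Dict, Iterable, List

WEIGHT_ATTRS = ("kernel", "bias")  # the two variables of a Dense layer


def group_by_layer(names: Iterable[str]) -> Dict[str, Dict[str, str]]:
    """Group name-based checkpoint keys such as "a/kernel" by their scope.

    Optimizer slots ("a/kernel/Adam") and scalars ("beta1_power") are skipped
    because their last path component is not "kernel" or "bias".
    """
    layers: Dict[str, Dict[str, str]] = defaultdict(dict)
    for name in names:
        scope, sep, attr = name.rpartition("/")
        if sep and attr in WEIGHT_ATTRS:
            layers[scope][attr] = name
    return dict(layers)


def restore_dense_layers(model, prefix: str, scopes: List[str]) -> None:
    """Copy kernel/bias tensors from a name-based checkpoint into a Keras model.

    Hypothetical helper: assumes scopes[i] is the checkpoint scope holding the
    weights of the i-th layer of `model` that has weights.
    """
    import tensorflow as tf  # lazy import: group_by_layer itself needs no TF

    reader = tf.train.load_checkpoint(prefix)
    groups = group_by_layer(reader.get_variable_to_shape_map())
    weighted_layers = [layer for layer in model.layers if layer.weights]
    for layer, scope in zip(weighted_layers, scopes):
        entry = groups[scope]
        layer.set_weights([reader.get_tensor(entry["kernel"]),
                           reader.get_tensor(entry["bias"])])
```

A call might look like `restore_dense_layers(model, "keras_model/variables/variables", ["a", "b"])`. The scopes to pass are the ones the checkpoint actually uses. The names in the error message belong to the Keras model and may differ from the checkpoint's.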

1 Answer

Answered by a Stack Overflow user on 2022-08-03 14:43:52

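When I ran into this problem I was using only Python, not C/C++. The actual problem was that I passed the `.index` file to `load_weights` instead of the checkpoint stem (the prefix shared by the `.index` and `.data-*` files):

Wrong: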

```python
model = make_some_model()  # placeholder for the code that builds the architecture
model.load_weights("output/20220801-pretrain_test/checkpoints/checkpoint_weights_e10.ckpt.index")  # the .index file itself
```

Right:

```python
model = make_some_model()
model.load_weights("output/20220801-pretrain_test/checkpoints/checkpoint_weights_e10.ckpt")  # the checkpoint prefix
```
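
If the path comes from a directory listing or a glob, the file suffix can be stripped before calling `load_weights`. The following is a minimal sketch, assuming TensorFlow's usual `.index` and five-digit `.data-XXXXX-of-YYYYY` shard naming. The helper `to_checkpoint_prefix` is hypothetical.

```python
import re

# ".index" or ".data-00000-of-00001"-style suffixes that TensorFlow appends
# to a checkpoint prefix when it writes the files to disk.
_CKPT_FILE_SUFFIX = re.compile(r"\.(?:index|data-\d{5}-of-\d{5})$")


def to_checkpoint_prefix(path: str) -> str:
    """Strip a checkpoint file suffix so the result can be passed to load_weights."""
    return _CKPT_FILE_SUFFIX.sub("", path)
```

Usage: `model.load_weights(to_checkpoint_prefix(path))`. `tf.train.latest_checkpoint(checkpoint_dir)` also returns a prefix rather than a file name. It works only when the directory holds the `checkpoint` bookkeeping file that TensorFlow's savers write.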
Votes: 0
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/57993832
