
Freezing and exporting a TensorFlow model in TensorFlow 2.0

Stack Overflow user
Asked 2019-08-10 05:24:20
Answers: 1, Views: 387, Followers: 0, Votes: 3

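I am trying to migrate existing code written for TensorFlow 1.13 (using Estimators) to TensorFlow 2.0, but I am having trouble finding the equivalent API for freezing the graph and exporting it as a .pb file.

In TensorFlow 1.13, the Estimator class has a method export_savedmodel that takes a model path and a serving_input_receiver_fn. I had trouble setting up the serving_input_receiver_fn because it appears to require placeholders. The same API still exists in TensorFlow 2.0, but eager execution is now the default and placeholders cannot be created in eager mode.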

def export(self):
    self.configure()
    # Batch size and spatial dimensions are left unknown.
    a_shape = (None, None, None, self.IMG_CHANNELS)
    b_shape = tf.TensorShape((None, None, self.IMU_DATA_DIM))
    # These two calls raise the RuntimeError below under eager execution.
    a = tf.compat.v1.placeholder(tf.float32, a_shape, name="a")
    b = tf.compat.v1.placeholder(tf.float32, b_shape, name="b")
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'a': a,
        'b': b,
    })
    return self.modelPath, input_fn

RuntimeError: tf.placeholder() is not compatible with eager execution.
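
The error can be reproduced outside of any Estimator code. The following is a minimal sketch, assuming a TF 2.x install with eager execution left at its default; the helper name try_placeholder and the shape (None, 3) are made up for illustration. It shows that the placeholder call fails eagerly but succeeds once an explicit tf.Graph is the default graph.

```python
import tensorflow as tf


def try_placeholder():
    """Return the error message raised by an eager placeholder call, or None."""
    try:
        tf.compat.v1.placeholder(tf.float32, (None, 3), name="a")
    except RuntimeError as err:
        return str(err)
    return None


# Under eager execution (the TF 2.x default) the call is rejected.
eager_error = try_placeholder()
print(eager_error)

# Inside an explicit graph context the same call is allowed.
graph = tf.Graph()
with graph.as_default():
    a = tf.compat.v1.placeholder(tf.float32, (None, 3), name="a")
print(a.name)
```

This is why code that builds placeholders itself breaks after the migration, while code that only runs while TensorFlow is constructing a graph can still use them.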

So my question is: what is the correct way to freeze and export a model from existing checkpoint files so that I get a .pb file?


1 Answer

Stack Overflow user

Answered 2020-06-30 05:09:55

Here is an example that uses tf.estimator.export.build_raw_serving_input_receiver_fn(). It can be pasted directly into a notebook running TF 2.x. Hope this helps.

import tensorflow as tf

checkpoint_dir = "/some/location/to/store/the_model"

input_column = tf.feature_column.numeric_column("x")
# Use a LinearClassifier but this would also work with a custom Estimator
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

# Create a fake dataset with only one feature 'x' and an associated label
def input_fn():
    return tf.data.Dataset.from_tensor_slices(
        ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)

# Do not call raw_input_fn yourself: that would result in the error
# "tf.placeholder() is not compatible with eager execution."
# Instead, pass raw_input_fn itself to estimator.export_saved_model().

feature_to_tensor = {
    # Dummy tensor used only to give the shape for the placeholder
    # that build_raw_serving_input_receiver_fn() will create.
    # Adjust it to match the shape of 'x'.
    'x': tf.constant(0.),
}
raw_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_to_tensor, default_batch_size=None)
export_dir = estimator.export_saved_model(checkpoint_dir, raw_input_fn).decode()

The exported model can then be inspected:

!saved_model_cli show --all --dir $export_dir

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: Placeholder:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['all_class_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 2)
        name: head/predictions/Tile:0
    outputs['all_classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 2)
        name: head/predictions/Tile_1:0
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: head/predictions/str_classes:0
    outputs['logistic'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: head/predictions/logistic:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: linear/linear_model/linear/linear_model/linear/linear_model/weighted_sum:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: head/predictions/probabilities:0
  Method name is: tensorflow/serving/predict
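
A SavedModel keeps its weights in a separate variables directory next to saved_model.pb. If a single frozen .pb with the weights baked in is what is needed, one commonly used approach in TF 2.x is to convert a concrete function's variables to constants. The sketch below is an illustration, not part of the answer above: the variable w and the function forward are hypothetical stand-ins for a real model, and convert_variables_to_constants_v2 is imported from an internal TensorFlow module rather than the public API, so its location may change between releases.

```python
import os
import tempfile

import tensorflow as tf
# Assumption: internal module path, not part of the public tf.* API.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical stand-in for trained weights.
w = tf.Variable([[2.0], [3.0]], name="w")


# The TensorSpec plays the role a placeholder had in TF 1.x.
@tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32, name="x")])
def forward(x):
    return tf.matmul(x, w)


# Trace the function, then replace its variables with constants.
concrete = forward.get_concrete_function()
frozen = convert_variables_to_constants_v2(concrete)
graph_def = frozen.graph.as_graph_def()

# Write the frozen GraphDef as a binary .pb file.
out_dir = tempfile.mkdtemp()
tf.io.write_graph(graph_def, out_dir, "frozen_graph.pb", as_text=False)
frozen_path = os.path.join(out_dir, "frozen_graph.pb")

# The frozen graph should contain no variable handles any more.
op_types = {node.op for node in graph_def.node}
print(sorted(op_types))
```

The input still shows up as a Placeholder op in the frozen graph, much like the Placeholder:0 input in the signature listing above, but it is created by TensorFlow during tracing rather than by user code.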

The exported model can now be loaded and used for inference from another process:

import tensorflow as tf
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
f(x=tf.constant([-2., 5., -3.]))

{'class_ids': <tf.Tensor: shape=(3, 1), dtype=int64, numpy=
 array([[1],
        [0],
        [1]], dtype=int64)>,
 'classes': <tf.Tensor: shape=(3, 1), dtype=string, numpy=
 array([[b'1'],
        [b'0'],
        [b'1']], dtype=object)>,
 'all_class_ids': <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
 array([[0, 1],
        [0, 1],
        [0, 1]])>,
...etc...
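
For code that no longer depends on Estimators, the same export and reload round trip can be done with tf.saved_model.save directly. This is a minimal sketch under stated assumptions: TinyModel, its predict method, and the temporary export path are hypothetical names chosen for illustration, and the model is a toy y = w * x + b rather than anything restored from a checkpoint.

```python
import os
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in model computing y = w * x + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)
        self.b = tf.Variable(1.0)

    # The input_signature fixes dtype and shape, with an unknown batch size.
    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32, name="x")]
    )
    def predict(self, x):
        return {"y": self.w * x + self.b}


model = TinyModel()
export_path = os.path.join(tempfile.mkdtemp(), "tiny_model")

# Writes saved_model.pb plus a variables/ directory under export_path.
tf.saved_model.save(model, export_path, signatures={"predict": model.predict})

# Reload and call the named signature, as in the answer above.
reloaded = tf.saved_model.load(export_path)
result = reloaded.signatures["predict"](x=tf.constant([1.0, 2.0]))
print(result["y"].numpy())
```

The signature is called with the keyword x because that is the name given in the TensorSpec, which mirrors how the exported Estimator signature is called with x=... above.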
Votes: 0
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/57437270
