首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >导出用于TensorFlow服务的textsum模型时,获取错误'str‘对象时没有属性'dtype’

导出用于TensorFlow服务的textsum模型时,获取错误'str‘对象时没有属性'dtype’
EN

Stack Overflow用户
提问于 2017-10-21 11:14:24
回答 1查看 3.5K关注 0票数 1

我目前正在尝试获得一个TF文本和模型,使用预测签名导出。我让_Decode返回测试中传递的项目字符串的结果,然后将其传递给buildTensorInfo。这实际上是一个返回的字符串.

现在,当我运行textsum_export.py逻辑来导出模型时,它已经到了构建TensorInfo对象的地步--但是使用下面的跟踪就会出现错误。我知道预测签名通常与图像一起使用。这就是问题所在吗?因为我正在处理字符串,所以不能将它用于Textsum模型吗?

错误是:

代码语言:javascript
复制
Traceback (most recent call last):
  File "export_textsum.py", line 129, in Export
    tensor_info_outputs = tf.saved_model.utils.build_tensor_info(res)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/saved_model/utils_impl.py", line 37, in build_tensor_info
    dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
AttributeError: 'str' object has no attribute 'dtype'

输出模型的TF会话如下:

代码语言:javascript
复制
with tf.Session(config = config) as sess:

                # Restore variables from training checkpoints.
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    print('Successfully loaded model from %s at step=%s.' %
                        (ckpt.model_checkpoint_path, global_step))
                    res = decoder._Decode(saver, sess)

                    print("Decoder value {}".format(type(res)))
                else:
                    print('No checkpoint file found at %s' % FLAGS.checkpoint_dir)
                    return

                # Export model
                export_path = os.path.join(FLAGS.export_dir,str(FLAGS.export_version))
                print('Exporting trained model to %s' % export_path)


                #-------------------------------------------

                tensor_info_inputs = tf.saved_model.utils.build_tensor_info(serialized_tf_example)
                tensor_info_outputs = tf.saved_model.utils.build_tensor_info(res)

                prediction_signature = (
                    tf.saved_model.signature_def_utils.build_signature_def(
                        inputs={ tf.saved_model.signature_constants.PREDICT_INPUTS: tensor_info_inputs},
                        outputs={tf.saved_model.signature_constants.PREDICT_OUTPUTS:tensor_info_outputs},
                        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
                        ))

                #----------------------------------

                legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
                builder = saved_model_builder.SavedModelBuilder(export_path)

                builder.add_meta_graph_and_variables(
                    sess=sess, 
                    tags=[tf.saved_model.tag_constants.SERVING],
                    signature_def_map={
                        'predict':prediction_signature,
                    },
                    legacy_init_op=legacy_init_op)
                builder.save()

                print('Successfully exported model to %s' % export_path)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-10-22 05:54:14

使用张量预测签名工作,如果res是'str‘类型python变量,则res_tensor将为dtype tf.string。

代码语言:javascript
复制
res_tensor = tf.convert_to_tensor(res) 
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46862662

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档