首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TFLITE的Tensorflow模型

TFLITE的Tensorflow模型
EN

Stack Overflow用户
提问于 2020-04-01 11:35:25
回答 1查看 1.2K关注 0票数 0

我有这个代码来构建一个语义搜索引擎,使用来自tensorflow中心的经过预先训练的通用编码器。我无法改信于特利特。我已将模型保存到我的目录中。

进口模型:

代码语言:javascript
复制
module_path ="/content/drive/My Drive/4"
%time model = hub.load(module_path)
#print ("module %s loaded" % module_url)

#Create function for using modeltraining
def embed(input):
    return model(input)

培训数据模型:

代码语言:javascript
复制
## training the model
Model_USE= embed(data)

保存模型:

代码语言:javascript
复制
exported = tf.train.Checkpoint(v=tf.Variable(Model_USE))
exported.f = tf.function(
    lambda  x: exported.v * x,
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
export_dir = "/content/drive/My Drive/"
tf.saved_model.save(exported,export_dir)

保存很好,但是当我转换到tflite时,它会产生错误。

转换代码:

代码语言:javascript
复制
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

错误:

代码语言:javascript
复制
as_list() is not defined on an unknown TensorShape.
EN

回答 1

Stack Overflow用户

发布于 2020-04-03 16:35:52

首先,您需要添加一个数据生成器,以便为转换器提供有代表性的输入。就像这样:

代码语言:javascript
复制
def representative_data_gen():
  for input_value in dataset.take(100):
    yield [input_value]

input value必须是形状为(1, your_iput_shape)的形状,就好像它的批处理形状为1一样。它必须作为一个列表产生;是强制性的。

您还应该声明需要哪种类型的优化,例如:

代码语言:javascript
复制
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]

然而,我也遇到了不同选择的转换器取决于网络结构的问题,在这种情况下,我不知道。因此,为了使转换器正常运行,我只需要做:

代码语言:javascript
复制
converter = lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True
converter.optimizations = [lite.Optimize.DEFAULT]
tfmodel = converter.convert()

converter.experimental_new_converter = True用于转换https://github.com/tensorflow/tensorflow/issues/34813中的RNN时出现的问题

编辑

正如这里所解释的:ValueError: None is only supported in the 1st dimension. Tensor 'flatbuffer_data' has invalid shape '[None, None, 1, 512]' TFLite只允许数据的第一个维度为None,即批处理。所有其他尺寸都必须固定。尝试将它们填充,例如,tf.keras.preprocessing.sequence.pad_sequences

然后屏蔽网络中的序列,如:tensorflow.org/guide/keras/masking_and_paddingEmbeddingMasking层所描述的那样。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60969947

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档