首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将onnx模型导入到tensorflow2.x 2.x?

将onnx模型导入到tensorflow2.x 2.x?
EN

Stack Overflow用户
提问于 2022-02-27 12:36:01
回答 1查看 1.2K关注 0票数 1

我使用tensorflow创建了一个修改后的lenet模型,如下所示:

代码语言:javascript
复制
img_height = img_width = 64
BS = 32

model = models.Sequential()
model.add(layers.InputLayer((img_height,img_width,1), batch_size=BS))
model.add(layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), batch_size=BS, activation='relu', padding="valid"))
model.add(layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), batch_size=BS, activation='relu', padding='valid'))
model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), batch_size=BS, padding='valid'))
model.add(layers.Dropout(0.25))
model.add(layers.Conv2D(filters=128, kernel_size=(1,1), strides=(1,1), batch_size=BS, activation='relu', padding='valid'))
model.add(layers.Dropout(0.5))
model.add(layers.Conv2D(filters=2, kernel_size=(1,1), strides=(1,1), batch_size=BS, activation='relu', padding='valid'))
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Activation('softmax'))
model.summary()

完成培训后,我使用tf.keras.models.save_model保存模型:

代码语言:javascript
复制
num = time.time()
tf.keras.models.save_model(model,'./saved_models/' + str(num) + '/')

然后使用"tf2onnx“模块将该模型转换为onnx格式:

代码语言:javascript
复制
! python -m tf2onnx.convert --saved-model saved_models/1645088924.84102/ --output 1645088924.84102.onnx

我想要一个可以将相同模型检索到tensorflow2.x 2.x中的方法。我尝试使用"onnx_tf“将onnx模型转换为tensorflow .pb模型:

代码语言:javascript
复制
import onnx

from onnx_tf.backend import prepare

onnx_model = onnx.load("1645088924.84102.onnx")  # load onnx model
tf_rep = prepare(onnx_model)  # prepare tf representation

但是这个方法只生成一个.pb文件,但是tensorflow2.x 2.x中的load_model方法需要与.pb文件位于同一个目录中的另外两个文件夹,它们被命名为“变量”和“资产”。

如果有一种方法可以使.pb文件像有“资产”和“变量”文件夹那样工作,或者如果有一种方法可以从onnx生成完整的模型,那么最好采用这两种解决方案。

我使用的是jupyter集线器服务器,所有东西都在anaconda环境中。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-12 07:40:50

事实证明,最简单的方法是Tensorflow对原始帖子的注释中所建议的支持,即将.pb文件转换回.h5,然后重用该模型。对于推理,我们可以使用graph_def和concrete_function。

将.pb转换为.h5:How to convert .pb file to .h5. (Tensorflow model to keras)

用于推断:https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71284804

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档