首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >误差加载Tensorflow角模型(.h5)

误差加载Tensorflow角模型(.h5)
EN

Stack Overflow用户
提问于 2021-02-23 07:20:58
回答 2查看 6.2K关注 0票数 0

我训练了一个数字图像并制作了一个模型文件。

相应的调味汁如下。

代码语言:javascript
复制
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense 
from tensorflow.keras.models import load_model
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Flatten, Convolution2D, MaxPooling2D
from tensorflow.keras.layers import Dropout, Activation, Dense
from tensorflow.keras.layers import Conv2D

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

path = '/Users/animalman/Documents/test/cnn/'

trainPath = os.listdir(path+'train')
testPath = os.listdir(path+'test')

categories = ["5"]
length = len(categories)

width = 28
height = 28
label = [1 for i in range(length)]
X = []
Y = []
for idx, categorie in enumerate(categories):
    label = [0 for i in range(length)]
    label[idx] = 1
    
    fileDir = path + 'train' + '/' + categorie + '/'
    for t, dir, f in os.walk(fileDir):
        for filename in f:
            print(fileDir + filename)
            img = cv2.imread(fileDir + filename)
            img = cv2.resize(img, None, fx=width/img.shape[0],fy=height/img.shape[1])
            X.append(img)
            Y.append(label)
X = np.array(X)
Y = np.array(Y)
X_train, X_test,Y_train, Y_test = train_test_split(X,Y)
xy = (X_train, X_test, Y_train, Y_test)


X_train = X_train.astype("float") / 256
X_test  = X_test.astype("float")  / 256

model = Sequential()
model.add(Conv2D(16, (3, 3), input_shape=X_train.shape[1:], padding='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))

model.add(Conv2D(64, (3, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten()) 
model.add(Dense(512))  
model.add(Activation('relu'))
model.add(Dropout(0.5))

model.add(Dense(length))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])


hdf5_file = "./7obj-model.h5"
if os.path.exists(hdf5_file):
    model.load_weights(hdf5_file)
else:
    model.fit(X_train, Y_train, batch_size=32, epochs=1)
    model.save_weights(hdf5_file)

然后我把保存下来的模型文件带来了。

代码语言:javascript
复制
loaded_model = tf.keras.models.load_model(hdf5_file)

但是,这是发生错误的地方。理由是什么呢?

追溯(最近一次调用):loaded_model = tf.keras.models.load_model(hdf5_file)文件第206行、load_model返回hdf5_format.load_model_from_hdf5(filepath,custom_objects )中的第112行:文件“/Users/动物/文档/test/Tra.py”文件"/Users/animalman/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py",第181行,在load_model_from_hdf5提起ValueError(‘在配置文件中找不到模型’)。ValueError:在配置文件中找不到模型。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-02-23 09:32:51

正如这个职位中提到的,您的h5文件只包含权重。您需要将模型架构保存在json文件中,然后使用model_from_json加载模型配置,因此可以使用load_weights加载权重。

另一个选择可能是简单地保存模型(体系结构+权重),方法是将最后一行替换为

代码语言:javascript
复制
model.save("model.h5")

然后加载,您可以使用

代码语言:javascript
复制
model = load_model('model.h5')
票数 0
EN

Stack Overflow用户

发布于 2021-12-02 21:33:33

添加到@Oscar响应中,对于较小和简单的模型,'h5'格式就足够了,但是对于具有custom_layers或自定义度量的复杂模型(函数和子类),最好以'tf'格式保存(也称为SavedModel格式)。

有关Keras网页的更详细指南,请在这里查看

Keras SavedModel格式限制: SavedModel为生成图层调用函数的图形所做的跟踪允许SavedModel比H5更可移植,但它也有缺点。 可能比H5更慢更大。无法序列化从掩码参数生成的操作(即,如果使用layer (.,mask=mask_value)调用一个层,则掩码参数不会保存到SavedModel)。不将重写的train_step()保存在子类模型中。使用掩码或具有自定义训练循环的自定义对象仍然可以从SavedModel中保存和加载,除非它们必须重写get_config()/from_config(),而且在加载时必须将类传递给custom_objects参数。 H5限制: 通过model.add_loss()和model.add_metric()添加的外部损失和度量没有保存(与SavedModel不同)。如果您在模型上有这样的损失&度量,并且希望恢复培训,那么您需要在加载模型之后自己添加这些损失。请注意,这不适用于通过self.add_loss() & self.add_metric()在层内创建的损失/度量。只要层被加载,这些损失&度量就会保留下来,因为它们是层调用方法的一部分。自定义对象(如自定义层)的计算图不包含在保存的文件中。加载时,Keras将需要访问这些对象的Python类/函数,以便重建模型。请参阅自定义对象。不支持预处理层。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66328719

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档