首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何保存keras recommenderNet模型?

如何保存keras recommenderNet模型?
EN

Stack Overflow用户
提问于 2020-07-17 04:45:32
回答 3查看 1K关注 0票数 0

从嵌入RecommenderNet模型构建模型后,如何保存它,链接到doc is 电影/

代码语言:javascript
复制
class RecommenderNet(keras.Model):
    def __init__(self, num_users, num_movies, embedding_size, **kwargs):
        super(RecommenderNet, self).__init__(**kwargs)
        self.num_users = num_users
        self.num_movies = num_movies
        self.embedding_size = embedding_size
        self.user_embedding = layers.Embedding(
            num_users,
            embedding_size,
            embeddings_initializer="he_normal",
            embeddings_regularizer=keras.regularizers.l2(1e-6),
        )
        self.user_bias = layers.Embedding(num_users, 1)
        self.movie_embedding = layers.Embedding(
            num_movies,
            embedding_size,
            embeddings_initializer="he_normal",
            embeddings_regularizer=keras.regularizers.l2(1e-6),
        )
        self.movie_bias = layers.Embedding(num_movies, 1)

    def call(self, inputs):
        user_vector = self.user_embedding(inputs[:, 0])
        user_bias = self.user_bias(inputs[:, 0])
        movie_vector = self.movie_embedding(inputs[:, 1])
        movie_bias = self.movie_bias(inputs[:, 1])
        dot_user_movie = tf.tensordot(user_vector, movie_vector, 2)
        # Add all the components (including bias)
        x = dot_user_movie + user_bias + movie_bias
        # The sigmoid activation forces the rating to between 0 and 1
        return tf.nn.sigmoid(x)


model = RecommenderNet(num_users, num_movies, EMBEDDING_SIZE)
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(), optimizer=keras.optimizers.Adam(lr=0.001)
)
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=5,
    verbose=1,
    validation_data=(x_val, y_val),
)

试过这些

代码语言:javascript
复制
model.save('model.h5py')
tf.keras.models.save_model(model, overwrite=True, include_optimizer=True, save_format='h5')

双掷

代码语言:javascript
复制
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or 
a Sequential model.It does not work for subclassed models, because such models are defined via the body of 
a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel 
format (by setting save_format="tf") or using `save_weights`.

模型类型为main.RecommenderNet

EN

回答 3

Stack Overflow用户

发布于 2020-07-17 04:50:53

因此,正如错误所述,您可以使用save_format="tf",因为您的模型不是FunctionalSequential

代码语言:javascript
复制
model.save('model.py')
tf.keras.models.save_model(model, overwrite=True, include_optimizer=True, save_format='tf')

另外,正如在文档中所看到的,您可以使用:

代码语言:javascript
复制
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
票数 1
EN

Stack Overflow用户

发布于 2020-07-20 10:46:09

实际上,用

代码语言:javascript
复制
keras.models.load_model('path_to_my_model')

对我来说不起作用了,我们必须从构建的模型中提取save_weights

代码语言:javascript
复制
model.save_weights('model_weights', save_format='tf')

然后,我们必须为子类模型创建一个新的实例,然后用已建模型的一个记录和train_on_batch来编译和load_weights

代码语言:javascript
复制
loaded_model = RecommenderNet(num_users, num_movies, EMBEDDING_SIZE)
loaded_model.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer=keras.optimizers.Adam(lr=0.001))
loaded_model.train_on_batch(x_train[:1], y_train[:1])
loaded_model.load_weights('model_weights')

这在TensorFlow==2.2.0中非常完美。

票数 1
EN

Stack Overflow用户

发布于 2020-07-20 01:14:14

我将tensorflow 1.14.0升级到2.2.0,save_model工作了

代码语言:javascript
复制
tf.keras.models.save_model(model,'./saved_model')

但是用load_model,加载后的模型,同时进行预测

代码语言:javascript
复制
loaded_model = tf.keras.models.load_model('./saved_model')

我得到了值错误

代码语言:javascript
复制
rates = loaded_model.predict(user_product_array).flatten()

    ValueError: Python inputs incompatible with input_signature:
      inputs: (
        Tensor("IteratorGetNext:0", shape=(None, 2), dtype=int32))
      input_signature: (
        TensorSpec(shape=(None, 2), dtype=tf.int64, name='input_1'))

我能知道这里的问题吗?

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62947266

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档