我正在做一个“新闻分类”项目。其中,模型必须将给定的文本分类为业务、娱乐、政治、体育和技术。。
我正在使用TensorFlow==2.7.0上的谷歌Colab。我训练了7种不同的模特。之后,对其进行训练和预测。与所有模型相比,Conv1d表现最好。保存了model_2.save('saved_model/my_model').It的表现最好的模型直到现在都做得很好。
但是,当我想使用代码加载保存的模型时
然后,loaded_model = tf.keras.models.load_model('saved_model/my_model')将得到以下异常:
TypeError Traceback (most recent call last)
<ipython-input-129-c92edaf0db7f> in <module>()
----> 1 load_model = tf.keras.models.load_model('saved_model/my_model')
2 # load_model.preditct(val_sentences)
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
547 str_values = [compat.as_bytes(x) for x in proto_values]
548 except TypeError:
--> 549 raise TypeError(f"Failed to convert elements of {values} to Tensor. "
550 "Consider casting elements to a supported type. See "
551 "https://www.tensorflow.org/api_docs/python/tf/dtypes "
TypeError: Exception encountered when calling layer "conv1d" (type Conv1D).
Failed to convert elements of tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 128), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(None,), dtype=int64)), row_splits=Tensor("conv1d/Conv1D/RaggedExpandDims/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/mul:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
Call arguments received:
• inputs=tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 128), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(None,), dtype=int64))发布于 2022-03-18 17:12:06
看起来您需要在文件名上添加.h5。您需要将.h5添加到文件名的末尾。而不是:
model_2.save('saved_model/my_model')尝试
model_2.save('saved_model/my_model.h5')
#Notice the .h5 -------------------^您还需要将loaded_model = tf.keras.models.load_model('saved_model/my_model')更改为loaded_model = tf.keras.models.load_model('saved_model/my_model.h5')
发布于 2022-07-18 13:46:09
它用tensorflow 2.9修好了。更新tf版本将有助于提供更多信息:https://github.com/tensorflow/tensorflow/commit/d3b6494e82085a397bad2260a43ea6769aa20f5d
https://stackoverflow.com/questions/70119317
复制相似问题