Running a Keras h5 model

Asked by a Stack Overflow user on 2021-02-13 08:59:29

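I'm trying to run this h5 model, found here, from the ALASKA2 Image Steganalysis competition.

I want to predict the label of an RGB image, c1.bmp, with the following code: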

```python
import efficientnet.tfkeras as efn
import tensorflow as tf
from tensorflow import keras
import numpy as np


def decode_image(filename, image_size=(512, 512)):
    bits = tf.io.read_file(filename)
    image = tf.image.decode_bmp(bits, channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, image_size)

    return image


img = decode_image('imgs/c1.bmp')
model = keras.models.load_model("model.h5")
print(model.predict(img, verbose=1))
```

However, running this code results in the following error:

```
File "alaska.py", line 20, in <module>
    print(model.predict(img, verbose=1))
  File "Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1629, in predict
    tmp_batch_outputs = self.predict_function(iterator)
  File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "Python38\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "Python38\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "Python38\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1478 predict_function  *
        return step_function(self, iterator)
    Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1468 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1461 run_step  **
        outputs = model.predict_step(data)
    Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1434 predict_step
        return self(x, training=False)
    Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:998 __call__
        input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
    Python38\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:271 assert_input_compatibility
        raise ValueError('Input ' + str(input_index) +

    ValueError: Input 0 is incompatible with layer sequential: expected shape=(None, 512, 512, 3), found shape=(32, 512, 3)
```
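The odd `found shape=(32, 512, 3)` comes from how `predict()` reads its input: it treats the first axis as the sample axis and, when `batch_size` is not given, splits the input into batches of 32 (the documented default). A single `(512, 512, 3)` image therefore looks like 512 samples of shape `(512, 3)`, and the first batch sent to the model has shape `(32, 512, 3)`. The NumPy sketch below only imitates that slicing to make the arithmetic visible; `split_into_batches` is a hypothetical helper, not Keras code, and a zero array stands in for the decoded image.

```python
import numpy as np

# Stand-in for the decoded image: height x width x channels, no batch axis.
image = np.zeros((512, 512, 3), dtype=np.float32)


def split_into_batches(x, batch_size=32):
    """Hypothetical helper: slice x along axis 0 into chunks of batch_size,
    roughly imitating how predict() batches its input."""
    return [x[i:i + batch_size] for i in range(0, x.shape[0], batch_size)]


batches = split_into_batches(image)
# The 512 image rows are treated as 512 "samples", giving 16 batches of 32.
print(len(batches), batches[0].shape)  # 16 (32, 512, 3)
```

Each "sample" is one image row of shape `(512, 3)`, which is rank 2 instead of the rank-3 `(512, 512, 3)` the model expects per sample, hence the incompatibility.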

I'm using Python 3.8.7 and TensorFlow 2.4.1, with PyCharm on Windows 8.

What does this error mean, and how can I fix it?


1 answer

Stack Overflow user

Accepted answer

Posted 2021-02-14 02:10:29

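You forgot to add the batch dimension. The model expects input of shape (None, 512, 512, 3), i.e. a batch of images, but decode_image returns a single image of shape (512, 512, 3). Just add the following transformation to the decode_image function, before the return: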

```python
# Add a leading batch axis: (512, 512, 3) -> (1, 512, 512, 3)
image = tf.expand_dims(image, axis=0)
```
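With the extra axis, the input shape `(1, 512, 512, 3)` matches the expected `(None, 512, 512, 3)`, with `None` taking the value 1. If the image is held as a NumPy array instead, the same axis can be added with `np.expand_dims`, and several images of the same size can be combined into one batch with `np.stack`. The sketch below assumes a zero array as a stand-in for a decoded image; `add_batch_axis_if_needed` is a hypothetical helper, not part of Keras or TensorFlow.

```python
import numpy as np


def add_batch_axis_if_needed(x, expected_rank=4):
    """Hypothetical helper: prepend a batch axis to a single HWC image.

    A model with input shape (None, H, W, C) expects rank-4 input, while a
    single decoded image is rank 3, so it gets a leading axis of size 1.
    Input that is already rank 4 is returned unchanged.
    """
    x = np.asarray(x)
    if x.ndim == expected_rank - 1:
        return np.expand_dims(x, axis=0)  # NumPy counterpart of tf.expand_dims(x, axis=0)
    if x.ndim != expected_rank:
        raise ValueError(f"expected rank {expected_rank - 1} or {expected_rank}, got {x.ndim}")
    return x


image = np.zeros((512, 512, 3), dtype=np.float32)  # stand-in for one decoded image
single = add_batch_axis_if_needed(image)

# Several same-sized images can be batched together along a new first axis.
several = np.stack([image, image, image], axis=0)

print(single.shape, several.shape)  # (1, 512, 512, 3) (3, 512, 512, 3)
```

`model.predict` then returns one row of outputs per element along the first axis, so the prediction for the single image is at index 0 of the result.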
Original link:

https://stackoverflow.com/questions/66180960