创建tfrecord的代码:
def convert(self):
with tf.python_io.TFRecordWriter(self.tfrecord_out) as writer:
example = self._convert_image()
writer.write(example.SerializeToString())
def _convert_image(self):
for (path, label) in zip(self.image_paths, self.labels):
label = int(label)
# Read image data in terms of bytes
with open(path, 'rb') as fid:
png_bytes = fid.read()
example = tf.train.Example(features=tf.train.Features(feature={
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[png_bytes]))
}))
return example我的问题是,当我从文件中读取时,图像无法正确解码:
def parse(self, serialized):
features = \
{
'image': tf.FixedLenFeature([], tf.string)
}
parsed_example = tf.parse_single_example(serialized=serialized,
features=features)
image_raw = parsed_example['image']
image = tf.image.decode_png(contents=image_raw, channels=3, dtype=tf.uint8)
image = tf.cast(image, tf.float32)
return image`有人知道这是为什么吗?

发布于 2019-01-02 02:38:43
找到了解决方案,希望我的愚蠢错误能帮助其他人。
当将张量重塑为4维的张量[batch_size, height, width, channels]时,我切换了宽度和高度。
正确的整形代码是:
x_reshaped = session.run(tf.reshape(tensor=decoded_png_uint8, shape=[batch_size, height, width, channels], name="x_reshaped"))但是我有shape=[batch_size, width, height, channels]。啊,好吧。每一天都是上学的日子。

https://stackoverflow.com/questions/53981111
复制相似问题