我想使用这个Pix2Pix教程,但是使用我自己的数据集。我的数据集位于我的Google中,在一个名为facePictues的文件夹中。在该文件夹中,我有151个子文件夹,其名称为1-151。在这些文件夹中有两个图像:一个是.jpg,它是一个模型脸的真实图像,另一个图像是一个.png,它是该模型的脸的漫画书样式。
我需要从Google加载我的数据集,并将数据集分成两个不同的数据集张量,一个用于训练和测试Pix2Pix模型,然后将真实的人脸图像与绘图人脸标签图像分开。
因此,我可以希望遵循本教程。
这是我的代码:
all_image_paths = glob.glob(str(data_dir/'*/*'))
all_image_paths = np.array([pathlib.Path(path)for path in all_image_paths])
print("Total images: ", len(all_image_paths))
train_dataset, test_dataset = train_test_split(all_image_paths, test_size=0.2)
print("Train dataset size: ", len(train_dataset))
print("Test dataset size: ", len(test_dataset))
def load_image(dataset):
input_image, target_image = imageio.imread(dataset), imageio.imread(dataset)
input_image = np.array(PIL.Image.fromarray(input_image).resize((256, 256)))
target_image = np.array(PIL.Image.fromarray(target_image).resize((256, 256)))
input_image = (input_image - 127.5) / 127.5
target_image = (target_image - 127.5) / 127.5
input_image = np.expand_dims(input_image, axis=0)
target_image = np.expand_dims(target_image, axis=0)
return input_image, target_image
input_image, target_image = load_image(train_dataset)
display.display(input_image)
display.display(target_image)我得到了这个错误:
---------------------------------------------------------------------------
OSError Traceback (most recent call last)
<ipython-input-13-ca46c378edd4> in <module>
----> 1 input_image, target_image = load_image(all_image_paths)
2 display.display(input_image)
3 display.display(target_image)
4 frames
/usr/local/lib/python3.7/dist-packages/imageio/core/request.py in _parse_uri(self, uri)
220 if len(uri_r) > 60:
221 uri_r = uri_r[:57] + "..."
--> 222 raise IOError("Cannot understand given URI: %s." % uri_r)
223
224 # Check if this is supported
OSError: Cannot understand given URI: array([PosixPath('gdrive/My
Drive/DrawingDataSet/FacePict....代码更新:
def load_image(image_file):
input_image, target_image = imageio.imread(image_file),
imageio.imread(image_file)
input_image =
np.array(PIL.Image.fromarray(input_image).resize((256, 256)))
target_image =
np.array(PIL.Image.fromarray(target_image).resize((256, 256)))
input_image = (input_image - 127.5) / 127.5
target_image = (target_image - 127.5) / 127.5
input_image = np.expand_dims(input_image, axis=0)
target_image = np.expand_dims(target_image, axis=0)
return input_image, target_image
real_image, drawing_image = load_image(train_dataset[0])
RealDataSet = tf.constant(real_image);
print(RealDataSet)发布于 2022-08-27 20:40:07
如果您想获得路径列表,您必须这样做:
[pathlib.Path(path) for path in all_image_paths]而不是:
np.array([pathlib.Path(path)for path in all_image_paths])使用您的代码,您基本上是在创建一个字符串列表的numpy数组,这正是您的错误所抱怨的。
此外,还应该将单个图像传递给load_image函数以使其工作,例如,数据集的第一个图像:
input_image, target_image = load_image(train_dataset[0])https://stackoverflow.com/questions/73514131
复制相似问题