首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Pix2PiX gan教程的自定义数据集

Pix2PiX gan教程的自定义数据集
EN

Stack Overflow用户
提问于 2022-08-27 20:15:17
回答 1查看 111关注 0票数 1

我想使用这个Pix2Pix教程,但是使用我自己的数据集。我的数据集位于我的Google中,在一个名为facePictues的文件夹中。在该文件夹中,我有151个子文件夹,其名称为1-151。在这些文件夹中有两个图像:一个是.jpg,它是一个模型脸的真实图像,另一个图像是一个.png,它是该模型的脸的漫画书样式。

我需要从Google加载我的数据集,并将数据集分成两个不同的数据集张量,一个用于训练和测试Pix2Pix模型,然后将真实的人脸图像与绘图人脸标签图像分开。

因此,我可以希望遵循本教程。

这是我的代码:

代码语言:javascript
复制
all_image_paths = glob.glob(str(data_dir/'*/*'))
all_image_paths = np.array([pathlib.Path(path)for path in all_image_paths])
print("Total images: ", len(all_image_paths))
train_dataset, test_dataset = train_test_split(all_image_paths, test_size=0.2)
print("Train dataset size: ", len(train_dataset))
print("Test dataset size: ", len(test_dataset))

def load_image(dataset):
    input_image, target_image = imageio.imread(dataset), imageio.imread(dataset)

    input_image = np.array(PIL.Image.fromarray(input_image).resize((256, 256)))

    target_image = np.array(PIL.Image.fromarray(target_image).resize((256, 256)))

    input_image = (input_image - 127.5) / 127.5

    target_image = (target_image - 127.5) / 127.5

    input_image = np.expand_dims(input_image, axis=0)
    target_image = np.expand_dims(target_image, axis=0)
    return input_image, target_image

input_image, target_image = load_image(train_dataset)
display.display(input_image)
display.display(target_image)

我得到了这个错误:

代码语言:javascript
复制
---------------------------------------------------------------------------
 OSError                                   Traceback (most recent call last)
<ipython-input-13-ca46c378edd4> in <module>
 ----> 1 input_image, target_image = load_image(all_image_paths)
  2 display.display(input_image)
  3 display.display(target_image)

  4 frames
   /usr/local/lib/python3.7/dist-packages/imageio/core/request.py in _parse_uri(self, uri)
  220             if len(uri_r) > 60:
  221                 uri_r = uri_r[:57] + "..."
--> 222             raise IOError("Cannot understand given URI: %s." % uri_r)
223 
224         # Check if this is supported

OSError: Cannot understand given URI: array([PosixPath('gdrive/My 
Drive/DrawingDataSet/FacePict....

代码更新:

代码语言:javascript
复制
  def load_image(image_file):
  input_image, target_image = imageio.imread(image_file), 
  imageio.imread(image_file)
  input_image = 
  np.array(PIL.Image.fromarray(input_image).resize((256, 256)))
  target_image = 
  np.array(PIL.Image.fromarray(target_image).resize((256, 256)))
  input_image = (input_image - 127.5) / 127.5
  target_image = (target_image - 127.5) / 127.5
  input_image = np.expand_dims(input_image, axis=0)
  target_image = np.expand_dims(target_image, axis=0)
  return input_image, target_image

   real_image, drawing_image = load_image(train_dataset[0])
   RealDataSet = tf.constant(real_image);
   print(RealDataSet)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-08-27 20:40:07

如果您想获得路径列表,您必须这样做:

代码语言:javascript
复制
[pathlib.Path(path) for path in all_image_paths]

而不是:

代码语言:javascript
复制
np.array([pathlib.Path(path)for path in all_image_paths])

使用您的代码,您基本上是在创建一个字符串列表的numpy数组,这正是您的错误所抱怨的。

此外,还应该将单个图像传递给load_image函数以使其工作,例如,数据集的第一个图像:

代码语言:javascript
复制
input_image, target_image = load_image(train_dataset[0])
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73514131

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档