首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Dataset.from_generator: TypeError:` `generator`‘必须是可调用的

Dataset.from_generator: TypeError:` `generator`‘必须是可调用的
EN

Stack Overflow用户
提问于 2020-08-10 18:21:29
回答 1查看 1.3K关注 0票数 1

我目前正在使用生成器来使用tf.data.Dataset.from_generator生成我的培训和验证数据集。我有一个类方法来处理这个问题:

代码语言:javascript
复制
def build_dataset(self, batch_size=16, shuffle=16, validation=None):
    
    train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
    self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
    
    if validation is not None:
        val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
        self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)

问题是将(validation=validation)传递给我的import_images生成器创建了Tensorflow不想要的生成器对象,并给出了错误:

代码语言:javascript
复制
TypeError: `generator` must be callable.

因为我必须传递validation来告诉我的生成器生成一个单独的培训和验证版本,所以我需要创建同一个生成器的两个版本。它也不允许我传递其他参数来控制训练和验证示例的百分比--这意味着生成器必须是静态的。有什么建议吗?

EN

回答 1

Stack Overflow用户

发布于 2020-08-19 00:15:43

我最近遇到了一个类似的问题,但我是一个初学者,所以不确定这是否会有帮助。

尝试在类中添加一个调用函数。

下面是引发TypeError: `generator` must be callable.的原始类

代码语言:javascript
复制
class DataGen:
  def __init__(self, files, data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self, files_name):
    data_path = os.path.join(self.data_path, files_name)
    arr_img, arr_mask = load_patch(data_path)
    return arr_img, arr_mask

  def getitem(self, index):
    _img, _mask = self.__load__(self.files[index])
    return _img, _mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr, mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr, mask_arr

然后,我修改了代码如下,它对我起作用。

代码语言:javascript
复制
class DataGen:
  def __init__(self, files, data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self, files_name):
    data_path = os.path.join(self.data_path, files_name)
    arr_img, arr_mask = load_patch(data_path)
    return arr_img, arr_mask

  def getitem(self, index):
    _img, _mask = self.__load__(self.files[index])
    return _img, _mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr, mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr, mask_arr
  
  def __call__(self):
    self.i = 0
    return self
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63345896

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档