我目前正在使用生成器来使用tf.data.Dataset.from_generator生成我的培训和验证数据集。我有一个类方法来处理这个问题:
def build_dataset(self, batch_size=16, shuffle=16, validation=None):
train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
if validation is not None:
val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)问题是将(validation=validation)传递给我的import_images生成器创建了Tensorflow不想要的生成器对象,并给出了错误:
TypeError: `generator` must be callable.因为我必须传递validation来告诉我的生成器生成一个单独的培训和验证版本,所以我需要创建同一个生成器的两个版本。它也不允许我传递其他参数来控制训练和验证示例的百分比--这意味着生成器必须是静态的。有什么建议吗?
发布于 2020-08-19 00:15:43
我最近遇到了一个类似的问题,但我是一个初学者,所以不确定这是否会有帮助。
尝试在类中添加一个调用函数。
下面是引发TypeError: `generator` must be callable.的原始类
class DataGen:
def __init__(self, files, data_path):
self.i = 0
self.files=files
self.data_path=data_path
def __load__(self, files_name):
data_path = os.path.join(self.data_path, files_name)
arr_img, arr_mask = load_patch(data_path)
return arr_img, arr_mask
def getitem(self, index):
_img, _mask = self.__load__(self.files[index])
return _img, _mask
def __iter__(self):
return self
def __next__(self):
if self.i < len(self.files):
img_arr, mask_arr = self.getitem(self.i)
self.i += 1
else:
raise StopIteration()
return img_arr, mask_arr然后,我修改了代码如下,它对我起作用。
class DataGen:
def __init__(self, files, data_path):
self.i = 0
self.files=files
self.data_path=data_path
def __load__(self, files_name):
data_path = os.path.join(self.data_path, files_name)
arr_img, arr_mask = load_patch(data_path)
return arr_img, arr_mask
def getitem(self, index):
_img, _mask = self.__load__(self.files[index])
return _img, _mask
def __iter__(self):
return self
def __next__(self):
if self.i < len(self.files):
img_arr, mask_arr = self.getitem(self.i)
self.i += 1
else:
raise StopIteration()
return img_arr, mask_arr
def __call__(self):
self.i = 0
return selfhttps://stackoverflow.com/questions/63345896
复制相似问题