首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何轻松地将PyTorch数据中心转换为tf.Dataset?

如何轻松地将PyTorch数据中心转换为tf.Dataset?
EN

Stack Overflow用户
提问于 2021-12-01 18:32:39
回答 1查看 556关注 0票数 0

如何将pytorch数据中心转换为tf.Dataset

我发现了这个片段:-

代码语言:javascript
复制
def convert_pytorch_dataloader_to_tf_dataset(dataloader, batch_size, shuffle=True):
    dataset = tf.data.Dataset.from_generator(
        lambda: dataloader,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape([256, 512]), tf.TensorShape([2,]))
    )
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(dataloader.dataset))
    dataset = dataset.batch(batch_size)
    return dataset

但根本不起作用。

是否有一个内置的选项可以轻松地将dataloaders导出到tf.Dataset?我有一个非常复杂的数据处理程序,所以一个简单的解决方案应该确保没有bug :)

EN

回答 1

Stack Overflow用户

发布于 2022-07-19 12:22:55

对于h5py格式的数据,可以使用下面的脚本。name_x是h5py中的特性名称,name_y是标签的文件名。此方法具有内存效率高的特点,您可以逐批输入数据。

代码语言:javascript
复制
class Generator(object):

def __init__(self,open_directory,batch_size,name_x,name_y):

    self.open_directory = open_directory

    data_f = h5py.File(open_directory, "r")

    self.x = data_f[name_x]
    self.y = data_f[name_y]

    if len(self.x.shape) == 4:
        self.shape_x = (None, self.x.shape[1], self.x.shape[2], self.x.shape[3])

    if len(self.x.shape) == 3:
        self.shape_x = (None, self.x.shape[1], self.x.shape[2])

    if len(self.y.shape) == 4:
        self.shape_y = (None, self.y.shape[1], self.y.shape[2], self.y.shape[3])

    if len(self.y.shape) == 3:
        self.shape_y = (None, self.y.shape[1], self.y.shape[2])

    self.num_samples = self.x.shape[0]
    self.batch_size = batch_size
    self.epoch_size = self.num_samples//self.batch_size+1*(self.num_samples % self.batch_size != 0)

    self.pointer = 0
    self.sample_nums = np.arange(0, self.num_samples)
    np.random.shuffle(self.sample_nums)


def data_generator(self):

    for batch_num in range(self.epoch_size):

        x = []
        y = []

        for elem_num in range(self.batch_size):

            sample_num = self.sample_nums[self.pointer]

            x += [self.x[sample_num]]
            y += [self.y[sample_num]]

            self.pointer += 1

            if self.pointer == self.num_samples:
                self.pointer = 0
                np.random.shuffle(self.sample_nums)
                break

        x = np.array(x,
                     dtype=np.float32)
        y = np.array(y,
                     dtype=np.float32)

        yield x, y

def get_dataset(self):
    dataset = tf.data.Dataset.from_generator(self.data_generator,
                                             output_types=(tf.float32,
                                                           tf.float32),
                                             output_shapes=(tf.TensorShape(self.shape_x),
                                                            tf.TensorShape(self.shape_y)))
    dataset = dataset.prefetch(1)

    return dataset
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70189513

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档