首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何找到tf.Dataset的len()

如何找到tf.Dataset的len()
EN

Stack Overflow用户
提问于 2020-12-29 01:13:24
回答 1查看 608关注 0票数 2

我已经开始使用tf.data.Dataset作为将数据加载到keras模型中的一种方式,因为它们看起来比keras的ImageDataGenerator快得多,而且比数组上的培训更节省内存。

一个人认为我无法理解我的想法,那就是我似乎找不到一种方法来访问数据集的len()。Keras的ImageDataGenerator有一个名为n的属性,用于此目的。这使得我的代码非常难看,因为我需要在片段的各个部分对长度进行硬编码(例如,了解一个时代有多少次迭代)。

有什么我能解决这个问题的主意吗?

一个示例脚本:

代码语言:javascript
复制
# Generator
def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_,_) = tf.keras.mnist.load_data()

    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 225.

    y_train = tf.keras.utils.to_categorical(y_train, 10)

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)

    return ds


model = ...  # create a tf.keras model

batch_size 256
gen = make_mnist_train_generator(batch_size)

# Training
model.fit(gen, epochs=50, steps_per_epoch=60000//batch_size+1)  # Hard coded size of generator
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-29 01:25:25

tl;dr

不幸的是,tf.data.Dataset是一个生成器,而且没有,没有找到其大小的固有方法。

但是..。

一般来说,当您使用.from_tensor_slices()时,您可以通过在这个方法中添加的参数来知道它的大小,在您的例子中是x_train。您唯一的问题是在函数中创建它。

要绕过这个问题,您可以做的一个简单的方法是自己添加一个__len__属性!我发现你能做到的最简单的方法是:

代码语言:javascript
复制
ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})

在你的例子中,它看起来应该是这样的:

代码语言:javascript
复制
def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_,_) = tf.keras.mnist.load_data()

    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 225.

    y_train = tf.keras.utils.to_categorical(y_train, 10)

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)

    ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})

    return ds


gen = make_mnist_train_generator(batch_size)

model.fit(gen, epochs=50, steps_per_epoch=len(gen)//batch_size+1)  # Hard coded size of generator

为什么要这么做?

我在过去做过这件事,它令人惊讶地有用。您希望您的生成器有一个len()的原因有很多。一些例子是:

如果希望将生成器放在单独的模块中,并导入生成器

  • ,则将其导入
  • ,如果生成器打算供不知道用于创建

的数据的其他人使用

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65486078

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档