首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在keras中处理数据集时的批处理

在keras中处理数据集时的批处理
EN

Stack Overflow用户
提问于 2019-06-12 20:19:10
回答 1查看 1K关注 0票数 0

我有一些可变长度数据矩阵及其相关标签的例子,我想用它来训练一个LSTM网络。我知道,至少对于每一批,我应该填充数据样本(例如使用keras.preprocessing.sequence.pad_sequences),并且我成功地为网络提供了numpy数组,但我不知道如何使用TFRecord数据集。

我有一个典型的TFRecord文件读取代码,如下所示:

代码语言:javascript
复制
featuresDict = {'data': tf.FixedLenSequenceFeature([], dtype=tf.string),
                'dataShape': tf.FixedLenSequenceFeature([], dtype=tf.int64),
                'label': tf.FixedLenSequenceFeature([], dtype=tf.int64)
               }

def parse_tfrecord(example):
    context, features = tf.parse_single_sequence_example(example, sequence_features=featuresDict)   
    label = features['label']
    data_shape = features['dataShape']
    data = tf.decode_raw(features['data'], tf.int64)
    data = tf.reshape(data, data_shape)
    return label, data

def DataGenerator(fileName, numEpochs=None, batchSize=None):    
  dataset = tf.data.TFRecordDataset(fileName, compression_type='GZIP')
  dataset = dataset.map(parse_tfrecord)
  dataset = dataset.batch(batchSize)
  dataset = dataset.repeat(numEpochs)
  return dataset

我可以解析每个示例并生成原始数据矩阵和标签。然后,DataGenerator函数定义数据集并设置该数据集的批处理和重复功能。然后创建一个DataGenerator对象并使用它来适应我的模型:

代码语言:javascript
复制
train_data = DataGenerator(fileName='train.gz', numEpochs=epochs, batchSize=batch_size)
model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, ...)

在代码中可以将填充函数放在哪里?一般来说,如果我想使用dataset API进行批处理级别的预处理,我如何才能做到这一点?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-08-09 15:44:12

这样做的一种方法是,当您向写入时,在TFRecords预处理期间使用TFRecords填充序列。然后您可以使用与上面相同的代码。

但我建议衬垫,它的工作方式类似于Keras序列预处理。如果已知维数(padded_shapes为某个常量),则将序列填充到该常数中。否则,它们会被填充到最长的序列中。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56569794

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档