我有一些可变长度数据矩阵及其相关标签的例子,我想用它来训练一个LSTM网络。我知道,至少对于每一批,我应该填充数据样本(例如使用keras.preprocessing.sequence.pad_sequences),并且我成功地为网络提供了numpy数组,但我不知道如何使用TFRecord数据集。
我有一个典型的TFRecord文件读取代码,如下所示:
featuresDict = {'data': tf.FixedLenSequenceFeature([], dtype=tf.string),
'dataShape': tf.FixedLenSequenceFeature([], dtype=tf.int64),
'label': tf.FixedLenSequenceFeature([], dtype=tf.int64)
}
def parse_tfrecord(example):
context, features = tf.parse_single_sequence_example(example, sequence_features=featuresDict)
label = features['label']
data_shape = features['dataShape']
data = tf.decode_raw(features['data'], tf.int64)
data = tf.reshape(data, data_shape)
return label, data
def DataGenerator(fileName, numEpochs=None, batchSize=None):
dataset = tf.data.TFRecordDataset(fileName, compression_type='GZIP')
dataset = dataset.map(parse_tfrecord)
dataset = dataset.batch(batchSize)
dataset = dataset.repeat(numEpochs)
return dataset我可以解析每个示例并生成原始数据矩阵和标签。然后,DataGenerator函数定义数据集并设置该数据集的批处理和重复功能。然后创建一个DataGenerator对象并使用它来适应我的模型:
train_data = DataGenerator(fileName='train.gz', numEpochs=epochs, batchSize=batch_size)
model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, ...)在代码中可以将填充函数放在哪里?一般来说,如果我想使用dataset API进行批处理级别的预处理,我如何才能做到这一点?
发布于 2019-08-09 15:44:12
这样做的一种方法是,当您向写入时,在TFRecords预处理期间使用TFRecords填充序列。然后您可以使用与上面相同的代码。
但我建议衬垫,它的工作方式类似于Keras序列预处理。如果已知维数(padded_shapes为某个常量),则将序列填充到该常数中。否则,它们会被填充到最长的序列中。
https://stackoverflow.com/questions/56569794
复制相似问题