我在gcp平台上运行一个tensorflow模型。数据集很大,并不是所有东西都可以同时保存在内存中,因此我使用以下代码将数据读入tf.dataset中:
def read_dataset(filepattern):
def decode_csv(value_column):
cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0])
features=[cols[1],cols[2]]
label = cols[0]
return features, label
# Create list of files that match pattern
file_list = tf.io.gfile.glob(filepattern)
# Create dataset from file list
dataset = tf.data.TextLineDataset(file_list).map(decode_csv)
return dataset
training_data=read_dataset(<filepattern>)问题是我的数据中的第二列是绝对的,我需要使用一个热编码。如何做到这一点,既可以在函数decode_csv中,也可以在以后操作tf.dataset。
发布于 2019-07-30 21:37:14
你可以用热。假设第二列为cols[1],并且已将分类值转换为整数,则可以执行以下操作:
def decode_csv(value_column):
cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0]])
features=[cols[1], tf.one_hot(cols[2], nb_classes)]
label = cols[0]
return features, label注意:未测试。
https://stackoverflow.com/questions/57273291
复制相似问题