首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何按特定值过滤tf.data.Dataset?

如何按特定值过滤tf.data.Dataset?
EN

Stack Overflow用户
提问于 2018-02-16 19:27:08
回答 4查看 9.9K关注 0票数 16

我通过读取TFRecords创建一个数据集,我映射值,并希望过滤数据集以获得特定值,但由于结果是带有张量的字典,因此我无法获得张量的实际值,也无法使用tf.cond() / tf.equal检查它。我怎么才能做到这一点呢?

代码语言:javascript
复制
def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()
EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2018-02-20 00:10:09

我在回答我自己的问题。我找到问题了!

我需要做的是像这样tf.unstack()标签:

代码语言:javascript
复制
label = tf.unstack(features['label'])
label = label[0]

在我把它给tf.equal()之前

代码语言:javascript
复制
result = tf.reshape(tf.equal(label, 'some_label_value'), [])

我认为问题是标签被定义为一个数组,其中一个元素的类型为string tf.FixedLenFeature([1], tf.string),所以为了获得第一个也是单个元素,我必须解压它(这会创建一个列表),然后获得索引为0的元素,如果我错了,请纠正我。

票数 5
EN

Stack Overflow用户

发布于 2018-05-08 02:57:15

我认为你一开始就不需要把label设为一维数组。

使用:

代码语言:javascript
复制
feature = {'label': tf.FixedLenFeature((), tf.string)}

您不需要在filter_func中拆分标签

票数 1
EN

Stack Overflow用户

发布于 2021-02-04 14:56:11

读取、过滤数据集非常容易,并且不需要对任何内容进行拆分。

要读取数据集,请执行以下操作:

代码语言:javascript
复制
print(my_dataset, '\n\n')
##let us print the first 3 records
for record in my_dataset.take(3):
    ##below could be large in case of image
    print(record)
    ##let us print a specific key
    print(record['key2'])

过滤同样简单:

代码语言:javascript
复制
my_filtereddataset = my_dataset.filter(_filtcond1)

在这里您可以随心所欲地定义_filtcond1。假设您的数据集中有一个'true‘'false’布尔标志,那么:

代码语言:javascript
复制
@tf.function
def _filtcond1(x):
    return x['key_bool'] == 1

或者甚至是lambda函数:

代码语言:javascript
复制
my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)

如果您正在读取一个尚未创建的数据集,或者您不知道键(就像OPs的情况一样),您可以首先使用下面的命令来了解键和结构:

代码语言:javascript
复制
import json
from google.protobuf.json_format import MessageToJson

for raw_record in noidea_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    ##print(example) ##if image it will be toooolong
    m = json.loads(MessageToJson(example))
    print(m['features']['feature'].keys())

现在你可以继续过滤了

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48825785

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档