首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >随机交错一个tf.Dataset和另一个tf.Dataset

随机交错一个tf.Dataset和另一个tf.Dataset
EN

Stack Overflow用户
提问于 2019-01-03 15:16:17
回答 1查看 87关注 0票数 1

我有两个数据集:

代码语言:javascript
复制
main_ds = tf.data.Dataset.from_tensor_slices(list(range(1000, 1100)))
backgroud_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

我想要一批交错的main_dsbackgroud_ds数据随机。例如,10大小的批处理应该如下所示:

代码语言:javascript
复制
[3, 1017, 1039, 3, 2, 1024, 4, 1, 1053, 4]

我尝试了以下几点:

代码语言:javascript
复制
def interlace_background(image, background):
    return  tf.cond(tf.random_uniform([]) < .5, lambda: image, lambda: background)

background_ds = background_ds.shuffle(10).repeat(-1)
background_it = background_ds.make_initializable_iterator()
background_next = background_it.get_next()

main_ds = main_ds.shuffle(10)\
                 .repeat(-1)\
                 .map(lambda x: interlace_background(x, background_next))\
                 .batch(10)
main_it = main_ds.make_initializable_iterator()
main_next = main_it.get_next()

但我在所有批次都有固定的背景:

代码语言:javascript
复制
batch 0: [   3 1006    3 1001    3 1005 1015 1000    3    3]
batch 1: [1007    3 1012 1018 1013    3 1008 1019    3    3]
batch 2: [1016    3 1025    3    3    3 1021    3    3 1035]
batch 3: [1038    3    3 1023 1020    3    3 1046 1034 1047]
batch 4: [   3    3 1039    3    3    3    3    3 1053    3]

为什么背景是固定的。上面的背景总是3),我该如何解决这个问题?

以下是完全可复制的代码:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

def interlace_background(image, background):
    return  tf.cond(tf.random_uniform([]) < .5, lambda: image, lambda: background)

main_ds = tf.data.Dataset.from_tensor_slices(list(range(1000, 1100)))
background_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

background_ds = background_ds.shuffle(10).repeat(-1)
background_it = background_ds.make_initializable_iterator()
background_next = background_it.get_next()

main_ds = main_ds.shuffle(10)\
                 .repeat(-1)\
                 .map(lambda x: interlace_background(x, background_next))\
                 .batch(10)
main_it = main_ds.make_initializable_iterator()
main_next = main_it.get_next()

with tf.Session() as sess:
    sess.run(background_it.initializer)
    sess.run(main_it.initializer)
    for i in range(5):
        print('batch %i' % i, sess.run(main_next))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-01-03 15:38:35

您可以使用Dataset.zip()Dataset.map()做同样的事情。

以下是代码:

代码语言:javascript
复制
import tensorflow as tf

def interlace_background(image, background):
    return tf.cond(tf.random_uniform([]) < .5, lambda: image, lambda: background)


main_ds = tf.data.Dataset.from_tensor_slices(list(range(1000, 1100))).shuffle(100)
background_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).shuffle(4)

new_ds = tf.data.Dataset \
    .zip((main_ds, background_ds)) \
    .repeat(-1) \
    .map(lambda x, y: interlace_background(x, y)) \
    .batch(10)

iterator = new_ds.make_initializable_iterator()
next_item = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(5):
        print('batch %i' % i, sess.run(next_item))

输出:

代码语言:javascript
复制
batch 0 [1065    2    4    1    2    4    1 1036 1072 1020]
batch 1 [   4    3    2 1057    1    4    2 1077    3    1]
batch 2 [   3 1044 1042 1049 1029    1    3 1069 1018    3]
batch 3 [   2    4 1089 1094    2 1022 1041 1006    1    3]
batch 4 [1079    2    1    3 1023 1042    4 1018 1054    4]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54025069

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档