首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用TensorFlow数据集API的历元计数器

使用TensorFlow数据集API的历元计数器
EN

Stack Overflow用户
提问于 2017-11-21 10:27:47
回答 4查看 3.7K关注 0票数 10

我正在将TensorFlow代码从旧的队列接口更改为新的数据集API。在我的旧代码中,我通过每次在队列中访问和处理新的输入张量时递增一个tf.Variable来跟踪时间计数。我希望使用新的Dataset API进行这一划时代的计数,但我在使它工作时遇到了一些困难。

因为我在预处理阶段产生了可变数量的数据项,所以在训练循环中增加一个(Python)计数器并不是简单的事情--我需要计算与队列或数据集的输入有关的历元计数。

我模仿了以前使用旧队列系统时的情况,下面是Dataset API (简化示例)的结果:

代码语言:javascript
复制
with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while

然而,这是行不通的。它以一条神秘的错误消息崩溃:

代码语言:javascript
复制
 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
 are not compatible with expected types ([tf.float32_ref, tf.float32])

更仔细的检查表明,epoch_counter变量可能根本无法在pre_processing_func中访问。它生活在不同的图形中吗?

知道如何修复上面的例子吗?或者如何通过其他方法获得时间计数器(带有小数点,例如0.4或2.9)?

EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2017-11-21 16:02:40

TL;DR:将epoch_counter的定义替换为:

代码语言:javascript
复制
epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)

在TensorFlow转换中使用tf.data.Dataset变量有一些限制。原则上的限制是,所有变量都必须是“资源变量”,而不是旧的“引用变量”;不幸的是,tf.Variable仍然为了向后兼容性的原因创建“引用变量”。

一般来说,如果可能的话,我不建议在tf.data管道中使用变量。例如,您可能可以使用Dataset.range()定义一个划时代计数器,然后执行如下操作:

代码语言:javascript
复制
epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat()))

上面的片段将一个划时代计数器作为第二个组件附加到每个值。

票数 8
EN

Stack Overflow用户

发布于 2018-09-14 19:20:11

要添加到@mrry的伟大答案,如果您想停留在tf.data管道内,并且希望跟踪每个时代的迭代,您可以在下面尝试我的解决方案。如果您有非单位批次大小,我想您将不得不添加行data = data.batch(bs)

代码语言:javascript
复制
import tensorflow as tf
import itertools

def step_counter(): 
    for i in itertools.count(): yield i

num_examples = 3
num_epochs = 2
num_iters = num_examples * num_epochs

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)

step = tf.data.Dataset.from_generator(step_counter, tf.int32)
data = tf.data.Dataset.zip((data, step))

epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip(
        (data, tf.data.Dataset.from_tensors(i).repeat())))

data = data.repeat(num_epochs)
it = data.make_one_shot_iterator()
example = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        ((x, y), st), ep = sess.run(example)
        print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')

指纹:

代码语言:javascript
复制
step 0   epoch 0     x 2     y 2
step 1   epoch 0     x 0     y 0
step 2   epoch 0     x 1     y 1
step 0   epoch 1     x 2     y 2
step 1   epoch 1     x 0     y 0
step 2   epoch 1     x 1     y 1
票数 1
EN

Stack Overflow用户

发布于 2019-02-20 19:35:25

data = data.repeat(num_epochs)的结果是重复已经为num_epochs重复的数据集(也是划时代计数器)。用for _ in range(num_iters):代替for _ in range(num_iters+1):,可以很容易地得到。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47410778

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档