首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用TensorFlow Dataset API进行GANs培训?

如何使用TensorFlow Dataset API进行GANs培训?
EN

Stack Overflow用户
提问于 2018-08-22 16:28:54
回答 1查看 586关注 0票数 1

我在训练甘斯模特。为了加载数据集,我使用TensorFlow的dataset API。

代码语言:javascript
复制
# train_dataset has image and label. z_train dataset has noise (z).
train_dataset = tf.data.TFRecordDataset(train_file)
z_train = tf.data.Dataset.from_tensor_slices(tf.random_uniform([total_training_samples, seq_length,  z_dim],
                                                                 minval=0, maxval=1, dtype=tf.float32))

train_dataset = tf.data.Dataset.zip((train_dataset, z_train))

创建迭代器:

代码语言:javascript
复制
iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

使用迭代器:

代码语言:javascript
复制
(img, label), z = iter.get_next()
train_init_op = iter.make_initializer(train_dataset)

在培训班上训练海关总署时:

培训鉴别器:

代码语言:javascript
复制
_, disc_loss = sess.run([disc_optim, disc_loss])

然后训练生成器:

代码语言:javascript
复制
_, gen_loss = sess.run([gen_optim, gen_loss])

这是陷阱。由于我使用标签作为条件(CGAN),在同一批处理中使用判别器和生成器图,使用两个sess.run生成两组不同的标签

代码语言:javascript
复制
for epoch in range(num_of_epochs):
    sess.run([tf.global_variables_initializer(), train_init_op.initializer])
    for batch in range(num_of_batches):
        _, disc_loss = sess.run([disc_optim, disc_loss])
        _, gen_loss = sess.run([gen_optim, gen_loss])

由于我必须在生成器的会话运行中提供与在鉴别器的会话运行中相同的批标签,所以如何防止Dataset API在一个批处理的同一个循环中生成两个不同的批?

注意:我使用的是TensorFlow v1.9

提前谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-11-06 18:40:22

您可以为同一数据集创建两个迭代器。如果需要对数据集进行洗牌,甚至可以将种子指定为张量。见下面的例子。

代码语言:javascript
复制
import tensorflow as tf

seed_ts = tf.placeholder(tf.int64)
ds = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).shuffle(5, seed=seed_ts, reshuffle_each_iteration=True)
it1 = ds.make_initializable_iterator()
it2 = ds.make_initializable_iterator()

input1 = it1.get_next()
input2 = it2.get_next()

with tf.Session() as sess:
    for ep in range(10):
        sess.run(it1.initializer, feed_dict={seed_ts: ep})
        sess.run(it2.initializer, feed_dict={seed_ts: ep})

        print("Epoch" + str(ep))
        for i in range(5):
            x = sess.run(input1)
            y = sess.run(input2)
            print([x, y])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51971324

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档