首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow:如何将相同的图像失真应用于多幅图像

TensorFlow:如何将相同的图像失真应用于多幅图像
EN

Stack Overflow用户
提问于 2016-05-15 10:41:57
回答 2查看 5.6K关注 0票数 9

Tensorflow CNN实例开始,我试图修改模型,使其有多个图像作为输入(这样输入就不只是3个输入通道,而是通过叠加图像来达到3的倍数)。为了增加输入,我尝试使用随机图像操作,例如TensorFlow提供的翻转、对比度和亮度。对所有输入图像应用相同的随机失真的当前解决方案是为这些操作使用固定的种子值:

代码语言:javascript
复制
def distort_image(image):
  flipped_image = tf.image.random_flip_left_right(image, seed=42)
  contrast_image = tf.image.random_contrast(flipped_image, lower=0.2, upper=1.8, seed=43)
  brightness_image = tf.image.random_brightness(contrast_image, max_delta=0.2, seed=44)
  return brightness_image

这种方法在图形构造时对每幅图像进行多次运算,因此我认为每幅图像都会使用相同的随机数序列,从而使图像输入序列具有相同的应用图像操作。

代码语言:javascript
复制
# ...

# distort images
distorted_prediction = distort_image(seq_record.prediction)
distorted_input = []
for i in xrange(INPUT_SEQ_LENGTH):
    distorted_input.append(distort_image(seq_record.input[i,:,:,:]))
stacked_distorted_input = tf.concat(2, distorted_input)

# Ensure that the random shuffling has good mixing properties.
min_queue_examples = int(num_examples_per_epoch *
                         MIN_FRACTION_EXAMPLES_IN_QUEUE)

# Generate a batch of sequences and prediction by building up a queue of examples.
return generate_sequence_batch(stacked_distorted_input, distorted_prediction, min_queue_examples, 
                               batch_size, shuffle=True)

从理论上讲,这很好。在做了一些测试之后,这似乎真的解决了我的问题。但是过了一会儿,我发现我有一个race-condition,,因为我使用了带有多个线程的CNN示例代码的输入管道(这是TensorFlow中建议的提高性能和减少运行时内存消耗的方法):

代码语言:javascript
复制
def generate_sequence_batch(sequence_in, prediction, min_queue_examples,
                        batch_size):
    num_preprocess_threads = 8 # <-- !!!
    sequence_batch, prediction_batch = tf.train.shuffle_batch(
        [sequence_in, prediction],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
return sequence_batch, prediction_batch

因为多个线程创建了我的示例,所以不再保证所有图像操作都按照正确的顺序执行(按照随机操作的正确顺序)。

我到了一个完全被困住的地步。有谁知道如何解决这个问题,以便将相同的图像失真应用于多幅图像?

我的一些想法:

  • 我想围绕这些图像失真方法做一些同步,但是我可以找到TensorFlow提供的任何东西。
  • 我试图生成一个随机数,例如,使用tf.random_uniform()自己生成随机亮度增量,并将此值用于tf.image.adjust_contrast()。但是TensorFlow随机生成器的结果总是一个张量,我还没有找到一种方法来使用这个张量作为tf.image.adjust_contrast()的一个参数,它期望它的contrast_factor参数有一个简单的float32。
  • 一种(部分)工作的解决方案是使用tf.concat()将所有图像组合到一个巨大的图像中,应用随机操作来改变对比度和亮度,然后再分割图像。但是这不适用于随机翻转,因为这将(至少在我的情况下)改变图像的顺序,并且无法检测tf.image.random_flip_left_right()是否执行了翻转操作,如果需要的话,这将需要修复图像的错误顺序。
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2016-07-15 19:24:00

下面是我通过查看tensorflow中的random_flip_up_down和random_flip_left_right代码得出的结果:

代码语言:javascript
复制
def image_distortions(image, distortions):
    distort_left_right_random = distortions[0]
    mirror = tf.less(tf.pack([1.0, distort_left_right_random, 1.0]), 0.5)
    image = tf.reverse(image, mirror)
    distort_up_down_random = distortions[1]
    mirror = tf.less(tf.pack([distort_up_down_random, 1.0, 1.0]), 0.5)
    image = tf.reverse(image, mirror)
    return image


distortions = tf.random_uniform([2], 0, 1.0, dtype=tf.float32)
image = image_distortions(image, distortions)
label = image_distortions(label, distortions)
票数 10
EN

Stack Overflow用户

发布于 2019-01-19 15:00:22

我会用tf.case做类似的事情。如果某些条件持有docs/python/tf/case,则可以指定返回的内容。

代码语言:javascript
复制
import tensorflow as tf

def distort(image, x):
    # flip vertically, horizontally, both, or do nothing
    image = tf.case({
        tf.equal(x,0): lambda: tf.reverse(image,[0]),
        tf.equal(x,1): lambda: tf.reverse(image,[1]),
        tf.equal(x,2): lambda: tf.reverse(image,[0,1]),
    }, default=lambda: image, exclusive=True)

    return image

def random_distortion(image):
    x = tf.random_uniform([1], 0, 4, dtype=tf.int32)
    return distort(image, x[0])

来检查它是否有效。

代码语言:javascript
复制
import numpy as np
import matplotlib.pyplot as plt

# create image
image = np.zeros((25,25))
image[:10,5:10] = 1.
# create subplots
fig, axes = plt.subplots(2,2)
for i in axes.flatten(): i.axis('off')

with tf.Session() as sess:
    for i in range(4):
        distorted_img = sess.run(distort(image, i))
        axes[i % 2][i // 2].imshow(distorted_img, cmap='gray')

    plt.show()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/37237245

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档