首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何反转tensorflow中的tf.image.per_image_standardization()函数?

如何反转tensorflow中的tf.image.per_image_standardization()函数?
EN

Stack Overflow用户
提问于 2018-04-26 02:00:19
回答 2查看 1.1K关注 0票数 2

Tensorflow中的tf.image.per_image_standardization()转换每个图像的均值和单位方差为零。因此,当我们想要显示图像数组时,当我们训练深度学习model.But时,这将导致一个非爆炸梯度,我们如何在Tensorflow中恢复这个z-score归一化步骤?

EN

回答 2

Stack Overflow用户

发布于 2018-04-26 04:29:21

通过“显示图像数组”,我想你的意思是在tensorboard中显示它。如果是这种情况,那么你不需要做任何事情,tensorboard可以处理已经标准化的图像。如果您希望原始值用于任何其他目的,为什么不在标准化变量之前使用它,例如:

代码语言:javascript
复制
img = tf.placeholder(...)
img_std = tf.image.per_image_standardization(img)

您可以以任何您认为合适的方式使用imgimg_std

如果你有一个对上面没有覆盖的标准化图像进行去正规化的用例,那么你需要自己计算平均值和标准差,然后乘以标准差,再加上平均值。请注意,tf.image.per_image_standardization使用文档中定义的adjusted_stddev

代码语言:javascript
复制
adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))
票数 1
EN

Stack Overflow用户

发布于 2018-04-26 04:32:49

tf.image.per_image_standardization()层将创建一些内部变量,您可以使用它们来恢复原始数据。请注意,这是未记录的行为,不能保证保持不变。不过,现在,您可以使用以下代码(经过测试)来参考如何获取相关张量并恢复原始数据:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

img_size = 3
a = tf.placeholder( shape = ( img_size, img_size, 1 ), dtype = tf.float32 )
b = tf.image.per_image_standardization( a )

with tf.Session() as sess:
    tensors, tensor_names = [], []
    for l in sess.graph.get_operations():
        tensors.append( sess.graph.get_tensor_by_name( l.name + ":0" ) )
        tensor_names.append( l.name )

    #mean_t = sess.graph.get_tensor_by_name( "per_image_standardization/Mean:0" )
    #variance_t = sess.graph.get_tensor_by_name( "per_image_standardization/Sqrt:0" )

    foobar = np.reshape( np.array( range( img_size * img_size ), dtype = np.float32 ), ( img_size, img_size, 1 ) )
    res =  sess.run( tensors, feed_dict = { a : foobar } )
    #for i in xrange( len( res ) ):
    #    print( i, tensor_names[ i ] + ":" )
    #    print( res[ i ] )
    #    print()

    mean = res[ 6 ] # "per_image_standardization/Mean:0"
    variance = res[ 13 ] # "per_image_standardization/Sqrt:0"
    standardized = res[ 18 ] # "per_image_standardization:0"
    original = standardized * variance + mean
    print( original )

您可以取消对mean_tvariance_t行的注释,以按名称获取对相关张量的引用。(需要重写sess.run()部件。)您可以取消对以for i in xrange(...开头的四行代码的注释(无需重写),以便为您的启迪打印所有可用的已创建张量。:)

上面的代码按原样输出:

[[0.

2.] [3.

5.] [6.

8.]]

这正是提供给网络的数据。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50028639

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档