首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Theano的平均池化

Theano的平均池化
EN

Stack Overflow用户
提问于 2014-07-09 20:58:26
回答 1查看 2K关注 0票数 3

我正在尝试用Theano为神经网络实现另一个池化函数,expect已经存在的最大池,例如平均池。

对于已经实现了平均池化的this source,我的代码看起来像这样:

随机初始化只是为了测试:

代码语言:javascript
复制
invals = numpy.random.RandomState(1).rand(3,2,5,5) 

Theano标量和函数的定义:

代码语言:javascript
复制
pdim = T.scalar('pool dim', dtype='float32')
pool_inp = T.tensor4('pool input', dtype='float32')
pool_sum = TSN.images2neibs(pool_inp, (pdim, pdim))
pool_out = pool_sum.mean(axis=-1) 
pool_fun = theano.function([pool_inp, pdim], pool_out, name = 'pool_fun', allow_input_downcast=True)

TSN为theano.sandbox.neighbours

以及函数的调用:

代码语言:javascript
复制
pool_dim = 2
temp = pool_fun(invals, pool_dim)
temp.shape = (invals.shape[0], invals.shape[1], invals.shape[2]/pool_dim,
            invals.shape[3]/pool_dim)
print ('invals[1,0,:,:]=\n', invals[1,0,:,:])
print ('output[1,0,:,:]=\n',temp[1,0,:,:])

我得到了一个错误:

代码语言:javascript
复制
TypeError: neib_shape[0]=2, neib_step[0]=2 and ten4.shape[2]=5 not consistent
Apply node that caused the error: Images2Neibs{valid}(pool input, MakeVector.0, MakeVector.0)
Inputs shapes: [(3, 2, 5, 5), (2,), (2,)]
Inputs strides: [(200, 100, 20, 4), (4,), (4,)]
Inputs types: [TensorType(float32, 4D), TensorType(float32, vector), TensorType(float32, vector)]
Use the Theano flag 'exception_verbosity=high' for a debugprint of this apply node.

我真的不理解这个错误。我很高兴有任何建议,如何纠正这个错误或其他池化技术的例子,在Theano中编程。

谢谢!

编辑:通过忽略边框,它可以完美地工作

代码语言:javascript
复制
pool_sum = TSN.images2neibs(pool_inp, (pdim, pdim), mode='ignore_borders')

invals[1,0,:,:]=
[[ 0.01936696  0.67883553  0.21162812  0.26554666  0.49157316]
[ 0.05336255  0.57411761  0.14672857  0.58930554  0.69975836]
[ 0.10233443  0.41405599  0.69440016  0.41417927  0.04995346]
[ 0.53589641  0.66379465  0.51488911  0.94459476  0.58655504]
[ 0.90340192  0.1374747   0.13927635  0.80739129  0.39767684]]
output[1,0,:,:]=
[[ 0.33142066  0.30330223]
[ 0.42902038  0.64201581]]
EN

回答 1

Stack Overflow用户

发布于 2014-07-21 20:47:40

invals在最后两个维度中具有形状(5, 5),但是您希望将其放在(2, 2)子集上。只有当您忽略边框(即invals的最后一列和最后一行)时,这才有效。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/24654389

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档