首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Theano HiddenLayer激活函数

Theano HiddenLayer激活函数
EN

Stack Overflow用户
提问于 2014-10-21 22:38:18
回答 5查看 9.1K关注 0票数 11

是否可以用校正线性单元(ReLU)代替tanh()sigmoid()作为隐层的激活函数?隐藏层的实现如下所示,就我在互联网上搜索的情况而言,ReLU没有在Theano内部实现。

代码语言:javascript
复制
class HiddenLayer(object):
  def __init__(self, rng, input, n_in, n_out, W=None, b=None, activation=T.tanh):
    pass
EN

回答 5

Stack Overflow用户

回答已采纳

发布于 2014-10-22 00:19:43

雷鲁在西亚诺很容易做到:

代码语言:javascript
复制
switch(x<0, 0, x)

要在您的情况下使用它,请创建一个python函数,该函数将实现relu并将其传递给激活:

代码语言:javascript
复制
def relu(x):
    return theano.tensor.switch(x<0, 0, x)
HiddenLayer(..., activation=relu)

有些人使用这个实现:x * (x > 0)

更新:较新的Theano版本有theano.tensor.nnet.relu(x)可用。

票数 17
EN

Stack Overflow用户

发布于 2015-02-28 00:21:37

更新:最新版本的theano有对ReLU:T.nnet.relu的本地支持,这应该比自定义解决方案更好。

我决定比较解决方案的速度,因为它对NNs非常重要。比较了函数本身的速度和梯度,在第一种情况下,switch是首选的,x* (x>0)的梯度更快。所有计算的梯度都是正确的。

代码语言:javascript
复制
def relu1(x):
    return T.switch(x<0, 0, x)

def relu2(x):
    return T.maximum(x, 0)

def relu3(x):
    return x * (x > 0)


z = numpy.random.normal(size=[1000, 1000])
for f in [relu1, relu2, relu3]:
    x = theano.tensor.matrix()
    fun = theano.function([x], f(x))
    %timeit fun(z)
    assert numpy.all(fun(z) == numpy.where(z > 0, z, 0))

Output: (time to compute ReLU function)
>100 loops, best of 3: 3.09 ms per loop
>100 loops, best of 3: 8.47 ms per loop
>100 loops, best of 3: 7.87 ms per loop

for f in [relu1, relu2, relu3]:
    x = theano.tensor.matrix()
    fun = theano.function([x], theano.grad(T.sum(f(x)), x))
    %timeit fun(z)
    assert numpy.all(fun(z) == (z > 0)

Output: time to compute gradient 
>100 loops, best of 3: 8.3 ms per loop
>100 loops, best of 3: 7.46 ms per loop
>100 loops, best of 3: 5.74 ms per loop

最后,让我们比较一下如何计算梯度(最快的方法)。

代码语言:javascript
复制
x = theano.tensor.matrix()
fun = theano.function([x], x > 0)
%timeit fun(z)
Output:
>100 loops, best of 3: 2.77 ms per loop

因此,theano生成梯度的非最优代码。IMHO,今天的切换版本应该是首选。

票数 8
EN

Stack Overflow用户

发布于 2014-10-23 18:52:43

我认为用这样的方式写它更准确:

代码语言:javascript
复制
x * (x > 0.) + 0. * (x < 0.)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/26497564

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档