首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >自定义激活函数依赖于Keras中的其他输出节点

自定义激活函数依赖于Keras中的其他输出节点
EN

Stack Overflow用户
提问于 2022-07-11 10:33:58
回答 1查看 92关注 0票数 0

我想预测一个多维数组使用长期短期记忆(LSTM)网络,同时对感兴趣的表面形状施加限制。

我想通过将输出的某些元素(表面区域)设置为与其他元素之间的功能关系(简单的缩放条件)来实现这一点。

是否可以在Keras中为输出(其参数是其他输出节点)设置这样的自定义激活函数?如果没有,是否还有其他接口允许这样做?你有手册的来源吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-11 13:37:38

科拉斯- GitHub团队回答了关于如何定制激活函数的问题。

还有一个带有自定义激活函数的有密码的问题

这几页可能对你有帮助!

附加评论

对于这个问题,这些页面是不够的,所以我在下面添加评论;

也许PyTorch比Keras更适合定制。我试着编写这样一个网络,尽管它非常简单,基于PyTorch教程和"用自定义激活函数扩展PyTorch“。

我做了一个自定义的激活函数,其中输出向量的第1元素(从0开始计数)等于第0元素的两倍。训练使用了一个非常简单的一层网络。训练结束后,我检查了条件是否满足。

代码语言:javascript
复制
import torch
import matplotlib.pyplot as plt

# Define the custom activation function
# reference: https://towardsdatascience.com/extending-pytorch-with-custom-activation-functions-2d8b065ef2fa
def silu(input):
    input[:,1] = input[:,0] * 2
    return input 

class SiLU(torch.nn.Module):
    def __init__(self):
        super().__init__() # init the base class

    def forward(self, input):
        return silu(input) # simply apply already implemented SiLU


# Training
# reference: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
k = 10
x = torch.rand([k,3])
y = x * 2
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3),
    SiLU()  # custom activation function
)

loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-3
for t in range(2000):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

# check the behaviour
yy = model(x)  # predicted
print('ground truth')
print(y)
print('predicted')
print(yy)


# examples for the first five data
colorlist = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']
plt.figure()
for i in range(5):
  plt.plot(y[i,:].detach().numpy(), linestyle = "solid", label = "ground truth_" + str(i), color=colorlist[i])
  plt.plot(yy[i,:].detach().numpy(), linestyle = "dotted", label = "predicted_" + str(i), color=colorlist[i])
  plt.legend()

# check if the custom activation works correctly
plt.figure()
plt.plot(yy[:,0].detach().numpy()*2, label = '0th * 2')
plt.plot(yy[:,1].detach().numpy(), label = '1th')
plt.legend()

print(yy[:,0]*2)
print(yy[:,1])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72937188

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档