我想预测一个多维数组使用长期短期记忆(LSTM)网络,同时对感兴趣的表面形状施加限制。
我想通过将输出的某些元素(表面区域)设置为与其他元素之间的功能关系(简单的缩放条件)来实现这一点。
是否可以在Keras中为输出(其参数是其他输出节点)设置这样的自定义激活函数?如果没有,是否还有其他接口允许这样做?你有手册的来源吗?
发布于 2022-07-11 13:37:38
科拉斯- GitHub团队回答了关于如何定制激活函数的问题。
还有一个带有自定义激活函数的有密码的问题。
这几页可能对你有帮助!
附加评论
对于这个问题,这些页面是不够的,所以我在下面添加评论;
也许PyTorch比Keras更适合定制。我试着编写这样一个网络,尽管它非常简单,基于PyTorch教程和"用自定义激活函数扩展PyTorch“。
我做了一个自定义的激活函数,其中输出向量的第1元素(从0开始计数)等于第0元素的两倍。训练使用了一个非常简单的一层网络。训练结束后,我检查了条件是否满足。
import torch
import matplotlib.pyplot as plt
# Define the custom activation function
# reference: https://towardsdatascience.com/extending-pytorch-with-custom-activation-functions-2d8b065ef2fa
def silu(input):
input[:,1] = input[:,0] * 2
return input
class SiLU(torch.nn.Module):
def __init__(self):
super().__init__() # init the base class
def forward(self, input):
return silu(input) # simply apply already implemented SiLU
# Training
# reference: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
k = 10
x = torch.rand([k,3])
y = x * 2
model = torch.nn.Sequential(
torch.nn.Linear(3, 3),
SiLU() # custom activation function
)
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-3
for t in range(2000):
y_pred = model(x)
loss = loss_fn(y_pred, y)
if t % 100 == 99:
print(t, loss.item())
model.zero_grad()
loss.backward()
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
# check the behaviour
yy = model(x) # predicted
print('ground truth')
print(y)
print('predicted')
print(yy)
# examples for the first five data
colorlist = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']
plt.figure()
for i in range(5):
plt.plot(y[i,:].detach().numpy(), linestyle = "solid", label = "ground truth_" + str(i), color=colorlist[i])
plt.plot(yy[i,:].detach().numpy(), linestyle = "dotted", label = "predicted_" + str(i), color=colorlist[i])
plt.legend()
# check if the custom activation works correctly
plt.figure()
plt.plot(yy[:,0].detach().numpy()*2, label = '0th * 2')
plt.plot(yy[:,1].detach().numpy(), label = '1th')
plt.legend()
print(yy[:,0]*2)
print(yy[:,1])https://stackoverflow.com/questions/72937188
复制相似问题