我经常遇到的一件事是,我的模型将包含具有NaN值的矩阵。有没有一种通用的Flux方法,我可以把我的矩阵传递到和检测这些NaN的?我知道朱莉娅有一个内置的isnan()函数,可以在某些情况下使用,但我不确定是否有一个特定的通量版本?
发布于 2021-07-05 22:43:48
不,没有特定的功能。在大多数情况下,使用any(isnan, A)可能是您想要做的事情。一个与流量相关的“增强”将是使用训练循环回调来停止训练,如果检测到NaNs。
# assumes (x, y) is your training data
# and loss(x, y, mode) will compute the loss of model on (x, y)
cb = () -> isnan(loss(x, y, model)) && Flux.stop()
# basic train loop
# assuming opt is your optimizer
Flux.train!((x, y) -> loss(x, y, model), params(model), [(x, y)], opt; cb = cb)上面的例子是基本思想,您可以扩展到检查NaN的不同数组。例如,你可以
cb = () -> any(params(m)) do p
any(isnan, p)
end && Flux.stop()检查任何参数是否为NaN。
https://stackoverflow.com/questions/68248145
复制相似问题