首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Chainer中的渐变裁剪

Chainer中的渐变裁剪
EN

Stack Overflow用户
提问于 2019-09-23 14:26:34
回答 1查看 134关注 0票数 1

我能在Chainer中得到一个渐变裁剪函数吗?

我在Pytorch文档中找到了一些代码:https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

在Chainer中有类似于替代函数的东西吗?我刚刚发现了chainer.optimizer_hooks.GradientClipping,但它很难使用。

提前谢谢。

EN

回答 1

Stack Overflow用户

发布于 2019-09-23 14:37:11

试试这个怎么样。我刚刚用Chainer风格重写了pyTorch函数。

代码语言:javascript
复制
import cupy
def clip_grad_norm(model, max_norm, norm_type=2):
    params  = list( filter(lambda p : p.grad is not None ,  model.params()) )
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0
    for p in params:
        g = p.grad
        norm = cupy.linalg.norm(g)
        total_norm += norm**(norm_type)
    total_norm = total_norm **(1/norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in params:
            g = p.grad
            p.grad = g * clip_coef
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58056693

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档