首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >torch.nn.Embedding 中 max_norm 的作用

torch.nn.Embedding 中 max_norm 的作用

作者头像
AlphaHinex
发布2026-03-16 14:50:29
发布2026-03-16 14:50:29
860
举报
文章被收录于专栏:周拱壹卒周拱壹卒

https://docs.pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding

1. 类比 “词典”

nn.Embedding(num_embeddings, embedding_dim) 可以看成是一个查表词典

  • num_embeddings 行,每一行是一个 embedding_dim 维的向量。
  • 输入是索引(比如单词 ID、类别 ID),输出是对应行的向量。

max_norm 的作用:

给这个“词典”里的每一行向量设一个“最大长度”限制:

  • 这里的“长度”是向量的 p-范数,默认是 L2 范数(欧几里得长度)。
  • 当某一行被用到(即该 index 出现在输入中)时,如果它的范数大于 max_norm,就会被按比例缩小,使得它的范数刚好等于 max_norm

重要点:

  • 这个“缩小”发生在 forward 的时候,是原地修改 embedding.weight 的部分行
  • 只有被当前 batch 索引到的那些行才会被检查和可能被缩放。

2. 一个具体、可手算的小例子

2.1 定义一个简单的 Embedding

假设我们手动构造一个 embedding 权重(方便算):

代码语言:javascript
复制
num_embeddings = 5   # 总共有 5 行
embedding_dim = 3   # 每行是 3 维向量
max_norm = 1.5

我们设定当前(某一次训练时刻)的权重矩阵为:

也就是:

  • 第 0 行:[0.3367, 0.1288, 0.2345]
  • 第 1 行:[0.2303, -1.1229, -0.1863]
  • 第 2 行:[2.2082, -0.6380, 0.4617]
  • 第 3 行:[0.2674, 0.5349, 0.8094]
  • 第 4 行:[1.1103, -1.6898, -0.9890]

2.2 计算每一行的 L2 范数

对每一行向量 ,L2 范数定义为:

逐行计算:

向量 0

(约等于 0.43,小于 1.5,不会被改动)

向量 1

(小于 1.5,不会被改动)

向量 2(会被裁剪)

这个超过了 max_norm=1.5需要被缩小

向量 3

不改。

向量 4(会被裁剪)

也超过了 1.5,需要缩小。

小结:

  • 需要被处理的只有:行 2 和 行 4(它们的范数 > 1.5)

3. max_norm 的具体计算公式(重归一化)

对每一个需要被裁剪的向量 :

  1. 先算出当前范数:
  2. 如果 ,就按系数去缩小它:

这样就有:

3.1 对行 2 的计算

  • 原始向量:
  • 缩放系数:
  • 新向量(逐元素相乘):

数值约为:

再检查一下新范数:

3.2 对行 4 的计算

  • 原始向量:
  • 缩放系数:
  • 新向量:

新范数:

4. 重归一化后的权重矩阵

归一化之后,新的权重矩阵 变为:

新的各行范数:

  • 第 0 行:0.4300(未变)
  • 第 1 行:1.1613(未变)
  • 第 2 行:1.5000(从 2.3444 被缩到 1.5)
  • 第 3 行:1.0063(未变)
  • 第 4 行:1.5000(从 2.2508 被缩到 1.5)

5. 对权重的具体影响总结

只影响被访问到且范数超限的行

  • forward 时只拿到了某些 index(比如本 batch 里用到的词 ID),max_norm 只会根据这些 index 去检查并缩放对应行。
  • 其它没被访问到的行,这一轮 forward 不会去动它。

操作是 in-place 的

官方文档明确说明:当 max_norm 不为 None 时,Embedding.forward原地修改 weight

这意味着:

  • 缩放后的权重会被保留下来,用于之后的训练步骤。
  • 如果你在 forward 之前对 embedding.weight 做可微操作,需要先 .clone() 一份再用,否则会和 autograd 的 in-place 规则冲突。

不改变方向,只改变长度

  • 向量被按比例整体缩小:
  • 方向(单位向量)不变,只是“缩短”到指定长度。

起到正则化 / 稳定作用

  • 限制每个 embedding 行向量的最大范数,可以避免某些向量过大,防止梯度爆炸或某些词向量“过度主导”模型。

6. 一句话记住

nn.Embedding 里:

max_norm = “给每一行 embedding 向量设一个最大长度, 每次 forward 时,凡是被用到且长度超过这个上限的行,都会被按比例缩到这个长度,并且是直接改写权重矩阵的。”

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2026-01-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 周拱壹卒 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 类比 “词典”
  • 2. 一个具体、可手算的小例子
    • 2.1 定义一个简单的 Embedding
    • 2.2 计算每一行的 L2 范数
      • 向量 0
      • 向量 1
      • 向量 2(会被裁剪)
      • 向量 3
      • 向量 4(会被裁剪)
  • 3. max_norm 的具体计算公式(重归一化)
    • 3.1 对行 2 的计算
    • 3.2 对行 4 的计算
  • 4. 重归一化后的权重矩阵
  • 5. 对权重的具体影响总结
    • 只影响被访问到且范数超限的行
    • 操作是 in-place 的
    • 不改变方向,只改变长度
    • 起到正则化 / 稳定作用
  • 6. 一句话记住
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档