首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >FP8端到端精度强化学习训练技术

FP8端到端精度强化学习训练技术

原创
作者头像
用户11764306
发布2026-04-26 18:10:51
发布2026-04-26 18:10:51
130
举报

随着大语言模型从简单文本生成向复杂推理过渡,强化学习(RL)发挥着核心作用。像分组相对策略优化(GRPO)这样的算法推动了这一转变,使推理级模型能够通过迭代反馈不断改进。与标准的监督微调不同,RL训练循环分为两个不同的高强度阶段:具有严格延迟要求的生成阶段和需要高吞吐量的训练阶段。

为了使这些工作负载可行,研究人员和工程师正转向使用低精度数据类型(如FP8)来提高训练和面向吞吐量的生成性能。此外,在某些生成受限于GPU内存带宽的场景中,使用低精度参数可以减少每参数字节数,从而提升性能。

本文深入探讨了低精度RL的系统性挑战,以及NVIDIA NeMo RL(NVIDIA NeMo框架内的一个开源库)如何在保持精度的同时加速RL工作负载。

RL中线性层的FP8实现

我们的方案使用了DeepSeek-V3技术报告中引入的块状量化FP8。表1给出了线性投影层中张量格式的详细信息。

张量

数据类型

量化粒度

缩放因子类型

权重

FP8 (E4M3)

128, 128

FP32(块状)

输入激活

FP8 (E4M3)

1, 128

FP32(块状)

输出梯度

FP8 (E4M3)

1, 128

FP32(块状)

表1. 线性投影层中的张量格式

使用此方案,线性层可以用FP8数学计算,其峰值吞吐量是BF16数学的2倍。其他模块,包括注意力机制、归一化层、非线性函数和输出投影,则使用BF16数学计算。

RL中数值不一致的挑战

RL流程通常使用独立的引擎:vLLM用于策略采样,NVIDIA Megatron Core用于训练。每个引擎都使用独特的自定义NVIDIA CUDA内核来最大化性能。这不可避免地引入了数值差异,由于额外的量化和反量化逻辑,这些差异在低精度下会累积放大。我们将此数值差异量化为令牌乘法概率误差:

令牌乘法概率误差 = (1/n) * Σⁿᵢ₌₁ exp(|log_train_fwk_i – log_probs_inference_fwk_i|)

完美对齐的得分为1,我们发现不使用任何额外技术时,“可接受”的值通常小于1.03-1.05。

线性层中的端到端FP8减少数值不一致

在开发FP8方案的过程中,我们实验了三种方案:

  • 基线方案:生成和训练均使用BF16。
  • 候选方案1:仅在生成阶段应用FP8,而策略模型训练使用BF16。
  • 最终方案(端到端FP8):在生成和训练引擎中均使用FP8。

我们观察到,与仅在生成阶段使用FP8的候选方案1相比,最终方案在生成和训练之间表现出更低的数值不一致。注意,基线方案的数值不一致始终最低。图1显示了三种方案的令牌乘法概率误差指标。

图1. 三种方案中的令牌乘法概率误差

使用重要性采样减轻数值不一致

重要性采样用于纠正生成数据的模型(分布)与正在训练的模型(分布)之间的分布不匹配。它是一个逐令牌的权重,乘以损失函数。关于重要性采样的详细理论背景,请参阅GRPO文档。

实验表明:

  • 对于候选方案1(FP8生成 + BF16训练),重要性采样可以缩小与BF16 RL的精度差距,但无法完全消除。
  • 对于最终方案(端到端FP8),重要性采样完全消除了与BF16训练的精度差距。图2显示了不同方案在训练期间的验证准确率。

图2. 在Llama 3.1 8B Instruct模型和数学数据集上进行GRPO训练的验证准确率

FP8线性层端到端的结果

我们在密集模型和混合专家模型上评估了端到端的FP8方案,测量了与BF16基线相比的验证准确率和训练吞吐量。

密集模型上的FP8端到端:Llama 3.1 8B Instruct

表2显示了在Llama 3.1 8B Instruct模型的GRPO训练中,FP8端到端方案与BF16方案在训练4000步后的准确率。

精度

BF16

仅FP8生成

FP8端到端

验证准确率

0.616

0.586

0.613

表2:不同精度配置下Llama3 8B验证准确率结果

在加速方面,与BF16相比,FP8方案实现了持续超过15%的吞吐量提升。图3显示了两种方案在1000步训练中的GRPO训练吞吐量(每GPU每秒处理的令牌数)。

图3. 两种方案的吞吐量(每GPU每秒令牌数)(蓝色:BF16,粉色:FP8端到端)

尽管理论上FP8比BF16快2倍,但实践中较低,因为只有线性层受益于更快的数学吞吐量,而注意力和逐元素层保持不变。线性层之前添加的额外量化内核会带来一些开销。15%-25%的加速与我们对vLLM的独立测试结果相符。通过进一步的优化(如在vLLM中融合量化内核),我们预计加速比可以进一步提高到1.25倍。

MoE模型上的FP8端到端:Qwen3-30B

在混合专家模型上进行了类似实验,Qwen3-30B的结果显示出匹配的准确率曲线。FP8实现了与BF16相似的准确率。加速效果正在研究中。

图4. Qwen3-30B GRPO在OpenMathInstruct-2数据集上的准确率曲线,使用8个H100节点。蓝色为BF16,粉色为FP8端到端

将FP8扩展到KV缓存和注意力机制

对于Transformer模型,线性层并非唯一的瓶颈。在具有长输出序列长度的RL工作流中,KV缓存的增长和注意力计算通常主导端到端的策略采样时间,同时还会使内存带宽饱和并减慢令牌生成速度。这促使我们探索在RL循环中将FP8用于KV缓存和注意力机制。使用了按张量缩放FP8。

在RL环境中为KV缓存实现FP8具有独特的挑战性,因为策略权重在每一步都会发生变化。与静态推理(只需校准一次)不同,RL需要动态处理量化缩放因子。

NeMo RL采用以下方法解决此问题:

  • 重新校准:在每个训练步骤结束时,训练器使用更新后的策略权重重新查询、键、值(QKV)的缩放因子。
  • 数据选择:此校准使用训练数据(提示和生成的响应)执行,以确保缩放因子反映当前的分布。
  • 同步:然后将新计算的缩放因子同步到推理引擎(vLLM),用于下一轮的策略采样。

图5. 使用FP8 KV缓存的RL工作流程

此设计确保策略采样引擎始终使用源自最新策略状态的最优量化缩放因子,从而最大限度地减少精度下降。校准开销极小,约占总步骤时间的2-3%。

张量

数据类型

缩放因子类型

QKV注意力激活

FP8 (E4M3)

FP32(张量级)

存储的KV缓存

FP8 (E4M3)

FP32(张量级)

表3:注意力激活和存储的KV缓存的张量格式

KV缓存和注意力机制FP8的结果总结

我们在Qwen3-8B-Base模型上使用GRPO算法运行了结果,在策略采样中应用FP8,训练中使用BF16。虽然同时量化KV缓存和注意力机制时,由于复合误差,不匹配的KL散度略高,但我们的方案减轻了不稳定性。通过启用令牌级别的截断重要性采样,线性层+KV缓存+注意力机制均使用FP8的方案实现了与BF16基线以及线性层使用FP8(W8A8)方案对齐的验证准确率。

图6. Qwen3-8B-Base的训练准确率曲线

为KV缓存和注意力操作启用FP8,比仅线性层使用W8A8配置在策略采样阶段额外带来了约30%的加速,相比BF16基线总体加速约48%。这些提升在响应长度较长时尤为显著,因为此时注意力计算占整体工作负载的比例更大。QKV缩放因子重新校准过程约占总步骤时间的2-3%,相对于所实现的显著加速而言,这是一项较小的成本。

图7. Qwen3-8B-Base模型的策略采样性能曲线

尝试使用NVIDIA NeMo RL进行端到端FP8训练

要为生成后端和训练后端中的线性层启用FP8,以下配置映射显示了每个调优参数如何传递给训练和生成后端。

图8. 在某机构NeMo RL中为线性层启用FP8

要为KV缓存和注意力机制启用FP8,需要配置策略的vllm_cfg中的kv_cache_dtype参数,这将自动处理训练器端的QKV缩放因子重新校准以及与vLLM后端的同步。

代码语言:yaml
复制
policy:
  generation:
    vllm_cfg:
      precision: fp8       # 为线性层启用FP8
      kv_cache_dtype: fp8  # 为KV缓存启用FP8

生成和训练的高级FP8配置选项

到目前为止,我们已经介绍了线性层和KV缓存+注意力层的FP8实现。高级用户可以尝试方案的各种变体。以下是一些功能示例:

  • 在生成期间将前N层和/或后M层Transformer层保持在BF16精度(N, M为整数)
代码语言:yaml
复制
policy:
  generation:
    vllm_cfg:
      num_first_layers_in_bf16: N # 将N替换为整数
      num_last_layers_in_bf16: M  # 将M替换为整数
  • 配置生成和/或训练使用2的幂缩放因子类型而非FP32
代码语言:yaml
复制
policy:
  generation:
    vllm_cfg:
      pow2_weight_scaling_factors: true
      pow2_activation_scaling_factors: true
  megatron_cfg:
    env_vars:
      NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "0"

开发者可以使用为Megatron Core后端预定义的FP8方案变体,而不是默认的块状量化FP8方案(如表1所示)。详情请参阅参数文档字符串。

代码语言:yaml
复制
policy:
  megatron_cfg:
    fp8_cfg:
      fp8: "e4m3"
      fp8_recipe: "blockwise"

开始使用

用户可以参阅NeMo RL GitHub仓库中的llama-3.1-8b和moonlight-16b方案开始使用。

致谢

此项工作是跨团队协作的成果。感谢Jimmy Zhang、Victor Cui、Zhiyu Li和Lark Zhang在FP8方案开发、实验以及集成到NeMo RL中所做的贡献。FINISHED

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • RL中线性层的FP8实现
  • RL中数值不一致的挑战
  • 线性层中的端到端FP8减少数值不一致
  • 使用重要性采样减轻数值不一致
  • FP8线性层端到端的结果
  • 将FP8扩展到KV缓存和注意力机制
  • KV缓存和注意力机制FP8的结果总结
  • 尝试使用NVIDIA NeMo RL进行端到端FP8训练
  • 生成和训练的高级FP8配置选项
  • 开始使用
  • 致谢
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档