Harmonic Loss Trains Interpretable AI Models
谐波损失训练可解释人工智能模型
https://arxiv.org/pdf/2502.01628v2

摘要 本文提出谐波损失(harmonic loss)作为一种替代性的监督信号,用于训练神经网络和大语言模型(LLMs)。谐波损失与标准交叉熵损失的主要区别在于:(a)以一种尺度不变(scale-invariant)的HarMax函数替代常规的SoftMax归一化;(b)通过欧氏距离而非点积计算logits。得益于其尺度不变性及设计上具有有限收敛点(可解释为类别中心)的特性,谐波损失能提升模型可解释性并加速收敛。我们首先在算法、视觉和语言任务数据集上验证谐波模型的性能。大量实验表明,采用谐波损失训练的模型相较于标准模型具有以下优势:(a)增强可解释性;(b)实现泛化所需训练数据更少;(c)减轻“顿悟”(grokking)现象。此外,我们对比了采用谐波损失训练的GPT-2模型与标准GPT-2模型,结果表明前者习得的表征更具可解释性。展望未来,我们认为谐波损失有望成为数据稀缺领域或对可解释性与可靠性要求极高的高风险应用中的有力工具,从而推动构建更稳健、高效的神经网络模型。
1 引言 随着机器学习模型日益强大,深入理解神经网络的行为变得愈发重要。其中尤为引人关注的一个特性是其泛化能力——大量实证研究表明,神经网络在训练中未曾显式接触过的未见数据上依然表现优异[1]。这种卓越能力源于网络通过训练习得可泛化的表征与算法。然而,当前模型在泛化方面面临三大关键挑战:
(1)可解释性缺失:神经网络常缺乏可解释性,这在医疗、金融和自动驾驶等高风险应用中尤为成问题。尽管已有诸多研究增进了我们对大语言模型内部机制的理解[2],但我们距离完全解释其输出仍有较大差距。归根结底,亟需设计本质上可解释(interpretable by design)的系统;否则,错误诊断、公平性保障以及用户对模型决策的信任都难以实现。
(2)数据效率低下:泛化往往依赖海量且多样化的训练数据。由此引出一个关键问题:模型能否以更少数据实现有效泛化?该问题在数据稀缺领域(如罕见病诊断或专门科学研究)尤为突出。以往提升泛化能力的方法包括高效数据采样[3]与改进训练流程以加快收敛[4],但这些工作多聚焦于优化既有训练程序,而非从模型设计层面根本性解决问题。
(3)延迟泛化(即“顿悟”,grokking):模型有时会出现所谓“顿悟”现象[5,6]——训练损失早已收敛,而测试损失却滞后许久才收敛。这种差距带来两个问题:(i)难以判定实现泛化的最佳训练终止时机;(ii)需额外耗费大量计算时间与资源,持续训练直至“顿悟”发生。
正如俗语所言:“魔鬼藏在SoftMax中。”我们认为,上述三大挑战部分源于交叉熵损失(用于分类任务)中SoftMax函数的广泛使用。因此,本文提出谐波损失作为替代方案。谐波损失具备两个理想的数学性质——(1)尺度不变性;(2)具有有限收敛点(可阐释为类别中心)——从而促进更快收敛并提升可解释性。通过全面实验,我们证实:采用谐波损失训练的模型能有效减轻顿悟现象、降低数据需求,并增强可解释性。进一步地,我们将谐波损失训练的GPT-2与标准GPT-2进行对比,发现前者习得的表征更具可解释性。
本文后续结构安排如下:第2节介绍谐波损失的基本原理,并阐明其在泛化与可解释性方面优于交叉熵损失的原因;第3节在算法数据集上开展一系列系统性实验,展示谐波模型所具备而标准模型缺失的诸多优良性质;第4节在MNIST手写数字分类这一视觉任务中验证谐波模型性能;第5节将分析扩展至大规模模型,证明谐波损失的优势在大模型尺度下依然成立;第6节提供消融实验;第7节综述相关文献;第8节总结全文。
2 谐波损失
我们首先回顾交叉熵损失,并介绍谐波损失,其可视化如图1(a)所示。记解嵌入矩阵为 W ∈ ℝᴺˣⱽ(N 为嵌入维度,V 为词汇表大小),以及解嵌入矩阵之前的表示(即未解嵌前的表示)为 x ∈ ℝᴺ。

交叉熵损失:Logits y 定义为矩阵-向量乘法,即 y = Wᵀx ∈ ℝⱽ(忽略偏置项),或写作 yᵢ = wᵢ · x,其中 wᵢ 是 W 的第 i 列。概率 p 可通过对 y 应用 SoftMax 函数获得,即:

假设真实类别标签为 c,则损失

。为简化符号,我们将线性层与交叉熵损失的组合称为交叉熵层。

其中 n(谐波指数)是一个控制概率分布重尾程度的超参数。若真实类别标签为 c,则损失 ℓ = -log p_c。为简化符号,我们将与谐波损失结合的层称为谐波层。由于两种损失的最后一部相同(ℓ = -log p),比较它们的值是有意义的。二者仅在从表征计算概率的方式上有所不同。
n 的合理取值为 n ~ √D,其中 D 表示底层数据的固有维度。在大语言模型(LLMs)中,D 可近似为 D ≈ d_embed,其中 d_embed 为嵌入维度。该近似源于考虑从 D 维高斯分布初始化的嵌入。两点间平方距离经维度数 D 归一化后,其量级为 1 ± O(1/√D)。为确保谐波距离 [1 ± O(1/√D)]^n 在缩放 D 时保持恒定,我们需要 n ~ √D,因为 lim_{x→∞} (1 + x⁻¹)^x = e。我们还在附录 E 中展示了指数对所学表征的经验影响。
玩具案例:为直观理解谐波损失相较于交叉熵损失的优势,我们考虑图1(b)(c)所示的两个二维玩具案例。在每个玩具案例中,我们使用 Adam 优化器训练交叉熵层和谐波层。
玩具案例1:x₁ = (1, 1) 和 x₂ = (-1, -1) 属于两个不同类别。谐波层产生更快的损失下降,因为谐波损失只需 dᵢ → 0(收敛点有限)即可使 pᵢ → 1。相比之下,交叉熵损失要求 yᵢ → ∞(收敛点无限)才能使 pᵢ → 1。谐波损失已产生一个趋于常数的 l₂ 权重范数,而交叉熵损失导致 l₂ 持续增大并趋向无穷。
玩具案例2:存在5个二维点,每点属于不同类别。特别地,红点 (0, 0) 被其余四点包围,即无法线性可分。交叉熵层在此任务上表现不佳,表现为高损失平台期。相比之下,谐波层可将损失降至机器精度。与案例1类似,谐波层具有趋于稳定的 l₂,而交叉熵层的 l₂ 则持续增长。我们还观察到,谐波层的权重对应于 x,比交叉熵层的权重更具可解释性。
谐波损失的优势:从这两个玩具案例中,我们理解了谐波损失的优势:(1)非线性可分性:在案例2中,即使红点非线性可分,仍能被正确分类。(2)快速收敛:收敛点有限这一特性既导致损失衰减更快,也带来稳定(不发散)的 l₂。(3)尺度不变性:谐波损失是尺度不变的,即 dᵢ → αdᵢ 会使 pᵢ(因而损失)保持不变;而 yᵢ → αyᵢ 会产生不同的交叉熵损失。(4)可解释性:权重向量对应类别中心。我们在附录 G 中给出了这些性质的形式化证明。
关于可解释性的说明:在缺乏真实表征的情况下,衡量可解释性本身具有挑战性。因此,我们在全文中提出两个原则性的可解释性指标:(1)压缩性:稀疏、低维的表征通过集中语义增强可解释性。我们通过 PCA 投影中的累积解释方差来衡量这一点。(2)几何性:在一般模型中,我们假设具有多个一维语义方向的平行四边形状单元能够实现组合推理;这支持向量运算如 man - woman = king - queen,并有助于忠实的特征归因。我们通过第5节中的平行四边形损失来衡量这一点。
3 算法实验 算法任务因其数学定义清晰,是评估可解释性的良好基准。然而,由于存在“顿悟”(延迟泛化)现象[5]以及多种可行算法共存等问题[7],在这些任务上训练神经网络并非易事。我们将证明:谐波模型所学习的表征更优、数据效率更高,且顿悟现象更轻微。
3.1 模型与数据集 模型:我们比较以下四种模型:
我们对MLP模型训练7000轮(epochs),对Transformer模型训练10000轮。对全部四种模型,均采用AdamW优化器,学习率设为 2 × 10⁻³,权重衰减为 10⁻²,并对嵌入层施加强度为0.01的L2正则化。
数据集:我们使用以下五个数据集训练上述四种模型,并对其性能及所学表征进行了分析:
3.2 表征忠实性
图2展示了MLP任务中各模型嵌入向量的前两个主成分的可视化图。所有任务的完整嵌入可视化结果见附录A。总体而言,采用谐波损失得到的表征比交叉熵损失对应的表征更清晰、更有条理。我们发现:在模加法任务中,表征几乎完美地呈现为圆形;在家谱树学习任务中,呈现出清晰的塔状结构;在置换合成任务中,则形成整齐的聚类。以下我们逐项分析各任务中的表征:

图3(a)进一步表明,谐波表征比标准模型的表征更紧凑,包含更少不可解释的成分。特别是,在上下文内学习任务中训练的谐波模型,仅使用前两个主成分即可达到100%的解释方差。

3.3 训练中的数据效率
图3(b)展示了在合成实验中,测试准确率随训练数据比例变化的曲线,反映了模型实现泛化所需的数据量。我们观察到,相较于对应的交叉熵模型,谐波模型泛化所需的训练数据量相当或显著更少。这一提升在上下文内学习任务中尤为突出:谐波模型几乎在训练伊始即实现泛化。
3.4 顿悟现象的缓解
“顿悟”(grokking)指延迟泛化的现象[5]:例如,模型在10³步内即可在训练集上达到完美准确率,却需10⁵步才能在测试集上泛化。“顿悟”是一种病理性现象,我们希望加以避免[8]。如图3(c)所示,总体而言,谐波损失有效减轻了顿悟现象。位于直线 y = x 上的数据点代表未出现顿悟的模型——其训练准确率与测试准确率同步提升。这种改善在模加法与置换合成任务中尤为显著:标准MLP表现出严重的顿悟现象,而谐波MLP的大部分数据点则明显更接近 y = x 线。
3.5 案例研究:模加法
本节以模加法任务为案例,深入分析为何谐波MLP相比标准MLP能促进更具可解释性的表征并实现更优泛化。如图4所示,若不施加权重衰减,标准MLP在模加法任务上往往无法泛化;仅当加入强权重衰减后才能实现泛化,但此时:(a) 如图4所示,仍出现显著的顿悟现象;(b) 尽管前两个主成分近似构成一个圆,其所解释的总方差比例却远低于100%,残留大量未解释方差。相比之下,谐波模型在模加法任务上可快速泛化且无顿悟现象;其嵌入表征形成一个完美圆形(见图4)。谐波MLP之所以能更优地构建圆形结构并提升泛化性能,可归因于谐波损失的固有性质(见第2节所述):

这种有限收敛点的存在,带来了三重优势:(a) 收敛速度更快;(b) 泛化性能更优;(c) 表征更具可解释性。
4 MNIST实验
在视觉任务中,卷积神经网络已通过展示“边缘检测器”、“车轮检测器”等特征被证明具有(至少一定程度的)可解释性[9]。本节中,我们表明:当使用全连接网络训练MNIST数据集时,谐波损失能够产生更具可解释性的网络。作为概念验证,我们对比了分别采用交叉熵损失与谐波损失训练的单层神经网络。输入图像首先被展平,随后通过一个784×10的线性层获得logits。模型训练时批量大小为64,学习率为0.001,共训练10轮,最终交叉熵损失模型达到92.50%的测试准确率,谐波损失模型达到92.49%的测试准确率。
图4显示,谐波模型的权重比标准模型的权重更具可解释性。与其核心原理一致,谐波模型的权重几乎完美地对齐于各类别的中心(即每个数字的典型图像)。此外,它对边缘像素赋予接近零的权重值;而交叉熵损失训练的模型则缺乏将无关背景权重精确压缩至零的内在激励机制。
5 GPT-2 实验
大量机制性可解释性研究致力于理解大语言模型。例如,探测(probing)与归因(attribution)方法是有效的后验分析工具。尽管这些方法取得了一定(部分)成功,但它们并非从一开始就构建可解释模型,而更像是在干草堆中寻找绣花针。我们主张:若能在预训练阶段就使语言模型本身具备更高可解释性,将更为理想。通过在训练中采用谐波损失,我们能够得到一个语言模型——它能够“生长”出类似晶体结构的表征,同时在性能上与标准模型(采用交叉熵损失训练)相当。
我们在 OpenWebText 上预训练了一个 GPT-2 Small 模型(1.28亿参数,基于 NanoGPT 实现)。其中嵌入矩阵与解嵌入矩阵权重共享(tied)。训练使用 8 块 V100 GPU,序列块长度(block size)为 1024,每批次处理 480 个块。优化器选用 Adam,参数设为 β₁ = 0.9,β₂ = 0.95。 针对谐波损失,依据第2节中关于谐波指数的讨论,我们取 n = √768 ≈ 28。 对于标准 GPT(谐波 GPT),我们采用线性学习率预热策略:在前 2k(1k)步将学习率升至最大值 6×10⁻⁴(6×10⁻³),随后在第 2k 至 10k 步之间采用余弦退火策略,最终学习率降至 3×10⁻⁵(3×10⁻⁴)。
如图5左上角所示,谐波 GPT 在初期收敛更快(部分归因于更大的初始学习率),最终在第 10k 步时达到与标准模型相近的性能。其最终验证损失分别为:标准模型 3.159,谐波模型 3.146。从训练损失曲线还可看出,谐波 GPT 的波动更小。这表明谐波损失在真实世界模型中同样有效。

为检验所学嵌入的可解释性,我们从文献[10]中选取了十二个功能向量任务。每个数据集包含大量具有特定关系的输入-输出词对。例如,“现在时–过去时”(present-past)数据集包含诸如 jump–jumped、fasten–fastened、win–won 等词对。
为构造平行四边形,我们可以从数据集中抽取两组不同的词对,得到类似 (jump, jumped, fasten, fastened) 的四元组,这些四元组预期应构成平行四边形。每个词被分词为若干词元;若得到多个词元,则取最后一个词元。我们将词元嵌入投影至前两个主成分空间。对于四元组 (i, j, m, n),其二维主成分嵌入为 (Eᵢ, Eⱼ, Eₘ, Eₙ);我们定义平行四边形损失

为:

其中,

是一个用于归一化损失的缩放因子(

)。我们获得了10000个四元组,通过计算它们的平行四边形损失来衡量平行四边形的质量。我们在图5右上角绘制了它们的累积分布函数:对于每一项任务,谐波GPT产生的平行四边形损失(即更好的平行四边形)低于标准GPT。我们在图5底部展示了在当前-过去任务中获得的平行四边形。平行四边形按质量从左到右降序排列。谐波GPT倾向于生成视觉上更吸引人的、更“矩形”的平行四边形,而标准GPT则生成扁平的“平行四边形”。关于内部表示的讨论包含在附录C中。
6 消融实验 谐波损失(Harmonic loss)对标准交叉熵损失做了两项主要修改:(i) 通过 ℓ₂ 距离计算 logits;(ii) 使用如公式 (2) 所示的 HarMax 函数。为分别考察这两项修改的独立贡献,我们进行了一系列有针对性的消融实验:每次仅替换其中一个组件,其余训练流程保持不变。具体而言,我们在上下文学习(in-context learning)任务与模加(modular addition)任务上,使用消融后的损失函数训练 MLP 模型。
结果如图 6 所示。在上下文学习任务中,我们观察到:仅引入 HarMax 或仅采用 ℓ₂ logits 均足以复现完整谐波损失的全部性能;相比之下,在模加任务中,HarMax 与 ℓ₂ logits 两者缺一不可,方能达到完整谐波损失的性能水平。尽管仅引入其中任一组件也能提升圆形表征(circular representation)的质量,但其解释方差(explained variance)仍显著低于 100%。总体而言,HarMax 与 ℓ₂ logits 在提升表征可解释性方面均发挥着关键作用。

7 相关工作
表征与机制可解释性:大量研究表明,大语言模型(LLMs)能够在空间 [11]、时间 [12] 与颜色 [13] 等领域形成概念性表征。此类表征的结构涵盖一维概念 [11, 14–16],以及多维表征,例如格点结构(lattices)[17–19] 和环形结构(circles)[20, 21]。尽管这些表征的结构与某些几何模式相关,但通常仍存在显著的未解释方差(unexplained variance),因此亟需提升神经网络表征的可解释性。
损失函数:已有研究指出,损失函数会影响模型学习数据表征的方式,并以独特方式影响其能力 [22–28]。关于机器学习中各类损失函数的全面综述,可参见 [29] 与 [30]。本文提出的谐波损失(harmonic loss)在标准监督学习中提供了一种替代性的监督信号,其方式为:(a) 以尺度不变的 HarMax 函数取代常规的 SoftMax 归一化;(b) 用欧氏距离而非点积来计算 logits。虽然该损失与对比损失(contrastive loss)有一定相似性——二者均通过欧氏距离这一度量以最大化不同类别之间的分离——但对比学习方法本身本质上并非监督式方法:其通常需额外附加一个交叉熵层以生成 logits,从而重新引入 SoftMax(及其固有缺陷)。我们在第 6 节中进一步表明,仅使用欧氏距离尚不足以完全复现谐波损失的全部能力。此外,在语言建模中,直接采用基于欧氏距离的监督学习方法相对未被充分探索,目前主要局限于简单任务(如句子情感分类)[31]。我们在附录 D 中提供了谐波损失与其他损失函数更全面的比较。
8 结论
本文提出了谐波损失(harmonic loss)作为训练神经网络与大语言模型(LLMs)时标准交叉熵损失的一种替代方案。我们发现,采用谐波损失训练的模型相较于标准模型表现更优,具体体现在:(a) 减少了“顿悟”(grokking)现象;(b) 实现泛化所需的训练数据更少;(c) 提升了表征的可解释性。我们还将采用谐波损失训练的 GPT-2 模型与标准 GPT-2 进行对比,结果表明:前者习得的表征具有更高的可解释性。未来仍需进一步研究,以探索本文发现对更大规模模型的可扩展性与适用性。
原文链接:https://arxiv.org/pdf/2502.01628v2