首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >循环深度Transformer与MoE路由编码教程

循环深度Transformer与MoE路由编码教程

原创
作者头像
用户11764306
发布2026-04-26 12:22:20
发布2026-04-26 12:22:20
240
举报

在本教程中,探索OpenMythos的实现,这是对Claude Mythos架构的理论重构,通过迭代计算而非增加参数规模来实现更深层的推理。构建并分析使用GQA和MLA注意力机制的模型,通过KV缓存比较检查内存效率,并通过循环更新的谱属性验证稳定性。然后在结构化的奇偶校验任务上训练模型,并研究在推理时增加循环深度如何在不重新训练的情况下提升性能。在此过程中,还通过ACT暂停机制检查自适应计算,并监控MoE层中的专家利用率,从而对这种新兴架构提供全面的实践理解。

代码语言:python
复制
import subprocess, sys
try:
   import open_mythos  # noqa: F401
except ImportError:
   subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                          "open-mythos"])

import math, time, copy
from collections import Counter, defaultdict

import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt

from open_mythos.main import (
   OpenMythos, MythosConfig,
   ACTHalting, MoEFFN,
)

torch.manual_seed(0); np.random.seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"▸ device = {device}   |   torch = {torch.__version__}")

安装并导入所有必需的依赖项,初始化运行OpenMythos的环境。为GQA和MLA注意力机制构建配置,并实例化各自的模型。比较它们的参数量,以理解架构差异如何影响模型规模。

代码语言:python
复制
def make_config(attn_type: str, *, dim=128, n_heads=4, n_experts=4,
               max_loops=8, seq_len=128, vocab=256):
   base = dict(
       vocab_size=vocab, dim=dim, n_heads=n_heads,
       max_seq_len=seq_len, max_loop_iters=max_loops,
       prelude_layers=1, coda_layers=1,
       n_experts=n_experts, n_shared_experts=1,
       n_experts_per_tok=2, expert_dim=dim // 2,
       lora_rank=8, attn_type=attn_type,
   )
   if attn_type == "gqa":
       return MythosConfig(**base, n_kv_heads=2)
   return MythosConfig(
       **base, n_kv_heads=n_heads,
       kv_lora_rank=32, q_lora_rank=64,
       qk_rope_head_dim=16, qk_nope_head_dim=16, v_head_dim=16,
   )

cfg_gqa = make_config("gqa")
cfg_mla = make_config("mla")
m_gqa = OpenMythos(cfg_gqa).to(device)
m_mla = OpenMythos(cfg_mla).to(device)

print("\n─── 第1部分 ─ 模型规模 ──────────────────────────────")
print(f"GQA 参数量 : {sum(p.numel() for p in m_gqa.parameters()):>10,}")
print(f"MLA 参数量 : {sum(p.numel() for p in m_mla.parameters()):>10,}")
代码语言:python
复制
def cache_bytes(kv: dict) -> int:
   total = 0
   for entry in kv.values():
       for t in entry.values():
           total += t.element_size() * t.numel()
   return total

x = torch.randint(0, 256, (1, 64), device=device)
ck_gqa, ck_mla = {}, {}
with torch.no_grad():
   m_gqa(x, n_loops=4, kv_cache=ck_gqa)
   m_mla(x, n_loops=4, kv_cache=ck_mla)

gqa_kb = cache_bytes(ck_gqa) / 1024
mla_kb = cache_bytes(ck_mla) / 1024
print("\n─── 第2部分 ─ KV缓存占用 (1×64 tokens, 4 loops) ─")
print(f"GQA 缓存 : {gqa_kb:6.2f} KB   ({len(ck_gqa)} 层键)")
print(f"MLA 缓存 : {mla_kb:6.2f} KB   ({len(ck_mla)} 层键)")
print(f"比例      : MLA 约小 {gqa_kb / max(mla_kb, 1e-9):.2f} 倍")

def show_stability(model, tag):
   A = model.recurrent.injection.get_A()
   print(f"{tag:3s}  ρ(A): min={A.min():.4f}  max={A.max():.4f}  "
         f"mean={A.mean():.4f}  stable={bool((A < 1).all() and (A > 0).all())}")

print("\n─── 第3部分 ─ 初始化时的谱半径 ──────────────────")
show_stability(m_gqa, "GQA")
show_stability(m_mla, "MLA")

opt = torch.optim.Adam(m_mla.parameters(), lr=1.0)
for _ in range(30):
   loss = m_mla(torch.randint(0, 256, (2, 16), device=device),
                n_loops=2).square().mean()
   opt.zero_grad(); loss.backward(); opt.step()
show_stability(m_mla, "MLA 经过极端训练后 (lr=1.0, 30步)")

计算并比较前向传播过程中GQA和MLA注意力类型的KV缓存内存占用。通过分析矩阵A的谱半径来检查循环组件的稳定性。用极端训练条件对模型进行压力测试,以确认稳定性得以保持。

代码语言:python
复制
VOCAB = 64
SEQ_LEN = 24

def make_batch(batch=64, seq_len=SEQ_LEN):
   x = torch.randint(1, 3, (batch, seq_len), device=device)
   bits = x - 1
   parity = bits.cumsum(dim=1) % 2
   y = parity + 1
   return x, y

cfg = MythosConfig(
   vocab_size=VOCAB, dim=64, n_heads=4, n_kv_heads=2,
   max_seq_len=SEQ_LEN + 4, max_loop_iters=16,
   prelude_layers=1, coda_layers=1,
   n_experts=4, n_shared_experts=1, n_experts_per_tok=2,
   expert_dim=32, lora_rank=4, attn_type="gqa",
   act_threshold=0.99,
)
model = OpenMythos(cfg).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
T_TRAIN = 3

print("\n─── 第5部分 ─ 训练 (T_train = 3) ───────────────────")
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")
losses = []
t0 = time.time()
for step in range(600):
   x, y = make_batch(64)
   logits = model(x, n_loops=T_TRAIN)
   loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
   opt.zero_grad(); loss.backward()
   opt.step()
   losses.append(loss.item())
   if step % 100 == 0 or step == 599:
       with torch.no_grad():
           acc = (logits.argmax(-1) == y).float().mean().item()
       print(f"step {step:3d}   loss={loss.item():.4f}   acc@T3={acc:.3f}")
print(f"训练耗时: {time.time() - t0:.1f}秒")

定义一个累积奇偶校验任务,用于在结构化的序列问题上训练模型。使用固定的循环深度初始化OpenMythos模型,并通过交叉熵损失进行训练。在整个训练过程中,监控损失和准确率,以评估模型在受限深度下的学习效果。

代码语言:python
复制
model.eval()
T_sweep = [1, 2, 3, 4, 6, 8, 10, 12, 14, 16]
accs = []
with torch.no_grad():
   x_eval, y_eval = make_batch(512)
   for T in T_sweep:
       logits = model(x_eval, n_loops=T)
       accs.append((logits.argmax(-1) == y_eval).float().mean().item())

print("\n─── 第6部分 ─ 深度外推 (T_train=3) ──────────")
for T, a in zip(T_sweep, accs):
   bar = "█" * int(a * 40)
   marker = "  ← 训练深度" if T == T_TRAIN else ""
   print(f"T={T:2d}  acc={a:.3f}  {bar}{marker}")

halt_trace: list[torch.Tensor] = []
orig_halt = model.recurrent.act.forward

def halt_hook(self, h):
   p = orig_halt(h)
   halt_trace.append(p.detach().cpu())
   return p
model.recurrent.act.forward = halt_hook.__get__(model.recurrent.act, ACTHalting)

with torch.no_grad():
   x_h, _ = make_batch(1)
   _ = model(x_h, n_loops=16)

model.recurrent.act.forward = orig_halt

halts = torch.stack(halt_trace, dim=0)[:, 0].numpy()
print(f"\n─── 第7部分 ─ ACT暂停矩阵 (循环次数 × 位置) ───")
print(f"形状: {halts.shape}  |  "
     f"每轮平均暂停概率: "
     f"{', '.join(f'{v:.2f}' for v in halts.mean(1))}")

通过改变推理循环次数来评估训练好的模型,以研究深度外推。观察增加循环深度如何在不重新训练模型的情况下提高准确率。同时,对ACT机制进行插桩,以捕获每个序列位置和迭代的暂停概率。

代码语言:python
复制
expert_hits = Counter()
orig_moe = model.recurrent.block.ffn.forward

def moe_hook(self, x):
   flat = x.view(-1, x.shape[-1])
   logits = self.router(flat) + self.router_bias
   scores = F.softmax(logits, dim=-1)
   _, idx = scores.topk(self.topk, dim=-1)
   for e in idx.flatten().tolist():
       expert_hits[e] += 1
   return orig_moe(x)

model.recurrent.block.ffn.forward = moe_hook.__get__(
   model.recurrent.block.ffn, MoEFFN)

with torch.no_grad():
   x_m, _ = make_batch(32)
   _ = model(x_m, n_loops=T_TRAIN)

model.recurrent.block.ffn.forward = orig_moe

print("\n─── 第8部分 ─ MoE专家利用率 ───────────────────")
total = sum(expert_hits.values())
for eid in range(cfg.n_experts):
   share = expert_hits.get(eid, 0) / max(total, 1)
   print(f"专家 {eid}: {share*100:5.2f}% 的 top-k 槽位")

prompt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 2]], device=device)
print("\n─── 第9部分 ─ 生成 ───────────────────────────────")
print(f"提示 (奇偶模式): {prompt.tolist()[0]}")
for T_gen in [1, 4, 12]:
   with torch.no_grad():
       out = model.generate(prompt, max_new_tokens=8,
                            n_loops=T_gen, temperature=0.1, top_k=2)
   print(f"T_gen={T_gen:2d}  → {out.tolist()[0]}")

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(losses)
axes[0].set_title("训练损失 (奇偶校验任务)")
axes[0].set_xlabel("步数"); axes[0].set_ylabel("交叉熵")
axes[0].grid(alpha=0.3)

axes[1].plot(T_sweep, accs, "o-", linewidth=2, markersize=8)
axes[1].axvline(T_TRAIN, color="red", linestyle="--",
               label=f"T_train = {T_TRAIN}")
axes[1].set_title("深度外推: 准确率 vs 推理循环数")
axes[1].set_xlabel("推理时的 n_loops"); axes[1].set_ylabel("准确率")
axes[1].legend(); axes[1].grid(alpha=0.3); axes[1].set_ylim(0, 1.05)

im = axes[2].imshow(halts, aspect="auto", cmap="viridis",
                   vmin=0, vmax=halts.max())
axes[2].set_title("ACT 暂停概率\n(循环 t × 位置)")
axes[2].set_xlabel("位置"); axes[2].set_ylabel("循环迭代 t")
plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig("openmythos_tutorial.png", dpi=120, bbox_inches="tight")
plt.show()

分析MoE层中的专家利用率,追踪token如何被路由到不同的专家。然后,在不同的循环深度下生成序列,观察其对输出结果的影响。最后,通过图表可视化训练损失、深度外推性能和ACT暂停行为。

总之,展示了OpenMythos如何有效利用循环计算实现深度外推,使模型仅通过增加推理时的循环次数就能提高准确率。观察到即使在极端训练条件下循环机制仍保持稳定,且与GQA相比,MLA注意力显著减少了KV缓存的内存使用。同时了解了ACT如何实现跨序列位置的动态计算,以及MoE路由如何将工作负载分配给不同的专家。总体而言,确立了这种架构为计算自适应推理提供了一个有前景的方向——通过增加推理计算量来换取更好的性能,而无需修改模型参数。

在此查看完整代码和笔记本。另外,欢迎在Twitter上关注,别忘了加入130k+的ML SubReddit并订阅新闻通讯。FINISHED

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档