在本教程中,探索OpenMythos的实现,这是对Claude Mythos架构的理论重构,通过迭代计算而非增加参数规模来实现更深层的推理。构建并分析使用GQA和MLA注意力机制的模型,通过KV缓存比较检查内存效率,并通过循环更新的谱属性验证稳定性。然后在结构化的奇偶校验任务上训练模型,并研究在推理时增加循环深度如何在不重新训练的情况下提升性能。在此过程中,还通过ACT暂停机制检查自适应计算,并监控MoE层中的专家利用率,从而对这种新兴架构提供全面的实践理解。
# --- Environment setup -------------------------------------------------------
# Install open-mythos on demand, import all tutorial dependencies, seed the
# RNGs for reproducibility, and choose the compute device.
import subprocess, sys
try:
    import open_mythos  # noqa: F401
except ImportError:
    # Quiet, one-off install only when the package is missing.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                           "open-mythos"])
import math, time, copy
from collections import Counter, defaultdict
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt
from open_mythos.main import (
    OpenMythos, MythosConfig,
    ACTHalting, MoEFFN,
)
# Deterministic runs for numpy and torch; fall back to CPU without a GPU.
torch.manual_seed(0); np.random.seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"▸ device = {device} | torch = {torch.__version__}")
# Next: build configs for the GQA and MLA attention variants, instantiate one
# model per variant, and compare parameter counts to see how the architectural
# choice affects model size.
def make_config(attn_type: str, *, dim=128, n_heads=4, n_experts=4,
                max_loops=8, seq_len=128, vocab=256):
    """Build a MythosConfig for either 'gqa' or 'mla' attention.

    All hyper-parameters common to both variants go into one shared dict;
    only the attention-specific fields differ between the two branches.
    """
    shared = dict(
        vocab_size=vocab, dim=dim, n_heads=n_heads,
        max_seq_len=seq_len, max_loop_iters=max_loops,
        prelude_layers=1, coda_layers=1,
        n_experts=n_experts, n_shared_experts=1,
        n_experts_per_tok=2, expert_dim=dim // 2,
        lora_rank=8, attn_type=attn_type,
    )
    if attn_type == "gqa":
        # Grouped-query attention: 2 KV heads shared by the query heads.
        return MythosConfig(**shared, n_kv_heads=2)
    # MLA: full KV head count plus low-rank latent projection settings.
    return MythosConfig(
        **shared, n_kv_heads=n_heads,
        kv_lora_rank=32, q_lora_rank=64,
        qk_rope_head_dim=16, qk_nope_head_dim=16, v_head_dim=16,
    )
# Instantiate one model per attention variant and report their sizes.
cfg_gqa = make_config("gqa")
cfg_mla = make_config("mla")
m_gqa = OpenMythos(cfg_gqa).to(device)
m_mla = OpenMythos(cfg_mla).to(device)
gqa_total = sum(p.numel() for p in m_gqa.parameters())
print("\n─── 第1部分 ─ 模型规模 ──────────────────────────────")
print(f"GQA 参数量 : {gqa_total:>10,}")
print(f"MLA 参数量 : {sum(p.numel() for p in m_mla.parameters()):>10,}")

def cache_bytes(kv: dict) -> int:
    """Return the total number of bytes held by a nested KV-cache dict.

    `kv` maps layer keys to dicts of cached tensors; each tensor's storage
    footprint (element_size * numel) is accumulated.
    """
    total = 0
    for entry in kv.values():
        for t in entry.values():
            total += t.element_size() * t.numel()
    return total
# Run one cached forward pass per model and compare cache footprints.
probe = torch.randint(0, 256, (1, 64), device=device)
cache_gqa: dict = {}
cache_mla: dict = {}
with torch.no_grad():
    m_gqa(probe, n_loops=4, kv_cache=cache_gqa)
    m_mla(probe, n_loops=4, kv_cache=cache_mla)
kb_gqa = cache_bytes(cache_gqa) / 1024
kb_mla = cache_bytes(cache_mla) / 1024
print("\n─── 第2部分 ─ KV缓存占用 (1×64 tokens, 4 loops) ─")
print(f"GQA 缓存 : {kb_gqa:6.2f} KB ({len(cache_gqa)} 层键)")
print(f"MLA 缓存 : {kb_mla:6.2f} KB ({len(cache_mla)} 层键)")
print(f"比例 : MLA 约小 {kb_gqa / max(kb_mla, 1e-9):.2f} 倍")
def show_stability(model, tag):
    """Report min/max/mean of the recurrent injection matrix A, plus whether
    every entry lies strictly inside (0, 1), i.e. the loop update is stable."""
    A = model.recurrent.injection.get_A()
    is_stable = bool(((A > 0) & (A < 1)).all())
    print(f"{tag:3s} ρ(A): min={A.min():.4f} max={A.max():.4f} "
          f"mean={A.mean():.4f} stable={is_stable}")
print("\n─── 第3部分 ─ 初始化时的谱半径 ──────────────────")
show_stability(m_gqa, "GQA")
show_stability(m_mla, "MLA")
# Stress test: 30 Adam steps with a deliberately absurd lr=1.0 on a dummy
# objective, then re-check that the recurrent update remains stable.
opt = torch.optim.Adam(m_mla.parameters(), lr=1.0)
for _ in range(30):
    loss = m_mla(torch.randint(0, 256, (2, 16), device=device),
                 n_loops=2).square().mean()
    opt.zero_grad(); loss.backward(); opt.step()
show_stability(m_mla, "MLA 经过极端训练后 (lr=1.0, 30步)")
# Sections 2-3 above measured the KV-cache footprint of GQA vs MLA during a
# forward pass, inspected the spectral properties of matrix A, and confirmed
# that stability survives extreme training conditions.
VOCAB = 64      # model vocabulary size; the task itself only emits tokens 1 and 2
SEQ_LEN = 24    # training sequence length

def make_batch(batch=64, seq_len=SEQ_LEN, device=None):
    """Sample a cumulative-parity batch.

    Tokens are drawn uniformly from {1, 2} (bit = token - 1).  The target at
    each position is the running parity of the bits up to and including that
    position, shifted back into {1, 2}.

    `device=None` preserves the historical behavior: fall back to the
    module-level `device` global when one exists, otherwise let torch use
    its default (CPU).
    """
    if device is None:
        device = globals().get("device")  # backward-compatible fallback
    x = torch.randint(1, 3, (batch, seq_len), device=device)
    bits = x - 1
    parity = bits.cumsum(dim=1) % 2
    y = parity + 1
    return x, y
# Parity-task model: GQA attention, up to 16 recurrent loops, ACT halting.
cfg = MythosConfig(
    vocab_size=VOCAB, dim=64, n_heads=4, n_kv_heads=2,
    max_seq_len=SEQ_LEN + 4, max_loop_iters=16,
    prelude_layers=1, coda_layers=1,
    n_experts=4, n_shared_experts=1, n_experts_per_tok=2,
    expert_dim=32, lora_rank=4, attn_type="gqa",
    act_threshold=0.99,
)
model = OpenMythos(cfg).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
T_TRAIN = 3  # fixed recurrence depth used throughout training
print("\n─── 第5部分 ─ 训练 (T_train = 3) ───────────────────")
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")
losses = []
t0 = time.time()
for step in range(600):
    x, y = make_batch(64)
    logits = model(x, n_loops=T_TRAIN)
    loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
    opt.zero_grad(); loss.backward()
    opt.step()
    losses.append(loss.item())
    if step % 100 == 0 or step == 599:
        # Periodic accuracy probe on the current training batch.
        with torch.no_grad():
            acc = (logits.argmax(-1) == y).float().mean().item()
        print(f"step {step:3d} loss={loss.item():.4f} acc@T3={acc:.3f}")
print(f"训练耗时: {time.time() - t0:.1f}秒")
# This section defined the cumulative-parity task, trained an OpenMythos model
# at a fixed loop depth with cross-entropy loss, and monitored loss/accuracy
# to gauge learning under constrained depth.
# Depth extrapolation: evaluate at loop counts the model never trained at.
model.eval()
T_sweep = [1, 2, 3, 4, 6, 8, 10, 12, 14, 16]
accs = []
with torch.no_grad():
    x_eval, y_eval = make_batch(512)
    for T in T_sweep:
        preds = model(x_eval, n_loops=T).argmax(-1)
        accs.append((preds == y_eval).float().mean().item())
print("\n─── 第6部分 ─ 深度外推 (T_train=3) ──────────")
for T, a in zip(T_sweep, accs):
    bar = "█" * int(a * 40)
    marker = " ← 训练深度" if T == T_TRAIN else ""
    print(f"T={T:2d} acc={a:.3f} {bar}{marker}")
# Instrument the ACT halting head: wrap its forward method to record the
# halting probabilities at every loop iteration without altering its output.
halt_trace: list[torch.Tensor] = []
orig_halt = model.recurrent.act.forward

def halt_hook(self, h):
    """Pass-through wrapper that logs halting probabilities to CPU."""
    p = orig_halt(h)
    halt_trace.append(p.detach().cpu())
    return p

model.recurrent.act.forward = halt_hook.__get__(model.recurrent.act, ACTHalting)
with torch.no_grad():
    x_h, _ = make_batch(1)
    _ = model(x_h, n_loops=16)
model.recurrent.act.forward = orig_halt  # restore the unwrapped forward

# Stack traces into (n_loops, seq_len) for the single probe sequence.
halts = torch.stack(halt_trace, dim=0)[:, 0].numpy()
print(f"\n─── 第7部分 ─ ACT暂停矩阵 (循环次数 × 位置) ───")
print(f"形状: {halts.shape} | "
      f"每轮平均暂停概率: "
      f"{', '.join(f'{v:.2f}' for v in halts.mean(1))}")
# Sections 6-7: sweeping the inference loop count shows accuracy improving
# beyond the training depth without retraining, while the ACT instrumentation
# captures halting probabilities per sequence position and iteration.
# Count how often each expert lands in the router's top-k selection.
expert_hits = Counter()
orig_moe = model.recurrent.block.ffn.forward

def moe_hook(self, x):
    """Replay the routing computation to tally expert picks, then defer to
    the original MoE forward so the layer's output is unchanged."""
    flat = x.view(-1, x.shape[-1])
    scores = F.softmax(self.router(flat) + self.router_bias, dim=-1)
    chosen = scores.topk(self.topk, dim=-1).indices
    expert_hits.update(chosen.flatten().tolist())
    return orig_moe(x)

model.recurrent.block.ffn.forward = moe_hook.__get__(
    model.recurrent.block.ffn, MoEFFN)
with torch.no_grad():
    x_m, _ = make_batch(32)
    _ = model(x_m, n_loops=T_TRAIN)
model.recurrent.block.ffn.forward = orig_moe
print("\n─── 第8部分 ─ MoE专家利用率 ───────────────────")
total = sum(expert_hits.values())
for eid in range(cfg.n_experts):
    share = expert_hits.get(eid, 0) / max(total, 1)
    print(f"专家 {eid}: {share*100:5.2f}% 的 top-k 槽位")
# Sample continuations of a fixed parity prompt at several loop depths.
prompt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 2]], device=device)
print("\n─── 第9部分 ─ 生成 ───────────────────────────────")
print(f"提示 (奇偶模式): {prompt.tolist()[0]}")
with torch.no_grad():
    for T_gen in [1, 4, 12]:
        out = model.generate(prompt, max_new_tokens=8,
                             n_loops=T_gen, temperature=0.1, top_k=2)
        print(f"T_gen={T_gen:2d} → {out.tolist()[0]}")
# Three-panel summary figure: training loss, depth extrapolation, ACT halting.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(losses)
axes[0].set_title("训练损失 (奇偶校验任务)")
axes[0].set_xlabel("步数"); axes[0].set_ylabel("交叉熵")
axes[0].grid(alpha=0.3)
axes[1].plot(T_sweep, accs, "o-", linewidth=2, markersize=8)
axes[1].axvline(T_TRAIN, color="red", linestyle="--",
                label=f"T_train = {T_TRAIN}")
axes[1].set_title("深度外推: 准确率 vs 推理循环数")
axes[1].set_xlabel("推理时的 n_loops"); axes[1].set_ylabel("准确率")
axes[1].legend(); axes[1].grid(alpha=0.3); axes[1].set_ylim(0, 1.05)
im = axes[2].imshow(halts, aspect="auto", cmap="viridis",
                    vmin=0, vmax=halts.max())
axes[2].set_title("ACT 暂停概率\n(循环 t × 位置)")
axes[2].set_xlabel("位置"); axes[2].set_ylabel("循环迭代 t")
plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
plt.tight_layout()
plt.savefig("openmythos_tutorial.png", dpi=120, bbox_inches="tight")
plt.show()
# Sections 8-9 analyzed MoE expert utilization (how tokens get routed to
# experts) and generated sequences at different loop depths; this figure
# visualizes training loss, depth-extrapolation accuracy, and ACT halting.
总之,展示了OpenMythos如何有效利用循环计算实现深度外推,使模型仅通过增加推理时的循环次数就能提高准确率。观察到即使在极端训练条件下循环机制仍保持稳定,且与GQA相比,MLA注意力显著减少了KV缓存的内存使用。同时了解了ACT如何实现跨序列位置的动态计算,以及MoE路由如何将工作负载分配给不同的专家。总体而言,确立了这种架构为计算自适应推理提供了一个有前景的方向——通过增加推理计算量来换取更好的性能,而无需修改模型参数。
在此查看完整代码和笔记本。另外,欢迎在Twitter上关注,别忘了加入130k+的ML SubReddit并订阅新闻通讯。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。