跳到主要内容

LLM 数学推导(详细版)

文档定位:把 LLM 里最常见、最容易在面试中被追问的数学公式单独抽出来,并按“详细推导版”来写。每个模块都按四件事组织:

  1. 公式是什么
  2. 在 LLM 中起什么作用
  3. 数学推导或直觉推导
  4. PyTorch 代码展示

这份文档默认不是背诵提纲,而是完整展开版。重点是把:

  • 公式从哪里来
  • 为什么这样设计
  • 梯度和复杂度怎么推
  • PyTorch 里对应哪段实现

全部放在一起。


目录


为什么要单独学 LLM 数学

很多面试会问你:

  • 注意力公式为什么要除以 sqrt(d_k)
  • FFN 为什么通常是参数量大头?
  • LayerNorm 和 RMSNorm 到底差在哪?
  • Adam 里的一阶、二阶矩估计是什么?
  • 反向传播里梯度是怎么一路传回去的?
  • RoPE 为什么能编码相对位置?
  • MoE 为什么“算力省了,通信炸了”?

如果你只会背模块名,不会写公式,也不会说梯度和复杂度,面试官通常会继续深挖。所以这份文档的重点不是“数学炫技”,而是把 能直接帮助你解释 LLM 行为的公式 讲清楚。


1. Embedding 与输出投影

公式

给定词表大小 V、隐藏维度 d_model,Embedding 矩阵记作:

E ∈ R^(V × d_model)

如果输入 token id 为 t,它对应的向量就是:

x = E[t]

模型最后输出 logits 时,通常会做:

logits = h W_out + b

其中:

  • h ∈ R^(d_model) 是最后一层 hidden state
  • W_out ∈ R^(d_model × V)

很多 LLM 会做 weight tying,即:

W_out = E^T

在 LLM 中的作用

  • Embedding 负责把离散 token id 映射到连续向量空间
  • 输出投影负责把 hidden state 映射回词表概率空间
  • weight tying 可以减少参数量,并让输入输出语义空间更一致

详细推导

Embedding 本质上不是“查字典”之外的更复杂操作,它就是一个 one-hot 向量和矩阵相乘:

如果 e_t ∈ R^V 是 token t 的 one-hot 表示,那么:

x = e_t^T E

因为 one-hot 只有第 t 个位置是 1,所以结果恰好就是 E 的第 t 行。

输出层也是类似的线性分类器。对每个词表项 i

logit_i = h · W_out[:, i]

也就是说,logit 本质上是“当前隐藏状态和每个词向量方向的匹配程度”。

如果使用 weight tying:

W_out = E^T

那么:

logit_i = h · E[i]

这说明输出层其实是在问:当前 hidden state h 和词表中第 i 个词向量是否对齐。
从优化角度看,这会把输入表征空间和输出分类空间绑在一起,通常能减少参数量,也常能带来更稳定的训练。

PyTorch 代码

import torch
import torch.nn as nn

vocab_size = 32000
d_model = 4096

embedding = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

# weight tying
lm_head.weight = embedding.weight

token_ids = torch.tensor([[1, 42, 256]])
x = embedding(token_ids) # (batch, seq, d_model)
h = x[:, -1, :] # 取最后一个 token 的 hidden state
logits = lm_head(h) # (batch, vocab_size)

2. Positional Encoding 与 RoPE

公式

Transformer 本身对输入顺序不敏感,所以需要额外引入位置信息。

经典正弦位置编码:

PE(pos, 2i) = sin(pos / 10000^(2i / d_model))

PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

RoPE(Rotary Positional Embedding)则把位置编码写成二维平面的旋转:

对一对特征 (x_1, x_2),位置 m 的旋转结果为:

[x_1', x_2'] = [x_1 cos θ_m - x_2 sin θ_m, x_1 sin θ_m + x_2 cos θ_m]

在 LLM 中的作用

  • 让模型知道“哪个 token 在前、哪个 token 在后”
  • RoPE 特别适合 decoder-only LLM,因为它天然把相对位置信息编码进 attention 点积里

详细推导

正弦位置编码的关键性质是:

PE(pos + k) 可以由 PE(pos) 线性组合出来

这让模型更容易从绝对位置中恢复相对位移关系。

更具体一点,利用三角恒等式:

sin(a + b) = sin a cos b + cos a sin b

cos(a + b) = cos a cos b - sin a sin b

可以看到 PE(pos + k) 能由 PE(pos) 和只依赖偏移量 k 的系数组合出来。
这就是“绝对位置编码里隐含相对位移信息”的来源。

RoPE 的核心更直接。若对 Query 和 Key 都做相同频率的旋转:

q_m = R_m q

k_n = R_n k

那么它们的点积满足:

q_m^T k_n = q^T R_(n-m) k

也就是说,attention 分数只和相对位置 (n - m) 有关,而不是绝对位置本身。这就是为什么 RoPE 特别适合长上下文和自回归建模。

若把二维旋转矩阵写出来:

R_m = [[cos θ_m, -sin θ_m], [sin θ_m, cos θ_m]]

那么:

q_m^T k_n = q^T R_m^T R_n k

因为旋转矩阵满足:

R_m^T = R_(-m)

所以:

R_m^T R_n = R_(n-m)

这一步就是 RoPE 最关键的数学结论:相对位置被自然编码进了 Query-Key 点积里。

PyTorch 代码

import torch

def apply_rope(x, cos, sin):
# x: (batch, heads, seq, dim)
x_even = x[..., ::2]
x_odd = x[..., 1::2]

x_rot_even = x_even * cos - x_odd * sin
x_rot_odd = x_even * sin + x_odd * cos

out = torch.stack([x_rot_even, x_rot_odd], dim=-1)
return out.flatten(-2)

3. Softmax 与 Cross Entropy

公式

给定 logits z_i,softmax 定义为:

p_i = exp(z_i) / Σ_j exp(z_j)

若正确类别为 y,cross entropy loss 为:

L = -log p_y

代入 softmax 可得:

L = -z_y + log Σ_j exp(z_j)

在 LLM 中的作用

  • softmax 把 logits 变成词表上的概率分布
  • cross entropy 是语言模型训练中最常见的目标函数

详细推导

cross entropy 的一个极重要结果是它对 logits 的梯度非常简洁:

∂L / ∂z_i = p_i - 1(i = y)

其中 1(i = y) 是 one-hot 标签。

这个结果意味着:

  • 对正确类别,梯度是 p_y - 1
  • 对错误类别,梯度是 p_i

也就是说,模型会自动“压低错误类别,抬高正确类别”。

这个梯度之所以重要,是因为它让 softmax + cross entropy 的反向传播既稳定又高效。

下面把这个梯度推出来。

先写:

p_i = exp(z_i) / Σ_j exp(z_j)

loss 为:

L = -log p_y = -z_y + log Σ_j exp(z_j)

对任意 z_i 求导:

∂L / ∂z_i = ∂(-z_y) / ∂z_i + ∂ log Σ_j exp(z_j) / ∂z_i

第一项:

  • i = y 时是 -1
  • i != y 时是 0

也就是:

∂(-z_y) / ∂z_i = -1(i = y)

第二项:

∂ log Σ_j exp(z_j) / ∂z_i = exp(z_i) / Σ_j exp(z_j) = p_i

所以最终得到:

∂L / ∂z_i = p_i - 1(i = y)

这也是为什么在实现里经常把 softmax + nll_loss 融合成一个 kernel:
前向和反向公式都非常规整,数值稳定版本也容易统一处理。

PyTorch 代码

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, requires_grad=True)
targets = torch.tensor([1, 3])

loss = F.cross_entropy(logits, targets)
loss.backward()

print(loss.item())
print(logits.grad) # 梯度大致对应 p - y_one_hot

4. Scaled Dot-Product Attention

公式

注意力的核心公式:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

其中:

  • Q ∈ R^(n × d_k)
  • K ∈ R^(n × d_k)
  • V ∈ R^(n × d_v)

在 LLM 中的作用

  • 让当前 token 能看到历史 token 的信息
  • 是 Transformer 能建模长距离依赖的核心

详细推导

未缩放时,QK^T 的每个元素是长度为 d_k 的点积。

如果 q_i, k_i 独立且方差为 1,那么:

Var(q · k) = d_k

也就是说,点积的方差会随着维度线性增长。d_k 一大,softmax 输入就容易过大,导致分布过尖、梯度变差。

所以要除以:

sqrt(d_k)

使得点积的量级更稳定。

把这件事写得更严格一些。

设:

  • q = (q_1, ..., q_d)
  • k = (k_1, ..., k_d)

并假设每个分量独立、零均值、单位方差:

E[q_i] = E[k_i] = 0

Var(q_i) = Var(k_i) = 1

则点积:

s = q · k = Σ_i q_i k_i

因为独立且零均值:

E[s] = 0

又因为:

Var(q_i k_i) = E[q_i^2] E[k_i^2] = 1

所以:

Var(s) = Σ_i Var(q_i k_i) = d_k

这意味着 s 的典型量级会随着 sqrt(d_k) 增长。
如果直接把这样的 s 丢进 softmax,当 d_k 很大时,softmax 会快速饱和,注意力分布接近 one-hot,梯度变得很差。

将其缩放成:

s' = s / sqrt(d_k)

后就有:

Var(s') = 1

因此分数分布在不同 head_dim 下更稳定。

PyTorch 代码

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
if mask is not None:
scores = scores.masked_fill(~mask, float("-inf"))
probs = torch.softmax(scores, dim=-1)
out = torch.matmul(probs, v)
return out, probs

5. Multi-Head Attention 与 GQA

公式

多头注意力:

head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

MHA(Q, K, V) = Concat(head_1, ..., head_h) W^O

GQA(Grouped Query Attention)中,多个 Query 头共享较少的 KV 头。

在 LLM 中的作用

  • MHA 让模型从不同子空间学习不同关系
  • GQA / MQA 主要是为了减少 KV Cache 体积,降低 decode 带宽压力

详细推导

MHA 的本质不是“把 attention 做很多遍”这么简单,而是:

  • 每个头的投影矩阵不同
  • 所以每个头在不同的表示子空间中做匹配

GQA 的关键计算收益来自 KV cache 大小:

若 attention 头数为 h_q,KV 头数为 h_kv,则 KV cache 大小大致与 h_kv 成正比。

所以当 h_kv << h_q 时,decode 的访存开销会显著下降。

更具体地,如果每层 KV cache 大小近似为:

KV_bytes_per_layer ≈ seq_len * h_kv * d_head * 2 * bytes_per_elem

那么:

  • 对 MHA,h_kv = h_q
  • 对 MQA,h_kv = 1
  • 对 GQA,1 < h_kv < h_q

因此 GQA / MQA 的收益不是来自“attention 算法公式变了”,而是来自 K/V 存储和读取量变小了
这也是为什么它们对 decode 更重要,而不是对 prefill 一样重要。

PyTorch 代码

import torch
import torch.nn as nn

class GQAProjection(nn.Module):
def __init__(self, d_model, n_heads, n_kv_heads):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = d_model // n_heads

self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim)
self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim)

def forward(self, x):
b, s, _ = x.shape
q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(b, s, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(b, s, self.n_kv_heads, self.head_dim).transpose(1, 2)
return q, k, v

6. FFN 与 SwiGLU

公式

标准 FFN:

FFN(x) = W_2 σ(W_1 x + b_1) + b_2

常见激活函数可以是 ReLU、GELU。

SwiGLU 常写成:

SwiGLU(x) = (x W_a) ⊙ swish(x W_b)

其中:

swish(t) = t * sigmoid(t)

在 LLM 中的作用

  • FFN 负责逐 token 的非线性变换
  • 在很多 LLM 中,FFN 参数量通常比 attention 还大
  • SwiGLU 往往比普通 ReLU / GELU FFN 表达能力更强

详细推导

FFN 可以理解为“逐 token 的 MLP”。attention 负责信息混合,FFN 负责在每个 token 的特征维度上做非线性变换。

SwiGLU 的关键是门控:

  • 一路做特征变换
  • 一路做门控权重
  • 两路逐元素相乘

这比单一路径激活函数更灵活,因此很多现代 LLM 采用 SwiGLU / GeGLU / GLU 变体。

参数量为什么 FFN 常是大头?

设隐藏维度为 d,中间维度为 d_ff
标准 FFN 两层线性层参数量近似为:

params_FFN ≈ d * d_ff + d_ff * d = 2 d d_ff

如果 d_ff = 4d,则:

params_FFN ≈ 8 d^2

而 attention 的 Q/K/V/O 四个投影总参数量大致为:

params_attn ≈ 4 d^2

所以在很多 Transformer 里,FFN 参数量约为 attention 的两倍量级。
这也是“FFN 往往是参数量大头”的数学来源。

PyTorch 代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
def __init__(self, d_model, hidden_dim):
super().__init__()
self.w_a = nn.Linear(d_model, hidden_dim, bias=False)
self.w_b = nn.Linear(d_model, hidden_dim, bias=False)
self.w_out = nn.Linear(hidden_dim, d_model, bias=False)

def forward(self, x):
gate = F.silu(self.w_b(x))
value = self.w_a(x)
return self.w_out(value * gate)

7. LayerNorm 与 RMSNorm

公式

LayerNorm:

μ = (1 / d) Σ_i x_i

σ^2 = (1 / d) Σ_i (x_i - μ)^2

LN(x)_i = γ_i * (x_i - μ) / sqrt(σ^2 + ε) + β_i

RMSNorm:

RMS(x) = sqrt((1 / d) Σ_i x_i^2 + ε)

RMSNorm(x)_i = γ_i * x_i / RMS(x)

在 LLM 中的作用

  • 归一化有助于稳定训练
  • LayerNorm 更完整,RMSNorm 更轻量
  • 很多现代 LLM 更偏向 RMSNorm

详细推导

LayerNorm 做了两件事:

  1. 减均值
  2. 除标准差

RMSNorm 只保留第二步的“按尺度归一化”思想,不减均值。

在很多 LLM 场景里,减均值并不是最关键的,真正重要的是控制激活尺度,所以 RMSNorm 常能用更简单的计算达到相近效果。

还可以把 LayerNorm 的方差写开:

σ^2 = E[x^2] - (E[x])^2

而 RMSNorm 用的是:

RMS(x)^2 = E[x^2]

也就是说,RMSNorm 省掉的是 (E[x])^2 这部分居中操作。
如果模型训练过程中,真正更关键的是“防止激活尺度失控”,而不是“强制零均值”,那么 RMSNorm 就可能足够用了。

从实现角度看,RMSNorm:

  • 少一次均值减法
  • 少一次方差中心化
  • kernel 更简单

这也是它在 LLM 中流行的工程原因。

PyTorch 代码

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def forward(self, x):
rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
return x * rms * self.weight

8. Residual Connection

公式

残差连接最常见的形式:

y = x + F(x)

在 LLM 中的作用

  • 帮助深层网络训练
  • 让梯度更容易传播
  • 避免层数一深就退化

数学推导 / 直觉

看梯度:

∂y / ∂x = I + ∂F(x) / ∂x

即使 ∂F/∂x 很小,恒等映射 I 仍然给梯度保留了一条“直通路径”。这就是为什么残差连接对深层 Transformer 非常关键。

PyTorch 代码

def residual_block(x, sublayer):
return x + sublayer(x)

9. Transformer Block 的前向传播

公式

以 Pre-Norm Transformer Block 为例:

h_1 = x + Attention(Norm_1(x))

h_2 = h_1 + FFN(Norm_2(h_1))

在 LLM 中的作用

  • 这是现代 decoder-only LLM 最常见的基本骨架
  • 把 attention、FFN、norm、residual 串成完整前向路径

详细推导

Transformer block 的关键不是某一个单独公式,而是:

  • norm 控尺度
  • attention 做 token 间信息混合
  • FFN 做 token 内非线性变换
  • residual 保证深层稳定性

如果写成 decoder-only pre-norm block 的更完整形式:

u = Norm_1(x)

a = Attention(u)

h = x + a

v = Norm_2(h)

f = FFN(v)

y = h + f

这样拆开看更清楚:

  • Norm_1 / Norm_2 负责把输入尺度拉回稳定区间
  • Attention 负责 token 间信息交换
  • FFN 负责逐 token 特征变换
  • 两次 residual 确保梯度路径不断裂

PyTorch 代码

class SimpleTransformerBlock(nn.Module):
def __init__(self, dim, attn, ffn, norm_cls=nn.LayerNorm):
super().__init__()
self.norm1 = norm_cls(dim)
self.norm2 = norm_cls(dim)
self.attn = attn
self.ffn = ffn

def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), mask=mask)
x = x + self.ffn(self.norm2(x))
return x

10. 反向传播与链式法则

公式

链式法则:

z = f(y)y = g(x),则

∂z / ∂x = (∂z / ∂y) (∂y / ∂x)

以线性层为例:

y = xW + b

若上游梯度为 G = ∂L / ∂y,则:

∂L / ∂W = x^T G

∂L / ∂x = G W^T

在 LLM 中的作用

  • backward 就是把 loss 的梯度沿计算图一层层传回去
  • 训练的所有参数更新都依赖这些梯度

详细推导

这部分最容易被问的不是“会不会推一大页矩阵求导”,而是:

  1. 你知不知道梯度是怎么传回去的
  2. 你能不能说清哪些量需要缓存给 backward
  3. 你是否理解为什么反向传播显存开销大

例如线性层:

y_j = Σ_i x_i W_ij + b_j

所以:

∂L / ∂W_ij = x_i * ∂L / ∂y_j

写成矩阵形式就是:

∂L / ∂W = x^T G

如果把 batch 维一起考虑,设:

  • X ∈ R^(B × d_in)
  • W ∈ R^(d_in × d_out)
  • Y = XW
  • G = ∂L / ∂Y ∈ R^(B × d_out)

则:

∂L / ∂W = X^T G

∂L / ∂X = G W^T

这两个式子非常重要,因为它解释了两件事:

  1. backward 需要拿到前向输入 X,所以前向中间结果要缓存
  2. 梯度计算本身也是矩阵乘法,所以 backward 往往同样是高代价算子

这也是为什么训练时显存不仅存参数,还要存激活。

PyTorch 代码

import torch

x = torch.randn(2, 4, requires_grad=True)
linear = torch.nn.Linear(4, 3)
target = torch.randn(2, 3)

out = linear(x)
loss = ((out - target) ** 2).mean()
loss.backward()

print(linear.weight.grad.shape) # (3, 4)
print(x.grad.shape) # (2, 4)

11. Adam 与 AdamW

公式

给定梯度 g_t

一阶矩估计:

m_t = β_1 m_(t-1) + (1 - β_1) g_t

二阶矩估计:

v_t = β_2 v_(t-1) + (1 - β_2) g_t^2

偏差修正:

m̂_t = m_t / (1 - β_1^t)

v̂_t = v_t / (1 - β_2^t)

参数更新:

θ_t = θ_(t-1) - α * m̂_t / (sqrt(v̂_t) + ε)

AdamW 会把 weight decay 和梯度更新解耦:

θ_t = θ_(t-1) - α * m̂_t / (sqrt(v̂_t) + ε) - α λ θ_(t-1)

在 LLM 中的作用

  • Adam / AdamW 是训练 LLM 最常见的优化器之一
  • 它能对不同参数维度自适应调整学习率
  • AdamW 的 decoupled weight decay 在大模型训练中更常用

详细推导

Adam 里的两个统计量:

  • m_t 类似“梯度的滑动平均”,让更新方向更平滑
  • v_t 类似“梯度平方的滑动平均”,用来估计每个参数方向上的梯度尺度

所以 Adam 的更新可以理解为:

“方向上参考一阶动量,步长上参考历史波动做自适应缩放”

偏差修正则是因为 m_t, v_t 在训练初期从 0 开始,会系统性偏小。

把偏差修正写开更清楚。

由于初始时 m_0 = 0,递推展开:

m_t = (1 - β_1) Σ_(i=1)^t β_1^(t-i) g_i

如果梯度在统计意义上近似平稳,E[g_i] = μ,则:

E[m_t] = (1 - β_1) Σ_(i=1)^t β_1^(t-i) μ = (1 - β_1^t) μ

所以 m_t 比真实均值 μ 少了一个因子 `(1 - β_1^t)``,训练初期会偏小。
因此要除以这个因子,得到:

m̂_t = m_t / (1 - β_1^t)

v_t 的偏差修正同理。

为什么 AdamW 比 Adam 更常见?

因为传统 Adam 如果把 L2 正则直接混进梯度,会和自适应学习率缩放耦合在一起;AdamW 则把 weight decay 单独拆开:

θ <- θ - α * update - α λ θ

这样正则项不会被二阶矩缩放干扰,通常更符合“参数衰减”的原始意图。

PyTorch 代码

import torch
import torch.nn as nn

model = nn.Linear(16, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(4, 16)
y = torch.randn(4, 8)

pred = model(x)
loss = ((pred - y) ** 2).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()

12. MoE Router 与 Top-k Gating

公式

设输入 token hidden state 为 x,router logits 为:

r = W_r x

router 概率:

p = softmax(r)

若采用 top-k routing,则只保留概率最大的 k 个 expert。

在 LLM 中的作用

  • 决定每个 token 该送到哪些 expert
  • 让模型在参数总量很大时,单次前向仍只激活少数 expert

详细推导

MoE 的核心不是“参数很多”,而是“激活稀疏”:

  • 总参数量可以很大
  • 但每个 token 只走少数 expert

这样单 token 的理论计算量更低,但代价是 token 分发会引入额外通信和负载均衡问题。

若一共有 N 个 expert,每个 token 只选 k 个 expert,则每个 token 的 FFN 计算量大致从:

O(N * FFN_cost) 的“全激活”

变成:

O(k * FFN_cost)

其中通常 k << N

但系统侧的新问题是:

  1. router 必须先算出每个 token 去哪
  2. token 要被重新分发到对应 expert
  3. 不同 expert 负载可能严重不均

所以 MoE 节省的是 理论算力,未必直接节省 端到端时间

PyTorch 代码

import torch
import torch.nn.functional as F

def topk_router(x, router_weight, k=2):
# x: (batch, seq, dim)
logits = torch.matmul(x, router_weight) # (batch, seq, n_experts)
probs = F.softmax(logits, dim=-1)
topk_probs, topk_idx = torch.topk(probs, k=k, dim=-1)
return topk_probs, topk_idx

13. 参数量、FLOPs 与 KV Cache 估算

公式

参数量

一个线性层 W ∈ R^(d_in × d_out) 的参数量约为:

params = d_in * d_out

Attention 计算复杂度

attention 分数矩阵 QK^T 的复杂度大致为:

O(n^2 d)

其中:

  • n 是序列长度
  • d 是 head_dim 或隐藏维度相关量

FFN 复杂度

FFN 通常是两个大矩阵乘法,复杂度近似:

O(n d d_ff)

KV Cache 大小

KV_bytes ≈ batch * seq_len * num_layers * num_kv_heads * head_dim * 2 * bytes_per_elem

在 LLM 中的作用

  • 帮你判断哪一部分是参数大头
  • 帮你估算长上下文显存压力
  • 帮你解释为什么 prefill 和 decode 的瓶颈不一样

详细推导

attention 的 n^2 来自哪里?

因为长度为 n 的序列中,每个 token 都要和另外 n 个 token 做匹配,构成一个 n × n 的注意力分数矩阵。

为什么 decode 常常更 memory-bound?

因为 decode 每一步新增的计算很少,但要不断读取庞大的权重和 KV Cache。也就是说,它不是“算不完”,而是“搬不动”。

再把 prefill 和 decode 的差别写得更数学一点:

  • prefill 处理整段长度为 n 的序列,attention 会形成 n × n 分数矩阵
  • decode 每一步只新增 1 个 token,对已有 n 个历史 token 做注意力

所以:

  • prefill 的 attention 计算更像大矩阵乘法,算力利用率高
  • decode 的 attention 更像“读很多历史 KV,再做较小计算”,更容易受带宽限制

KV cache 公式:

KV_bytes ≈ batch * seq_len * num_layers * num_kv_heads * head_dim * 2 * bytes_per_elem

中:

  • batch * seq_len 决定 token 总数
  • num_layers 表示每层都要存 K 和 V
  • num_kv_heads * head_dim 是每个 token 每层的 KV 向量大小
  • 2 对应 K 和 V 两份缓存

这也是为什么长上下文时,KV cache 会迅速成为显存主压力之一。

PyTorch 代码

def estimate_kv_cache_bytes(
batch,
seq_len,
num_layers,
num_kv_heads,
head_dim,
bytes_per_elem=2,
):
return (
batch
* seq_len
* num_layers
* num_kv_heads
* head_dim
* 2
* bytes_per_elem
)

bytes_used = estimate_kv_cache_bytes(
batch=8,
seq_len=4096,
num_layers=32,
num_kv_heads=8,
head_dim=128,
)

print(bytes_used / (1024 ** 3), "GB")

14. 访存量、激活参数量与 DeepSeek 类结构手算

公式

标准 MHA / GQA 的 decode 侧 KV 读取

对单个生成步、单个样本、单层来说,历史 KV 的读取量可近似写成:

KV_read_bytes_per_step_per_layer ≈ seq_len * 2 * num_kv_heads * head_dim * bytes_per_elem

把层数和 batch 带上:

KV_read_bytes_per_step ≈ batch * num_layers * seq_len * 2 * num_kv_heads * head_dim * bytes_per_elem

MLA 的缓存量

如果系统缓存的是更紧凑的 latent 表示,而不是完整 K/V,那么单层单 token 的缓存量更接近:

MLA_cache_bytes_per_token_per_layer ≈ d_latent * bytes_per_elem

相比标准 KV:

MHA_cache_bytes_per_token_per_layer ≈ 2 * num_kv_heads * head_dim * bytes_per_elem

Linear Attention 的状态量

若使用特征映射后的线性 attention,并维护前缀状态:

S_t = Σ_{i<=t} phi(k_i) v_i^T

z_t = Σ_{i<=t} phi(k_i)

则每层缓存状态大小近似为:

Linear_state_bytes_per_layer ≈ num_heads * (d_phi * d_v + d_phi) * bytes_per_elem

它随序列长度增长得更慢,甚至可以做到对 seq_len 不敏感。

MoE 的总参数量与激活参数量

若一个 MoE 层有 N 个 expert,每个 expert 参数量为 P_expert,共享部分参数量为 P_shared,router 参数量为 P_router,则:

P_total = P_shared + P_router + N * P_expert

若每个 token 只激活 k 个 expert,则单 token 激活参数量近似:

P_active = P_shared + P_router + k * P_expert

DeepSeek 类 MLA + MoE 结构的单步 decode 读取

把它写成最通用的面试模板:

Bytes_decode_step ≈ Bytes_dense_weight + Bytes_active_expert_weight + Bytes_cache_read + Bytes_comm

其中:

  • Bytes_dense_weight:主干 dense 部分权重读取
  • Bytes_active_expert_weight:本步被激活 expert 的权重读取
  • Bytes_cache_read:历史 attention 状态读取,MHA / GQA / MLA / linear attention 形式不同
  • Bytes_comm:多卡 MoE token 分发或聚合通信

在 LLM 中的作用

  • 帮你区分“模型总参数很大”和“单步真正读了多少”不是一回事
  • 帮你解释为什么 MoE 看起来参数很大,但单 token 激活的只是其中一小部分
  • 帮你解释为什么 MLA / GQA / linear attention 的价值主要体现在缓存和带宽,而不是只看理论 FLOPs
  • 帮你在面试里用统一框架手算 DeepSeek-V3 这类结构,而不是死记一个公开数字

详细推导

很多人手算时最容易把三件事混在一起:

  1. 总参数量
  2. 单 token 激活参数量
  3. 单步 decode 访存量

这三者必须先拆开。

先看标准 MHA。decode 第 t 步时,新 token 需要和前 t 个历史 token 的 K/V 做注意力,所以单层至少要把历史 KV 各读一遍。于是有:

KV_read_bytes_per_step_per_layer ≈ t * 2 * num_kv_heads * head_dim * bytes_per_elem

t 很大时,这个量会线性增长。也就是说,长上下文里 decode 慢,很多时候不是算不动,而是历史 KV 越读越多。

GQA / MQA 为什么有效?因为它们直接把 num_kv_heads 变小了。公式里别的量不变,只有 num_kv_heads 缩小,所以缓存和读取量会按比例下降。

MLA 再进一步。它的关键不是“把 attention 换成另一个公式”,而是让缓存更接近一个低维 latent 表示。于是每个 token、每层缓存的不是完整 K/V,而是更小的 latent。手算时最重要的是比较:

2 * num_kv_heads * head_dim

d_latent

谁更大。只要 d_latent 显著更小,长上下文缓存和读取压力就会明显下降。

Linear attention 又是另一个方向。它不再显式保存所有历史 token 的 K/V,而是把历史压缩成前缀状态 S_tz_t。所以它的状态量更像:

num_heads * (d_phi * d_v + d_phi)

而不再是:

seq_len * 2 * num_kv_heads * head_dim

这也是为什么 linear attention 在超长序列题里经常被拿来和标准 attention 对比。

再看 MoE。很多人一看到 expert 数量大,就直接说“推理一定更慢”,这是不对的。因为总参数量:

P_total = P_shared + P_router + N * P_expert

但单 token 激活的通常只有:

P_active = P_shared + P_router + k * P_expert

其中 k << N
所以 MoE 的问题不只是“算了多少”,而是“这些被选中的 expert 在哪里、要不要跨卡取、通信是不是成为瓶颈”。

把这些合起来,DeepSeek-V3 这类 MLA + MoE 结构在面试里最稳的手算法就是:

  1. 先估主干 dense 部分参数和每步权重读取
  2. 再估 MLA 相比标准 MHA/GQA 少掉多少缓存和读取
  3. 再估 MoE 的总 expert 参数量,以及每步只激活 k 个 expert 时的读取量
  4. 最后补一个系统项:如果 expert 分布在多卡,还要加上 AllToAll 或类似 token 分发通信

这样你的答案就会从“背模型名词”变成“能把结构翻译成带宽、显存和通信”。

PyTorch 代码

def estimate_mha_kv_read_bytes(
batch,
seq_len,
num_layers,
num_kv_heads,
head_dim,
bytes_per_elem=2,
):
return (
batch
* num_layers
* seq_len
* 2
* num_kv_heads
* head_dim
* bytes_per_elem
)


def estimate_mla_cache_bytes(
batch,
seq_len,
num_layers,
d_latent,
bytes_per_elem=2,
):
return batch * seq_len * num_layers * d_latent * bytes_per_elem


def estimate_linear_attention_state_bytes(
batch,
num_layers,
num_heads,
d_phi,
d_v,
bytes_per_elem=2,
):
return (
batch
* num_layers
* num_heads
* (d_phi * d_v + d_phi)
* bytes_per_elem
)


def estimate_moe_params(p_shared, p_router, p_expert, n_experts, top_k):
p_total = p_shared + p_router + n_experts * p_expert
p_active = p_shared + p_router + top_k * p_expert
return p_total, p_active


def estimate_deepseek_like_decode_bytes(
dense_weight_bytes,
active_expert_weight_bytes,
cache_read_bytes,
comm_bytes=0,
):
return (
dense_weight_bytes
+ active_expert_weight_bytes
+ cache_read_bytes
+ comm_bytes
)


mha_kv = estimate_mha_kv_read_bytes(
batch=1,
seq_len=8192,
num_layers=32,
num_kv_heads=8,
head_dim=128,
)

mla_cache = estimate_mla_cache_bytes(
batch=1,
seq_len=8192,
num_layers=32,
d_latent=512,
)

moe_total, moe_active = estimate_moe_params(
p_shared=400_000_000,
p_router=10_000_000,
p_expert=120_000_000,
n_experts=64,
top_k=2,
)

print("MHA decode KV read (GB):", mha_kv / (1024 ** 3))
print("MLA cache (GB):", mla_cache / (1024 ** 3))
print("MoE total params:", moe_total)
print("MoE active params per token:", moe_active)

15. Transformer 架构高频面试题

Transformer 模型结构总览

答:

三大架构变体:

架构代表模型注意力类型主要任务
Encoder-onlyBERT, RoBERTa双向全注意力文本分类、NER
Decoder-onlyGPT, LLaMA, Qwen因果单向注意力文本生成、对话
Encoder-DecoderT5, BART编码双向 + 解码因果 + 交叉注意力翻译、摘要

现代 LLM(Decoder-only)的标准结构:

Input Token IDs

[Embedding Layer] ← token embedding(通常 weight tying 共享 LM Head)

┌─────────────────────────────────┐
│ × N layers: │
│ RMSNorm → MHA → + Residual │ ← Pre-Norm + 残差
│ RMSNorm → FFN → + Residual │ ← Pre-Norm + 残差
└─────────────────────────────────┘

[RMSNorm] ← 最终归一化

[LM Head (Linear)] ← 投影到词表大小,输出 logits

Softmax → Token Probabilities

说一下 Decoder 的因果注意力,QKV 分别来自哪

答:

因果注意力(Causal Attention): 每个 token 只能看到自己和之前的 token,不能看到未来的 token。

因果掩码示例(4 个 token):
t1 t2 t3 t4
t1 [ 1 0 0 0 ] ← t1 只能看 t1
t2 [ 1 1 0 0 ] ← t2 能看 t1, t2
t3 [ 1 1 1 0 ] ← t3 能看 t1, t2, t3
t4 [ 1 1 1 1 ] ← t4 能看所有

实现:在 softmax 前将未来位置设为 -inf
scores = QK^T / √d_k
scores = scores.masked_fill(mask == 0, -inf)
attn = softmax(scores)

QKV 的来源:

架构Q 来自K 来自V 来自
Decoder self-attention当前层输入当前层输入当前层输入
Cross-attention (Enc-Dec)Decoder 层输入Encoder 输出Encoder 输出
Decoder-only LLM只有 self-attention,QKV 全来自同一输入

Transformer 介绍下 QKV 的作用

答:

Q(Query) = "我在找什么" K(Key) = "我有什么信息" V(Value) = "我的实际内容"

类比信息检索:
Q = 搜索关键词
K = 文档标题/索引
V = 文档正文

Attention(Q, K, V) = softmax(QK^T / √d_k) · V

Step 1: QK^T → 计算每个 token 与其他 token 的相关性(注意力分数)
Step 2: softmax → 归一化为概率分布
Step 3: × V → 用概率加权聚合各 token 的信息

为什么要分离 QKV 而不直接用 X?

  • 分离后,Q/K/V 可以学到不同的线性变换
  • Q 学"查什么",K 学"提供什么索引",V 学"给出什么内容"
  • 这比 X·X^T 更灵活(否则只能计算输入之间的余弦相似度)

推导多头注意力计算复杂度

答:

单头注意力:

Q, K, V ∈ R^{n × d_h}  (n=序列长度, d_h=head_dim)

QK^T: (n × d_h) × (d_h × n) = O(n² d_h) ← 注意力分数
softmax: O(n²) ← 逐行归一化
× V: (n × n) × (n × d_h) = O(n² d_h) ← 加权聚合

单头总计: O(n² d_h)

多头注意力(h 个头,d_h = d/h):

h 个头并行: h × O(n² · d/h) = O(n² d)          ← 总注意力计算

QKV 投影: 3 × (n × d) × (d × d) = O(3nd²) ← 线性投影
Output 投影: (n × d) × (d × d) = O(nd²) ← 输出投影

MHA 总复杂度: O(n²d + nd²)

瓶颈分析:

  • 当 n >> d(长序列):O(n²d) 主导 → attention 是瓶颈
  • 当 d >> n(短序列大模型):O(nd²) 主导 → 投影是瓶颈
  • LLM 实际中:prefill 时 n 较大 → attention 瓶颈;decode 时 n=1 → 投影瓶颈

为什么 Transformer 使用多头注意力

答:

  1. 多子空间表示:不同 head 可以关注不同类型的关系

    Head 1: 关注语法结构(主谓宾)
    Head 2: 关注语义相似性
    Head 3: 关注位置距离
    Head 4: 关注共指关系
  2. 训练更稳定:多个 head 的 attention 分布更分散,不容易出现某些 token 被完全忽略

  3. 并行计算:多个小矩阵乘法比一个大矩阵乘法更适合 GPU 并行

  4. 参数效率:多头的参数量和单头相同(d² 总共),但表达能力更强

  5. 经验验证:论文实验证明 h=8 显著优于 h=1(同参数量)

类比: 多头注意力相当于"注意力的集成学习"——多个专家从不同角度看问题,最后综合判断。

注意力机制除了 MHA、GQA,还知道哪些

答:

机制核心思想KV Cache典型模型
MHA每个 Q head 有独立 KV headn_h × d_h × 2GPT-3, LLaMA-1
MQA所有 Q head 共享 1 个 KV headd_h × 2PaLM, Falcon
GQA每组 Q head 共享 1 个 KV headn_kv × d_h × 2LLaMA-2/3
MLAKV 低秩压缩到 latentd_c + d_ropeDeepSeek-V2/R1
Linear Attention用 kernel 替代 softmax,O(n)固定状态矩阵RWKV, Mamba
Sliding Window只关注局部窗口内 token窗口大小 × 层数Mistral
Sparse Attention稀疏注意力模式(local + global)取决于稀疏度BigBird, Longformer
Cross AttentionQ 来自一个序列,KV 来自另一个-T5, 多模态模型

注意力机制类型(MHA MQA GQA)各自优缺点

答:

维度MHAMQAGQA
Q headshhh
KV headsh1g (1 < g < h)
KV Cache最大最小中间
模型质量最好有损失接近 MHA
推理吞吐最低最高中高
典型配置32Q/32KV32Q/1KV32Q/8KV

选择指南:

  • 追求质量 → MHA(参数充足时)
  • 追求推理效率 → MQA(可接受质量损失)
  • 平衡方案 → GQA(目前最主流,LLaMA-2/3、Qwen-2+ 都用 GQA)

Transformer 底层原理,为啥能替代 RNN

答:

维度RNN/LSTMTransformer
并行性必须顺序处理(t 依赖 t-1)完全并行(所有 token 同时处理)
长距离依赖O(n) 路径,梯度易消失O(1) 路径(任意 token 对直接交互)
训练效率无法利用 GPU 并行矩阵乘法天然适合 GPU
表达能力受限于固定大小隐状态Attention 可关注任意 token
计算复杂度O(n × d²)O(n² × d)(长序列时更贵)

Transformer 能替代 RNN 的根本原因:

  1. Self-Attention 实现全局信息聚合:任意两个 token 距离为 1(vs RNN 的 O(n))
  2. 完全并行化:训练速度快数十倍
  3. Scaling Law 表现更好:模型越大、数据越多,Transformer 的优势越明显
  4. KV Cache 使推理可行:虽然训练是 O(n²),但推理通过缓存实现增量计算

FFN 层是干嘛的,为什么先升维再降维

答:

FFN 的作用: 对每个 token 独立做非线性变换(MHA 负责 token 间交互,FFN 负责单 token 的特征变换)。

标准 FFN:    FFN(x) = W₂ · ReLU(W₁ · x + b₁) + b₂
SwiGLU FFN: FFN(x) = W₂ · (SiLU(W_gate · x) ⊙ W_up · x)

W₁/W_up: R^{d → 4d} ← 升维
W₂: R^{4d → d} ← 降维
(SwiGLU 中 4d 变为 8d/3 × 2 ≈ 5.3d)

为什么先升维再降维:

  1. 信息瓶颈理论:升维到更高维空间,非线性变换可以学习更复杂的模式;降维压缩回原始维度,迫使网络保留最重要的信息

  2. 类比

    d 维 → 4d 维:展开(把信息铺开,更容易做分离和变换)
    激活函数:非线性筛选(保留有用信息,抑制无用信息)
    4d 维 → d 维:压缩(把处理后的信息压回标准维度)
  3. 参数量分配:FFN 的参数量约占 Transformer 的 2/3(2 × d × 4d = 8d²),是模型"记忆知识"的主要载体

梯度消失、梯度爆炸的根本原因

答:

根本原因:链式法则中的连乘。

反向传播中,梯度通过链式法则传播:
∂L/∂w₁ = ∂L/∂hₙ × ∂hₙ/∂hₙ₋₁ × ... × ∂h₂/∂h₁ × ∂h₁/∂w₁
= ∂L/∂hₙ × Π_{i=1}^{n-1} ∂hᵢ₊₁/∂hᵢ × ∂h₁/∂w₁

如果 |∂hᵢ₊₁/∂hᵢ| < 1(每层梯度 < 1):
n 层连乘 → 梯度指数衰减 → 梯度消失

如果 |∂hᵢ₊₁/∂hᵢ| > 1(每层梯度 > 1):
n 层连乘 → 梯度指数增长 → 梯度爆炸

具体触发场景:

问题触发原因
梯度消失Sigmoid/Tanh 在饱和区梯度 → 0;深层网络连乘;权重初始化过小
梯度爆炸权重矩阵特征值 > 1;学习率过大;RNN 长序列

解决方案:

  • 残差连接(+1 保底)
  • LayerNorm/RMSNorm(稳定分布)
  • 合理初始化(Xavier/He)
  • 梯度裁剪(Gradient Clipping)
  • 使用 ReLU 族激活函数(正区梯度恒为 1)

ResNet 和 Transformer 中残差连接的作用

答:

数学原理:

无残差:y = F(x)        → dy/dx = dF/dx(可能很小)
有残差:y = F(x) + x → dy/dx = dF/dx + 1(至少为 1)

"+1" 保证了梯度至少有一条直通路径,不会消失

在 Transformer 中的三重作用:

  1. 梯度高速公路:梯度可以跳过中间层直接传到底层,解决深度网络的梯度消失

  2. 学习"增量"而非"全量"

    F(x) 学习的是"需要修改的部分"(delta/residual)
    而不是从头学完整变换
    → 优化更容易(学小的 delta 比学完整映射简单)
  3. 支持更深的网络:LLaMA-65B 有 80 层,没有残差连接根本训不动

Pre-Norm 变体(现代 LLM 标准):

y = x + Layer(Norm(x))    ← Pre-Norm:先归一化再变换
vs
y = Norm(x + Layer(x)) ← Post-Norm:先变换再归一化

Pre-Norm 训练更稳定(梯度更平滑),Post-Norm 理论表达力更强

大模型位置编码方式、RoPE 相比传统正余弦编码的区别

答:

常见位置编码方法:

方法原理外推能力额外参数代表模型
正弦位置编码固定 sin/cos 函数原始 Transformer
可学习绝对编码学习每个位置的 embedding有(max_len × d)GPT-2
相对位置编码(T5)学习相对距离的 bias一般T5
RoPE旋转矩阵编码位置LLaMA, Qwen
ALiBi基于距离的线性衰减BLOOM

RoPE vs 传统正弦编码:

维度正弦位置编码RoPE
编码方式加到 token embedding 上对 Q/K 做旋转
位置信息绝对位置相对位置(内积自然编码距离)
外推能力差(超过训练长度性能骤降)较好(配合 NTK/YaRN 可大幅外推)
与 KV Cache兼容天然兼容(K 已编码位置,可直接缓存)
前缀共享不同绝对位置导致 KV 不同相同前缀 KV 可复用

为什么用 RoPE:

  1. 相对位置编码不依赖绝对位置 → 更好的泛化
  2. 无额外参数 → 简洁
  3. 与高效 attention(FlashAttention)和 KV Cache 完美兼容
  4. 通过 NTK-aware/YaRN 缩放可扩展到更长上下文

为什么要用 LN,不用 BN

答:

维度Batch Normalization (BN)Layer Normalization (LN)
归一化维度跨 batch,对每个特征跨特征,对每个样本
依赖 batch✅ 需要一定 batch size❌ 独立于 batch
变长序列❌ 不同位置统计量不同✅ 每个 token 独立归一化
推理时 batch=1需要 running statistics直接计算,无问题
自回归生成不适用(batch=1 逐 token)完美适用

LN 在 LLM 中必须用的原因:

  1. 自回归生成时 batch=1:BN 退化(统计量不可靠)
  2. 序列长度动态变化:BN 需要固定维度的统计量
  3. 分布式训练:BN 需要跨 GPU 同步统计量,通信开销大

现代 LLM 实际用 RMSNorm(简化版 LN):

LN:      y = (x - μ) / √(σ² + ε) × γ + β    ← 减均值 + 除标准差
RMSNorm: y = x / √(mean(x²) + ε) × γ ← 只除 RMS,更快

PreNorm 和 PostNorm 区别

答:

PostNorm(原始 Transformer):
x → [MHA] → + x → [LayerNorm] → [FFN] → + → [LayerNorm]
↑__________________| ↑___________________|

PreNorm(GPT/LLaMA/现代 LLM):
x → [LayerNorm] → [MHA] → + x → [LayerNorm] → [FFN] → +
↑______| ↑______|
维度PostNormPreNorm
梯度稳定性差(深层训练困难)好(残差路径不经过 Norm)
训练难度高(需要 warmup 等技巧)低(开箱即用)
理论表达力更强(Norm 在残差之后)略弱
实际表现浅层模型可能更好深层模型显著更好
谁在用原始 Transformer, BERTGPT, LLaMA, Qwen, DeepSeek

面试关键点: 现代 LLM 全部使用 PreNorm(+ RMSNorm),因为训练稳定性远比理论表达力重要。

LLM 中常用的激活函数

答:

激活函数公式特点使用模型
ReLUmax(0, x)简单,但有 dying neuron 问题早期模型
GELUx · Φ(x)平滑版 ReLU,概率解释BERT, GPT-2/3
SiLU/Swishx · σ(x)平滑,自门控LLaMA
SwiGLU(SiLU(W_gate·x) ⊙ W_up·x)门控机制 + SiLULLaMA, Qwen, DeepSeek

为什么不用 Sigmoid/Tanh:

  • Sigmoid:输出在 (0,1),梯度在饱和区趋近 0 → 梯度消失
  • Tanh:输出在 (-1,1),同样有饱和区问题
  • 两者的 exp 计算比 ReLU/SiLU 更昂贵

为什么 SwiGLU 成为主流:

SwiGLU(x) = SiLU(x · W_gate) ⊙ (x · W_up)
= 门控信号 × 信息通道

门控机制让网络学会"选择性地激活",比纯激活函数更灵活
PaLM 论文实验证明 SwiGLU > GELU > ReLU(同参数量下)

双向 attention、因果 attention 和 prefix-attention 的区别

答:

双向 Attention(BERT):
Mask: [1 1 1 1] 所有 token 都能看到所有 token
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]

因果 Attention(GPT/LLaMA):
Mask: [1 0 0 0] 每个 token 只能看到自己和之前的
[1 1 0 0]
[1 1 1 0]
[1 1 1 1]

Prefix Attention(T5 Decoder / Prefix-tuning):
Mask: [1 1 1 0 0] 前缀部分双向 attention
[1 1 1 0 0] 生成部分因果 attention
[1 1 1 0 0]
[1 1 1 1 0] ← 生成 token 能看前缀 + 已生成
[1 1 1 1 1]
prefix gen
类型上下文适用场景代表模型
双向全局理解任务(分类、匹配)BERT, RoBERTa
因果仅历史自回归生成GPT, LLaMA
Prefix前缀双向 + 生成因果条件生成、few-shotT5, UniLM

什么是旋转位置编码(RoPE),解决了什么问题

答:

RoPE(Rotary Position Embedding) 通过对 Q/K 向量施加旋转来编码位置信息。

核心数学:

将 d 维向量视为 d/2 个二维向量,对每对施加旋转:

[q_{2i} ] [cos(mθ_i) -sin(mθ_i)] [q_{2i} ]
[q_{2i+1}] = [sin(mθ_i) cos(mθ_i)] [q_{2i+1}]

其中 m 是位置,θ_i = 10000^{-2i/d} 是频率

关键性质:
<RoPE(q, m), RoPE(k, n)> = f(q, k, m-n)
内积只依赖相对位置 (m-n),不依赖绝对位置!

解决的问题:

  1. 相对位置编码:无需显式学习 position bias,旋转自然编码相对距离
  2. 长度外推:比绝对编码更好的外推能力(配合 NTK/YaRN 可从 4K 扩到 128K+)
  3. KV Cache 友好:K 在生成时已包含位置信息,缓存后不需重新编码
  4. 无额外参数:不增加模型参数量
  5. 高效实现:旋转操作可融合到 Q/K 投影中

推荐使用方式

如果你是为面试准备这份文档,建议按下面顺序读:

  1. 先读 3, 4, 6, 7, 11
  2. 再读 2, 5, 9, 10
  3. 最后补 12, 13, 15

因为:

  • softmax / attention / FFN / norm / Adam 是最高频
  • RoPE / GQA / backward 是常见追问
  • MoE / FLOPs / KV cache 是资深面试官加深题
  • Transformer 架构题 是基础必答题

和现有文档的关系