从零实现 Transformer:用 Python 手写 LLM 的核心架构

从零开始用 Python 和 PyTorch 实现 Transformer 架构,涵盖 Self-Attention、Multi-Head Attention、位置编码等核心组件,附完整可运行代码,帮你真正理解 ChatGPT 和 Claude 的底层原理。

数据结构与算法 2026-06-05 15 分钟

2026 年,几乎所有主流 AI 应用——从 ChatGPT 到 Claude,从 GitHub Copilot 到 Gemini——都基于同一个架构:Transformer(变换器)。根据 Stanford HAI 2026 年 AI Index 报告,基于 Transformer 的大语言模型(LLM)在全球市场规模已超过 500 亿美元。但对大多数开发者来说,Transformer 仍然是一个「黑盒」——知道它能做什么,却不知道它为什么能做。本文将带你从零开始,用 Python 和 PyTorch 手写一个完整的 Transformer,帮你真正理解 LLM 的底层原理。

🧠 一、Transformer 架构总览

📌 从 RNN 到 Transformer 的范式跃迁

在 Transformer 出现之前,序列建模的主流方案是 RNN(循环神经网络)和 LSTM(长短期记忆网络)。它们的核心问题是串行计算——必须逐个处理 token,无法利用 GPU 的并行能力。一个 1024 token 的序列,RNN 需要串行执行 1024 步,而 Transformer 只需一步矩阵运算。

Transformer 的核心创新是 Self-Attention(自注意力机制)——让序列中的每个 token 都能直接「看到」所有其他 token,无需经过中间步骤传递信息。这不仅解决了长距离依赖问题,还实现了完全并行计算。

以下是三种架构的对比:

特性 RNN LSTM Transformer
并行计算 ❌ 逐步计算 ❌ 逐步计算 ✅ 完全并行
长距离依赖 ❌ 梯度消失严重 ⚠️ 有限改善 ✅ 直接连接
训练速度 慢(O(n) 串行) 慢(O(n) 串行) 快(O(1) 并行)
内存占用 高(O(n²) 注意力矩阵)
典型应用 早期机器翻译 语音识别、时序预测 GPT、Claude、BERT

关键结论: Transformer 的并行计算能力是它超越 RNN/LSTM 的根本原因。GPT-4 和 Claude 等现代 LLM 能处理数万甚至数十万 token 的上下文窗口,正是得益于 Transformer 架构的这一特性。

🔑 核心组件一览

一个完整的 Transformer 由以下核心组件构成:

  • Token Embedding(词嵌入):将 token ID 转换为稠密向量
  • Positional Encoding(位置编码):为模型注入位置信息
  • Self-Attention(自注意力):让 token 之间建立关联
  • Multi-Head Attention(多头注意力):从多个角度理解输入
  • Feed-Forward Network(前馈网络):非线性变换
  • Layer Normalization(层归一化):稳定训练过程
  • Residual Connection(残差连接):缓解梯度消失

让我们逐一实现这些组件。

🔧 二、用 Python 实现核心组件

🔤 Token Embedding(词嵌入)

Embedding 层的作用是将离散的 token ID 映射到连续的向量空间。简单来说,它就是一个可学习的查找表(Lookup Table)。

# Token Embedding:将 token ID 转换为稠密向量
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        """
        Args:
            vocab_size: 词表大小(如 GPT-2 的 50257)
            d_model: 模型维度(如 768)
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len) -> (batch_size, seq_len, d_model)
        return self.embedding(x) * math.sqrt(self.d_model)

💡 提示: 为什么乘以 √d_model?这是原论文 “Attention Is All You Need” 中的设计——缩放 embedding 的幅度,使其与位置编码的幅度匹配,避免位置编码在训练初期被淹没。

⚡ Self-Attention(自注意力机制)

Self-Attention 是 Transformer 最核心的创新。它的思想可以用一句话概括:对于序列中的每个 token,计算它与所有其他 token 的相关性,然后用这些相关性对所有 token 的信息做加权求和。

具体实现分三步:

  1. 用三个线性变换生成 Q(Query)、K(Key)、V(Value)
  2. 计算 Q 和 K 的点积得到注意力分数
  3. 用 Softmax 归一化后,对 V 做加权求和
# Self-Attention:让每个 token 都能"看到"序列中的所有其他 token
class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # 三个线性变换:生成 Q、K、V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model)
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)  # (batch, seq, d_model)
        V = self.W_v(x)  # (batch, seq, d_model)

        # 计算注意力分数:Q @ K^T / √d_k
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
        # scores: (batch, seq, seq) — 每个 token 对所有 token 的注意力权重

        # Softmax 归一化
        attention_weights = torch.softmax(scores, dim=-1)

        # 加权求和
        output = torch.matmul(attention_weights, V)
        return output  # (batch, seq, d_model)

⚠️ 警告: 不要忘记除以 √d_k!当 d_model 较大时(如 768 或 4096),Q 和 K 的点积值会非常大,导致 Softmax 进入饱和区,梯度趋近于零,训练直接崩溃。这个缩放操作是 Transformer 能稳定训练的关键。

🔀 Multi-Head Attention(多头注意力)

单头注意力只能从一个「角度」理解 token 之间的关系。Multi-Head Attention 的思想是:把注意力分成多个「头」,每个头独立学习不同的关系模式,最后拼接起来。

比如,一个头可能学到了语法关系(主语-谓语),另一个头学到了语义关系(同义词),还有一个头学到了位置关系(相邻 token)。

# Multi-Head Attention:从多个角度理解 token 之间的关系
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()

        # 线性变换并拆分为多个头
        # (batch, seq, d_model) -> (batch, seq, n_heads, d_k) -> (batch, n_heads, seq, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # 缩放点积注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        # 拼接所有头:(batch, n_heads, seq, d_k) -> (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # 最终线性变换
        return self.W_o(context)

📌 记住: Multi-Head Attention 不是让模型「看得更多」,而是让模型「从不同角度看」。就像人类阅读一段文字时,会同时关注语法结构、语义含义和上下文线索——每个「头」负责捕捉一种模式。

📐 Positional Encoding(位置编码)

Self-Attention 的一个「缺陷」是:它对 token 的顺序完全无感。打乱输入序列的顺序,输出不会改变。但语言是有序的——「狗咬人」和「人咬狗」含义完全不同。

Positional Encoding 的解决方案是:为每个位置生成一个固定的向量,加到 embedding 上,让模型能够区分不同位置的 token。

# Positional Encoding:为 Transformer 注入位置信息
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # 预计算位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 不同频率的正弦/余弦函数
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度用 sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度用 cos
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

💡 提示: 为什么用正弦/余弦函数?因为它们有一个优雅的性质:位置 pos+k 的编码可以表示为位置 pos 编码的线性函数,这意味着模型可以轻松学习相对位置关系。现代 LLM(如 LLaMA)已改用 RoPE(旋转位置编码),但正弦编码仍是理解位置编码的最佳起点。

🔄 Feed-Forward Network 与 Layer Normalization

Feed-Forward Network(FFN,前馈网络) 是每个 Transformer Block 中的另一个关键组件。它由两个线性变换和一个激活函数组成,作用是对每个 token 的表示做非线性变换。

Layer Normalization(层归一化) 用于稳定训练过程,防止中间层的值分布漂移。

# Feed-Forward Network:对每个 token 做非线性变换
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),  # GELU 比 ReLU 更平滑,现代 Transformer 都用它
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Layer Normalization:稳定训练过程
class LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

⚠️ 警告: 注意 Layer Normalization 的计算维度!它应该在 d_model 维度上做归一化(dim=-1),而不是在 seq_len 维度上。搞错维度会导致模型完全无法训练,而且这个 Bug 不会报错,只会输出垃圾结果。

🚀 三、组装、训练与推理

🏗️ Transformer Block

现在把所有组件组装成一个完整的 Transformer Block。每个 Block 包含一个 Multi-Head Attention 和一个 FFN,中间用 LayerNorm 和残差连接。

# Transformer Block:组装所有核心组件
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-Norm 架构(比 Post-Norm 更稳定,GPT 系列采用)
        attn_output = self.attention(self.norm1(x))
        x = x + self.dropout(attn_output)  # 残差连接
        ff_output = self.ff(self.norm2(x))
        x = x + self.dropout(ff_output)    # 残差连接
        return x


# 完整的 Mini Transformer 模型
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
                 n_layers: int = 4, d_ff: int = 512, max_len: int = 256,
                 dropout: float = 0.1):
        super().__init__()
        self.embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.output(x)  # (batch, seq_len, vocab_size)

🎯 训练微型 Transformer

让我们用一段完整的代码来训练这个微型 Transformer。我们用一个简单的任务:学习预测序列中的下一个 token。

# 训练一个微型 Transformer
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# 超参数
VOCAB_SIZE = 256    # 字符级词表
D_MODEL = 128       # 模型维度
N_HEADS = 4         # 注意力头数
N_LAYERS = 4        # Transformer 层数
D_FF = 512          # FFN 隐层维度
SEQ_LEN = 64        # 序列长度
BATCH_SIZE = 32
EPOCHS = 50
LR = 3e-4

# 生成训练数据:简单的字符序列(模拟)
def generate_data(n_samples=1000, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE):
    """生成随机序列数据,任务是预测下一个 token"""
    data = torch.randint(0, vocab_size, (n_samples, seq_len + 1))
    inputs = data[:, :-1]   # (n_samples, seq_len)
    targets = data[:, 1:]   # (n_samples, seq_len) — 右移一位
    return inputs, targets

# 准备数据
inputs, targets = generate_data()
dataset = TensorDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# 初始化模型
model = MiniTransformer(
    vocab_size=VOCAB_SIZE, d_model=D_MODEL, n_heads=N_HEADS,
    n_layers=N_LAYERS, d_ff=D_FF
)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# 训练循环
model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for batch_inputs, batch_targets in dataloader:
        logits = model(batch_inputs)
        # logits: (batch, seq, vocab) -> (batch*seq, vocab)
        # targets: (batch, seq) -> (batch*seq,)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        avg_loss = total_loss / len(dataloader)
        perplexity = math.exp(avg_loss)
        print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")

✨ 文本生成

训练完成后,我们可以用自回归的方式生成文本——每次预测一个 token,将其加入输入序列,然后继续预测。

# 自回归文本生成
@torch.no_grad()
def generate(model: nn.Module, start_tokens: torch.Tensor,
             max_new_tokens: int = 50, temperature: float = 0.8) -> torch.Tensor:
    """
    Args:
        model: 训练好的 Transformer
        start_tokens: 起始 token 序列 (1, seq_len)
        max_new_tokens: 最多生成多少个新 token
        temperature: 温度参数,越高越随机,越低越确定
    """
    model.eval()
    tokens = start_tokens

    for _ in range(max_new_tokens):
        # 截断到模型支持的最大长度
        logits = model(tokens[:, -SEQ_LEN:])
        # 取最后一个位置的 logits
        next_logits = logits[:, -1, :] / temperature
        # 采样
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # 拼接到序列末尾
        tokens = torch.cat([tokens, next_token], dim=1)

    return tokens

# 使用示例
start = torch.randint(0, VOCAB_SIZE, (1, 8))  # 随机起始 token
generated = generate(model, start, max_new_tokens=32)
print(f"Generated sequence length: {generated.shape[1]}")

💡 提示: temperature 参数是控制生成多样性的关键。temperature=1.0 是标准采样;temperature=0.1 接近贪心搜索(总是选概率最高的 token);temperature=2.0 则非常随机。实际应用中,0.7-0.9 通常是最佳范围。

⚠️ 四、避坑指南与最佳实践

🚨 常见陷阱

在实现 Transformer 的过程中,以下是最容易踩的坑:

  • 忘记缩放注意力分数:不除以 √d_k,Softmax 会饱和,梯度消失,模型无法训练
  • d_model 不能被 n_heads 整除:会导致 view 操作报错或产生错误结果
  • 位置编码维度不匹配max_len 小于实际序列长度会导致索引越界
  • 训练时忘记 model.train()、推理时忘记 model.eval():Dropout 和 LayerNorm 行为不同
  • 忘记梯度裁剪:Transformer 训练容易出现梯度爆炸,clip_grad_norm_ 是标配
  • Softmax 维度错误:应该在 dim=-1(token 维度)上归一化,不是 dim=1

✅ 最佳实践

  • 从简单开始:先实现单头注意力,验证正确性后再扩展到多头
  • 用小模型验证:先用 d_model=64, n_layers=2 快速迭代,确认无 Bug 后再放大
  • 可视化注意力权重:用 matplotlib 画出注意力矩阵,帮助理解模型行为
  • 使用 Pre-Norm 架构:比 Post-Norm 更稳定,GPT 系列和 LLaMA 都采用 Pre-Norm
  • 使用 AdamW 优化器:比 Adam 更适合 Transformer,权重衰减是正则化的关键
  • 使用学习率预热(Warmup):前几千步线性增加学习率,避免训练初期不稳定

⚠️ 警告: 本文实现的是最基础的 Transformer。真实世界的 LLM 还包含大量工程优化:Flash Attention(将注意力计算的内存复杂度从 O(n²) 降到 O(n))、KV Cache(推理加速)、GQA/MQA(减少 KV 头数以节省内存)等。但理解了本文的基础,这些优化都只是「锦上添花」。

📊 五、从微型 Transformer 到真实 LLM

我们实现的 MiniTransformer 只有约 200 万参数。真实的 LLM 有多大?以下是对比:

模型 参数量 层数 d_model 注意力头数 上下文窗口
MiniTransformer(本文) ~2M 4 128 4 256
GPT-2 Small 117M 12 768 12 1024
LLaMA-3 8B 8B 32 4096 32 8192
GPT-4 (推测) ~1.8T ~120 ~12288 ~96 128K
Claude 3.5 Sonnet 未公开 未公开 未公开 未公开 200K

关键结论: 从 2M 参数到 1.8T 参数,架构的核心原理没有改变——都是 Transformer。规模的提升带来了涌现能力(Emergent Abilities),但底层的 Self-Attention、Multi-Head Attention、残差连接这些组件,和我们在本文中实现的完全一致。

💡 总结与进阶建议

通过从零实现 Transformer,我们深入理解了以下核心概念:

  • Self-Attention 是 Transformer 的灵魂,它让每个 token 能直接「看到」所有其他 token
  • Multi-Head Attention 让模型从多个角度理解输入,捕捉不同层面的关系
  • Positional Encoding 弥补了 Transformer 缺乏位置信息的缺陷
  • 残差连接 + LayerNorm 是稳定深层网络训练的关键技术
  • Pre-Norm 架构 比 Post-Norm 更稳定,是现代 LLM 的标配

进阶学习建议:

  1. 📖 阅读原论文 “Attention Is All You Need”(Vaswani et al., 2017)
  2. 🔬 尝试在真实文本数据(如 Shakespeare)上训练,观察生成效果
  3. 🚀 探索 Decoder-Only 架构(GPT 系列采用)与 Encoder-Decoder 架构的区别
  4. ⚡ 学习 Flash Attention、KV Cache 等工程优化技术
  5. 🛠️ 尝试用 Hugging Face Transformers 库加载真实模型,对比你的实现

📌 记住: 理解 Transformer 不仅是为了做 AI 研究——作为开发者,当你知道 LLM 的注意力机制如何工作,你就能更好地设计 Prompt、理解上下文窗口限制、优化 API 调用成本。底层原理的理解,是高效使用 AI 工具的基础。

📚 相关文章