接上篇,定义 SLM 的模型架构 代码如下
显示已折叠代码(172 行) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 import torch import torch.nn as nn import torch.nn.functional as F import math from dataclasses import dataclass import numpy as np from tqdm.auto import tqdm from contextlib import nullcontext import os class LayerNorm(nn.Module): """ - 与 `nn.LayerNorm(ndim, elementwise_affine=bias)` 等价,手搓是为了**可控是否带 bias**。 - 作用:把最后一维做标准化,稳定训练。 - 形状不变:`(B, T, C) → (B, T, C)`。 """ def __init__(self, ndim, bias): super().__init__() self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None def forward(self, x): return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) class CausalSelfAttention(nn.Module): """ 自回归注意力 """ def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) # 生成 Q,K,V self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) # 多头拼回后做投影 # 注意力/残差丢弃 self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) self.n_head = config.n_head self.n_embd = config.n_embd self.flash = hasattr(F, 'scaled_dot_product_attention') if not self.flash: self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) def forward(self, x): B, T, C = x.size() q, k, v = self.c_attn(x).split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) if self.flash: y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True) else: att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.resid_dropout(self.c_proj(y)) return y class MLP(nn.Module): def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) self.gelu = nn.GELU() self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): return self.dropout(self.c_proj(self.gelu(self.c_fc(x)))) class Block(nn.Module): def __init__(self, config): super().__init__() self.ln1 = LayerNorm(config.n_embd, config.bias) self.attn = CausalSelfAttention(config) self.ln2 = LayerNorm(config.n_embd, config.bias) self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x @dataclass class GPTConfig: block_size: int vocab_size: int n_layer: int n_head: int n_embd: int dropout: float = 0.0 bias: bool = True class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = nn.ModuleDict(dict( wte=nn.Embedding(config.vocab_size, config.n_embd), wpe=nn.Embedding(config.block_size, config.n_embd), drop=nn.Dropout(config.dropout), h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), ln_f=LayerNorm(config.n_embd, config.bias), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.transformer.wte.weight = self.lm_head.weight # weight tying self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith('c_proj.weight'): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): device = idx.device b, t = idx.size() assert t <= self.config.block_size pos = torch.arange(0, t, dtype=torch.long, device=device) tok_emb = self.transformer.wte(idx) pos_emb = self.transformer.wpe(pos) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) if targets is not None: logits = self.lm_head(x) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) return logits, loss else: logits = self.lm_head(x[:, [-1], :]) return logits, None @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ Generate tokens given a conditioning sequence. idx: Tensor of shape (B, T) """ for _ in range(max_new_tokens): idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) idx = torch.cat((idx, idx_next), dim=1) return idx config = GPTConfig( vocab_size=50257, # 使用分词器的词汇表大小 block_size=128, # 或你训练时使用的任何上下文大小 n_layer=6, n_head=6, n_embd=384, dropout=0.1, bias=True ) model = GPT(config) CausalSelfAttention(自回归注意力) 1 2 3 4 assert n_embd % n_head == 0 self.c_attn = nn.Linear(C, 3C) # 生成 Q,K,V self.c_proj = nn.Linear(C, C) # 多头拼回后做投影 self.attn_dropout / resid_dropout # 注意力/残差丢弃 self.flash = hasattr(F, 'scaled_dot_product_attention') if not flash: self.register_buffer("bias", tril(ones(T,T)).view(1,1,T,T)) 将输入 x 线性变换得到 q/k/v,每个 shape (B, T, C),再 reshape 成 (B, n_head, T, head_dim)。 Flash/SDPA 路径:PyTorch 内置 scaled_dot_product_attention,设 is_causal=True 自动做下三角 mask,快而省显存。 回退路径:手动算 att = q @ k^T / sqrt(d),再用注册的下三角 bias 做 mask: bias 维度 (1,1,T,T),只保留 i≥j 的位置,保证只能看见过去。 最后把多头输出拼回 (B, T, C),投影并做残差丢弃。 形状小抄 入:x: (B, T, C) 出:y: (B, T, C) 注意 bias 的 T 用的是 config.block_size。若推理时序列长度 > block_size,回退路径会越界;但 Flash 路径不受此限。 解释 基础设定 x: (B, T, C) 你有一段输入,比如一句话里的字/词。 B = batch,大概就是“同时处理多少句话”。 T = 时间步长,就是句子有多少个字/词。 C = 每个字/词的向量维度,可以理解成“每个字有多少个特征”。 Q/K/V(查询、键、值) • 把输入向量 x 分三份: • Q (Query 查询):我要看别的词。 • K (Key 键):我能提供什么信息。 • V (Value 值):具体信息内容。 • 每个 shape 最开始都是 (B, T, C)。再 reshape 成 (B, n_head, T, head_dim):就像让很多小组(head)分头看,不同小组看问题的角度不同。 Flash 路径 vs 回退路径 • Flash / SDPA 路径:PyTorch 内置的 scaled_dot_product_attention,就像显卡加速版,自动帮你算“谁可以看谁”。只要设 is_causal=True,它会自动加下三角遮罩(mask),保证当前词不能偷看未来。又快又省内存。 • 回退路径(手工实现):如果没有显卡加速,就自己算: 做点积:att = q @ k^T / sqrt(d) → 代表“查询词对键的相关性”。 用 bias(下三角矩阵)遮住未来:bias 形状 (1, 1, T, T),保证第 i 个词只能看到自己和之前的词。 拼回与投影 • 各个头(小组)得到的信息会合并成 (B, T, C)。 • 再过一层线性变换(c_proj),把信息“压缩整理”回原来的维度。 • 最后做 dropout(随机丢弃部分连接),防止过拟合。 形象比喻 想象你在写作文,每个字要决定自己怎么写: • Q = 我在想:我要参考哪些前面的字? • K = 每个字举手说:我能告诉你些什么。 • V = 每个字手里的小抄内容。 • Attention = 根据 Q 和 K 的匹配程度,决定 V 的权重。 • Causal Mask(下三角遮罩) = 老师规定:写第 i 个字时,只能看前面写过的,不许看后面的。 • 多头 (Multi-head) = 你同时派出好几个“审稿小人”,从不同角度帮你挑参考内容,最后合并。 注意点 • bias 的大小是按照 最大序列长度 block_size 来建的。如果实际推理时句子更长,mask 不够用,就会报错。 • 但用 Flash 路径就不会有这个问题,因为它是动态生成的。 总结 CausalSelfAttention 就像写作文时的小抄机制。每个字只能看前面的字,不许看未来。多个小组(头)一起决定要参考谁,最后合并结果。
...