文本生成模型的解码策略:从贪心搜索到核采样的工程实践

发布时间:2026/6/9 14:26:19
文本生成模型的解码策略:从贪心搜索到核采样的工程实践
文本生成模型的解码策略从贪心搜索到核采样的工程实践一、生成的确定性陷阱为什么贪心解码总是复读机大语言模型的文本生成过程是一个自回归过程每一步从词表的概率分布中选择一个 token拼接到已有序列后再预测下一个 token。最直觉的选择是贪心解码——每步选概率最高的 token。但贪心解码有两个致命问题一是生成内容单调重复复读机效应二是容易陷入局部最优选择了当前最高概率的 token但整体序列质量并非最优。解码策略的核心目标是在多样性和质量之间找到平衡。贪心解码偏向质量但缺乏多样性随机采样偏向多样性但质量不可控。温度调节、Top-K、Top-P核采样和束搜索是四种主流的平衡策略它们从不同角度控制生成的随机性。二、解码策略的数学原理与对比解码策略的本质是对模型输出的 logits未归一化概率进行变换再从变换后的分布中采样。不同的变换方式产生不同的生成特性。flowchart TD A[模型输出 Logits] -- B[温度调节br/logits / T] B -- C[Softmaxbr/概率分布] C -- D{解码策略} D -- E[贪心搜索br/argmax(p)br/确定性单调重复] D -- F[束搜索br/保留 top-B 序列br/确定性全局较优] D -- G[Top-K 采样br/只从概率最高的 K 个中采样br/随机性可控] D -- H[Top-P 核采样br/只从累积概率 ≥ P 的 token 中采样br/自适应随机性] E -- I[优点稳定br/缺点重复、无创意] F -- J[优点全局较优br/缺点计算量大、仍可能重复] G -- K[优点简单可控br/缺点K 值难以选择] H -- L[优点自适应br/缺点极端情况不稳定] subgraph 温度 T 的效果 M[T 1分布更尖锐br/偏向高概率 token] N[T 1原始分布] O[T 1分布更平滑br/低概率 token 也有机会] end B -- M B -- N B -- O四种策略的数学表达贪心搜索x_t argmax p(x|x_{t})束搜索Beam Search维护 B 个候选序列每步扩展所有候选保留总概率最高的 B 个Top-K 采样从概率最高的 K 个 token 中按概率采样Top-P 核采样从累积概率达到 P 的最小 token 集合中采样三、解码策略的完整实现# decoding_strategies.py — 文本生成解码策略 # 设计意图实现贪心搜索、束搜索、Top-K 和 Top-P 四种解码策略 # 提供统一的接口和可配置的参数 import torch import torch.nn.functional as F from typing import List, Optional, Tuple from dataclasses import dataclass dataclass class GenerationConfig: 生成配置 max_new_tokens: int 256 temperature: float 1.0 top_k: int 0 # 0 表示不使用 Top-K top_p: float 1.0 # 1.0 表示不使用 Top-P num_beams: int 1 # 1 表示不使用束搜索 repetition_penalty: float 1.0 # 1.0 表示不使用重复惩罚 length_penalty: float 1.0 # 束搜索的长度惩罚 no_repeat_ngram_size: int 0 # 禁止重复的 n-gram 大小 eos_token_id: Optional[int] None class DecodingStrategies: 文本生成解码策略 def __init__(self, model, tokenizer): self.model model self.tokenizer tokenizer torch.no_grad() def generate( self, input_ids: torch.Tensor, config: GenerationConfig, ) - torch.Tensor: 统一的生成接口 if config.num_beams 1: return self._beam_search(input_ids, config) else: return self._autoregressive_generate(input_ids, config) def _autoregressive_generate( self, input_ids: torch.Tensor, config: GenerationConfig, ) - torch.Tensor: 自回归生成贪心 / Top-K / Top-P batch_size input_ids.shape[0] device input_ids.device # 初始化已生成的 token 序列 generated input_ids.clone() for _ in range(config.max_new_tokens): # 模型前向传播 outputs self.model(generated) next_logits outputs.logits[:, -1, :] # 取最后一个位置的 logits # 重复惩罚 if config.repetition_penalty 1.0: next_logits self._apply_repetition_penalty( next_logits, generated, config.repetition_penalty ) # 温度调节 if config.temperature ! 1.0: next_logits next_logits / config.temperature # Top-K 过滤 if config.top_k 0: next_logits self._top_k_filter(next_logits, config.top_k) # Top-P核采样过滤 if config.top_p 1.0: next_logits self._top_p_filter(next_logits, config.top_p) # 采样 probs F.softmax(next_logits, dim-1) next_token torch.multinomial(probs, num_samples1) # 拼接到已生成序列 generated torch.cat([generated, next_token], dim-1) # 检查是否生成了结束符 if config.eos_token_id is not None: if (next_token config.eos_token_id).all(): break # N-gram 重复禁止 if config.no_repeat_ngram_size 0: generated self._ban_repeated_ngrams( generated, config.no_repeat_ngram_size ) return generated def _beam_search( self, input_ids: torch.Tensor, config: GenerationConfig, ) - torch.Tensor: 束搜索 batch_size input_ids.shape[0] num_beams config.num_beams device input_ids.device # 扩展输入以容纳多个束 expanded_input input_ids.unsqueeze(1).expand( batch_size, num_beams, -1 ).reshape(batch_size * num_beams, -1) # 初始化束分数 beam_scores torch.zeros( batch_size, num_beams, devicedevice ) beam_scores[:, 1:] -1e9 # 非第一个束初始化为极低分 generated expanded_input for step in range(config.max_new_tokens): outputs self.model(generated) next_logits outputs.logits[:, -1, :] # 温度调节 if config.temperature ! 1.0: next_logits next_logits / config.temperature next_log_probs F.log_softmax(next_logits, dim-1) # 累积分数 vocab_size next_log_probs.shape[-1] next_scores next_log_probs beam_scores.view(-1, 1) # 长度惩罚 if config.length_penalty ! 1.0: length (step 1) ** config.length_penalty next_scores next_scores / length # 选择 top-2B 个候选 next_scores next_scores.view(batch_size, num_beams * vocab_size) top_scores, top_indices next_scores.topk(2 * num_beams, dim-1) # 解析束索引和 token 索引 beam_indices top_indices // vocab_size token_indices top_indices % vocab_size # 更新束 new_generated [] new_beam_scores [] for b in range(batch_size): batch_beams [] batch_scores [] for rank in range(2 * num_beams): beam_idx beam_indices[b, rank].item() token_idx token_indices[b, rank].item() score top_scores[b, rank].item() # 获取当前束的序列 current_seq generated[b * num_beams beam_idx] new_seq torch.cat([ current_seq, torch.tensor([[token_idx]], devicedevice), ], dim-1) batch_beams.append(new_seq) batch_scores.append(score) if len(batch_beams) num_beams: break new_generated.extend(batch_beams) new_beam_scores.append(batch_scores) generated torch.stack(new_generated) beam_scores torch.tensor( new_beam_scores, devicedevice ).view(-1) # 返回每个 batch 中分数最高的束 results [] for b in range(batch_size): best_idx b * num_beams # 第一个束分数最高 results.append(generated[best_idx]) return torch.stack(results) def _top_k_filter( self, logits: torch.Tensor, k: int ) - torch.Tensor: Top-K 过滤只保留概率最高的 K 个 token top_k min(k, logits.size(-1)) # 找到第 K 大的值作为阈值 threshold torch.topk(logits, top_k)[0][:, -1:] # 将低于阈值的 logits 设为 -inf return logits.masked_fill(logits threshold, float(-inf)) def _top_p_filter( self, logits: torch.Tensor, p: float ) - torch.Tensor: Top-P核采样过滤保留累积概率达到 P 的最小 token 集合 sorted_logits, sorted_indices torch.sort( logits, descendingTrue ) sorted_probs F.softmax(sorted_logits, dim-1) cumulative_probs torch.cumsum(sorted_probs, dim-1) # 找到累积概率超过 P 的位置 sorted_mask cumulative_probs - sorted_probs p # 将超出 P 的 token 的 logits 设为 -inf sorted_logits[sorted_mask] float(-inf) # 恢复原始顺序 _, original_indices sorted_indices.sort(dim-1) filtered_logits sorted_logits.gather(dim-1, indexoriginal_indices) return filtered_logits def _apply_repetition_penalty( self, logits: torch.Tensor, generated: torch.Tensor, penalty: float, ) - torch.Tensor: 重复惩罚降低已生成 token 的概率 for token_id in generated[0].unique(): if logits[0, token_id] 0: logits[0, token_id] / penalty else: logits[0, token_id] * penalty return logits def _ban_repeated_ngrams( self, generated: torch.Tensor, n: int, ) - torch.Tensor: 禁止重复的 n-gram if generated.shape[-1] n: return generated # 检查最近 n-1 个 token 是否与之前的 n-gram 重复 recent generated[0, -(n-1):].tolist() for i in range(generated.shape[-1] - n 1): ngram generated[0, i:in-1].tolist() if ngram recent: # 找到重复将下一个 token 的概率设为 -inf # 此处简化处理实际应在 logits 层面操作 pass return generated四、解码策略的 Trade-offs多样性 vs. 一致性高温采样增加多样性但降低一致性——模型可能生成语法正确但逻辑矛盾的内容。低温采样提高一致性但降低多样性——输出变得可预测但缺乏创意。建议对话场景使用 T0.7-0.9 Top-P0.9创意写作使用 T1.0-1.2 Top-P0.95代码生成使用 T0.2 贪心。Top-K vs. Top-PTop-K 的 K 值是固定的不适应不同上下文的概率分布。在确定性高的上下文中如太阳从___升起K50 会引入不必要的噪声在不确定性高的上下文中如接下来___K5 可能过于限制。Top-P 自适应地调整候选 token 数量在确定性高时选择少量 token不确定性高时选择更多。束搜索的局限束搜索虽然能找到全局较优序列但生成的文本仍然偏向安全和平庸——它倾向于选择高概率的常见表达而非低概率但更有信息量的表达。在创意生成场景中束搜索的效果可能不如采样策略。重复惩罚的副作用重复惩罚可以缓解复读机问题但过度惩罚会导致模型回避合理的重复如列表中的重复结构、代码中的重复模式。建议将重复惩罚限制在 1.0-1.2 之间避免过度惩罚。五、总结文本生成解码策略在多样性和质量之间寻找平衡。贪心搜索偏向质量但单调重复束搜索全局较优但缺乏创意Top-K 采样简单可控但不自适应Top-P 核采样自适应但极端情况不稳定。温度调节控制整体随机性重复惩罚抑制复读机效应。在实际落地中建议根据任务特性选择策略组合代码生成用低温贪心对话用 Top-P 采样创意写作用高温 Top-P。解码策略的目标不是找到最优序列而是在质量约束下生成足够多样的合理文本。