DeepSeek-R1:冷启动下的强化学习之旅(代码实现)
在追求大语言模型(LLM)推理能力的道路上,DeepSeek 团队推出了 DeepSeek-R1-Zero,一个完全通过纯强化学习(RL)训练的模型,展现了令人惊叹的推理能力。然而,它的局限性(如可读性差和语言混合)促使团队进一步探索,最终开发出更强大的 DeepSeek-R1。本文将总结 DeepSeek-R1 的训练过程,重点介绍其“冷启动 + 强化学习”的创新 pipeline,带你走进这场
DeepSeek-R1:冷启动下的强化学习之旅
在追求大语言模型(LLM)推理能力的道路上,DeepSeek 团队推出了 DeepSeek-R1-Zero,一个完全通过纯强化学习(RL)训练的模型,展现了令人惊叹的推理能力。然而,它的局限性(如可读性差和语言混合)促使团队进一步探索,最终开发出更强大的 DeepSeek-R1。本文将总结 DeepSeek-R1 的训练过程,重点介绍其“冷启动 + 强化学习”的创新 pipeline,带你走进这场技术旅程。
背景:从 DeepSeek-R1-Zero 到 DeepSeek-R1
DeepSeek-R1-Zero (具体细节可以参考笔者的另一篇博客DeepSeek-R1-Zero 的训练过程:pytorch代码实现,这里有GRPO算法的介绍,本文不在赘述)的成功证明了纯 RL 可以让模型自发发展出复杂的推理行为,例如反思和长链推理。然而,其输出往往不够用户友好,难以直接应用于实际场景。为了解决这些问题并进一步提升性能,DeepSeek-R1 引入了冷启动数据和多阶段训练策略,最终在推理任务上达到了与 OpenAI-o1-1217 相当的水平。下面是其训练过程的四个关键阶段。
DeepSeek-R1 训练过程
1. 冷启动(Cold Start)
- 目标:避免纯 RL 初期的训练不稳定性,提升输出可读性。
- 方法:从 DeepSeek-V3-Base 模型开始,收集数千个包含长链推理(Chain-of-Thought, CoT)的冷启动数据进行微调。这些数据通过以下方式生成:
- 少样本提示:用带有长 CoT 的示例引导模型生成。
- 直接提示:要求模型输出详细推理和验证过程。
- 后处理:利用 DeepSeek-R1-Zero 的输出,经人工标注精炼。
- 输出格式:设计为
[special_token] <reasoning_process> [special_token] <summary>
,确保推理过程清晰,总结便于阅读。 - 优势:
- 可读性:相比 DeepSeek-R1-Zero 的混乱输出,冷启动数据加入了人类偏好(如 Markdown 格式),显著提高了用户体验。
- 潜力:为后续 RL 提供更好的起点,加速收敛。
2. 推理导向的强化学习(Reasoning-oriented RL)
- 目标:增强模型在推理密集任务(如数学、编程、科学推理)上的能力。
- 方法:
- 使用与 DeepSeek-R1-Zero 相同的 RL 框架(GRPO),基于冷启动微调后的模型继续训练。
- 奖励设计:
- 准确性奖励:评估输出是否正确。
- 语言一致性奖励:新增一项奖励,计算 CoT 中目标语言(例如英语)的比例,解决语言混合问题。
- 最终奖励:直接将准确性和语言一致性奖励相加。
- 效果:训练至收敛后,模型在推理任务上的性能显著提升,同时语言输出更加一致,但仍需进一步优化非推理能力。
3. 拒绝采样与监督微调(Rejection Sampling and SFT)
- 目标:平衡推理与通用能力,提升模型在写作、问答等领域的表现。
- 方法:
- 数据收集:
- 推理数据:从 RL 检查点通过拒绝采样生成推理轨迹,保留正确答案,扩展到约 60 万样本。包括规则奖励和部分生成式奖励模型(基于 DeepSeek-V3 判断)。
- 非推理数据:复用 DeepSeek-V3 的 SFT 数据(如写作、事实问答),约 20 万样本。对于复杂问题,生成 CoT;简单问题直接回答。
- 微调:用约 80 万样本对 DeepSeek-V3-Base 进行两轮监督微调(SFT)。
- 数据收集:
- 改进:通过筛选掉语言混乱或格式不佳的输出,模型的通用性和可读性进一步增强。
4. 全场景强化学习(RL for All Scenarios)
- 目标:优化模型的帮助性和无害性,同时保持推理能力。
- 方法:
- 奖励组合:
- 推理任务:延续规则奖励(如准确性)。
- 通用任务:引入基于 DeepSeek-V3 的奖励模型,评估复杂场景下的人类偏好。
- 帮助性:仅评估最终总结的实用性,避免干扰推理过程。
- 无害性:检查完整输出(包括 CoT 和总结),减少偏见或有害内容。
- 训练数据:使用多样化的提示分布,覆盖推理和非推理场景。
- 奖励组合:
- 结果:最终得到的 DeepSeek-R1 在推理任务(如 AIME 2024 得分 79.8%)和通用任务(如 AlpacaEval 2.0 胜率 87.6%)上均表现出色,且更符合人类偏好。
成果与亮点
经过这四阶段的训练,DeepSeek-R1 实现了以下突破:
- 推理能力:在数学(MATH-500 得分 97.3%)、编程(Codeforces Elo 2029)等任务上与 OpenAI-o1-1217 匹敌。
- 通用性:在知识问答(MMLU 90.8%)、长上下文理解等任务中大幅超越 DeepSeek-V3。
- 用户友好性:通过冷启动和多轮微调,输出的可读性和一致性显著提升。
与 DeepSeek-R1-Zero 相比,冷启动策略不仅解决了可读性问题,还通过迭代训练加速了性能提升,最终打造出一个更强大、更实用的模型。
总结
DeepSeek-R1 的训练过程是一个从“冷启动”到“全场景优化”的完整 pipeline:
- 冷启动奠定基础,提供高质量初始数据。
- 推理 RL专注提升核心推理能力。
- 拒绝采样 + SFT扩展通用性。
- 全场景 RL兼顾帮助性和安全性。
这种多阶段方法巧妙结合了监督学习和强化学习的优势,不仅验证了 RL 在推理任务中的潜力,还为构建用户友好的 LLM 提供了宝贵经验。未来,DeepSeek 团队计划进一步优化语言混合、软件工程任务等领域,让这一模型更上一层楼。
如果你对大模型的推理能力提升感兴趣,不妨关注 DeepSeek-R1 的开源模型和后续研究,或许能从中找到灵感!
代码实现
以下是基于原论文中 2.3 DeepSeek-R1: Reinforcement Learning with Cold Start 部分的 DeepSeek-R1 训练流程的代码实现。由于 DeepSeek-R1 的训练涉及多个复杂阶段(冷启动、推理导向 RL、拒绝采样与 SFT、全场景 RL),完整的工业级实现需要大量基础设施支持(例如分布式训练、预训练模型、数据管道等)。这里提供一个简化的 PyTorch 代码框架,涵盖每个阶段的核心逻辑,并尽量贴近文档描述,同时保持可读性和可运行性。
前提假设
- 模型:假设使用一个简单的 Transformer 模型代替 DeepSeek-V3-Base。
- 数据:用简单的数学问题示例替代大规模数据集。
- 简化:省略分布式训练、复杂 tokenization 等细节,聚焦流程。
代码实现
import torch
import torch.nn as nn
import torch.optim as optim
import re
from typing import List, Tuple
from copy import deepcopy
# 简化的 Transformer 模型(代替 DeepSeek-V3-Base)
class SimpleTransformer(nn.Module):
def __init__(self, vocab_size=1000, d_model=256, n_heads=4, n_layers=2):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.transformer = nn.Transformer(d_model, n_heads, n_layers)
self.fc_out = nn.Linear(d_model, vocab_size)
def forward(self, x):
x = self.embedding(x)
x = self.transformer(x, x) # 简化:自注意力
return self.fc_out(x)
# 奖励计算函数
def compute_reward(output: str, ground_truth: str = None, target_lang: str = "en") -> float:
# 格式奖励
think_pattern = r"<think>.*?</think>"
answer_pattern = r"<answer>.*?</answer>"
format_reward = 1.0 if (re.search(think_pattern, output) and re.search(answer_pattern, output)) else 0.0
# 准确性奖励
accuracy_reward = 0.0
if ground_truth:
answer_match = re.search(r"<answer>(.*?)</answer>", output)
if answer_match and answer_match.group(1).strip() == ground_truth:
accuracy_reward = 1.0
# 语言一致性奖励(简单模拟:检查是否全为英文)
lang_reward = 1.0 if all(c.isascii() for c in output if c.isalpha()) else 0.5
return format_reward + accuracy_reward + lang_reward
# GRPO 损失函数
def compute_grpo_loss(
policy: SimpleTransformer,
old_policy: SimpleTransformer,
ref_policy: SimpleTransformer,
states: torch.Tensor,
outputs: List[str],
rewards: List[float],
epsilon: float = 0.2,
beta: float = 0.01
) -> torch.Tensor:
G = len(outputs)
logits = policy(states) # [batch, seq_len, vocab_size]
old_logits = old_policy(states).detach()
ref_logits = ref_policy(states).detach()
# 简化为平均 log_probs(实际需序列化)
log_probs = torch.softmax(logits, dim=-1).log().mean(dim=1)
old_log_probs = torch.softmax(old_logits, dim=-1).log().mean(dim=1)
ref_log_probs = torch.softmax(ref_logits, dim=-1).log().mean(dim=1)
ratios = torch.exp(log_probs - old_log_probs)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
advantages = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-6)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
clipped_loss = torch.min(surr1, surr2).mean()
kl_div = (ref_log_probs - log_probs).mean()
return - (clipped_loss - beta * kl_div)
# 阶段 1:冷启动微调
def cold_start_sft(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 2):
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 模拟冷启动数据(问题 + CoT + 答案)
cold_start_data = [
(q, f"<think>{q}的推理过程:一步步计算。</think><answer>{a}</answer>")
for q, a in data
]
for epoch in range(epochs):
total_loss = 0
for question, target in cold_start_data:
# 简化为 token IDs(实际需 tokenizer)
input_ids = torch.randint(0, 1000, (1, 10)) # 占位符
target_ids = torch.randint(0, 1000, (1, 10))
optimizer.zero_grad()
logits = model(input_ids)
loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Cold Start Epoch {epoch}, Loss: {total_loss / len(data)}")
return model
# 阶段 2:推理导向 RL
def reasoning_rl(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 10, group_size: int = 4):
ref_policy = deepcopy(model) # 参考策略
old_policy = deepcopy(model) # 旧策略
optimizer = optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(epochs):
for question, ground_truth in data:
state = torch.randint(0, 1000, (1, 10)) # 模拟状态
outputs = []
for _ in range(group_size):
with torch.no_grad():
logits = old_policy(state)
# 模拟生成(实际需解码)
output = f"<think>{question}推理</think><answer>{ground_truth}</answer>"
outputs.append(output)
rewards = [compute_reward(o, ground_truth) for o in outputs]
optimizer.zero_grad()
loss = compute_grpo_loss(model, old_policy, ref_policy, state, outputs, rewards)
loss.backward()
optimizer.step()
old_policy.load_state_dict(model.state_dict())
print(f"Reasoning RL Epoch {epoch}, Loss: {loss.item()}")
return model
# 阶段 3:拒绝采样与 SFT
def rejection_sampling_sft(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 2):
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 拒绝采样:生成多条输出,保留正确且格式良好的
sft_data = []
for question, ground_truth in data:
candidates = []
for _ in range(5): # 采样 5 次
with torch.no_grad():
state = torch.randint(0, 1000, (1, 10))
logits = model(state)
output = f"<think>{question}推理</think><answer>{ground_truth}</answer>" # 模拟
reward = compute_reward(output, ground_truth)
candidates.append((output, reward))
# 选择最佳输出
best_output = max(candidates, key=lambda x: x[1])[0]
sft_data.append((question, best_output))
# SFT 训练
for epoch in range(epochs):
total_loss = 0
for _, target in sft_data:
input_ids = torch.randint(0, 1000, (1, 10))
target_ids = torch.randint(0, 1000, (1, 10))
optimizer.zero_grad()
logits = model(input_ids)
loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"SFT Epoch {epoch}, Loss: {total_loss / len(sft_data)}")
return model
# 阶段 4:全场景 RL
def full_scenario_rl(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 5, group_size: int = 4):
ref_policy = deepcopy(model)
old_policy = deepcopy(model)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(epochs):
for question, ground_truth in data:
state = torch.randint(0, 1000, (1, 10))
outputs = []
for _ in range(group_size):
with torch.no_grad():
logits = old_policy(state)
output = f"<think>{question}推理</think><answer>{ground_truth}</answer>"
outputs.append(output)
# 结合推理和通用奖励(这里简化)
rewards = [compute_reward(o, ground_truth) for o in outputs]
optimizer.zero_grad()
loss = compute_grpo_loss(model, old_policy, ref_policy, state, outputs, rewards)
loss.backward()
optimizer.step()
old_policy.load_state_dict(model.state_dict())
print(f"Full RL Epoch {epoch}, Loss: {loss.item()}")
return model
# 主流程
def train_deepseek_r1(data: List[Tuple[str, str]]):
# 初始化模型
model = SimpleTransformer()
# 阶段 1:冷启动
print("Stage 1: Cold Start SFT")
model = cold_start_sft(model, data)
# 阶段 2:推理 RL
print("Stage 2: Reasoning RL")
model = reasoning_rl(model, data)
# 阶段 3:拒绝采样与 SFT
print("Stage 3: Rejection Sampling SFT")
model = rejection_sampling_sft(model, data)
# 阶段 4:全场景 RL
print("Stage 4: Full Scenario RL")
model = full_scenario_rl(model, data)
return model
# 示例数据
data = [("2+2=?", "4"), ("3*3=?", "9")]
model = train_deepseek_r1(data)
代码说明
1. 模型(SimpleTransformer
)
- 一个简化的 Transformer 模型,代替 DeepSeek-V3-Base,用于生成 token 序列。
2. 奖励函数(compute_reward
)
- 准确性奖励:检查
<answer>
中的答案与 ground truth 是否匹配。 - 格式奖励:验证
<think>
和<answer>
标签是否存在。 - 语言一致性奖励:简单模拟,检查输出是否全为 ASCII 字符(代表英语)。
- 符合文档中基于规则、不用神经网络的设计。
3. GRPO 损失(compute_grpo_loss
)
- 实现 GRPO 的公式,包含剪切和 KL 散度正则化。
- 优势 ( A i A_i Ai ) 根据组内奖励统计计算。
4. 训练阶段
- 冷启动 SFT:
- 用预定义的 CoT 数据微调模型,模拟文档中的格式
[special_token] <reasoning_process> [special_token] <summary>
。
- 用预定义的 CoT 数据微调模型,模拟文档中的格式
- 推理 RL:
- 使用 GRPO 和奖励函数优化推理能力,加入语言一致性奖励。
- 拒绝采样与 SFT:
- 从 RL 模型生成多条输出,保留最佳样本(基于奖励),然后进行 SFT。
- 模拟推理和非推理数据的混合(这里简化)。
- 全场景 RL:
- 再次用 GRPO 训练,结合推理和通用奖励(原论文中提到帮助性和无害性,这里简化为单一奖励函数)。
5. 主流程(train_deepseek_r1
)
- 按顺序执行四个阶段,返回最终模型。
简化与局限
- 数据处理:未实现真实的 tokenization 和序列生成,需结合 transformers 库(例如
model.generate()
)。 - 奖励模型:语言一致性奖励过于简单,实际需更复杂的自然语言处理。
- 规模:未体现大规模训练,需扩展至分布式环境。
- 通用任务:阶段 4 的帮助性和无害性奖励未详细实现,可添加外部奖励模型。
如何扩展
- 真实模型:用
transformers
的预训练模型(如 LLaMA)替换SimpleTransformer
。 - 数据管道:接入真实数据集(如数学题、编程题),用 tokenizer 处理输入输出。
- 奖励细化:实现更复杂的规则(如代码编译器)或外部评估。
针对 拒绝采样与 SFT 和 全场景 RL 两部分的优化代码
以下是针对 拒绝采样与 SFT 和 全场景 RL 两部分的优化代码实现,模拟推理和非推理数据的混合,并在全场景 RL 中结合推理、帮助性和无害性奖励,同时保持与文档描述一致。代码使用 PyTorch,并针对 DeepSeek-R1 的训练流程进行了优化。
优化代码
1. 拒绝采样与 SFT(Rejection Sampling and SFT)
- 目标:从 RL 检查点生成多条输出,筛选最佳样本,混合推理和非推理数据进行监督微调(SFT)。
- 优化点:
- 模拟推理数据(数学问题)和非推理数据(简单问答)的生成与筛选。
- 使用奖励函数评估输出质量,保留正确且格式良好的样本。
import torch
import torch.nn as nn
import torch.optim as optim
import re
from typing import List, Tuple
from copy import deepcopy
# 简化的 Transformer 模型(假设已定义)
class SimpleTransformer(nn.Module):
def __init__(self, vocab_size=1000, d_model=256, n_layers=2):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
self.fc_out = nn.Linear(d_model, vocab_size)
def forward(self, x):
x = self.embedding(x)
for layer in self.layers:
x = torch.relu(layer(x))
return self.fc_out(x)
# 奖励函数(用于筛选)
def compute_basic_reward(output: str, ground_truth: str = None) -> float:
format_reward = 1.0 if (re.search(r"<think>.*?</think>", output) and re.search(r"<answer>.*?</answer>", output)) else 0.0
accuracy_reward = 0.0
if ground_truth:
answer_match = re.search(r"<answer>(.*?)</answer>", output)
if answer_match and answer_match.group(1).strip() == ground_truth:
accuracy_reward = 1.0
return format_reward + accuracy_reward
# 拒绝采样与 SFT
def rejection_sampling_sft(
model: SimpleTransformer,
data: List[Tuple[str, str, str]], # (question, ground_truth, type: "reasoning" 或 "general")
epochs: int = 2,
num_samples: int = 5
) -> SimpleTransformer:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 拒绝采样:生成并筛选数据
sft_data = []
for question, ground_truth, data_type in data:
candidates = []
for _ in range(num_samples):
with torch.no_grad():
state = torch.randint(0, 1000, (1, 10)) # 模拟输入
logits = model(state)
# 模拟生成输出
if data_type == "reasoning":
output = f"<think>{question}的步骤:计算。</think><answer>{ground_truth}</answer>"
else: # general
output = f"<answer>{ground_truth}</answer>" # 非推理数据无需 CoT
reward = compute_basic_reward(output, ground_truth if data_type == "reasoning" else None)
candidates.append((output, reward))
# 选择最佳输出
best_output = max(candidates, key=lambda x: x[1])[0]
sft_data.append((question, best_output))
# SFT 训练
for epoch in range(epochs):
total_loss = 0
for question, target in sft_data:
input_ids = torch.randint(0, 1000, (1, 10)) # 模拟 token IDs
target_ids = torch.randint(0, 1000, (1, 10))
optimizer.zero_grad()
logits = model(input_ids)
loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"SFT Epoch {epoch}, Loss: {total_loss / len(sft_data)}")
return model
# 测试示例
data = [
("2+2=?", "4", "reasoning"), # 推理数据
("3*3=?", "9", "reasoning"),
("Hello?", "Hi!", "general"), # 非推理数据
("What’s the weather?", "Sunny", "general")
]
model = SimpleTransformer()
model = rejection_sampling_sft(model, data)
说明:
- 数据混合:
data
包含推理数据(带 CoT 和 ground truth)和非推理数据(仅答案)。推理数据要求准确性和格式,非推理数据只要求格式。 - 拒绝采样:为每个问题生成
num_samples
条输出,用compute_basic_reward
筛选最佳样本。 - SFT:用筛选后的数据进行监督微调,模拟文档中的 60 万推理 + 20 万非推理样本。
2. 全场景 RL(Full Scenario RL)
- 目标:结合推理奖励和通用奖励(帮助性、无害性),使用 GRPO 进一步优化模型。
- 优化点:
- 实现多维度奖励:推理准确性、帮助性(输出实用性)、无害性(避免有害内容)。
- 使用 GRPO 框架训练。
# 综合奖励函数
def compute_full_reward(output: str, ground_truth: str = None, question: str = None) -> float:
# 推理奖励(准确性 + 格式)
format_reward = 1.0 if (re.search(r"<think>.*?</think>", output) and re.search(r"<answer>.*?</answer>", output)) else 0.0
accuracy_reward = 0.0
if ground_truth:
answer_match = re.search(r"<answer>(.*?)</answer>", output)
if answer_match and answer_match.group(1).strip() == ground_truth:
accuracy_reward = 1.0
# 帮助性奖励(简单模拟:检查答案是否为空)
helpfulness_reward = 1.0 if re.search(r"<answer>.+?</answer>", output) else 0.0
# 无害性奖励(简单模拟:检查是否有负面词)
harmful_words = {"hate", "kill", "bad"}
harmless_reward = 0.0 if any(word in output.lower() for word in harmful_words) else 1.0
return format_reward + accuracy_reward + helpfulness_reward + harmless_reward
# GRPO 损失函数(已定义,这里重用)
def compute_grpo_loss(
policy: SimpleTransformer,
old_policy: SimpleTransformer,
ref_policy: SimpleTransformer,
states: torch.Tensor,
outputs: List[str],
rewards: List[float],
epsilon: float = 0.2,
beta: float = 0.01
) -> torch.Tensor:
G = len(outputs)
logits = policy(states)
old_logits = old_policy(states).detach()
ref_logits = ref_policy(states).detach()
log_probs = torch.softmax(logits, dim=-1).log().mean(dim=1)
old_log_probs = torch.softmax(old_logits, dim=-1).log().mean(dim=1)
ref_log_probs = torch.softmax(ref_logits, dim=-1).log().mean(dim=1)
ratios = torch.exp(log_probs - old_log_probs)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
advantages = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-6)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
clipped_loss = torch.min(surr1, surr2).mean()
kl_div = (ref_log_probs - log_probs).mean()
return - (clipped_loss - beta * kl_div)
# 全场景 RL
def full_scenario_rl(
model: SimpleTransformer,
data: List[Tuple[str, str, str]], # (question, ground_truth, type)
epochs: int = 5,
group_size: int = 4
) -> SimpleTransformer:
ref_policy = deepcopy(model)
old_policy = deepcopy(model)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(epochs):
total_loss = 0
for question, ground_truth, data_type in data:
state = torch.randint(0, 1000, (1, 10))
outputs = []
for _ in range(group_size):
with torch.no_grad():
logits = old_policy(state)
if data_type == "reasoning":
output = f"<think>{question}的步骤</think><answer>{ground_truth}</answer>"
else:
output = f"<answer>{ground_truth}</answer>"
outputs.append(output)
# 计算综合奖励
rewards = [compute_full_reward(o, ground_truth if data_type == "reasoning" else None, question)
for o in outputs]
optimizer.zero_grad()
loss = compute_grpo_loss(model, old_policy, ref_policy, state, outputs, rewards)
loss.backward()
optimizer.step()
old_policy.load_state_dict(model.state_dict())
total_loss += loss.item()
print(f"Full RL Epoch {epoch}, Loss: {total_loss / len(data)}")
return model
# 测试示例
data = [
("2+2=?", "4", "reasoning"),
("3*3=?", "9", "reasoning"),
("Hello?", "Hi!", "general"),
("What’s the weather?", "Sunny", "general")
]
model = SimpleTransformer()
model = full_scenario_rl(model, data)
说明:
- 综合奖励:
- 推理奖励:包括格式和准确性,与文档一致。
- 帮助性奖励:检查
<answer>
是否非空,模拟实用性评估。 - 无害性奖励:简单检查是否有负面词汇,模拟安全性。
- 数据处理:根据
data_type
区分推理和非推理输出,推理任务带 CoT,非推理任务仅答案。 - GRPO 训练:用综合奖励驱动策略优化,保留剪切和 KL 散度。
优化亮点
- 拒绝采样与 SFT:
- 混合推理和非推理数据,推理任务要求 CoT 和准确性,非推理任务简化格式。
- 拒绝采样通过奖励筛选,确保高质量训练数据。
- 全场景 RL:
- 奖励函数扩展到推理、帮助性和无害性三个维度,符合文档中多目标优化的描述。
- GRPO 框架保持一致,适应多样化任务。
注意事项
- 简化:序列生成用静态字符串模拟,实际需用
transformers
的generate()
。 - 奖励模型:帮助性和无害性奖励过于简单,工业实现可能需要外部 LLM(如 DeepSeek-V3)评分。
- 扩展性:支持更大规模数据需批量处理和分布式训练。
针对拒绝采样(Rejection Sampling)中选择最佳输出的逻辑
针对拒绝采样(Rejection Sampling)中选择最佳输出的逻辑,尤其是 best_output = max(candidates, key=lambda x: x[1])[0]
这行代码的作用和合理性合适吗?详细解释这段代码的逻辑,为什么使用 max
,以及它是否完全符合 DeepSeek-R1 的设计意图。
当前代码的逻辑
代码片段回顾
candidates = []
for _ in range(num_samples):
with torch.no_grad():
state = torch.randint(0, 1000, (1, 10)) # 模拟输入
logits = model(state)
# 模拟生成输出
if data_type == "reasoning":
output = f"<think>{question}的步骤:计算。</think><answer>{ground_truth}</answer>"
else: # general
output = f"<answer>{ground_truth}</answer>"
reward = compute_basic_reward(output, ground_truth if data_type == "reasoning" else None)
candidates.append((output, reward))
best_output = max(candidates, key=lambda x: x[1])[0]
- 生成候选:代码为每个问题生成
num_samples
个输出(这里是 5 次),每个输出是一个(output, reward)
元组,output
是生成的文本,reward
是根据compute_basic_reward
计算的分数。 - 选择最佳输出:
max(candidates, key=lambda x: x[1])
:从candidates
列表中选择reward
(索引 1)最大的元组。[0]
:从选中的元组中提取output
(索引 0),即最佳输出文本。
逻辑是什么?
- 目标:拒绝采样的目的是从多个候选输出中挑选“最好”的一个,用于后续的监督微调(SFT)。这里的“最好”由
reward
定义,包含格式奖励和准确性奖励(对于推理任务)。 - max 的作用:
max
函数基于key=lambda x: x[1]
(即reward
)比较所有候选,选出得分最高的输出。换句话说,它假设更高的 reward 表示更高质量的输出。 - 假设:
- 对于推理任务(
data_type == "reasoning"
),reward 越高意味着输出更可能正确(匹配ground_truth
)且格式正确。 - 对于非推理任务(
data_type == "general"
),reward 越高意味着格式正确(仅要求<answer>
标签)。
- 对于推理任务(
示例
假设 question = "2+2=?"
, ground_truth = "4"
, num_samples = 3
,生成以下候选:
("<think>2+2的步骤:计算。</think><answer>4</answer>", 2.0)
(格式正确 + 答案正确)("<think>2+2的步骤。</think><answer>5</answer>", 1.0)
(格式正确 + 答案错误)("2+2=4", 0.0)
(无格式)
运行 max(candidates, key=lambda x: x[1])
:
- 比较 reward:2.0 > 1.0 > 0.0
- 返回
( "<think>2+2的步骤:计算。</think><answer>4</answer>", 2.0 )
[0]
提取输出:"<think>2+2的步骤:计算。</think><answer>4</answer>"
最终,best_output
是 reward 最高的输出。
这里用 max
可以吗?
从代码的实现来看,max
是合理的,但是否完全符合 DeepSeek-R1 的设计意图,需要结合文档(原paper)进一步分析。
文档中的描述
文档第 2.3.3 节(Rejection Sampling and Supervised Fine-Tuning)提到:
- “We curate reasoning prompts and generate reasoning trajectories by performing rejection sampling from the checkpoint from the above RL training.”
- “For each prompt, we sample multiple responses and retain only the correct ones.”
- “In total, we collect about 600k reasoning related training samples.”
关键点:
- 筛选标准:文档明确要求保留“正确”的输出(correct ones),而不是简单地取 reward 最高。
- 数据量:实际中生成了大量样本(60 万推理 + 20 万非推理),筛选后保留高质量部分。
- 奖励评估:推理数据基于规则奖励(准确性)和生成式奖励模型(DeepSeek-V3 判断),非推理数据也有特定标准。
当前代码的问题
- 仅用
max
:- 当前逻辑只取 reward 最高的输出,未严格验证“正确性”。例如,如果所有候选都不正确(reward < 2.0),仍会选一个次优输出。
- 文档要求“retain only the correct ones”,暗示可能需要一个阈值(threshold),而不是单纯取最大值。
- 模拟生成局限:
- 当前代码中,
output
是直接用ground_truth
构造的(模拟),而不是模型真实生成。这导致 reward 评估不反映模型实际能力,筛选逻辑过于理想化。
- 当前代码中,
- 非推理数据:
- 对于
general
类型,仅检查格式,未考虑帮助性或语义质量,与文档中“增强写作、问答等能力”的目标不完全匹配。
- 对于
改进建议
- 加入阈值:只保留 reward 达到一定标准的输出(例如 >= 2.0 表示格式和准确性都满足),而不是盲目取最大值。
- 真实生成:用模型生成输出,而不是预设
ground_truth
,以反映 RL 检查点的真实能力。 - 多标准筛选:结合准确性、格式和语义质量(例如用外部模型评估),尤其是非推理任务。
优化后的代码
以下是改进后的拒绝采样逻辑:
# 改进的奖励函数(加入语义检查)
def compute_advanced_reward(output: str, ground_truth: str = None, data_type: str = "reasoning") -> float:
format_reward = 1.0 if (re.search(r"<think>.*?</think>", output) and re.search(r"<answer>.*?</answer>", output)) else 0.0 if data_type == "reasoning" else (1.0 if re.search(r"<answer>.*?</answer>", output) else 0.0)
accuracy_reward = 0.0
if ground_truth and data_type == "reasoning":
answer_match = re.search(r"<answer>(.*?)</answer>", output)
if answer_match and answer_match.group(1).strip() == ground_truth:
accuracy_reward = 1.0
# 简单模拟语义质量(非推理任务)
semantic_reward = 1.0 if ("<answer>" in output and len(output) > 10) else 0.5 if data_type == "general" else 0.0
return format_reward + accuracy_reward + semantic_reward
# 优化后的拒绝采样
def improved_rejection_sampling(
model: SimpleTransformer,
data: List[Tuple[str, str, str]],
num_samples: int = 5,
threshold: float = 2.0 # 推理任务要求格式+准确性
) -> List[Tuple[str, str]]:
sft_data = []
for question, ground_truth, data_type in data:
candidates = []
for _ in range(num_samples):
with torch.no_grad():
state = torch.randint(0, 1000, (1, 10))
logits = model(state)
# 模拟生成(实际需模型解码)
if data_type == "reasoning":
# 假设模型生成的答案可能是随机的
possible_answer = ground_truth if torch.rand(1) > 0.3 else str(int(ground_truth) + 1)
output = f"<think>{question}的步骤:计算。</think><answer>{possible_answer}</answer>"
else:
output = f"<answer>{ground_truth} today!</answer>"
reward = compute_advanced_reward(output, ground_truth, data_type)
candidates.append((output, reward))
# 筛选:保留 reward >= threshold 的正确输出
valid_candidates = [c for c in candidates if c[1] >= threshold]
if valid_candidates:
best_output = max(valid_candidates, key=lambda x: x[1])[0] # 从正确输出中选最佳
else:
best_output = candidates[0][0] # 若无正确输出,退回第一个(次优)
sft_data.append((question, best_output))
return sft_data
# 测试
data = [
("2+2=?", "4", "reasoning"),
("Hello?", "Hi!", "general")
]
model = SimpleTransformer()
sft_data = improved_rejection_sampling(model, data)
print("SFT Data:", sft_data)
改进点
- 阈值筛选:
threshold = 2.0
确保推理任务至少满足格式和准确性(reward >= 2.0),非推理任务至少满足格式(reward >= 1.0)。- 用列表推导式
valid_candidates
过滤掉不合格输出。
- 真实性模拟:
- 推理任务加入随机性(答案可能是错的),模拟 RL 检查点的生成能力。
- 多维度奖励:
- 增加
semantic_reward
,为非推理任务提供简单语义评估。
- 增加
- 次优回退:
- 如果没有满足阈值的输出,选择第一个候选,避免数据丢失。
Max 的合理性与局限
- 合理性:
- 在当前代码中,
max
简单高效,适用于 reward 明确反映输出质量的情况。 - 如果所有候选都不完美,
max
仍能选出相对较好的输出。
- 在当前代码中,
- 局限性:
- 未严格遵循“仅保留正确输出”的要求,可能选出次优解。
- 对非推理任务,reward 设计简单,
max
可能不足以捕捉语义质量。
是否需要替代?
- 如果严格按文档要求(只保留正确输出),可以用
filter
+ 阈值替代max
,例如:valid_candidates = [c for c in candidates if c[1] >= 2.0] # 仅保留正确的 best_output = valid_candidates[0][0] if valid_candidates else None # 取第一个正确输出,或 None
- 但这可能导致数据量不足(如果所有候选都不正确),需要更多采样。
总结
- 当前逻辑:
max
选择 reward 最高的输出,简单有效,但不够严格。 - 文档意图:要求“正确”输出,建议加入阈值筛选。
- 优化版本:结合阈值和
max
,更贴近 DeepSeek-R1 的拒绝采样过程。
后记
2025年3月3日14点15分于上海,在grok3大模型辅助下完成。
更多推荐
所有评论(0)