DeepSeek-R1:冷启动下的强化学习之旅

在追求大语言模型(LLM)推理能力的道路上,DeepSeek 团队推出了 DeepSeek-R1-Zero,一个完全通过纯强化学习(RL)训练的模型,展现了令人惊叹的推理能力。然而,它的局限性(如可读性差和语言混合)促使团队进一步探索,最终开发出更强大的 DeepSeek-R1。本文将总结 DeepSeek-R1 的训练过程,重点介绍其“冷启动 + 强化学习”的创新 pipeline,带你走进这场技术旅程。

背景:从 DeepSeek-R1-Zero 到 DeepSeek-R1

DeepSeek-R1-Zero (具体细节可以参考笔者的另一篇博客DeepSeek-R1-Zero 的训练过程:pytorch代码实现,这里有GRPO算法的介绍,本文不在赘述)的成功证明了纯 RL 可以让模型自发发展出复杂的推理行为,例如反思和长链推理。然而,其输出往往不够用户友好,难以直接应用于实际场景。为了解决这些问题并进一步提升性能,DeepSeek-R1 引入了冷启动数据和多阶段训练策略,最终在推理任务上达到了与 OpenAI-o1-1217 相当的水平。下面是其训练过程的四个关键阶段。


DeepSeek-R1 训练过程

1. 冷启动(Cold Start)

  • 目标:避免纯 RL 初期的训练不稳定性,提升输出可读性。
  • 方法:从 DeepSeek-V3-Base 模型开始,收集数千个包含长链推理(Chain-of-Thought, CoT)的冷启动数据进行微调。这些数据通过以下方式生成:
    • 少样本提示:用带有长 CoT 的示例引导模型生成。
    • 直接提示:要求模型输出详细推理和验证过程。
    • 后处理:利用 DeepSeek-R1-Zero 的输出,经人工标注精炼。
  • 输出格式:设计为 [special_token] <reasoning_process> [special_token] <summary>,确保推理过程清晰,总结便于阅读。
  • 优势
    • 可读性:相比 DeepSeek-R1-Zero 的混乱输出,冷启动数据加入了人类偏好(如 Markdown 格式),显著提高了用户体验。
    • 潜力:为后续 RL 提供更好的起点,加速收敛。

2. 推理导向的强化学习(Reasoning-oriented RL)

  • 目标:增强模型在推理密集任务(如数学、编程、科学推理)上的能力。
  • 方法
    • 使用与 DeepSeek-R1-Zero 相同的 RL 框架(GRPO),基于冷启动微调后的模型继续训练。
    • 奖励设计
      • 准确性奖励:评估输出是否正确。
      • 语言一致性奖励:新增一项奖励,计算 CoT 中目标语言(例如英语)的比例,解决语言混合问题。
    • 最终奖励:直接将准确性和语言一致性奖励相加。
  • 效果:训练至收敛后,模型在推理任务上的性能显著提升,同时语言输出更加一致,但仍需进一步优化非推理能力。

3. 拒绝采样与监督微调(Rejection Sampling and SFT)

  • 目标:平衡推理与通用能力,提升模型在写作、问答等领域的表现。
  • 方法
    • 数据收集
      • 推理数据:从 RL 检查点通过拒绝采样生成推理轨迹,保留正确答案,扩展到约 60 万样本。包括规则奖励和部分生成式奖励模型(基于 DeepSeek-V3 判断)。
      • 非推理数据:复用 DeepSeek-V3 的 SFT 数据(如写作、事实问答),约 20 万样本。对于复杂问题,生成 CoT;简单问题直接回答。
    • 微调:用约 80 万样本对 DeepSeek-V3-Base 进行两轮监督微调(SFT)。
  • 改进:通过筛选掉语言混乱或格式不佳的输出,模型的通用性和可读性进一步增强。

4. 全场景强化学习(RL for All Scenarios)

  • 目标:优化模型的帮助性和无害性,同时保持推理能力。
  • 方法
    • 奖励组合
      • 推理任务:延续规则奖励(如准确性)。
      • 通用任务:引入基于 DeepSeek-V3 的奖励模型,评估复杂场景下的人类偏好。
      • 帮助性:仅评估最终总结的实用性,避免干扰推理过程。
      • 无害性:检查完整输出(包括 CoT 和总结),减少偏见或有害内容。
    • 训练数据:使用多样化的提示分布,覆盖推理和非推理场景。
  • 结果:最终得到的 DeepSeek-R1 在推理任务(如 AIME 2024 得分 79.8%)和通用任务(如 AlpacaEval 2.0 胜率 87.6%)上均表现出色,且更符合人类偏好。

成果与亮点

经过这四阶段的训练,DeepSeek-R1 实现了以下突破:

  • 推理能力:在数学(MATH-500 得分 97.3%)、编程(Codeforces Elo 2029)等任务上与 OpenAI-o1-1217 匹敌。
  • 通用性:在知识问答(MMLU 90.8%)、长上下文理解等任务中大幅超越 DeepSeek-V3。
  • 用户友好性:通过冷启动和多轮微调,输出的可读性和一致性显著提升。

与 DeepSeek-R1-Zero 相比,冷启动策略不仅解决了可读性问题,还通过迭代训练加速了性能提升,最终打造出一个更强大、更实用的模型。


总结

DeepSeek-R1 的训练过程是一个从“冷启动”到“全场景优化”的完整 pipeline:

  1. 冷启动奠定基础,提供高质量初始数据。
  2. 推理 RL专注提升核心推理能力。
  3. 拒绝采样 + SFT扩展通用性。
  4. 全场景 RL兼顾帮助性和安全性。

这种多阶段方法巧妙结合了监督学习和强化学习的优势,不仅验证了 RL 在推理任务中的潜力,还为构建用户友好的 LLM 提供了宝贵经验。未来,DeepSeek 团队计划进一步优化语言混合、软件工程任务等领域,让这一模型更上一层楼。

如果你对大模型的推理能力提升感兴趣,不妨关注 DeepSeek-R1 的开源模型和后续研究,或许能从中找到灵感!

代码实现

以下是基于原论文中 2.3 DeepSeek-R1: Reinforcement Learning with Cold Start 部分的 DeepSeek-R1 训练流程的代码实现。由于 DeepSeek-R1 的训练涉及多个复杂阶段(冷启动、推理导向 RL、拒绝采样与 SFT、全场景 RL),完整的工业级实现需要大量基础设施支持(例如分布式训练、预训练模型、数据管道等)。这里提供一个简化的 PyTorch 代码框架,涵盖每个阶段的核心逻辑,并尽量贴近文档描述,同时保持可读性和可运行性。


前提假设

  1. 模型:假设使用一个简单的 Transformer 模型代替 DeepSeek-V3-Base。
  2. 数据:用简单的数学问题示例替代大规模数据集。
  3. 简化:省略分布式训练、复杂 tokenization 等细节,聚焦流程。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import re
from typing import List, Tuple
from copy import deepcopy

# 简化的 Transformer 模型(代替 DeepSeek-V3-Base)
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=1000, d_model=256, n_heads=4, n_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, n_heads, n_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer(x, x)  # 简化:自注意力
        return self.fc_out(x)

# 奖励计算函数
def compute_reward(output: str, ground_truth: str = None, target_lang: str = "en") -> float:
    # 格式奖励
    think_pattern = r"<think>.*?</think>"
    answer_pattern = r"<answer>.*?</answer>"
    format_reward = 1.0 if (re.search(think_pattern, output) and re.search(answer_pattern, output)) else 0.0
    
    # 准确性奖励
    accuracy_reward = 0.0
    if ground_truth:
        answer_match = re.search(r"<answer>(.*?)</answer>", output)
        if answer_match and answer_match.group(1).strip() == ground_truth:
            accuracy_reward = 1.0
    
    # 语言一致性奖励(简单模拟:检查是否全为英文)
    lang_reward = 1.0 if all(c.isascii() for c in output if c.isalpha()) else 0.5
    
    return format_reward + accuracy_reward + lang_reward

# GRPO 损失函数
def compute_grpo_loss(
    policy: SimpleTransformer,
    old_policy: SimpleTransformer,
    ref_policy: SimpleTransformer,
    states: torch.Tensor,
    outputs: List[str],
    rewards: List[float],
    epsilon: float = 0.2,
    beta: float = 0.01
) -> torch.Tensor:
    G = len(outputs)
    logits = policy(states)  # [batch, seq_len, vocab_size]
    old_logits = old_policy(states).detach()
    ref_logits = ref_policy(states).detach()

    # 简化为平均 log_probs(实际需序列化)
    log_probs = torch.softmax(logits, dim=-1).log().mean(dim=1)
    old_log_probs = torch.softmax(old_logits, dim=-1).log().mean(dim=1)
    ref_log_probs = torch.softmax(ref_logits, dim=-1).log().mean(dim=1)

    ratios = torch.exp(log_probs - old_log_probs)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    advantages = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-6)

    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    clipped_loss = torch.min(surr1, surr2).mean()

    kl_div = (ref_log_probs - log_probs).mean()
    return - (clipped_loss - beta * kl_div)

# 阶段 1:冷启动微调
def cold_start_sft(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 2):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    # 模拟冷启动数据(问题 + CoT + 答案)
    cold_start_data = [
        (q, f"<think>{q}的推理过程:一步步计算。</think><answer>{a}</answer>")
        for q, a in data
    ]
    
    for epoch in range(epochs):
        total_loss = 0
        for question, target in cold_start_data:
            # 简化为 token IDs(实际需 tokenizer)
            input_ids = torch.randint(0, 1000, (1, 10))  # 占位符
            target_ids = torch.randint(0, 1000, (1, 10))
            
            optimizer.zero_grad()
            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Cold Start Epoch {epoch}, Loss: {total_loss / len(data)}")
    return model

# 阶段 2:推理导向 RL
def reasoning_rl(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 10, group_size: int = 4):
    ref_policy = deepcopy(model)  # 参考策略
    old_policy = deepcopy(model)  # 旧策略
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    
    for epoch in range(epochs):
        for question, ground_truth in data:
            state = torch.randint(0, 1000, (1, 10))  # 模拟状态
            outputs = []
            for _ in range(group_size):
                with torch.no_grad():
                    logits = old_policy(state)
                    # 模拟生成(实际需解码)
                    output = f"<think>{question}推理</think><answer>{ground_truth}</answer>"
                    outputs.append(output)
            
            rewards = [compute_reward(o, ground_truth) for o in outputs]
            optimizer.zero_grad()
            loss = compute_grpo_loss(model, old_policy, ref_policy, state, outputs, rewards)
            loss.backward()
            optimizer.step()
            old_policy.load_state_dict(model.state_dict())
        print(f"Reasoning RL Epoch {epoch}, Loss: {loss.item()}")
    return model

# 阶段 3:拒绝采样与 SFT
def rejection_sampling_sft(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 2):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    # 拒绝采样:生成多条输出,保留正确且格式良好的
    sft_data = []
    for question, ground_truth in data:
        candidates = []
        for _ in range(5):  # 采样 5 次
            with torch.no_grad():
                state = torch.randint(0, 1000, (1, 10))
                logits = model(state)
                output = f"<think>{question}推理</think><answer>{ground_truth}</answer>"  # 模拟
                reward = compute_reward(output, ground_truth)
                candidates.append((output, reward))
        # 选择最佳输出
        best_output = max(candidates, key=lambda x: x[1])[0]
        sft_data.append((question, best_output))
    
    # SFT 训练
    for epoch in range(epochs):
        total_loss = 0
        for _, target in sft_data:
            input_ids = torch.randint(0, 1000, (1, 10))
            target_ids = torch.randint(0, 1000, (1, 10))
            optimizer.zero_grad()
            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"SFT Epoch {epoch}, Loss: {total_loss / len(sft_data)}")
    return model

# 阶段 4:全场景 RL
def full_scenario_rl(model: SimpleTransformer, data: List[Tuple[str, str]], epochs: int = 5, group_size: int = 4):
    ref_policy = deepcopy(model)
    old_policy = deepcopy(model)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    
    for epoch in range(epochs):
        for question, ground_truth in data:
            state = torch.randint(0, 1000, (1, 10))
            outputs = []
            for _ in range(group_size):
                with torch.no_grad():
                    logits = old_policy(state)
                    output = f"<think>{question}推理</think><answer>{ground_truth}</answer>"
                    outputs.append(output)
            
            # 结合推理和通用奖励(这里简化)
            rewards = [compute_reward(o, ground_truth) for o in outputs]
            optimizer.zero_grad()
            loss = compute_grpo_loss(model, old_policy, ref_policy, state, outputs, rewards)
            loss.backward()
            optimizer.step()
            old_policy.load_state_dict(model.state_dict())
        print(f"Full RL Epoch {epoch}, Loss: {loss.item()}")
    return model

# 主流程
def train_deepseek_r1(data: List[Tuple[str, str]]):
    # 初始化模型
    model = SimpleTransformer()
    
    # 阶段 1:冷启动
    print("Stage 1: Cold Start SFT")
    model = cold_start_sft(model, data)
    
    # 阶段 2:推理 RL
    print("Stage 2: Reasoning RL")
    model = reasoning_rl(model, data)
    
    # 阶段 3:拒绝采样与 SFT
    print("Stage 3: Rejection Sampling SFT")
    model = rejection_sampling_sft(model, data)
    
    # 阶段 4:全场景 RL
    print("Stage 4: Full Scenario RL")
    model = full_scenario_rl(model, data)
    
    return model

# 示例数据
data = [("2+2=?", "4"), ("3*3=?", "9")]
model = train_deepseek_r1(data)

代码说明

1. 模型(SimpleTransformer
  • 一个简化的 Transformer 模型,代替 DeepSeek-V3-Base,用于生成 token 序列。
2. 奖励函数(compute_reward
  • 准确性奖励:检查 <answer> 中的答案与 ground truth 是否匹配。
  • 格式奖励:验证 <think><answer> 标签是否存在。
  • 语言一致性奖励:简单模拟,检查输出是否全为 ASCII 字符(代表英语)。
  • 符合文档中基于规则、不用神经网络的设计。
3. GRPO 损失(compute_grpo_loss
  • 实现 GRPO 的公式,包含剪切和 KL 散度正则化。
  • 优势 ( A i A_i Ai ) 根据组内奖励统计计算。
4. 训练阶段
  • 冷启动 SFT
    • 用预定义的 CoT 数据微调模型,模拟文档中的格式 [special_token] <reasoning_process> [special_token] <summary>
  • 推理 RL
    • 使用 GRPO 和奖励函数优化推理能力,加入语言一致性奖励。
  • 拒绝采样与 SFT
    • 从 RL 模型生成多条输出,保留最佳样本(基于奖励),然后进行 SFT。
    • 模拟推理和非推理数据的混合(这里简化)。
  • 全场景 RL
    • 再次用 GRPO 训练,结合推理和通用奖励(原论文中提到帮助性和无害性,这里简化为单一奖励函数)。
5. 主流程(train_deepseek_r1
  • 按顺序执行四个阶段,返回最终模型。

简化与局限

  1. 数据处理:未实现真实的 tokenization 和序列生成,需结合 transformers 库(例如 model.generate())。
  2. 奖励模型:语言一致性奖励过于简单,实际需更复杂的自然语言处理。
  3. 规模:未体现大规模训练,需扩展至分布式环境。
  4. 通用任务:阶段 4 的帮助性和无害性奖励未详细实现,可添加外部奖励模型。

如何扩展

  • 真实模型:用 transformers 的预训练模型(如 LLaMA)替换 SimpleTransformer
  • 数据管道:接入真实数据集(如数学题、编程题),用 tokenizer 处理输入输出。
  • 奖励细化:实现更复杂的规则(如代码编译器)或外部评估。

针对 拒绝采样与 SFT全场景 RL 两部分的优化代码

以下是针对 拒绝采样与 SFT全场景 RL 两部分的优化代码实现,模拟推理和非推理数据的混合,并在全场景 RL 中结合推理、帮助性和无害性奖励,同时保持与文档描述一致。代码使用 PyTorch,并针对 DeepSeek-R1 的训练流程进行了优化。


优化代码

1. 拒绝采样与 SFT(Rejection Sampling and SFT)
  • 目标:从 RL 检查点生成多条输出,筛选最佳样本,混合推理和非推理数据进行监督微调(SFT)。
  • 优化点
    • 模拟推理数据(数学问题)和非推理数据(简单问答)的生成与筛选。
    • 使用奖励函数评估输出质量,保留正确且格式良好的样本。
import torch
import torch.nn as nn
import torch.optim as optim
import re
from typing import List, Tuple
from copy import deepcopy

# 简化的 Transformer 模型(假设已定义)
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=1000, d_model=256, n_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = torch.relu(layer(x))
        return self.fc_out(x)

# 奖励函数(用于筛选)
def compute_basic_reward(output: str, ground_truth: str = None) -> float:
    format_reward = 1.0 if (re.search(r"<think>.*?</think>", output) and re.search(r"<answer>.*?</answer>", output)) else 0.0
    accuracy_reward = 0.0
    if ground_truth:
        answer_match = re.search(r"<answer>(.*?)</answer>", output)
        if answer_match and answer_match.group(1).strip() == ground_truth:
            accuracy_reward = 1.0
    return format_reward + accuracy_reward

# 拒绝采样与 SFT
def rejection_sampling_sft(
    model: SimpleTransformer,
    data: List[Tuple[str, str, str]],  # (question, ground_truth, type: "reasoning" 或 "general")
    epochs: int = 2,
    num_samples: int = 5
) -> SimpleTransformer:
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # 拒绝采样:生成并筛选数据
    sft_data = []
    for question, ground_truth, data_type in data:
        candidates = []
        for _ in range(num_samples):
            with torch.no_grad():
                state = torch.randint(0, 1000, (1, 10))  # 模拟输入
                logits = model(state)
                # 模拟生成输出
                if data_type == "reasoning":
                    output = f"<think>{question}的步骤:计算。</think><answer>{ground_truth}</answer>"
                else:  # general
                    output = f"<answer>{ground_truth}</answer>"  # 非推理数据无需 CoT
                reward = compute_basic_reward(output, ground_truth if data_type == "reasoning" else None)
                candidates.append((output, reward))
        
        # 选择最佳输出
        best_output = max(candidates, key=lambda x: x[1])[0]
        sft_data.append((question, best_output))

    # SFT 训练
    for epoch in range(epochs):
        total_loss = 0
        for question, target in sft_data:
            input_ids = torch.randint(0, 1000, (1, 10))  # 模拟 token IDs
            target_ids = torch.randint(0, 1000, (1, 10))
            optimizer.zero_grad()
            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"SFT Epoch {epoch}, Loss: {total_loss / len(sft_data)}")
    
    return model

# 测试示例
data = [
    ("2+2=?", "4", "reasoning"),           # 推理数据
    ("3*3=?", "9", "reasoning"),
    ("Hello?", "Hi!", "general"),          # 非推理数据
    ("What’s the weather?", "Sunny", "general")
]
model = SimpleTransformer()
model = rejection_sampling_sft(model, data)

说明

  • 数据混合data 包含推理数据(带 CoT 和 ground truth)和非推理数据(仅答案)。推理数据要求准确性和格式,非推理数据只要求格式。
  • 拒绝采样:为每个问题生成 num_samples 条输出,用 compute_basic_reward 筛选最佳样本。
  • SFT:用筛选后的数据进行监督微调,模拟文档中的 60 万推理 + 20 万非推理样本。

2. 全场景 RL(Full Scenario RL)
  • 目标:结合推理奖励和通用奖励(帮助性、无害性),使用 GRPO 进一步优化模型。
  • 优化点
    • 实现多维度奖励:推理准确性、帮助性(输出实用性)、无害性(避免有害内容)。
    • 使用 GRPO 框架训练。
# 综合奖励函数
def compute_full_reward(output: str, ground_truth: str = None, question: str = None) -> float:
    # 推理奖励(准确性 + 格式)
    format_reward = 1.0 if (re.search(r"<think>.*?</think>", output) and re.search(r"<answer>.*?</answer>", output)) else 0.0
    accuracy_reward = 0.0
    if ground_truth:
        answer_match = re.search(r"<answer>(.*?)</answer>", output)
        if answer_match and answer_match.group(1).strip() == ground_truth:
            accuracy_reward = 1.0
    
    # 帮助性奖励(简单模拟:检查答案是否为空)
    helpfulness_reward = 1.0 if re.search(r"<answer>.+?</answer>", output) else 0.0
    
    # 无害性奖励(简单模拟:检查是否有负面词)
    harmful_words = {"hate", "kill", "bad"}
    harmless_reward = 0.0 if any(word in output.lower() for word in harmful_words) else 1.0
    
    return format_reward + accuracy_reward + helpfulness_reward + harmless_reward

# GRPO 损失函数(已定义,这里重用)
def compute_grpo_loss(
    policy: SimpleTransformer,
    old_policy: SimpleTransformer,
    ref_policy: SimpleTransformer,
    states: torch.Tensor,
    outputs: List[str],
    rewards: List[float],
    epsilon: float = 0.2,
    beta: float = 0.01
) -> torch.Tensor:
    G = len(outputs)
    logits = policy(states)
    old_logits = old_policy(states).detach()
    ref_logits = ref_policy(states).detach()

    log_probs = torch.softmax(logits, dim=-1).log().mean(dim=1)
    old_log_probs = torch.softmax(old_logits, dim=-1).log().mean(dim=1)
    ref_log_probs = torch.softmax(ref_logits, dim=-1).log().mean(dim=1)

    ratios = torch.exp(log_probs - old_log_probs)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    advantages = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-6)

    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    clipped_loss = torch.min(surr1, surr2).mean()

    kl_div = (ref_log_probs - log_probs).mean()
    return - (clipped_loss - beta * kl_div)

# 全场景 RL
def full_scenario_rl(
    model: SimpleTransformer,
    data: List[Tuple[str, str, str]],  # (question, ground_truth, type)
    epochs: int = 5,
    group_size: int = 4
) -> SimpleTransformer:
    ref_policy = deepcopy(model)
    old_policy = deepcopy(model)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    for epoch in range(epochs):
        total_loss = 0
        for question, ground_truth, data_type in data:
            state = torch.randint(0, 1000, (1, 10))
            outputs = []
            for _ in range(group_size):
                with torch.no_grad():
                    logits = old_policy(state)
                    if data_type == "reasoning":
                        output = f"<think>{question}的步骤</think><answer>{ground_truth}</answer>"
                    else:
                        output = f"<answer>{ground_truth}</answer>"
                    outputs.append(output)
            
            # 计算综合奖励
            rewards = [compute_full_reward(o, ground_truth if data_type == "reasoning" else None, question) 
                      for o in outputs]
            
            optimizer.zero_grad()
            loss = compute_grpo_loss(model, old_policy, ref_policy, state, outputs, rewards)
            loss.backward()
            optimizer.step()
            old_policy.load_state_dict(model.state_dict())
            total_loss += loss.item()
        print(f"Full RL Epoch {epoch}, Loss: {total_loss / len(data)}")
    
    return model

# 测试示例
data = [
    ("2+2=?", "4", "reasoning"),
    ("3*3=?", "9", "reasoning"),
    ("Hello?", "Hi!", "general"),
    ("What’s the weather?", "Sunny", "general")
]
model = SimpleTransformer()
model = full_scenario_rl(model, data)

说明

  • 综合奖励
    • 推理奖励:包括格式和准确性,与文档一致。
    • 帮助性奖励:检查 <answer> 是否非空,模拟实用性评估。
    • 无害性奖励:简单检查是否有负面词汇,模拟安全性。
  • 数据处理:根据 data_type 区分推理和非推理输出,推理任务带 CoT,非推理任务仅答案。
  • GRPO 训练:用综合奖励驱动策略优化,保留剪切和 KL 散度。

优化亮点

  1. 拒绝采样与 SFT
    • 混合推理和非推理数据,推理任务要求 CoT 和准确性,非推理任务简化格式。
    • 拒绝采样通过奖励筛选,确保高质量训练数据。
  2. 全场景 RL
    • 奖励函数扩展到推理、帮助性和无害性三个维度,符合文档中多目标优化的描述。
    • GRPO 框架保持一致,适应多样化任务。

注意事项

  • 简化:序列生成用静态字符串模拟,实际需用 transformersgenerate()
  • 奖励模型:帮助性和无害性奖励过于简单,工业实现可能需要外部 LLM(如 DeepSeek-V3)评分。
  • 扩展性:支持更大规模数据需批量处理和分布式训练。

针对拒绝采样(Rejection Sampling)中选择最佳输出的逻辑

针对拒绝采样(Rejection Sampling)中选择最佳输出的逻辑,尤其是 best_output = max(candidates, key=lambda x: x[1])[0] 这行代码的作用和合理性合适吗?详细解释这段代码的逻辑,为什么使用 max,以及它是否完全符合 DeepSeek-R1 的设计意图。


当前代码的逻辑

代码片段回顾
candidates = []
for _ in range(num_samples):
    with torch.no_grad():
        state = torch.randint(0, 1000, (1, 10))  # 模拟输入
        logits = model(state)
        # 模拟生成输出
        if data_type == "reasoning":
            output = f"<think>{question}的步骤:计算。</think><answer>{ground_truth}</answer>"
        else:  # general
            output = f"<answer>{ground_truth}</answer>"
        reward = compute_basic_reward(output, ground_truth if data_type == "reasoning" else None)
        candidates.append((output, reward))

best_output = max(candidates, key=lambda x: x[1])[0]
  • 生成候选:代码为每个问题生成 num_samples 个输出(这里是 5 次),每个输出是一个 (output, reward) 元组,output 是生成的文本,reward 是根据 compute_basic_reward 计算的分数。
  • 选择最佳输出
    • max(candidates, key=lambda x: x[1]):从 candidates 列表中选择 reward(索引 1)最大的元组。
    • [0]:从选中的元组中提取 output(索引 0),即最佳输出文本。
逻辑是什么?
  • 目标:拒绝采样的目的是从多个候选输出中挑选“最好”的一个,用于后续的监督微调(SFT)。这里的“最好”由 reward 定义,包含格式奖励和准确性奖励(对于推理任务)。
  • max 的作用max 函数基于 key=lambda x: x[1](即 reward)比较所有候选,选出得分最高的输出。换句话说,它假设更高的 reward 表示更高质量的输出。
  • 假设
    • 对于推理任务(data_type == "reasoning"),reward 越高意味着输出更可能正确(匹配 ground_truth)且格式正确。
    • 对于非推理任务(data_type == "general"),reward 越高意味着格式正确(仅要求 <answer> 标签)。
示例

假设 question = "2+2=?", ground_truth = "4", num_samples = 3,生成以下候选:

  1. ("<think>2+2的步骤:计算。</think><answer>4</answer>", 2.0)(格式正确 + 答案正确)
  2. ("<think>2+2的步骤。</think><answer>5</answer>", 1.0)(格式正确 + 答案错误)
  3. ("2+2=4", 0.0)(无格式)

运行 max(candidates, key=lambda x: x[1])

  • 比较 reward:2.0 > 1.0 > 0.0
  • 返回 ( "<think>2+2的步骤:计算。</think><answer>4</answer>", 2.0 )
  • [0] 提取输出:"<think>2+2的步骤:计算。</think><answer>4</answer>"

最终,best_output 是 reward 最高的输出。


这里用 max 可以吗?

从代码的实现来看,max 是合理的,但是否完全符合 DeepSeek-R1 的设计意图,需要结合文档(原paper)进一步分析。

文档中的描述

文档第 2.3.3 节(Rejection Sampling and Supervised Fine-Tuning)提到:

  • “We curate reasoning prompts and generate reasoning trajectories by performing rejection sampling from the checkpoint from the above RL training.”
  • “For each prompt, we sample multiple responses and retain only the correct ones.”
  • “In total, we collect about 600k reasoning related training samples.”

关键点:

  1. 筛选标准:文档明确要求保留“正确”的输出(correct ones),而不是简单地取 reward 最高。
  2. 数据量:实际中生成了大量样本(60 万推理 + 20 万非推理),筛选后保留高质量部分。
  3. 奖励评估:推理数据基于规则奖励(准确性)和生成式奖励模型(DeepSeek-V3 判断),非推理数据也有特定标准。
当前代码的问题
  • 仅用 max
    • 当前逻辑只取 reward 最高的输出,未严格验证“正确性”。例如,如果所有候选都不正确(reward < 2.0),仍会选一个次优输出。
    • 文档要求“retain only the correct ones”,暗示可能需要一个阈值(threshold),而不是单纯取最大值。
  • 模拟生成局限
    • 当前代码中,output 是直接用 ground_truth 构造的(模拟),而不是模型真实生成。这导致 reward 评估不反映模型实际能力,筛选逻辑过于理想化。
  • 非推理数据
    • 对于 general 类型,仅检查格式,未考虑帮助性或语义质量,与文档中“增强写作、问答等能力”的目标不完全匹配。
改进建议
  • 加入阈值:只保留 reward 达到一定标准的输出(例如 >= 2.0 表示格式和准确性都满足),而不是盲目取最大值。
  • 真实生成:用模型生成输出,而不是预设 ground_truth,以反映 RL 检查点的真实能力。
  • 多标准筛选:结合准确性、格式和语义质量(例如用外部模型评估),尤其是非推理任务。

优化后的代码

以下是改进后的拒绝采样逻辑:

# 改进的奖励函数(加入语义检查)
def compute_advanced_reward(output: str, ground_truth: str = None, data_type: str = "reasoning") -> float:
    format_reward = 1.0 if (re.search(r"<think>.*?</think>", output) and re.search(r"<answer>.*?</answer>", output)) else 0.0 if data_type == "reasoning" else (1.0 if re.search(r"<answer>.*?</answer>", output) else 0.0)
    accuracy_reward = 0.0
    if ground_truth and data_type == "reasoning":
        answer_match = re.search(r"<answer>(.*?)</answer>", output)
        if answer_match and answer_match.group(1).strip() == ground_truth:
            accuracy_reward = 1.0
    # 简单模拟语义质量(非推理任务)
    semantic_reward = 1.0 if ("<answer>" in output and len(output) > 10) else 0.5 if data_type == "general" else 0.0
    return format_reward + accuracy_reward + semantic_reward

# 优化后的拒绝采样
def improved_rejection_sampling(
    model: SimpleTransformer,
    data: List[Tuple[str, str, str]],
    num_samples: int = 5,
    threshold: float = 2.0  # 推理任务要求格式+准确性
) -> List[Tuple[str, str]]:
    sft_data = []
    for question, ground_truth, data_type in data:
        candidates = []
        for _ in range(num_samples):
            with torch.no_grad():
                state = torch.randint(0, 1000, (1, 10))
                logits = model(state)
                # 模拟生成(实际需模型解码)
                if data_type == "reasoning":
                    # 假设模型生成的答案可能是随机的
                    possible_answer = ground_truth if torch.rand(1) > 0.3 else str(int(ground_truth) + 1)
                    output = f"<think>{question}的步骤:计算。</think><answer>{possible_answer}</answer>"
                else:
                    output = f"<answer>{ground_truth} today!</answer>"
                reward = compute_advanced_reward(output, ground_truth, data_type)
                candidates.append((output, reward))
        
        # 筛选:保留 reward >= threshold 的正确输出
        valid_candidates = [c for c in candidates if c[1] >= threshold]
        if valid_candidates:
            best_output = max(valid_candidates, key=lambda x: x[1])[0]  # 从正确输出中选最佳
        else:
            best_output = candidates[0][0]  # 若无正确输出,退回第一个(次优)
        sft_data.append((question, best_output))
    
    return sft_data

# 测试
data = [
    ("2+2=?", "4", "reasoning"),
    ("Hello?", "Hi!", "general")
]
model = SimpleTransformer()
sft_data = improved_rejection_sampling(model, data)
print("SFT Data:", sft_data)
改进点
  • 阈值筛选
    • threshold = 2.0 确保推理任务至少满足格式和准确性(reward >= 2.0),非推理任务至少满足格式(reward >= 1.0)。
    • 用列表推导式 valid_candidates 过滤掉不合格输出。
  • 真实性模拟
    • 推理任务加入随机性(答案可能是错的),模拟 RL 检查点的生成能力。
  • 多维度奖励
    • 增加 semantic_reward,为非推理任务提供简单语义评估。
  • 次优回退
    • 如果没有满足阈值的输出,选择第一个候选,避免数据丢失。

Max 的合理性与局限

  • 合理性
    • 在当前代码中,max 简单高效,适用于 reward 明确反映输出质量的情况。
    • 如果所有候选都不完美,max 仍能选出相对较好的输出。
  • 局限性
    • 未严格遵循“仅保留正确输出”的要求,可能选出次优解。
    • 对非推理任务,reward 设计简单,max 可能不足以捕捉语义质量。
是否需要替代?
  • 如果严格按文档要求(只保留正确输出),可以用 filter + 阈值替代 max,例如:
    valid_candidates = [c for c in candidates if c[1] >= 2.0]  # 仅保留正确的
    best_output = valid_candidates[0][0] if valid_candidates else None  # 取第一个正确输出,或 None
    
    • 但这可能导致数据量不足(如果所有候选都不正确),需要更多采样。

总结

  • 当前逻辑max 选择 reward 最高的输出,简单有效,但不够严格。
  • 文档意图:要求“正确”输出,建议加入阈值筛选。
  • 优化版本:结合阈值和 max,更贴近 DeepSeek-R1 的拒绝采样过程。

后记

2025年3月3日14点15分于上海,在grok3大模型辅助下完成。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐