一、DeepGEMM概述

DeepGEMM 是一个专为 NVIDIA Hopper 架构设计的高效 FP8 矩阵乘法库,支持普通和混合专家模型(MoE)分组矩阵乘法,通过简洁的实现和即时编译技术,实现了高性能和易用性。

官方开源代码链接:https://github.com/deepseek-ai/DeepGEMM

1. 核心亮点

  • FP8 低精度支持:DeepGEMM 最大的特色在于从架构上优先设计为 FP8 服务。传统GEMM库主要优化FP16和FP32,而DeepGEMM针对FP8的特殊性进行了优化设计。
  • 极致性能与极简核心实现:DeepGEMM在NVIDIA Hopper GPU上实现了高达1350+ FP8 TFLOPS的计算性能,同时其核心代码仅有约300行
  • JIT 即时编译:DeepGEMM 不是预先编译好所有可能配置的内核,而是利用 JIT 在运行时生成最佳内核。例如,根据矩阵大小、FP8尺度等参数,JIT 会即时优化指令顺序和寄存器分配。

DeepGEMM 是一个用于高效执行 FP8 精度通用矩阵乘法(GEMM)的库,支持细粒度缩放,相关设计思路来源于 DeepSeek-V3。它同时支持普通矩阵乘法和混合专家模型(MoE)分组矩阵乘法。该库使用 CUDA 编写,安装时无需编译,通过轻量级的即时编译(JIT)模块在运行时动态编译所有内核。
目前,DeepGEMM 仅支持 NVIDIA Hopper 架构的张量核心。为了解决 FP8 张量核心在累加时的精度问题,它采用了 CUDA 核心的双级累加(提升)策略。虽然它借鉴了 CUTLASS 和 CuTe 的部分概念,但并未过度依赖它们的模板或代数结构。相反,该库的设计注重简洁性,其核心内核函数仅包含约 300 行代码。这使得它成为学习 Hopper 架构下 FP8 矩阵乘法及优化技术的清晰且易于上手的资源。

2. 性能表现

在 H800 SXM5 GPU 上使用 NVCC 12.8 测试 DeepSeek-V3/R1 推理中可能使用的所有形状(包括预填充和解码阶段,但不包括张量并行)。所有加速指标均与基于 CUTLASS 3.6 的内部优化实现进行比较。

(1)普通 GEMM(适用于稠密模型)

M N K Computation Memory bandwidth Speedup
64 2112 7168 206 TFLOPS 1688 GB/s 2.7x
64 24576 1536 289 TFLOPS 2455 GB/s 1.7x
64 32768 512 219 TFLOPS 2143 GB/s 1.8x
64 7168 16384 336 TFLOPS 2668 GB/s 1.4x
64 4096 7168 287 TFLOPS 2320 GB/s 1.4x
64 7168 2048 295 TFLOPS 2470 GB/s 1.7x
128 2112 7168 352 TFLOPS 1509 GB/s 2.4x
128 24576 1536 535 TFLOPS 2448 GB/s 1.6x
128 32768 512 358 TFLOPS 2103 GB/s 1.5x
128 7168 16384 645 TFLOPS 2604 GB/s 1.4x
128 4096 7168 533 TFLOPS 2221 GB/s 2.0x
128 7168 2048 510 TFLOPS 2277 GB/s 1.7x
4096 2112 7168 1058 TFLOPS 527 GB/s 1.1x
4096 24576 1536 990 TFLOPS 786 GB/s 1.0x
4096 32768 512 590 TFLOPS 1232 GB/s 1.0x
4096 7168 16384 1358 TFLOPS 343 GB/s 1.2x
4096 4096 7168 1304 TFLOPS 500 GB/s 1.1x
4096 7168 2048 1025 TFLOPS 697 GB/s 1.1x

(2)分组 GEMM(适用于 MoE 模型,连续布局)

#Groups M per group N K Computation Memory bandwidth Speedup
4 8192 4096 7168 1297 TFLOPS 418 GB/s 1.2x
4 8192 7168 2048 1099 TFLOPS 681 GB/s 1.2x
8 4096 4096 7168 1288 TFLOPS 494 GB/s 1.2x
8 4096 7168 2048 1093 TFLOPS 743 GB/s 1.1x

(3)分组 GEMM(适用于 MoE 模型,掩码布局)

#Groups M per group N K Computation Memory bandwidth Speedup
1 1024 4096 7168 1233 TFLOPS 924 GB/s 1.2x
1 1024 7168 2048 925 TFLOPS 968 GB/s 1.2x
2 512 4096 7168 1040 TFLOPS 1288 GB/s 1.2x
2 512 7168 2048 916 TFLOPS 1405 GB/s 1.2x
4 256 4096 7168 932 TFLOPS 2064 GB/s 1.1x
4 256 7168 2048 815 TFLOPS 2047 GB/s 1.2x

二、通用矩阵乘法机制简述

在深度学习中,矩阵乘法是核心计算任务,比如神经网络的全连接层和Transformer的自注意力机制都需要大量的矩阵乘法运算。随着模型规模的增大,高效的矩阵乘法变得至关重要,因为它能显著缩短模型训练和推理的时间。

在这里插入图片描述

通用矩阵乘法 (General Matrix Multiplication,GEMM) 是各种模型和计算中的核心部分,同时也是评估计算硬件性能 (FLOPS) 的标准技术。

GEMM 的定义为:

C ← α A B + β C \boldsymbol{C} \leftarrow \alpha \boldsymbol{A} \boldsymbol{B} + \beta \boldsymbol{C} CαAB+βC

即将矩阵 A \boldsymbol{A} A B \boldsymbol{B} B 进行矩阵相乘,并将结果缩放 α \alpha α 倍,然后与缩放 β \beta β 倍的矩阵 C \boldsymbol{C} C 相加,并将最终结果存入 C \boldsymbol{C} C 中。

接下来分析计算复杂度,假设 A \boldsymbol{A} A 的形状是 M × K M × K M×K B \boldsymbol{B} B 的形状是 K × N K × N K×N ,则 C \boldsymbol{C} C 形状是 M × N M × N M×N 。其中主要的部分是 A B \boldsymbol{A} \boldsymbol{B} AB 矩阵相乘,根据矩阵乘法的定义

A B = ( a 1 , 1 ⋯ a 1 , K ⋮ ⋱ ⋮ a M , 1 ⋯ a M , K ) ( b 1 , 1 ⋯ b 1 , N ⋮ ⋱ ⋮ b K , 1 ⋯ b K , N ) = ( ∑ k = 1 K a 1 , k b k , 1 ⋯ ∑ k = 1 K a 1 , k b k , N ⋮ ⋱ ⋮ ∑ k = 1 N a M , k b k , 1 ⋯ ∑ k = 1 K a M , k b k , N ) \boldsymbol{A} \boldsymbol{B}=\left(\begin{array}{ccc} a_{1,1} & \cdots & a_{1, K} \\ \vdots & \ddots & \vdots \\ a_{M, 1} & \cdots & a_{M, K} \end{array}\right) \left(\begin{array}{ccc} b_{1,1} & \cdots & b_{1, N} \\ \vdots & \ddots & \vdots \\ b_{K, 1} & \cdots & b_{K, N} \end{array}\right) =\left(\begin{array}{ccc} \sum_{k=1}^{K} a_{1, k} b_{k, 1} & \cdots & \sum_{k=1}^{K} a_{1, k} b_{k, N} \\ \vdots & \ddots & \vdots \\ \sum_{k=1}^{N} a_{M, k} b_{k, 1} & \cdots & \sum_{k=1}^{K} a_{M, k} b_{k, N} \end{array}\right) AB= a1,1aM,1a1,KaM,K b1,1bK,1b1,NbK,N = k=1Ka1,kbk,1k=1NaM,kbk,1k=1Ka1,kbk,Nk=1KaM,kbk,N

其中第 i 行第 j 列元素 ∑ k = 1 K a i , k b k , j \sum_{k=1}^{K} a_{i, k} b_{k, j} k=1Kai,kbk,j ,即每个元素的计算需要 K K K 次乘法和 K − 1 K-1 K1 次加法,即计算 A B \boldsymbol{A} \boldsymbol{B} AB 共需要执行 ( 2 K − 1 ) M N (2K-1)MN (2K1)MN 次浮点数运算。

另外 A B \boldsymbol{A} \boldsymbol{B} AB C \boldsymbol{C} C 的放缩都需要 M N MN MN 次浮点运算,那么总的浮点运算次数则为 ( 2 K + 1 ) M N (2K+1)MN (2K+1)MN ,由于 K ≫ 1 K\gg 1 K1 ,因此通常浮点运算次数近似等于 2 K M N 2KMN 2KMN

在这里插入图片描述

深入阅读:https://zhuanlan.zhihu.com/p/657632577

三、DeepGEMM矩阵乘法解析

1. 传统矩阵乘法 vs. DeepGEMM算法

特性 传统矩阵乘法 DeepGEMM算法
计算方式 逐元素顺序计算,无并行加速 利用GPU并行计算,分块处理,使用张量核心等优化
计算速度 慢,尤其在大规模矩阵时效率低下 极快,能在短时间内完成大规模矩阵乘法
内存使用 内存占用高,可能面临内存不足问题 内存管理高效,避免内存瓶颈
适用场景 适用于小型或简单矩阵运算场景 专为大规模矩阵运算设计,如深度学习模型训练和推理

2. 接口函数

DeepGEMM 提供了多个接口函数,包括:

  • 常规稠密 GEMMgemm_fp8_fp8_bf16_nt 是标准的矩阵乘法,适用于简单的矩阵计算场景。
  • 分组 GEMM(连续布局)m_grouped_gemm_fp8_fp8_bf16_nt_contiguous 引入了分组机制,允许对输入数据进行分组处理,适用于需要对输入数据进行分组计算的场景。
  • 分组 GEMM(掩码布局)m_grouped_gemm_fp8_fp8_bf16_nt_masked 在分组机制的基础上增加了掩码机制,允许动态裁剪每个分组的计算行数,适用于需要灵活处理分组数据的场景。

(1)gemm_fp8_fp8_bf16_nt

  • 输入形状:

    • LHS[m, k],FP8 类型。
    • LHS 缩放因子[m, ⌈k / 128⌉],FP32 类型。
    • RHS[n, k],FP8 类型。
    • RHS 缩放因子[⌈n / 128⌉, ⌈k / 128⌉],FP32 类型。
    • 输出[m, n],BF16 类型。
  • 原理:

    • 执行标准的矩阵乘法操作,LHS 的每一行与 RHS 的每一列进行点积计算,结果存储在输出矩阵中。
    • 缩放因子用于在计算过程中调整数值范围,以避免溢出或下溢。
    • LHS 和 RHS 的缩放因子分别以 1x128 和 128x128 的形式提供,用于在计算过程中动态调整数值范围。

(2)m_grouped_gemm_fp8_fp8_bf16_nt_contiguous

  • 输入形状:

    • LHS[m_sum, k],FP8 类型。
    • LHS 缩放因子[m_sum, ⌈k / 128⌉],FP32 类型。
    • RHS[num_groups, n, k],FP8 类型。
    • RHS 缩放因子[num_groups, ⌈n / 128⌉, ⌈k / 128⌉],FP32 类型。
    • 输出[m_sum, n],BF16 类型。
    • m_indices[m_sum],int 类型,记录每行 LHS 所属的分组。
  • 分组机制:

    • 分组机制:输入矩阵 LHS 的每一行被分配到不同的分组,分组信息由 m_indices 提供。
    • 分组大小对齐:分组大小对齐到 get_m_alignment_for_contiguous_layout()(通常是 128)。
    • 动态分组m_indices[i] 表示 LHS 的第 i 行属于第 m_indices[i] 个分组,计算时会与 RHS[m_indices[i]] 进行乘法操作。
  • 原理:

    • 每个分组的 RHS 矩阵独立参与计算,LHS 的每一行根据 m_indices 的指示选择对应的 RHS 分组进行乘法操作。
    • 这种机制允许对不同分组的输入数据进行独立计算,适用于需要对输入数据进行分组处理的场景。
    • 分组大小对齐到固定值(如 128),以优化内存访问和计算效率。

(3)m_grouped_gemm_fp8_fp8_bf16_nt_masked

  • 输入形状:

    • LHS[num_groups, m_max, k],FP8 类型。
    • LHS 缩放因子[num_groups, m_max, ⌈k / 128⌉],FP32 类型。
    • RHS[num_groups, n, k],FP8 类型。
    • RHS 缩放因子[num_groups, ⌈n / 128⌉, ⌈k / 128⌉],FP32 类型。
    • 输出[num_groups, m_max, n],BF16 类型。
    • masked_m[num_groups],int 类型,记录每个分组实际参与计算的行数。
    • expected_m:一个标量,表示每个分组的期望行数(用于性能优化)。
  • 分组机制:

    • 分组机制:输入矩阵 LHS 和输出矩阵 out 按分组维度组织,每个分组的大小为 m_max
    • 掩码机制masked_m[i] 表示第 i 个分组实际参与计算的行数,允许每个分组的计算行数小于 m_max
    • 分组独立转置:与 m_grouped_gemm_fp8_fp8_bf16_nt_contiguous 不同,masked 版本要求每个分组独立转置,以优化内存访问。
  • 原理:

    • 每个分组的 LHS 矩阵根据 masked_m 的指示选择实际参与计算的行数,与对应的 RHS 分组进行乘法操作。
    • 掩码机制允许灵活处理不同分组的实际计算行数,避免不必要的计算开销。
    • expected_m 提供了一个性能优化的提示值,帮助内核更好地分配计算资源。
    • 这种机制适用于需要对分组数据进行动态裁剪的场景,例如在某些分组中只有部分行需要计算。

(4)综合对比

特性 gemm_fp8_fp8_bf16_nt m_grouped_gemm_fp8_fp8_bf16_nt_contiguous m_grouped_gemm_fp8_fp8_bf16_nt_masked
输入形状 [m, k][n, k] [m_sum, k][num_groups, n, k] [num_groups, m_max, k][num_groups, n, k]
分组机制 无分组 动态分组,m_indices 指示行所属分组 动态分组,masked_m 指示每个分组的实际行数
对齐要求 无特殊对齐 分组大小对齐到 get_m_alignment_for_contiguous_layout() 每个分组独立转置,对齐到 m_max
适用场景 标准矩阵乘法 需要对输入数据进行分组处理 需要对分组数据进行动态裁剪
性能优化 无特殊优化 分组大小对齐优化 掩码机制优化,expected_m 提供性能提示
计算灵活性 固定矩阵形状 动态选择 RHS 分组 动态选择计算行数,灵活处理分组数据

3. 代码简析

以gemm_fp8_fp8_bf16_nt为例,简析代码:

function gemm_fp8_fp8_bf16_nt(lhs, rhs, out):
    # lhs: (FP8 matrix [m, k], FP32 scales [m, ceil(k/128)])
    # rhs: (FP8 matrix [n, k], FP32 scales [ceil(n/128), ceil(k/128)])
    # out: BF16 matrix [m, n]

    # Step 1: 输入提取
    lhs_matrix, lhs_scales = lhs
    rhs_matrix, rhs_scales = rhs
    m, k = shape(lhs_matrix)
    n, k_ = shape(rhs_matrix)
    m_, n_ = shape(out)

    # Step 2: 输入验证
    assert ...

    # Step 3: 数据预处理:确保 lhs_scales 是 TMA 对齐的格式。如果不符合要求,则对其进行转置和对齐。
    if not is_tma_aligned(lhs_scales):
        lhs_scales = transpose_and_align(lhs_scales)  # Ensure TMA-aligned format
    assert is_contiguous(rhs_scales)

    # Step 4: 自动调优: 动态选择最优的配置参数
    num_sms = get_number_of_sms()  # Get the number of SMs on the GPU
    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
        m, n, k, num_groups=1, num_sms=num_sms
    )

    # Step 5: JIT 编译和内核准备
    args = (lhs_matrix, lhs_scales, rhs_matrix, rhs_scales, out, m, current_cuda_stream(), num_sms, smem_size)
    runtime = jit_tuner.compile_and_tune(
        name='gemm_fp8_fp8_bf16_nt',
        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
        includes=("deep_gemm/fp8_gemm.cuh"),
        template=generate_cuda_template(),
        args=args
    )

    # Step 6: 内核执行
    runtime(*args)

四、持久化Warp专业化 & Hopper TMA特性

1. 持久化Warp专业化

Warp 的概念是 NVIDIA GPU 架构中的一个重要特性,它使得开发者能够更细致地控制线程的执行,以优化并行计算的性能。
在 CUDA 编程模型中,Warp 是 GPU 上并行执行的最小单元。一个 Warp 包含 32 个线程,这些线程在同一个周期内执行相同的指令。这意味着,如果一个 Warp 中的所有线程都执行相同的操作,那么它们可以并行地在 GPU 上执行,从而提高计算效率。

遵循CUTLASS的设计,DeepGEMM中的内核是warp专业化的,能够重叠数据移动、张量核心MMA指令和CUDA核心提升。下图展示了这一过程的简化示意图:

在这里插入图片描述

  • TMA Warps:负责异步数据加载,通过 TMA 指令减少内存访问延迟。
  • Math Warps:执行实际的矩阵乘法计算(WGMMA 表示张量核心矩阵乘法)。
  • Promotion:在计算过程中进行数据的累加操作。

2. Hopper TMA特性

Tensor Memory Accelerator(TMA) 是NVIDIA Hopper架构中的一项新功能,它旨在加速GPU的内存访问。TMA通过使用全局内存(GMEM)和共享内存(SMEM)进行数据复制,从而提高GPU的整体性能。
在DeepGEMM中,TMA被用于以下操作:

  • 用于LHS、LHS缩放因子和RHS矩阵的TMA加载
  • 用于输出矩阵的TMA存储
  • TMA多播(仅限LHS矩阵)
  • TMA描述符预取

五、补充:JIT技术

1. JIT 技术的定义与原理

即时编译(JIT)是一种在程序运行时动态生成和优化代码的技术。与传统的编译方式(如提前编译,Ahead-Of-Time,AOT)不同,JIT 编译器不会在程序安装或部署时生成最终的可执行代码,而是在程序运行时根据实际的输入和运行环境动态生成优化后的代码。这种技术的核心思想是“延迟编译”,即在代码真正需要执行时才进行编译和优化,从而针对具体的运行场景生成最高效的目标代码。
在这里插入图片描述

2. JIT 技术的优势

JIT 技术在现代计算中具有广泛的应用,尤其是在高性能计算和动态语言运行环境中。它相比传统的静态编译方式具有以下显著优势:

  • 更高的性能:JIT 编译器可以根据运行时的具体输入和硬件环境动态生成优化代码。例如,它可以针对不同的矩阵大小、数据类型或硬件特性选择最适合的算法和优化策略,从而实现比静态编译更高的性能。
  • 灵活性与可扩展性:由于 JIT 编译器在运行时生成代码,因此它可以轻松适应不同的硬件架构和输入数据特征。这种灵活性使得 JIT 技术特别适合于需要处理多种输入场景和硬件环境的应用程序。
  • 减少编译时间与资源消耗:在传统的静态编译中,编译器需要考虑所有可能的输入情况并生成通用的代码。这往往导致复杂的编译过程和较长的编译时间。而 JIT 编译器只需在运行时针对当前输入生成代码,因此可以显著减少编译时间和资源消耗。

3. 在 DeepGEMM 中的应用

DeepGEMM 采用完全即时编译(JIT)设计,安装时无需编译。所有内核在运行时使用轻量级 JIT 实现进行编译。这种方法具有以下几个优点:

  • GEMM形状、块大小和流水线阶段数被视为编译时常量
    • 节省寄存器
    • 编译器可以进行更多优化
  • 自动选择块大小、warpgroup数量、最佳流水线阶段和TMA集群大小
    • 但没有自动调优,确定性地选择最佳方案
  • 完全展开MMA流水线,为编译器提供更多优化机会
    • 对于小形状非常重要
    • 详情请参阅内核文件中的launch_k_iterations

六、补充:FP8精度改进

本部分学习参考:https://zhuanlan.zhihu.com/p/26437292382

使用FP8框架进行训练的主要挑战在于精度与误差的处理,DeepSeek为其FP8低比特训练框架做了以下优化:

(1)细粒度量化
将数据分解成更小的组,每个组都使用特定乘数进行调整以保持高精度。这一方法类似于Tile-Wise或Block-Wise。对于激活,在1x128大小的基础上对计算数据进行分组和缩放;对于权重,以128x128大小对计算数据进行分组和缩放。该方法可以根据最大或最小数据调整缩放系数,来更好的适应计算中的异常值。
(2)在线量化
为了提高精度并简化框架,该框架在线计算每个1x128激活块或128x128权重块的最大绝对值,在线推算缩放因子,然后将激活或权重在线转化为FP8格式,而不是采用静态的历史数据。相对静态的量化方法,该方法可以获得更高的转换精度,减小误差的累积
(3)提高累加精度
FP8在大量累加时会累积出现随机误差。例如FP8 GEMM在英伟达H800 GPU上的累加精度保留14位左右,明显低于FP32累加精度。以K= 4096的两个随机矩阵的GEMM运算为例,Tensor Core中的有限累加精度可导致最大相对误差接近2%。
DeepSeek将中间结果储存计算升级为FP32(32位浮点),实行高精度累加,然后再转换回FP8,以降低大量微小误差累加带来的训练偏差。
(4)低精度/混合精度存储与通信
为了进一步减少MoE训练中的显存和通信开销,该框架基于FP8进行数据/参数缓存和处理激活,以节省显存与缓存空间并提升性能,并在BF16(16位浮点数)中存储低精度优化器状态。该框架中以下组件保持原始精度(例如BF16或FP32):嵌入模块、MoE门控模块、归一化算子和注意力算子,以确保模型的动态稳定训练。为保证数值稳定性,以高精度存储主要权重、权重梯度和优化器状态

在这里插入图片描述

七、小结

正如 Tim Dettmers 所说,“一切都是为了从硬件中榨干每一滴性能。” DeepGEMM 正是这样做的。它通过利用GPU的并行计算能力、优化内存访问模式以及采用先进的硬件特性,实现了矩阵乘法的高效计算。它不仅在性能上达到了极致,还通过简洁的代码结构和即时编译技术提供了高度的灵活性和适用性。
总结来说,DeepGEMM 是一个在矩阵乘法领域实现重大突破的库,为深度学习和高性能计算提供了一个强大而灵活的工具。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐