SageAttention is said to be much faster than FlashAttention-2. But how do we get it working in DeepSeek-VL2 training?

        1. Essentially, sageattention is an SDPA implementation. SDPA stands for Scaled Dot-Product Attention, a multiplicative attention mechanism: in one sentence, it weights the Value according to how well the Query (Q) matches each Key. Since Query, Key, and Value all come from the input, SDPA is essentially a recombination of the input information.
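
        As a quick, self-contained illustration of that recombination (toy shapes, nothing DeepSeek-specific), here is SDPA computed by hand and checked against torch's built-in implementation:

import torch
import torch.nn.functional as F

# toy shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)

# softmax(Q K^T / sqrt(d)) V: weight Value by how well Query matches each Key
scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
weights = scores.softmax(dim=-1)
manual_out = weights @ v

sdpa_out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual_out, sdpa_out, atol=1e-5))  # True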

        2. sageattention uses the Triton package, which JIT-compiles Python-written kernel code into machine code for the target device, greatly speeding up the computation.
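
        For context on Triton: a Triton kernel is ordinary Python that gets JIT-compiled to GPU machine code. Below is a minimal, self-contained example (unrelated to sageattention's own kernels, just to show the mechanism):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program instance handles one BLOCK_SIZE-wide slice of the vectors
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
assert torch.allclose(out, x + y)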

        3. As the official example shows, the way to use it is simply to replace the torch.nn.functional.scaled_dot_product_attention function, like this:

from sageattention import sageattn
import torch.nn.functional as F
...
F.scaled_dot_product_attention = sageattn

        So, to use sageattention, all that is needed is for the original model to support the sdpa attention implementation.

        Unfortunately, the official DeepSeek-VL2 GitHub repository shows that neither DeepseekVLV2ForCausalLM nor DeepseekV2ForCausalLM supports sdpa: neither of these two classes declares _supports_sdpa = True.

ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2
}

        Nor does this ATTENTION_CLASSES mapping contain an sdpa implementation.

        Therefore DeepSeek-VL2 cannot simply use sageattention as-is; we need to modify DeepSeek's open-source code before sageattention can be used.

        The modification steps are as follows:

1. First, make DeepseekVLV2ForCausalLM and DeepseekV2ForCausalLM declare sdpa support by adding _supports_sdpa, so that _check_and_enable_sdpa in transformers/modeling_utils.py passes its check:

class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _supports_sdpa = True  # opt in so transformers' sdpa check passes

class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
    _supports_sdpa = True  # opt in so transformers' sdpa check passes
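
The gate on the transformers side is roughly the following (a paraphrased sketch, not the exact source, which differs between transformers versions): without the class-level _supports_sdpa flag, an explicit sdpa request is rejected and sdpa is never enabled.

# Simplified, paraphrased rendering of transformers' sdpa gate (hypothetical
# standalone function; the real method is a classmethod on PreTrainedModel).
def check_and_enable_sdpa(model_cls, config, hard_check_only: bool = False):
    if not getattr(model_cls, "_supports_sdpa", False):
        if hard_check_only:
            raise ValueError(
                f"{model_cls.__name__} does not support an attention implementation "
                "through torch.nn.functional.scaled_dot_product_attention yet."
            )
        return config  # silently fall back; sdpa is never enabled
    if not hard_check_only:
        config._attn_implementation = "sdpa"
    return config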

2. Then add an sdpa attention implementation class in DeepSeek-VL2/deepseek_vl2/models/modeling_deepseek.py and make it inherit from LlamaAttention. It is essentially copied straight from LlamaSdpaAttention; copying makes it easier to modify. The implementation is as follows:


from sageattention import sageattn  # needed by the forward pass below


class DeepSeekSdpaAttention(LlamaAttention):
    """
    Deepseek attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DeepseekV2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            if isinstance(position_embeddings, torch.Tensor):
                cos, sin = self.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb2(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True  #if causal_mask is None and q_len > 1 else False

        attn_output = sageattn(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
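
Before wiring the class into the model, it can help to sanity-check sageattn on its own. Below is a sketch that assumes a CUDA GPU and the keyword arguments shown in sageattention's own examples; since SageAttention quantizes internally, expect close but not bitwise agreement with exact SDPA:

import torch
import torch.nn.functional as F
from sageattention import sageattn

# (batch, heads, seq_len, head_dim), fp16 on GPU
q = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = sageattn(q, k, v, is_causal=True)
print((out - ref).abs().max())  # small but nonzero: sageattn is an approximation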

3. Modify ATTENTION_CLASSES in modeling_deepseek.py to add sdpa support, as follows:


ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention
}
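
For reference, the decoder layers pick their attention class out of this dict by key, roughly like the hypothetical sketch below (the real modeling_deepseek.py derives the key from the model config; the training log in step 5 shows it resolves to "mha_sdpa" here):

def build_self_attn(config, layer_idx: int, attn_key: str = "mha_sdpa"):
    # e.g. DeepSeekSdpaAttention once "mha_sdpa" is registered above
    attn_cls = ATTENTION_CLASSES[attn_key]
    return attn_cls(config=config, layer_idx=layer_idx)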

4. The sageattention and Triton versions used are:

Name: sageattention
Version: 1.0.6

Name: triton
Version: 3.2.0

5. Run a training test:

swift sft  --model "deepseek-ai/deepseek-vl2-tiny"  --dataset  ../TEST.json  --attn_impl sdp

...
using attn_implementation: mha_sdpa
[INFO:swift] model.hf_device_map: {'': device(type='cuda', index=0)}

...
[INFO:swift] End time of running main: 2025-03-28 10:43:29.304164

        The training ran successfully.
