FlashMLA:把 KV Cache 压缩到原来的八分之一

FlashMLA:把 KV Cache 压缩到原来的八分之一 标准 MHA 的 KV Cache 是推理显存的第一大户。LLaMA-7B32 层每层 32 头HeadDim128SeqLen128K——KV Cache 吃 40GB。MLAMulti-head Latent Attention用低秩分解把 KV 映射到一个远小于 HeadDim 的潜在空间存压缩版。FlashMLA 是这套机制在昇腾NPU上的高性能实现。MLA 的压缩原理MLA 的核心是 KV Compression——不是存原始 K/V而是存它们的低秩表示。原始 K 和 V 分别通过下投影矩阵 W_DK 和 W_DV 映射到一个很小的潜在空间比如 16 维存下来的就是这个 16 维的压缩向量。推理时再通过上投影矩阵把压缩向量恢复到原始 HeadDim 做 Attention 计算。标准 MHA 的 K/V 存储 K X W_K → [B, H, S, 128] ← 每层每 Token 存 128 个 FP16 V X W_V → [B, H, S, 128] MLA 的 KV 存储DeepSeek-V2 参数 C_KV X W_DKV → [B, H, S, 16] ← 每层每 Token 只存 16 个 FP16 使用时解压 K C_KV W_UK → [B, H, S, 128] V C_KV W_UV → [B, H, S, 128] 压缩比128 / 16 8:1 128K 上下文 KV Cache40GB → 5GBFlashMLA 的计算流程FlashMLA 在 FlashAttention 的分块框架上加了两步解压和 RoPE 融合。不是一次性把所有 K/V 解压出来——而是每次算一个 Tile 的 Attention 时才解压这个 Tile 对应的 K/V。这样解压的中间结果不占显存。FlashMLA 的 Tile 级执行流程 for each Tile in Q: for each Tile in KV: 1. 解压K_tile C_KV_tile W_UK (16×128 矩阵乘在 Cube Unit 上算) 2. 解压V_tile C_KV_tile W_UV 3. AttentionS Q_tile K_tile^T (在 Cube Unit 上) 4. Online SoftmaxFlashAttention 的分块累积 5. 累积O Softmax(S) V_tile三个关键优化解压矩阵 W_UK/W_UV 很小16×128刚好塞进 Cube Unit 单次运算。RoPE 融进解压步骤——传统 Attention 里 RoPE 是独立 KernelFlashMLA 在解压的同时算 RoPE省一次 L1 写回。解压在 Cube Unit 上跑时Vector Unit 同时做下一个 Tile 的 Softmax 累积——双单元并行。昇腾NPU上的 FlashMLA 实现ops-transformer 仓库里 FlashMLA 的核心代码// FlashMLA Kernel——简化版核心循环__aicore__voidflash_mla_kernel(GlobalTensorfp16q,// [B, H, S_q, D]GlobalTensorint8kv_cache,// 压缩 KV [B, H, S_kv, 16]GlobalTensorfp16W_UK,// 解压矩阵 [H, 16, D]GlobalTensorfp16W_UV,// 解压矩阵 [H, 16, D]GlobalTensorfp16output// [B, H, S_q, D]){// L1 上分配 Tile BufferLocalTensorfp16q_tile,k_tile,v_tile,scores,attn,out_acc;for(intq_block0;q_blocknum_q_blocks;q_block){// 加载 Q 的一个 Tile 到 L1DataCopy(q_tile,q[q_block]);// Online Softmax 状态fp32 m_prev-INFINITY,l_prev0;for(intkv_block0;kv_blocknum_kv_blocks;kv_block){// 1. 解压 K/V——Cube Unit 做小矩阵乘MatMul(k_tile,kv_cache[kv_block],W_UK);// [T, 16] [16, D]MatMul(v_tile,kv_cache[kv_block],W_UV);// 2. RoPE 融合——直接在 k_tile 上原地做ApplyRoPEInPlace(k_tile,kv_block*block_size);// 3. Attention 计算MatMul(scores,q_tile,k_tile,/*transB*/true);// 4. Online Softmax 累积fp32 m_currrow_max(scores);fp32 m_newmax(m_prev,m_curr);fp32 l_currrow_sum(exp(scores-m_new));out_accout_acc*exp(m_prev-m_new);out_accexp(scores-m_new)*v_tile;m_prevm_new;l_prevexp(m_prev-m_new)*l_prevl_curr;}// 最终归一化DataCopy(output[q_block],out_acc/l_prev);}}显存收益DeepSeek-V2 配置下MLA 的 KV Cache 从每 Token 128KB标准 MHA128 头 × 128 HeadDim × 2降到约 16KB。128K 上下文从 16GB 降到 2GB。省下的 14GB 直接把 Batch Size 从 4 推到 32——在线推理吞吐涨 8 倍。Tensor缓存也跟着受益——压缩后的 KV 在 L1 上占的空间从 2MB/Tile 降到 256KB/Tile同一块 L1 能缓存的上下文 Tile 数翻 8 倍FlashAttention 的分块数减少整体延迟更低。参考仓库ops-transformer FlashAttention 实现ATB Transformer 加速库推理 RecipesCANN 学习中心