CANN/cann-bench MHA算子API描述

CANN/cann-bench MHA算子API描述 MHA 算子 API 描述【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench1. 算子简介多头注意力 (Multi-Head Attention) 算子对已分头的 Q/K/V 执行缩放点积注意力计算广泛应用于 Transformer 架构。主要应用场景Transformer 编码器和解码器中的自注意力与交叉注意力大语言模型和视觉 Transformer 中的核心注意力模块多模态模型中的跨模态注意力融合算子特征难度等级L4FusedComposite多输入query, key, value单输出执行缩放点积注意力输入为已分头的张量不包含 QKV 投影和输出投影步骤支持可配置的缩放因子支持is_causal因果掩码仅计算 attention 矩阵 [S, S_kv] 中从右下角向左上方延伸 45° 对角线及其下方部分其余位置在 softmax 前置 -inf2. 算子定义数学公式$$ y \text{softmax}\left(Q \times K^T \times \text{scaleValue}\right) \times V $$其中$Q$、$K$、$V$ 为已分头的查询、键、值张量$\text{scaleValue}$ 为缩放因子0 时自动使用 $1/\sqrt{D}$$D$ 为每头维度softmax 沿最后一维计算具体子步骤缩放点积$\text{scores} Q \times K^T \times \text{scaleValue}$因果掩码可选当is_causalTrue时对scores[..., i, j]满足 $j i (S_{kv} - S)$ 的位置置为 $-\infty$即仅保留从右下角向左上方 45° 延伸的对角线及其下方部分$SS_{kv}$ 时退化为标准下三角掩码Softmax 归一化$\text{attn_weights} \text{softmax}(\text{scores}, \text{dim}-1)$加权求和$y \text{attn_weights} \times V$3. 接口规范算子原型cann_bench.mha(Tensor query, Tensor key, Tensor value, float scaleValue-1.0, bool is_causalFalse) - Tensor y输入参数说明参数类型默认值描述queryTensor必选查询张量已分头shape 为 [B, S, N, D]keyTensor必选键张量已分头shape 为 [B, S_kv, N, D]valueTensor必选值张量已分头shape 为 [B, S_kv, N, D]scaleValuefloat-1.0缩放因子0 时自动使用 1/sqrt(D)is_causalboolFalse是否启用因果掩码。False 时全计算True 时仅计算 [S, S_kv] attention 矩阵中从右下角向左上方 45° 延伸的对角线及其下方部分即满足 $j \le i (S_{kv} - S)$ 的位置上方部分在 softmax 前置 -inf输出参数Shapedtype描述y[B, S, N, D]与输入 query 相同多头注意力输出张量数据类型输入 dtype输出 dtypefloat16float16bfloat16bfloat16规则与约束所有输入 Tensorquery, key, value的 dtype 必须一致query的 shape 为 [B, S, N, D]key和value的 shape 为 [B, S_kv, N, D]N 为注意力头数D 为每头维度均从输入 shape 中推断scaleValue通常设置为 $1/\sqrt{D}$当 0 时自动使用该值is_causalTrue时要求 $S \le S_{kv}$否则 mask 会将部分 query 行全部屏蔽导致 softmax 出现 NaN支持范围输入 tensor 各维度与参数的支持范围维度 / 参数范围备注Bbatch1 ~ 128cases.csv 实测 1 ~ 128Squery 序列长度1 ~ 2048cases.csv 实测 1 ~ 1024decode 场景 S1 或 2prefill 场景 S 与 S_kv 同量级S_kvkey/value 序列长度1 ~ 4096cases.csv 实测 128 ~ 2048is_causalTrue时要求 S ≤ S_kvN注意力头数1 ~ 64cases.csv 实测 8 ~ 32D每头维度64 ~ 25664 对齐cases.csv 实测 64 / 128 / 256scaleValue任意 floatcases.csv 实测 -1.0auto 1/sqrt(D)和 0.08838显式 ≈ 1/sqrt(128)0 时回退到 1/sqrt(D)is_causal{False, True}cases.csv 实测两值均覆盖True 走右下角对齐因果掩码False 全计算输入 value range任意有限实数cases.csv 实测 [-1, 1]常态高斯采样和 [0, 0]全零退化输入输入 dtypefloat16, bfloat16Q/K/V 三个 tensor dtype 必须一致约束query.shape [B, S, N, D]key.shape value.shape [B, S_kv, N, D]四个共享维度 B/N/D 必须严格相等is_causalTrue时要求S ≤ S_kv否则部分 query 行会被全部屏蔽导致 softmax 出现 NaN。4. 精度要求采用生态算子精度标准进行验证。误差指标平均相对误差MERE采样点中相对误差平均值$$ \text{MERE} \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$最大相对误差MARE采样点中相对误差最大值$$ \text{MARE} \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$通过标准数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2当平均相对误差 MERE Threshold最大相对误差 MARE 10 * Threshold 时判定为通过。5. 标准 Golden 代码import torch MHA算子Torch Golden参考实现 多头注意力 (Multi-Head Attention)对已分头的 Q/K/V 执行缩放点积注意力 公式: y softmax(Q K^T * scaleValue) V def mha( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, scaleValue: float -1.0, is_causal: bool False, ) - torch.Tensor: 多头注意力 (Multi-Head Attention) Args: query: 查询张量 [B, S, N, D]已分头 key: 键张量 [B, S_kv, N, D]已分头 value: 值张量 [B, S_kv, N, D]已分头 scaleValue: 缩放因子0 时自动使用 1/sqrt(D) is_causal: 是否启用因果掩码右下角对齐True 时 scores[..., i, j] 满足 j i (S_kv - S) 的位置置 -inf Returns: 输出张量 [B, S, N, D] B, S, N, D query.shape S_kv key.shape[1] if scaleValue 0: scaleValue 1.0 / (D ** 0.5) # 转置为 [B, N, S, D] q query.transpose(1, 2) k key.transpose(1, 2) v value.transpose(1, 2) # 缩放点积注意力 scores torch.matmul(q, k.transpose(-2, -1)) * scaleValue if is_causal: i torch.arange(S, devicescores.device).unsqueeze(-1) j torch.arange(S_kv, devicescores.device).unsqueeze(0) causal_mask j (i (S_kv - S)) # 右下角对齐上三角置 -inf scores scores.masked_fill(causal_mask, float(-inf)) attn_weights torch.nn.functional.softmax(scores, dim-1) attn_output torch.matmul(attn_weights, v) # 转回 [B, S, N, D] return attn_output.transpose(1, 2)6. 额外信息算子调用示例import torch import cann_bench B, S, S_kv, N, D 2, 128, 128, 8, 64 query torch.randn(B, S, N, D, dtypetorch.float16, devicenpu) key torch.randn(B, S_kv, N, D, dtypetorch.float16, devicenpu) value torch.randn(B, S_kv, N, D, dtypetorch.float16, devicenpu) y cann_bench.mha(query, key, value, scaleValue-1.0, is_causalFalse)【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考