CANN算子开发Cast设计

CANN算子开发Cast设计 Cast 算子设计文档【免费下载链接】cann-ops-competitions本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。项目地址: https://gitcode.com/cann/cann-ops-competitions一、需求背景1.1 需求来源参考昇腾版本内置 Cast 算子的 TBE 实现在昇腾 NPU 上基于 Ascend C 编程语言实现功能一致的算子并新增支持 BF16 数据类型输入。1.2 背景介绍1.2.1 Cast 算子概述Cast 算子执行数据类型转换将输入张量 self 的每个元素从源数据类型转换为目标数据类型输出相同形状的张量 out。其数学表达式为out[i] (DstType) self[i], for i 0, 1, ..., N-1 Cast 算子是深度学习模型中的基础算子广泛应用于精度切换、算子输入类型适配等场景。1.2.2 Cast 算子现状分析1.2.2.1 TBE 算子支持的数据类型和数据格式1.2.2.2 TBE 算子实现描述TBE Cast 算子核心逻辑路径/usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_impl/ai_core/tbe/impl/cast.pyCast 是逐元素单输入单输出类型转换算子实现流程如下参数校验校验输入、输出、kernel_name 合法性校验源数据类型和目标数据类型是否在白名单内特殊转换处理浮点数 → 整型输入中存在 NaN 则转换为 0INT32 → INT8数据在 (-2048, 1920) 范围内保证精度无误差FLOAT64 → UINT8输入为非负数保证精度无误差INT64 → FLOAT32数据在 (-2147483648, 2147483647) 范围内保证精度无误差核心计算调用 tbe.cast 指令完成类型转换编译生成自动生成调度编译输出可执行 kernelTBE 算子实现策略1.2.2.3 TBE 算子实现流程图-- 整体调度流程-- int8_uint8_process-- int32_process-- uint32_process-- float32_process-- float16_process-- bfloat16_process-- int64_process-- uint1_process-- int16_process-- uint16_process二、需求分析2.1 外部组件依赖不涉及外部组件依赖。2.2 内部适配模块适配 Aclnn 接口支持常规调用模式。2.3 需求模块设计2.3.1 AscendC 算子原型参数定义与任务书《参数说明》保持一致。参数名输入/输出/属性描述数据类型数据格式self输入待进行 cast 计算的入参。BOOL、FLOAT16、FLOAT、INT8、UINT8、INT16、INT32、INT64、BF16NDout输出cast 计算的出参。BOOL、FLOAT16、FLOAT、INT8、UINT8、INT16、INT32、INT64ND说明目标数据类型经 aclnn 接口的 dst_type 属性传入与 out 的 dtype 一致BF16 仅作为 self 输入支持输出不含 BF16self 与 out 的 shape 一致。2.3.2 AscendC 算子相关约束对外约束以任务书《约束说明》为唯一口径见 3.4 的 4 条精度约束。本节列出的是工程实现假设/支持范围不属于任务书新增的对外约束仅用于界定实现与测试覆盖输入输出 shape 相同Cast 不改变形状支持空 Tensor第一段接口短路、不下发 kernel。支持的张量维数范围与底层 ND 张量一致不额外收窄如需声明上限以底层框架为准。支持非连续 Tensor。BF16 仅作为源数据类型输入支持输出不扩展至 BF16与 TBE 版本一致已在参数表/3.4 体现。数据类型转换的对外精度约束与任务书《约束说明》一致详见 3.4在 kernel 侧保证三、需求详细设计3.1 使能方式3.2 需求总体设计3.2.1 host 侧设计3.2.1.0 目标硬件能力前提dav_m200能力结论对设计的影响硬件 vconv 直达支持 f16/f32/i32/i16/i8/u8/bool 间部分组合普通路径优先AscendC::CastBF16 vconvdav_m200 无 BF16 原生转换BF16 输入按 uint16 位模式展开为 FP32ShiftLeft/ShiftRight不支持BF16 展开不能依赖 16 位移位向量指令64bit 整数转换多数不直达int64 走低 32 位抽取 / 符号扩展软件路径DataCopyPadGM↔UB不支持CopyIn/CopyOut 必须 32B 对齐块3.2.1.0.1 交付 dtype 对矩阵源 → 目标交付范围对齐原 TBE Cast 注册的完整转换集参考ops-math/experimental/math/cast/op_host/cast_def.cpp即覆盖参考支持的全部源→目标组合仅排除任意类型 → BF16输出任务书 out 不含 BF16输出暂不扩展至 BF16。下表已据此补全不缩减原算子能力。源 \ 目标交付目标 dtypeFLOAT16FLOAT32, INT32, INT16, INT8, UINT8, BOOLFLOAT32FLOAT16, INT32, INT16, INT8, UINT8, BOOL, INT64INT32FLOAT32, FLOAT16, INT16, INT8, UINT8, BOOL, INT64INT8FLOAT16, FLOAT32, INT32, INT16, UINT8, BOOL, INT64UINT8FLOAT16, FLOAT32, INT32, INT8, INT16, INT64INT16FLOAT16, FLOAT32, INT32, INT8, UINT8, INT64BOOLFLOAT16, FLOAT32, INT32, INT8, UINT8, INT64INT64FLOAT16, FLOAT32, INT32, INT8, UINT8, INT16, BOOLBF16新增源FLOAT32, FLOAT16, INT32, INT16, INT8, UINT8, BOOL, INT64说明参考cast_def.cpp的 BF16 源仅注册 BF16→{F16, F32, I32, I8, U8, BOOL}本设计额外交付BF16→INT16、BF16→INT64两对经 BF16→FP32 展开后复用 FP32→目标链实现其精度沿用对应 FP32→INT16/INT64 路径的约束见 3.4不弱于参考的等价两跳。3.2.1.1 分核策略根据输出张量总元素数 totalElements 和 AI Core 数量进行均分数据总量不能被核数整除时前 remainder 个核心各多处理一个数据块获取平台核心数量通过 GetCoreNum() 获取 Atlas 300V Pro 的 AI Core 数量3.2.1.2 数据分块和内存优化策略UB 容量适配CopyIn 缓冲SrcT与 CopyOut 缓冲DstT在 UB 中同时存在且部分路径需要中间 buffer故按源 目标 中间之和预算而非取较大值BUFFER_NUM 2 bytesPerElem BUFFER_NUM × (sizeof(SrcT) sizeof(DstT)) middleBytes 单 Tile 处理元素数 ubFormer floor((UB_SIZE - reserve) / bytesPerElem) 向下按 256 元素对齐其中 middleBytes 由路径决定DIRECT 0TWO_CAST / DST_BOOL 需 f16/f32 中间 bufferBF16、int64 软件路径需更大的中间 buffer。早期版本用UB / (2 × max(sizeof(SrcT), sizeof(DstT)))估算单 Tile 大小是错误的源/目标缓冲同时占 UB用 max 会低估占用、导致 tile 过大溢出 UB。内存对齐因 dav_m200 无 DataCopyPadCopyIn/CopyOut 统一使用 32 字节对齐数据块tile 与尾块均按 32B/256 元素对齐设计避免越界读写。3.2.1.3 tilingKey 规划策略tilingKey 按 3.2.2.1 的路径类别取值0 DIRECT / 1 TWO_CAST / 2 DST_BOOL / 3 NARROW64 / 4 WIDEN64 / 5 BF16_TO_F32 / 6 BF16_TO_OTHER / 7 SAMEWIDTH_8BIT / 8 NARROW8由 host 依据源/目标 dtype 对查表确定。输入/输出 dtype 在 kernel 侧由编译期宏注入运行期 tiling 只描述分核与 UB 切分。host 侧维护三张[N][N]按 dtype 枚举索引真值表对齐参考cast_tiling.cpp的tilingKeyMap/ubDataNumMap/minDataTypeLengthMap工程模式避免运行期逐对 if 判断、也避免漏配pathClassMap[src][dst]→ 上述 0–8 路径类别即 tilingKeymiddleBytesMap[src][dst]→ 该对所需中间 buffer 字节数DIRECT/SAMEWIDTH_8BIT0TWO_CAST/DST_BOOL 需 1 块BF16/int64 软件路径需更大供 3.2.1.2 的 UB 预算精确取值roundModeMap[src][dst]→ 该对的 round mode规则见 3.2.2.1。将 tiling 参数totalElements、coreNum、ubFormer、tileNum、lastTileSize、roundMode封装到 CastTilingData 结构体中。3.2.2 kernel 侧设计3.2.2.1 kernel 侧实现描述Cast 算子 kernel 侧由编译期 dtype 宏注入输入/输出类型按源/目标 dtype 对选择以下路径tilingKey 即路径类别。所有路径共享 CopyIn → Compute → CopyOut 主框架因 dav_m200 无 DataCopyPadCopyIn/CopyOut 统一按 32 字节对齐块设计。tilingKey路径覆盖 dtype 对实现要点0DIRECTdav_m200 vconv 可直达组合f16/f32/i32/i16/i8/u8/bool 间的部分组合一跳AscendC::Cast浮点转整按对应 round mode见下round mode 逐对规则窄整型输出按 mod-256 回绕处理见 NARROW81TWO_CAST需中转的普通组合经 f16/f32 两段Cast串接中间加PipeBarrierPIPE_V2DST_BOOL目标 dtype 为 bool按ceil(min(abs(x),1))归一化语义实现Abs → Mins(1) → Cast(CAST_CEIL)3NARROW64int64 → 低位/窄类型抽取 int64 低 32 位 lane 后进入普通转换链4WIDEN64int32/float32/BF16 → int64先得 32 位结果再按符号生成高 32 位5BF16_TO_F32BF16 → FLOATBF16 halfword 放入 FP32 高 16 位、低 16 位补零保留 NaN/Inf 比特语义6BF16_TO_OTHERBF16 → 其它非 BF16 输出先展开为 FP32再复用 FP32→目标 dtype 路径7SAMEWIDTH_8BITbool ↔ int8 ↔ uint8同宽 8bit 互转位等价用TQueBindVECIN,VECOUT做 GM→UB→GM 零Cast直搬对齐参考 key8不走转换链是性能最优路径8NARROW8任意 → int8/uint8窄整型输出mod-256 回绕非真饱和And(x, 0xFF)取低 8 位得 uint8 语义int8 再Adds(128) → And(0xFF) → Adds(-128)把 [0,255] 映回 [-128,127]。这解释了 3.4 中 INT32→INT8 仅在 (-2048,1920) 无误差——超窗即回绕round mode 逐对规则对齐参考cast.h0TBuf 段浮点→整 CAST_TRUNCNaN 由硬件 vconv 截断为 0int64→float CAST_ROUND→bf16 CAST_RINT→bool CAST_CEIL其余等宽/扩宽 CAST_NONE。host 侧用roundModeMap[src][dst]真值表逐对落定避免运行期判断。Cast 计算核心DIRECT 路径示意template typename SrcT, typename DstT __aicore__ void CastKernel::Compute(int32_t tileIdx) { LocalTensorSrcT srcBuf srcQueue.DeQueSrcT(); LocalTensorDstT dstBuf dstQueue.AllocTensorDstT(); // round_mode: CAST_RINT / CAST_FLOOR / CAST_CEIL / CAST_ROUND / CAST_TRUNC按 TBE 对应模式 Cast(dstBuf, srcBuf, roundMode, curTileLen); dstQueue.EnQue(dstBuf); srcQueue.FreeTensor(srcBuf); }BF16 处理dav_m200 关键约束dav_m200无 BF16 原生 vconv且无 ShiftLeft/ShiftRight 向量指令因此不能用移位展开。BF16 输入按uint16_t位模式处理将每个 BF16 halfword 作为高 16 位、低 16 位补零拼成 FP32 的 bit layout保留 NaN/Inf 语义得到等值 FP32 后再进入目标 dtype 转换链。输出 dtype 不包含 BF16故不注册任意类型 → BF16路径。int64 处理dav_m200 多数 int64 相关 vconv 不直达采用低 32 位抽取NARROW64或符号扩展WIDEN64的软件路径仅在任务书约束范围内保证精度。特殊转换约束处理约束项说明浮点 → 整型 NaNNaN 转 0对齐 TBE / 硬件 vconv 行为8bit 整型输出mod-256 回绕非饱和uint8 取低 8 位And 0xFFint8 经128 → 0xFF → -128映回 [-128,127]。与 TBE 一致INT32 → INT8仅保证输入在 (-2048, 1920) 范围内精度无误差超窗即回绕故有此精度窗口INT64 → FLOAT32仅保证输入在 (-2147483648, 2147483647) 范围内精度无误差空 Tensor第一段接口短路不下发 kernel输出 BF16不属于任务书交付范围3.2.2.2 AscendC 实现流程图3.2.2.3 AscendC 实现流程图与 TBE 流程图存在的差异点和原因差异点TBE 实现AscendC 实现原因数据类型支持BOOL/F16/F32/I8/U8/I16/I32/I64新增 BF16 源数据类型输入任务要求dav_m200 无 BF16 原生 vconv按 uint16 位模式软件展开为 FP32计算指令tbe.castAscendC::Cast 高阶 APIAPI 差异数据格式NDND保持一致类型转换路径直接 cast 或中间类型按 dtype 对分 DIRECT/TWO_CAST/DST_BOOL/NARROW64/WIDEN64/BF16 路径适配 dav_m200 能力路径化分派尾块搬运依赖 pad 拷贝32B 对齐块 尾块对齐处理dav_m200 无 DataCopyPad3.3 支持硬件支持的芯片版本涉及勾选Atlas 300V Pro√3.4 算子约束限制以下对外约束与任务书《约束说明》一致针对数据类型从浮点数转换为整型的场景输入数据中存在 nan则将 nan 转换为 0。针对数据类型从 INT32 转换为 INT8 的场景只能保证输入数据在 (-2048, 1920) 范围内精度无误差。针对数据类型从 FLOAT64/COMPLEX64/COMPLEX128 转换为 UINT8 的场景只能保证输入数据为非负数精度无误差。针对数据类型从 INT64 转换为 FLOAT32 的场景只能保证输入数据在 (-2147483648, 2147483647) 范围内精度无误差。工程实现约束不改变上述对外语义self 与 out 的 shape 必须一致BF16 仅作为 self 输入支持输出不含 BF16。INT32 → INT8 仅保证输入在 (-2048, 1920) 范围内精度无误差INT64 → FLOAT32 仅保证输入在 (-2147483648, 2147483647) 范围内精度无误差所有核参与场景下性能不低于 TBE 算子的 95%BF16 相关转换不低于 FP32 输入性能的 90%四、特性交叉分析本算子为独立类型转换算子不涉及与其他算子的特性交叉。新增 BF16 源数据类型不影响已有类型组合的行为。五、可维可测分析5.1 精度标准/性能标准精度标准验收主口径任务书算子计算精度需满足 AscendOpTest 工具默认阈值。自测补充标准不替代主口径仅用于开发自验整型转换结果须与参考逐元素完全一致BF16 相关转换以 PyTorchtensor.to()为 CPU 标杆逐元素比对整体以 TBE 算子输出为基准对照。性能标准与任务书一致BF16 → 其他数据类型性能接近 FP32 输入性能相同 shape 对比≥ FP32 性能的 90%。其他格式数据转换性能与现有 TBE 实现不劣化≥ 现有实现的 95%。如小 shape 无法达标10μs 以下场景相差 3μs 以上提供性能仿真图和分析结论证明 Ascend C 实现与 TBE 完全一致或优于 TBE 实现。5.2 兼容性分析向后兼容完全兼容 TBE 算子已支持的数据类型和格式扩展支持新增 BF16 源数据类型输入调用模式支持常规计算接口与 TBE 版本一致5.3 测试用例设计维度覆盖用例dtype 对3.2.1.0.1 矩阵中全部交付源→目标对shape标量、小 shape、非 32B 对齐尾块、大 shape、多维 shape特殊值NaN、Inf、0、负数、int8/uint8 回绕边界±128/256 附近、int64 边界约束边界INT32→INT8 的 (-2048,1920)、INT64→FLOAT32 的 (±2147483648)BF16golden 先按 BF16 量化再转目标 dtype不能直接用 FP32 原值当 BF16 参考5.4 风险与评审确认项风险项说明处理方式BF16 性能dav_m200 无 BF16 原生 vconv需 halfword→FP32 软件展开且无移位指令与 FP32 直达有天然差异优先通过优化满足任务书 90% 要求确无法达标时按任务书规则提供性能仿真图与分析结论证明已接近 dav_m200 能力极限不写申请例外BF16 位拼接无移位无 ShiftLeft/ShiftRighthalfword→FP32 高半字拼接不能用移位按 4 字节 lane 视图清零低半字 字节重排 DataCopy 实现高 16 位写入列为 BF16 路径主要实现风险int64 软件路径缺多数组合直达转换NARROW64/WIDEN64 按 32 位 lane 抽取/符号扩展明确支持范围与数值约束测试覆盖边界DataCopy 对齐无 DataCopyPad尾块须避免越界host tiling 与 kernel CopyIn/CopyOut 统一 32B 对齐参考实现亦未用 DataCopyPad可借鉴其对齐 tiling5.5 算子接入模型验证项内容验证模型InternVL验证数据集flickr30k_entitieshttps://github.com/BryanPlummer/flickr30k_entities模型精度与 Atlas 800T A2 对比train loss 不超过 0.1EmbeddingDenseGrad 算子设计文档一、需求背景1.1 需求来源参考开源仓 embedding_dense_grad_v2 算子实现在昇腾 NPU 上基于 Ascend C 编程语言实现功能一致的算子完成算子设计、开发、测试全流程工作验收通过后将算子提交至昇腾算子开源仓。1.2 背景介绍1.2.1 EmbeddingDenseGrad 算子概述EmbeddingDenseGrad 是 Embedding 层的反向传播算子。前向按 token id 从权重表取行反向需把每个 token 对应的梯度累加回权重梯度表同一个 id 在一个 batch 中可能出现多次因此对应的多行 grad 需累加到同一个 out 行。grad 合轴后视为 [N, D]sort_indices 元素数为 N输出 out 形状为 [numWeights, D]。数学表达为对每个权重行 k ∈ [0, numWeights)out[k, :] Σ_{ i : sort_indices[i] k } grad[i, :]即按索引把 grad 中对应行累加到输出行scatter-add。若 scale_grad_by_freq True则按该索引出现频次 count[k] 缩放仅 count[k] ≥ 2 时缩放out[k, :] ( 1 / count[k] ) · Σ_{ i : sort_indices[i] k } grad[i, :]若 padding_idx ≥ 0则 out[padding_idx, :] 保持为 0跳过累加。sort_indices的有序性来自aclnn 调用链而非任务书口头约定aclnnEmbeddingDenseBackward在下发本算子前调用l0op::Sort对 indices 升序排序并返回排序后的sort_indices与原始行号posIdx参考ops-nn/index/embedding_dense_grad_v2/op_api/aclnn_embedding_dense_backward.cpp:277参考算子 kernel 据此用currentId ! lastIndices做段累加op_kernel/embedding_dense_grad_v2.h:259-279。故同一 index 的行在sort_indices中连续kernel 可在 UB 内对连续段先做段内累加再写回 GM显著降低对输出的离散写次数。该算子在推荐系统、自然语言处理等 Embedding 层反向传播中广泛使用。1.2.2 EmbeddingDenseGrad 算子现状分析1.2.2.1 参考算子支持的数据类型和数据格式1.2.2.2 参考算子实现描述参考算子实现路径https://gitcode.com/cann/ops-nn/blob/master/index/embedding_dense_grad_v2/README.md核心实现逻辑将 grad 和 indices 合轴reshape为一维形式grad_flat shape: [total_rows, dim]indices_flat shape: [total_rows]out_flat shape: [numWeights, dim]遍历 indices将 grad 对应行累加到 out 的对应行支持按频率缩放和 padding 行填充 01.2.2.3 tbe算子实现流程图二、需求分析2.1 外部组件依赖不涉及外部组件依赖。2.2 内部适配模块适配 Aclnn 接口支持常规调用模式。2.3 需求模块设计2.3.1 AscendC 算子原型参数定义与任务书《参数说明》保持一致。参数名输入/输出/属性描述数据类型数据格式grad输入表示数据的原始梯度。FLOATNDsort_indices输入表示 grad 输入对应的索引值。INT32NDout输出表示梯度求和的结果输出。FLOATNDnumWeights属性表示输出 tensor 的首轴大小。Int-padding_idx可选属性将输出 tensor 中第 paddingIdx 行填充成 0如果 paddingIdx 为负数则不进行处理。默认值为 -1。Int-scale_grad_by_freq可选属性根据单词出现的频率是否对梯度进行缩放。默认值为 false。Bool-说明grad 合轴后按 [N, D] 处理out 为 [numWeights, D]D grad.shape[-1]累加前清零sort_indices 已排序、元素数为 N。2.3.2 AscendC 算子相关约束grad 和 sort_indices 必须长度匹配grad.shape[0] indices.shape[0] 合轴后out 形状为 [numWeights, dim]其中 dim grad.shape[-1]grad 合轴成二维后第二维度需 32 字节对齐支持空 indicestotal_rows 0支持非连续 Tensor三、需求详细设计3.1 使能方式3.2 需求总体设计3.2.1 host 侧设计3.2.1.0 目标硬件能力前提dav_m200本设计基于 Atlas 300V Proascend310pdav_m200的 Ascend C 能力关键前提如下后续 host/kernel 设计据此展开能力结论对设计的影响SetAtomicAddfloat支持 fp32 GM 原子加同一 index 跨核累加正确性由原子加保证DataCopyPadGM↔UB不支持CopyIn/CopyOut 必须按 32 字节对齐块设计非对齐 D 走标量兜底裸 SyncAll()不支持scale 路径用 GM workspace 软同步KernelMode使用 MIX_MODE避免 dav_m200 上派发模式不匹配导致 kernel 未执行3.2.1.1 分核策略主路径按 grad 的行数 N 在多个 AI Core 间切分每核处理一段连续 row 区间gradRow Ngrad 合轴后的首轴 usedCoreNum min(platformCoreNum, gradRow) formerRowNum ceil(gradRow / usedCoreNum) // 前几个核各多处理一行 tailRowNum floor(gradRow / usedCoreNum)由于 sort_indices 已排序见 1.2.1 的依据相同 index 在核内连续可在核内做段内累加若同一 index 跨越两个核的边界最终由 GM 上的 SetAtomicAddfloat 归约保证数学语义正确。平台核数通过 GetCoreNum() 获取。grad 行寻址说明参考算子l0op::Sort同时返回排序后的sort_indices与原始行号posIdxgrad 本身不物理重排kernel 用posIdx[i] × D间接 gather 第 i 个排序位对应的 grad 行、用sort_indices[i]定位写回的 out 行参考embedding_dense_grad_v2.h:241。本设计同样按排序序 i遍历若实现选择接收posIdx做间接寻址则与参考一致若选择让调用方把 grad 随 indices 一起重排后顺序读取则需在接口约束中显式声明该前提。下文伪码以排序序的 grad 行为输入抽象不预设物理布局。3.2.1.2 数据分块和内存优化策略当 D 较大、单行无法一次装入 UB 时按列切分单次处理一段连续 embedding dim如 min(D, 4096)列切只改变单次 UB 处理宽度不改变累加语义。因 dav_m200 无 DataCopyPadCopyIn/CopyOut 统一按 32 字节对齐块设计D 非 32B 对齐时走标量兜底路径仅保证功能不作为高性能口径。本算子为 MTE3-bound 的 scatter 写操作禁用 double buffer双缓冲会额外占用 UB 并加剧 DRAM 总线带宽争抢反而降低吞吐。UB 大小通过 GetCoreMemSize(CoreMemType::UB) 获取。3.2.1.3 tilingKey 规划策略tilingKey触发条件说明0scale_grad_by_freq false默认高性能主路径段内累加 边界原子加1scale_grad_by_freq true累加后按频次除法需 counts workspace 与软同步任务书声明不保证高性能频次缩放不通过 tiling 传 freq 数组numWeights 可能很大不可行而是用 GM workspace 的 counts 缓冲在 kernel 内统计。tiling 参数dimSize、numWeights、paddingIdx、scaleGradByFreq、formerRowNum、tailRowNum、formerCoreNum、ubProcessNum 等封装到 EmbeddingDenseGradTilingData 结构体中。workspace 规划对齐参考 fp32 路径口径system 保留区16MB系统固定保留。counts 缓冲仅 scale 路径align(numWeights) × 4B每 index 一个 float 频次计数参考embedding_dense_grad_v2_tiling.cpp:244-245中 fp32 路径 workspace 系统区 counts无其它分量。软同步 slot极小常量区替代 dav_m200 缺失的裸SyncAll()参考用SyncAll本设计用 GM workspace 标志位软同步。参考实现还为 fp16/bf16 主路径分配outStage、为 small-dim 非 fp32 分配整张numWeights×D的outCasted中转区本设计仅交付 FP32、且不走 cast 中转/整表物化故无这两类 workspace也无参考 small-dim 的numWeights ≤ 16777216、D ≤ 512上限。3.2.2 kernel 侧设计3.2.2.1 kernel 侧实现描述EmbeddingDenseGrad 算子 kernel 侧采用模板化设计按 tilingKey 选择两种实现KernelEmbeddingDenseGrad默认版本tilingKey 0无频率缩放纯 scatter 累加KernelEmbeddingDenseGradWithFreq频率缩放版本tilingKey 1在默认累加基础上叠加 counts 统计与二阶段缩放两者共享 Init解析 tiling、按 blockIdx 计算本核 rowStart/rowCount、设置 GM 地址。输出 out 在累加前按 InitValue(0) 语义清零由 op_host/op_api 与 kernel 入口保持一致padding 行因被跳过而保持 0。默认路径tilingKey 0计算逻辑全程设置 SetAtomicAddfloat()把对 out 的写变为 GM 原子累加从而无需核间同步即可处理跨核同 index 冲突。利用 sort_indices 已排序的性质在 UB 内对连续同 index 的多行先做向量累加段内累加再整段原子写回降低 GM 原子写次数// 伪代码T float。按列块 cOff 遍历 D列块宽度 cLen 32B 对齐 SetAtomicAddfloat(); uint32_t row rowStart; while (row rowStart rowCount) { int32_t idx indicesGm.GetValue(row); uint32_t segEnd row; while (segEnd rowStart rowCount indicesGm.GetValue(segEnd) idx) segEnd; // 同 idx 连续段 if (idx ! paddingIdx) { // 段内累加把 grad[row..segEnd) 的对应列块累加进 UB 的 acc再原子写回 out[idx] Duplicate(acc, 0, cLen); for (uint32_t r row; r segEnd; r) { DataCopy(gradTile, gradGm[r * D cOff], cLen); // 32B 对齐块搬入 Add(acc, acc, gradTile, cLen); // 向量累加 } DataCopy(outGm[idx * D cOff], acc, cLen); // 原子加写回已 SetAtomicAdd } row segEnd; } SetAtomicNone();设计要点① grad 行连续读取降低 GM 访存离散度② 同 idx 连续段在 UB 内先累加减少 GM 原子写③ 跨核边界的同 idx 冲突由 fp32 原子加兜底④ padding_idx 行直接跳过写回⑤ 无 double bufferMTE3-bound。非对齐 D 兜底当 D % 8 ! 0fp32 非 32B 对齐时对齐块 DataCopy 无法安全覆盖尾列且 dav_m200 无 DataCopyPad改用标量 GetValue/SetValue 逐元素累加的功能兜底路径仅保证正确性不作为高性能验收口径。频率缩放路径tilingKey 1三阶段 —— Phase0 清零 counts workspace 并软同步Phase1 在默认累加的同时对每个有效 idx 用原子加把 count 累加到 counts workspacePhase2 软同步后按 out 行分核读取 count对 count ≥ 2 的行用 Muls 实现 out[k] * 1/count[k]以原子加形式写回。dav_m200 无裸 SyncAll()阶段间使用 GM workspace 的软同步保证 counts 累加完成后再进入缩放。3.2.2.2 AscendC 实现流程图注非对齐 DD % 8 ! 0走标量 GetValue/SetValue 兜底路径不在上图高性能主路径内。3.2.2.3 AscendC 实现流程图与参考算子流程图存在的差异点和原因差异点参考算子embedding_dense_grad_v2本 AscendC 实现原因交付硬件未覆盖 Atlas 300V Pro 形态面向 ascend310p / dav_m200任务要求新增 310P 交付数据搬运通用语义参考CopyIn/CopyOut 按 32B 对齐块非对齐 D 走标量兜底dav_m200 无 DataCopyPad同 index 累加按索引累加语义段内 UB 向量累加 边界 SetAtomicAddfloat利用已排序 sort_indices 降低 GM 原子写跨核同步通用同步scale 路径用 GM workspace 软同步dav_m200 无裸 SyncAll()派发模式—KernelMode::MIX_MODE避免 dav_m200 上 kernel 未实际执行缓冲策略—禁用 double bufferscatter 为 MTE3-bound双缓冲加剧 DRAM 带宽争抢3.3 支持硬件支持的芯片版本涉及勾选Atlas 300V Pro√3.4 算子约束限制以下约束与任务书《约束说明》一致在参数 shape 超过以下限制时输出无法保证高精度若开启了确定性计算也无法保证高性能grad 合轴成二维 shape 后第一个维度超过 INT32_MAX(2147483647)numWeights 超过 INT32_MAX(2147483647)。sort_indices 合轴后维度超过 INT32_INF(2139095040) 时无法保证高性能。grad 合轴成二维 shape 后第二个维度D需要 32 字节对齐否则无法保证高性能。scale_grad_by_freq 为 True 时对梯度进行缩放无法保证高性能。scale_grad_by_freq 为 True 时无法保证高性能实际耗时不得超过理论耗时输入 shape × 2 ÷ (204GB × 0.5)的 1.1 倍3.5 算子精度和性能要求性能要求理论耗时按访存口径计算theory_time input_bytes × 2 / (204GB/s × 0.5)其中 ×2 表示一读一写grad 读入 out 写出×0.5 为显存带宽利用系数有效带宽约 102GB/s。算子实际耗时不得超过理论耗时的1.1 倍。性能主判定聚焦D 为 32 字节对齐且 scale_grad_by_freq false的主路径非对齐 D 与 scale_grad_by_freq true 属任务书已声明不保证高性能的 explain-only 场景。如小 shape 无法达标100μs 以下场景超过理论耗时 30% 以上提供性能仿真图和瓶颈分析证明 Ascend C 实现已接近硬件极限。精度要求golden 参考以 numpy 标杆实现为基准按out[idx] grad[i]scale 时再除以 count等价于参考算子embedding_dense_grad_v2的反向语义。验收阈值算子计算精度需满足 AscendOpTest 工具 fp32 默认阈值。模型级验收接入 CLIP数据集 flickr30k_entities与 Atlas 800T A2 对比 train loss 偏差不超过 0.1。四、特性交叉分析本算子为 Embedding 层的反向传播算子与 Embedding 正向算子和优化器算子配合使用。不涉及与其他算子的直接特性交叉。五、可维可测分析5.1 精度标准/性能标准验收标准描述标准来源精度标准golden 采用 numpy 参考实现满足 AscendOpTest fp32 默认阈值任务书性能标准实际耗时 ≤ 理论耗时input_bytes×2 / (204GB/s×0.5)的 1.1 倍任务书模型级标准CLIP flickr30k_entities与 Atlas 800T A2 对比 train loss 偏差 ≤ 0.1任务书5.2 测试用例设计维度覆盖用例D 对齐D % 8 0 主路径D 非 32B 对齐的标量兜底路径索引分布单行多行不重复长段重复同 index跨核边界重复同 indexpadding_idxpadding_idx 0不处理padding_idx 0中间行 padding尾行 paddingscale_grad_by_freqfalse主路径true含 count 1 不缩放、count ≥ 2 缩放稀疏输出numWeights 远大于实际出现 index存在大量全 0 输出行维度2D / 3D / 更高维 grad 合轴场景5.3 风险与评审确认项风险项说明处理方式非对齐 D 性能dav_m200 无 DataCopyPad非 32B 对齐难以高效块搬运保留标量兜底性能按 explain-only 说明scale 路径性能需 counts、软同步与二阶段除法功能覆盖性能不作为主路径口径跨核重复 index同 index 可能跨核边界由 SetAtomicAddfloat 保证累加正确输出初始化反向累加依赖 out 初值为 0op_host/op_api 与 kernel 入口统一保持 InitValue(0) 语义确定性计算不实现范围声明参考有独立 determinist 路径倒序遍历 standIdice 固定跨核原子加求和序依赖 DataCopyPad 读单 indexdav_m200 无 DataCopyPad且任务未要求高性能确定性计算本实现不交付该路径确定性场景回退到原子加默认序small-dim 路径不实现范围声明参考有 small-dim 路径kernel 内Sort 批量段扫描并对非 fp32 物化整张numWeights×Dcast 中转本设计靠 host 端已排序的 sort_indices无需 kernel 内 Sort仅 fp32、不物化整表故不引入该路径及其D≤512 / numWeights≤16777216上限5.4 兼容性分析数据类型grad/out 为 FLOATsort_indices 为 INT32与参考算子语义一致。调用方式aclnn 两段式接口与现有调用模式兼容。硬件面向 Atlas 300V Proascend310p / dav_m200不依赖 910B/950 专属低阶能力。【免费下载链接】cann-ops-competitions本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。项目地址: https://gitcode.com/cann/cann-ops-competitions创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考