Catlass TileMuls标量乘模板

Catlass TileMuls标量乘模板 TileMuls【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass代码位置[TOC]功能说明TileMuls模板完成了 Vector 引擎的标量乘法AscendC::Muls对 UB 上某个 tensor 的全部元素乘以一个标量scalar结果写回 UB。适用的场景在 Epilogue 阶段对 L0C→UB 后累积结果进行 scale 缩放。该模板使用AscendC::SetVectorMaskAscendC::Muls的组合通过 COUNTER 掩码模式精确控制计算长度避免越界访问。模板原型template class ArchTag_, // 架构标签 class ComputeType_, // 计算类型Gemm::GemmTypeElement, Layout uint32_t COMPUTE_LENGTH_ // 单次计算长度 struct TileMuls { using Element typename ComputeType_::Element; static constexpr uint32_t COMPUTE_LENGTH COMPUTE_LENGTH_; };COMPUTE_LENGTH_为模板参数传入的常量不通过运行时参数控制用于常量折叠优化。调用接口void operator()( AscendC::LocalTensorElement dstTensor, // UB 目标 tensor AscendC::LocalTensorElement srcTensor, // UB 源 tensor Element scalar, // 标量 uint32_t len // 实际计算长度 );执行流程SetMaskCount()→SetVectorMaskElement, COUNTER(len)设置掩码MulsElement, false(dst, src, scalar, MASK_PLACEHOLDER, 1, {} )SetMaskNorm()→ResetMask()恢复掩码调用示例#include catlass/gemm/tile/tile_muls.hpp using namespace Catlass::Gemm; using ComputeType Gemm::GemmTypehalf, layout::RowMajor; constexpr uint32_t COMPUTE_LENGTH 256; using MulsOp Tile::TileMulsArch::AtlasA2, ComputeType, COMPUTE_LENGTH; half scalar 0.5_hf; uint32_t len 256; AscendC::LocalTensorhalf srcUB; AscendC::LocalTensorhalf dstUB; MulsOp mulsOp; mulsOp(dstUB, srcUB, scalar, len); // 等效于dstUB[i] srcUB[i] * 0.5 for i in [0, len)【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考