减少遍历优化【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills概述减少数据遍历次数是 Triton 性能优化的核心手段。有两种主要方法方法适用场景核心思路Pass 合并多个统计量计算同时计算多个统计量减少遍历次数循环消除数据量小 (N ≤ BLOCK_SIZE)一次加载所有数据消除循环方法一Pass 合并问题描述问题多次遍历数据每次独立计算统计量导致重复内存访问。# 问题代码3-pass BatchNorm # Pass 1: 计算 mean for ...: data tl.load(...) mean tl.sum(data) # Pass 2: 计算 variance再次遍历 for ...: data tl.load(...) # 重复加载 var tl.sum((data - mean) ** 2) # Pass 3: 归一化第三次遍历 for ...: data tl.load(...) # 第三次加载 tl.store(...)优化原理利用数学公式在单次遍历中同时计算多个统计量mean sum(x) / count var sum(x²) / count - mean² 证明 var E[(x - mean)²] E[x²] - 2·mean·E[x] mean² E[x²] - mean² sum(x²)/count - mean²优化代码# Pass 1: 同时计算 sum 和 sum_sq sum_val 0.0 sum_sq 0.0 for ...: data tl.load(...) sum_val tl.sum(data) sum_sq tl.sum(data * data) # 同时累加 mean sum_val / count var sum_sq / count - mean * mean # Pass 2: 归一化 for ...: ...案例BatchNorm2dtriton.jit def batchnorm_2pass( input_ptr, output_ptr, gamma_ptr, beta_ptr, N, C, H, W, stride_n, stride_c, stride_h, stride_w, eps: tl.constexpr, BLOCK_SIZE_HW: tl.constexpr, ): c tl.program_id(0) gamma tl.load(gamma_ptr c) beta tl.load(beta_ptr c) count N * H * W # Pass 1: 同时计算 sum 和 sum_sq sum_val 0.0 sum_sq 0.0 for n in range(N): for h in range(H): for w_start in range(0, W, BLOCK_SIZE_HW): data tl.load(...) sum_val tl.sum(data) sum_sq tl.sum(data * data) mean sum_val / count var sum_sq / count - mean * mean inv_std 1.0 / tl.sqrt(var eps) # Pass 2: normalize for n in range(N): for h in range(H): for w_start in range(0, W, BLOCK_SIZE_HW): data tl.load(...) output (data - mean) * inv_std * gamma beta tl.store(...)性能3-pass → 2-pass延迟从 73.67ms → ~52ms加速 1.42x方法二循环消除问题描述问题Triton 对 Python for 循环优化有限大量循环是性能杀手。# 问题代码循环多次加载 for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 循环内加载 max_val update(max_val, vals) for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 再次加载 sum_exp update(sum_exp, vals) for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 第三次加载 tl.store(...)为什么 Triton 循环是性能杀手循环次数延迟原因1~1 ms无循环开销10~10 ms线性增长100~100 ms线性增长原因分析:循环展开有限: 编译器不会激进展开所有循环无法向量化: 循环体被视为串行操作每次迭代独立编译: 增加编译开销无法流水线: 循环边界动态检查优化原理当N BLOCK_SIZE时可以一次加载所有数据消除循环# 优化一次加载多次使用 BLOCK_SIZE triton.next_power_of_2(N) # 确保 N BLOCK_SIZE vals tl.load(...) # 一次加载 # 直接在加载的数据上操作 max_val tl.max(vals, axis1, keep_dimsTrue) sum_val tl.sum(vals, axis1, keep_dimsTrue) output compute(vals, max_val, sum_val) tl.store(...)案例Log Softmax原始实现triton.jit def log_softmax_original(...): row_idx tl.program_id(0) # Phase 1: 循环计算 max max_val -float(inf) for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) max_val tl.max(tl.maximum(vals, max_val)) # Phase 2: 循环计算 sum sum_exp 0.0 for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 第二次加载 exp_vals tl.exp(vals - max_val) sum_exp tl.sum(exp_vals) # Phase 3: 循环存储 for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 第三次加载 output vals - max_val - tl.log(sum_exp) tl.store(...)问题分析3 次循环3 次数据加载每次 load 需要单独的内存访问Grid (M,)kernel launch 开销大优化实现triton.jit def log_softmax_optimized( input_ptr, output_ptr, stride_in, stride_out, n_rows, n_cols, BLOCK: tl.constexpr, ROWS_PER_BLOCK: tl.constexpr, ): pid tl.program_id(0) row_offs pid * ROWS_PER_BLOCK tl.arange(0, ROWS_PER_BLOCK) row_mask row_offs n_rows col_offs tl.arange(0, BLOCK) col_mask col_offs n_cols mask_2d row_mask[:, None] col_mask[None, :] # 一次加载所有数据 x tl.load( input_ptr row_offs[:, None] * stride_in col_offs[None, :], maskmask_2d, other-float(inf) ) # 在同一数据上完成所有计算 row_max tl.max(x, axis1, keep_dimsTrue) x_shifted x - row_max exp_x tl.exp(x_shifted) row_sum tl.sum(tl.where(mask_2d, exp_x, 0.0), axis1, keep_dimsTrue) output x_shifted - tl.log(row_sum) tl.store(output_ptr row_offs[:, None] * stride_out col_offs[None, :], output, maskmask_2d)性能延迟从 82.32μs → 7.97μs加速 10.3x性能收益场景原始优化后收益Softmax (3 phase)3 次加载1 次加载3xLayerNorm (2 phase)2 次加载1 次加载2xLog Softmax82.32 μs7.97 μs10.3x两种方法对比对比项Pass 合并循环消除适用条件任意数据量N ≤ BLOCK_SIZE优化对象多个统计量计算循环结构核心操作同时计算 sumsum_sq 等一次加载多次使用收益来源减少遍历次数减少加载次数 减少 kernel launch可组合性可与维度合并组合可与多行并行组合适用条件Pass 合并✅ 多个统计量可同时计算meanvar, sumsum_sq❌ 统计量之间有依赖关系Softmax 的 sum 依赖 max循环消除✅ N ≤ BLOCK_SIZE可一次性加载❌ N BLOCK_SIZE需保留循环累积大数据量处理当数据量超过 BLOCK_SIZE 时需要保留循环但优化为累积式# 累积式循环模板 triton.jit def kernel_large_n( input_ptr, output_ptr, stride_in, stride_out, n_rows, n_cols, BLOCK: tl.constexpr, ROWS_PER_BLOCK: tl.constexpr, ): pid tl.program_id(0) row_offs pid * ROWS_PER_BLOCK tl.arange(0, ROWS_PER_BLOCK) row_mask row_offs n_rows # 初始化累加器 max_acc -float(inf) # shape: [ROWS_PER_BLOCK, 1] sum_acc 0.0 # 循环处理列 for col_start in range(0, n_cols, BLOCK): col_offs col_start tl.arange(0, BLOCK) col_mask col_offs n_cols mask_2d row_mask[:, None] col_mask[None, :] # 加载当前块 x tl.load( input_ptr row_offs[:, None] * stride_in col_offs[None, :], maskmask_2d, other-float(inf) ) # 累积统计量 block_max tl.max(x, axis1, keep_dimsTrue) # ... 使用 Welford 算法或其他累积方法 # 最终处理...关键点循环内不要重复加载相同数据而是在循环内累积结果。其他常见应用LayerNorm2-pass → 1-pass# 原始2-pass for i in range(N): mean x[i] mean / N for i in range(N): var (x[i] - mean) ** 2 var / N # 优化1-pass for i in range(N): sum_val x[i] sum_sq x[i] ** 2 mean sum_val / N var sum_sq / N - mean ** 2Softmax3-pass → 2-pass# 原始3-pass (max sum normalize) for i in range(N): max_val max(max_val, x[i]) for i in range(N): sum_exp exp(x[i] - max_val) for i in range(N): output[i] exp(x[i] - max_val) / sum_exp # 注意无法进一步合并因为 sum 依赖 max常见错误错误 1忘记调整 BLOCK_SIZE# ❌ 错误BLOCK_SIZE 不够大 BLOCK_SIZE 256 # 如果 N 512会丢失数据 vals tl.load(...) # ✅ 正确确保 N BLOCK_SIZE BLOCK_SIZE triton.next_power_of_2(N)错误 2忘记更新除数# ❌ 错误有 mask 时仍用固定 count sum_val tl.sum(data) mean sum_val / (N * H * W) # 实际元素数可能更少 # ✅ 正确跟踪实际元素数 valid_count tl.sum(mask) mean sum_val / valid_count错误 3循环内重复加载# ❌ 错误循环内多次加载同一数据 for col_start in range(0, n_cols, BLOCK): x tl.load(...) max_val update(max_val, x) for col_start in range(0, n_cols, BLOCK): x tl.load(...) # 重复加载 sum_val update(sum_val, x) # ✅ 正确循环内累积减少加载 for col_start in range(0, n_cols, BLOCK): x tl.load(...) max_val update(max_val, x) sum_val update(sum_val, f(x)) # 同时处理总结方法适用场景核心操作收益Pass 合并多统计量计算同时计算 sumsum_sq减少遍历次数循环消除N ≤ BLOCK_SIZE一次加载多次使用减少加载 减少 launch选择依据数据量小优先循环消除数据量大 多统计量优先 Pass 合并两者可组合使用核心原则减少数据加载次数寻找可同时计算的统计量循环内累积而非重复加载【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
减少遍历优化 - CANN/cannbot-skills
减少遍历优化【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills概述减少数据遍历次数是 Triton 性能优化的核心手段。有两种主要方法方法适用场景核心思路Pass 合并多个统计量计算同时计算多个统计量减少遍历次数循环消除数据量小 (N ≤ BLOCK_SIZE)一次加载所有数据消除循环方法一Pass 合并问题描述问题多次遍历数据每次独立计算统计量导致重复内存访问。# 问题代码3-pass BatchNorm # Pass 1: 计算 mean for ...: data tl.load(...) mean tl.sum(data) # Pass 2: 计算 variance再次遍历 for ...: data tl.load(...) # 重复加载 var tl.sum((data - mean) ** 2) # Pass 3: 归一化第三次遍历 for ...: data tl.load(...) # 第三次加载 tl.store(...)优化原理利用数学公式在单次遍历中同时计算多个统计量mean sum(x) / count var sum(x²) / count - mean² 证明 var E[(x - mean)²] E[x²] - 2·mean·E[x] mean² E[x²] - mean² sum(x²)/count - mean²优化代码# Pass 1: 同时计算 sum 和 sum_sq sum_val 0.0 sum_sq 0.0 for ...: data tl.load(...) sum_val tl.sum(data) sum_sq tl.sum(data * data) # 同时累加 mean sum_val / count var sum_sq / count - mean * mean # Pass 2: 归一化 for ...: ...案例BatchNorm2dtriton.jit def batchnorm_2pass( input_ptr, output_ptr, gamma_ptr, beta_ptr, N, C, H, W, stride_n, stride_c, stride_h, stride_w, eps: tl.constexpr, BLOCK_SIZE_HW: tl.constexpr, ): c tl.program_id(0) gamma tl.load(gamma_ptr c) beta tl.load(beta_ptr c) count N * H * W # Pass 1: 同时计算 sum 和 sum_sq sum_val 0.0 sum_sq 0.0 for n in range(N): for h in range(H): for w_start in range(0, W, BLOCK_SIZE_HW): data tl.load(...) sum_val tl.sum(data) sum_sq tl.sum(data * data) mean sum_val / count var sum_sq / count - mean * mean inv_std 1.0 / tl.sqrt(var eps) # Pass 2: normalize for n in range(N): for h in range(H): for w_start in range(0, W, BLOCK_SIZE_HW): data tl.load(...) output (data - mean) * inv_std * gamma beta tl.store(...)性能3-pass → 2-pass延迟从 73.67ms → ~52ms加速 1.42x方法二循环消除问题描述问题Triton 对 Python for 循环优化有限大量循环是性能杀手。# 问题代码循环多次加载 for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 循环内加载 max_val update(max_val, vals) for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 再次加载 sum_exp update(sum_exp, vals) for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 第三次加载 tl.store(...)为什么 Triton 循环是性能杀手循环次数延迟原因1~1 ms无循环开销10~10 ms线性增长100~100 ms线性增长原因分析:循环展开有限: 编译器不会激进展开所有循环无法向量化: 循环体被视为串行操作每次迭代独立编译: 增加编译开销无法流水线: 循环边界动态检查优化原理当N BLOCK_SIZE时可以一次加载所有数据消除循环# 优化一次加载多次使用 BLOCK_SIZE triton.next_power_of_2(N) # 确保 N BLOCK_SIZE vals tl.load(...) # 一次加载 # 直接在加载的数据上操作 max_val tl.max(vals, axis1, keep_dimsTrue) sum_val tl.sum(vals, axis1, keep_dimsTrue) output compute(vals, max_val, sum_val) tl.store(...)案例Log Softmax原始实现triton.jit def log_softmax_original(...): row_idx tl.program_id(0) # Phase 1: 循环计算 max max_val -float(inf) for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) max_val tl.max(tl.maximum(vals, max_val)) # Phase 2: 循环计算 sum sum_exp 0.0 for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 第二次加载 exp_vals tl.exp(vals - max_val) sum_exp tl.sum(exp_vals) # Phase 3: 循环存储 for col_offset in range(0, n_cols, BLOCK_SIZE): vals tl.load(...) # 第三次加载 output vals - max_val - tl.log(sum_exp) tl.store(...)问题分析3 次循环3 次数据加载每次 load 需要单独的内存访问Grid (M,)kernel launch 开销大优化实现triton.jit def log_softmax_optimized( input_ptr, output_ptr, stride_in, stride_out, n_rows, n_cols, BLOCK: tl.constexpr, ROWS_PER_BLOCK: tl.constexpr, ): pid tl.program_id(0) row_offs pid * ROWS_PER_BLOCK tl.arange(0, ROWS_PER_BLOCK) row_mask row_offs n_rows col_offs tl.arange(0, BLOCK) col_mask col_offs n_cols mask_2d row_mask[:, None] col_mask[None, :] # 一次加载所有数据 x tl.load( input_ptr row_offs[:, None] * stride_in col_offs[None, :], maskmask_2d, other-float(inf) ) # 在同一数据上完成所有计算 row_max tl.max(x, axis1, keep_dimsTrue) x_shifted x - row_max exp_x tl.exp(x_shifted) row_sum tl.sum(tl.where(mask_2d, exp_x, 0.0), axis1, keep_dimsTrue) output x_shifted - tl.log(row_sum) tl.store(output_ptr row_offs[:, None] * stride_out col_offs[None, :], output, maskmask_2d)性能延迟从 82.32μs → 7.97μs加速 10.3x性能收益场景原始优化后收益Softmax (3 phase)3 次加载1 次加载3xLayerNorm (2 phase)2 次加载1 次加载2xLog Softmax82.32 μs7.97 μs10.3x两种方法对比对比项Pass 合并循环消除适用条件任意数据量N ≤ BLOCK_SIZE优化对象多个统计量计算循环结构核心操作同时计算 sumsum_sq 等一次加载多次使用收益来源减少遍历次数减少加载次数 减少 kernel launch可组合性可与维度合并组合可与多行并行组合适用条件Pass 合并✅ 多个统计量可同时计算meanvar, sumsum_sq❌ 统计量之间有依赖关系Softmax 的 sum 依赖 max循环消除✅ N ≤ BLOCK_SIZE可一次性加载❌ N BLOCK_SIZE需保留循环累积大数据量处理当数据量超过 BLOCK_SIZE 时需要保留循环但优化为累积式# 累积式循环模板 triton.jit def kernel_large_n( input_ptr, output_ptr, stride_in, stride_out, n_rows, n_cols, BLOCK: tl.constexpr, ROWS_PER_BLOCK: tl.constexpr, ): pid tl.program_id(0) row_offs pid * ROWS_PER_BLOCK tl.arange(0, ROWS_PER_BLOCK) row_mask row_offs n_rows # 初始化累加器 max_acc -float(inf) # shape: [ROWS_PER_BLOCK, 1] sum_acc 0.0 # 循环处理列 for col_start in range(0, n_cols, BLOCK): col_offs col_start tl.arange(0, BLOCK) col_mask col_offs n_cols mask_2d row_mask[:, None] col_mask[None, :] # 加载当前块 x tl.load( input_ptr row_offs[:, None] * stride_in col_offs[None, :], maskmask_2d, other-float(inf) ) # 累积统计量 block_max tl.max(x, axis1, keep_dimsTrue) # ... 使用 Welford 算法或其他累积方法 # 最终处理...关键点循环内不要重复加载相同数据而是在循环内累积结果。其他常见应用LayerNorm2-pass → 1-pass# 原始2-pass for i in range(N): mean x[i] mean / N for i in range(N): var (x[i] - mean) ** 2 var / N # 优化1-pass for i in range(N): sum_val x[i] sum_sq x[i] ** 2 mean sum_val / N var sum_sq / N - mean ** 2Softmax3-pass → 2-pass# 原始3-pass (max sum normalize) for i in range(N): max_val max(max_val, x[i]) for i in range(N): sum_exp exp(x[i] - max_val) for i in range(N): output[i] exp(x[i] - max_val) / sum_exp # 注意无法进一步合并因为 sum 依赖 max常见错误错误 1忘记调整 BLOCK_SIZE# ❌ 错误BLOCK_SIZE 不够大 BLOCK_SIZE 256 # 如果 N 512会丢失数据 vals tl.load(...) # ✅ 正确确保 N BLOCK_SIZE BLOCK_SIZE triton.next_power_of_2(N)错误 2忘记更新除数# ❌ 错误有 mask 时仍用固定 count sum_val tl.sum(data) mean sum_val / (N * H * W) # 实际元素数可能更少 # ✅ 正确跟踪实际元素数 valid_count tl.sum(mask) mean sum_val / valid_count错误 3循环内重复加载# ❌ 错误循环内多次加载同一数据 for col_start in range(0, n_cols, BLOCK): x tl.load(...) max_val update(max_val, x) for col_start in range(0, n_cols, BLOCK): x tl.load(...) # 重复加载 sum_val update(sum_val, x) # ✅ 正确循环内累积减少加载 for col_start in range(0, n_cols, BLOCK): x tl.load(...) max_val update(max_val, x) sum_val update(sum_val, f(x)) # 同时处理总结方法适用场景核心操作收益Pass 合并多统计量计算同时计算 sumsum_sq减少遍历次数循环消除N ≤ BLOCK_SIZE一次加载多次使用减少加载 减少 launch选择依据数据量小优先循环消除数据量大 多统计量优先 Pass 合并两者可组合使用核心原则减少数据加载次数寻找可同时计算的统计量循环内累积而非重复加载【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考