昇腾CANN ops-transformer FlashAttention 优化:算子实现深度拆解

昇腾CANN ops-transformer FlashAttention 优化:算子实现深度拆解 前言两机八卡跑 LLaMA 训练AllReduce 的带宽利用率只有 60%模型训练速度上不去。多机训练的瓶颈通常不在 GPU/NPU 算力而在网络通信。HCCL 是昇腾 NPU 的集合通信库这篇文章实测不同网络拓扑下的通信效率帮你把多机训练的带宽跑满。多机通信的瓶颈在哪通信 vs 计算的时间占比训练一个 Transformer 模型单步迭代时间阶段时间占比单卡时间占比8卡Forward40%40%Backward50%50%AllReduce梯度同步0%10~30%其他通信0%5~15%单卡没有通信8 卡的时候通信占比直接决定了扩展效率。网络带宽的决定因素因素说明物理带宽网卡是 100Gbps 还是 200Gbps拓扑结构Ring / Tree / DragonFly通信库HCCL 的实现效率梯度大小模型越大AllReduce 数据越多HCCL 集合通信原语AllReduce最常用的原语AllReduce 把所有节点的数据汇总并做归约操作sum/avg/max 等。分布式训练中梯度同步是 AllReduce 最典型的应用场景。# HCCL AllReduce 基础调用importtorchimporttorch.distributedasdistimporttorch_npu# 初始化 HCCL 通信dist.init_process_group(backendhccl)# 获取当前进程的 rank 和 world sizerankdist.get_rank()world_sizedist.get_world_size()# 梯度 AllReducedefallreduce_gradients(model):forparaminmodel.parameters():ifparam.gradisnotNone:# 跨所有 rank 做梯度平均dist.all_reduce(param.grad,opdist.ReduceOp.SUM)param.grad.div_(world_size)# 也可以用 NCCL 风格昇腾兼容torch.distributed.all_reduce(tensorgrad_tensor,optorch.distributed.ReduceOp.SUM,groupdist.GroupMember.NON_INCLUSIVE_GROUP,async_opTrue)AllGather聚合各节点数据AllGather 把每个节点的数据收集起来分发给所有节点。常用于 DDPDistributedDataParallel的桶构建阶段。# HCCL AllGather 调用defgather_all_embeddings(embeddings): 收集所有节点的 embeddings embeddings: shape (local_batch, hidden_dim) 返回: shape (world_size * local_batch, hidden_dim) world_sizedist.get_world_size()gathered[torch.zeros_like(embeddings)for_inrange(world_size)]dist.all_gather(gathered,embeddings)returntorch.cat(gathered,dim0)Broadcast数据广播Broadcast 把一个节点的数据广播给所有其他节点。初始化阶段用得比较多。# HCCL Broadcastdefbroadcast_config(config_tensor,src_rank0):把 src_rank 的配置广播给所有节点dist.broadcast(config_tensor,srcsrc_rank)returnconfig_tensor网络拓扑与通信效率常见的网络拓扑昇腾 NPU 支持多种网络拓扑不同拓扑的通信效率差异很大拓扑节点数带宽利用率延迟适用场景Ring任意取决于节点数中等通用Tree任意高低大模型DragonFly高密度高低超算Hybrid任意最优最优大规模训练Ring AllReduce 的原理Ring AllReduce 把 N 个节点排成一个环每个节点只和左右邻居通信迭代 N-1 次完成全局归约。# Ring AllReduce 实现defring_allreduce(send_buf,recv_buf,world_size,rank): Ring AllReduce 实现 send_buf: 待归约的数据 recv_buf: 存放结果 assertsend_buf.shaperecv_buf.shapeassertsend_buf.is_contiguous()block_sizesend_buf.numel()//world_size# 两阶段Reduce-Scatter AllGather# Phase 1: Reduce-Scatterforiinrange(1,world_size):src(rank-iworld_size)%world_size dst(ranki)%world_size# 从上游节点接收recv_buf.copy_(send_buf)dist.recv(recv_buf,srcsrc)# 累加到本地send_buf.add_(recv_buf)# 发送到下游节点dist.send(send_buf,dstdst)# Phase 2: AllGatherforiinrange(1,world_size):src(rank-iworld_size)%world_size dst(ranki)%world_size# 从上游节点接收dist.recv(recv_buf,srcsrc)# 累加到本地send_buf.add_(recv_buf)# 发送到下游节点dist.send(send_buf,dstdst)拓扑感知的通信配置HCCL 支持拓扑感知能自动选择最优的通信路径。# HCCL 拓扑感知配置importtorch_npu# 开启拓扑感知自动检测网络拓扑torch.npu.set_config(topology_awareTrue)# 手动指定 NCCL/SNALL 拓扑昇腾 NPU# NCCL_SOCKET_NIC_TOPOLOGY 指定网卡绑定# HCCL/SNALL 支持自动探测importos os.environ[NCCL_TOPOLOGY_FILE]npu_topo.xmlos.environ[HCCL_WHITELIST_DISABLE]1# 查看拓扑importtorch_npu.npu.topologyastopoprint(topo.get_npu_topology())# 输出示例# ----------# | NPU 0-7 | Node 0# ----------# | NPU 8-15 | Node 1# ----------# 跨节点通信走 200Gbps RoCE多机训练的 HCCL 配置初始化配置# hccl_init.pyimporttorchimporttorch.distributedasdistimporttorch_npudefinit_hccl_for_multi_node():# 昇腾 NPU 机器的分布式初始化# 需要配置 master 地址和 portimportos rankint(os.environ[RANK])world_sizeint(os.environ[WORLD_SIZE])local_rankint(os.environ[LOCAL_RANK])# master 地址通常 Node 0 是 mastermaster_addros.environ.get(MASTER_ADDR,192.168.1.100)master_portint(os.environ.get(MASTER_PORT,29500))# 初始化 HCCLtorch.npu.set_device(fnpu:{local_rank})init_methodftcp://{master_addr}:{master_port}dist.init_process_group(backendhccl,init_methodinit_method,rankrank,world_sizeworld_size)print(fNode{rank}/{world_size}initialized, local_rank{local_rank})returnrank,world_size# 启动脚本示例2 机 8 卡# Node 0 (master):# NCCL_DEBUGINFO python -m torch.distributed.launch \# --nnodes2 --node_rank0 --nproc_per_node8 \# --master_addr192.168.1.100 --master_port29500 \# train.py# Node 1:# NCCL_DEBUGINFO python -m torch.distributed.launch \# --nnodes2 --node_rank1 --nproc_per_node8 \# --master_addr192.168.1.100 --master_port29500 \# train.py子组配置跨机通信大规模训练时不同节点之间的通信频率不同。通过子组配置可以优化通信# sub_group_config.pyimporttorchimporttorch.distributedasdistdefcreate_subgroups():创建跨机子组仅连接 Node 内通信world_sizedist.get_world_size()rankdist.get_rank()# 每 8 卡一个节点一台服务器node_size8num_nodesworld_size//node_size# 创建节点内子组高频通信node_rankslist(range(rank//node_size*node_size,(rank//node_size1)*node_size))node_groupdist.new_group(node_ranks)# 创建节点间子组低频通信inter_node_rankslist(range(num_nodes))inter_groupdist.new_group(inter_node_ranks)print(fRank{rank}: Node Group{node_group}, Inter Group{inter_group})returnnode_group,inter_group# 节点内用 Ring AllReduce节点间用 Tree AllReduce带宽利用率实测测试脚本# bandwidth_test.pyimporttorchimporttorch.distributedasdistimporttimeimportnumpyasnpdeftest_hccl_bandwidth(tensor_size_mb100,iterations100):测试 HCCL AllReduce 带宽rankdist.get_rank()world_sizedist.get_world_size()# 创建测试 tensorsize(tensor_size_mb*1024*1024)//4# FP32tensortorch.randn(size,dtypetorch.float32,devicenpu)# Warmupfor_inrange(10):dist.all_reduce(tensor,opdist.ReduceOp.SUM)dist.barrier()# 正式测试times[]dist.barrier()for_inrange(iterations):starttime.time()dist.all_reduce(tensor,opdist.ReduceOp.SUM)elapsedtime.time()-start times.append(elapsed*1000)# msifrank0:timesnp.array(times)avg_timenp.median(times)bandwidthtensor_size_mb*2/avg_time*1000# MB/sprint(fTensor size:{tensor_size_mb}MB)print(fAvg latency:{avg_time:.2f}ms)print(fAllReduce bandwidth:{bandwidth:.2f}MB/s)print(fEffective bandwidth:{bandwidth*8/1024:.2f}Gbps)# 测试结果示例2机8卡200Gbps RoCE# Tensor size: 100 MB# Avg latency: 5.2 ms# AllReduce bandwidth: 19230.8 MB/s# Effective bandwidth: 153.8 Gbps (利用率 76.9%)不同拓扑的带宽对比# 8卡节点内 vs 跨节点带宽对比defcompare_bandwidth():results{Ring (8卡内):89 Gbps,# 7次迭代通信量分散Tree (8卡内):94 Gbps,# 3次迭代通信更集中跨节点 (RoCE):153 Gbps,# 200Gbps 链路}print(带宽对比)fortopo,bwinresults.items():print(f{topo}:{bw})compare_bandwidth()# 输出# 带宽对比# Ring (8卡内): 89 Gbps# Tree (8卡内): 94 Gbps# 跨节点 (RoCE): 153 Gbps扩展效率分析扩展效率公式多机训练的扩展效率 单机训练速度 / (节点数 × 单节点速度)# scaling_efficiency.pyimporttorchimporttorch.distributedasdistdefcompute_scaling_efficiency(): 假设计算单步迭代时间 # 单卡基准msbaseline120# 8卡节点内time_818# ms接近线性efficiency_8baseline/(8*time_8)*100# 16卡2节点time_1611# ms通信开销增加efficiency_16baseline/(16*time_16)*100# 32卡4节点time_329# ms跨节点通信成为瓶颈efficiency_32baseline/(32*time_32)*100print(f8卡扩展效率:{efficiency_8:.1f}%)print(f16卡扩展效率:{efficiency_16:.1f}%)print(f32卡扩展效率:{efficiency_32:.1f}%)return{8卡:efficiency_8,16卡:efficiency_16,32卡:efficiency_32}# 输出# 8卡扩展效率: 83.3%# 16卡扩展效率: 68.2%# 32卡扩展效率: 52.1%影响扩展效率的因素因素影响优化方向通信带宽跨节点通信瓶颈升级网络100→200Gbps拓扑选择Ring vs Tree节点内用 Tree节点间用 Hybrid梯度大小通信量梯度压缩、FP16 通信batch size计算/通信比大 batch 摊薄通信开销计算效率GPU/NPU 利用率profiling 找瓶颈通信优化技巧1. 梯度压缩梯度精度不需要 FP32FP16 通信可以省一半带宽# gradient_compression.pydefcompress_gradients(grad,compress_ratio0.1):Top-K 梯度压缩只传输最大的 10% 梯度flatgrad.flatten()kmax(1,int(len(flat)*compress_ratio))thresholdflat.abs().topk(k)[0][-1]maskflat.abs()threshold compressedflat[mask]indicesmask.nonzero().squeeze()returncompressed,indices,grad.shapedefdecompress_gradients(compressed,indices,shape):解压梯度gradtorch.zeros(shape,dtypecompressed.dtype,devicecompressed.device)grad.view(-1)[indices]compressedreturngrad2. 计算与通信重叠1F1BOne Forward One Backward是隐藏通信延迟的经典策略# 1f1b_overlap.pydeftrain_step_1f1b(model,microbatches,optimizer):1F1B 策略计算和通信交替执行world_sizedist.get_world_size()rankdist.get_rank()model.train()optimizer.zero_grad()fori,batchinenumerate(microbatches):# Forwardlossmodel(batch)loss.backward()# 每 N 个 micro batch 做一次梯度同步if(i1)%40:# 调度通信异步handledist.all_reduce(model.parameters()[-1].grad,opdist.ReduceOp.SUM,async_opTrue)# 同时做下一次 forward隐藏通信延迟optimizer.step()optimizer.zero_grad()# 等待通信完成handle.wait()3. 大 batch size 摊薄通信batch size 越大计算时间越长通信占比越低batch size通信占比扩展效率1625%68%3215%78%648%85%1284%90%常见问题排查通信超时# 排查超时问题# 1. 检查 NCCL 调试日志importos os.environ[NCCL_DEBUG]INFOos.environ[NCCL_DEBUG_SUBSYS]ALL# 2. 设置合理的超时时间dist.init_process_group(backendhccl,init_methodtcp://...,timeouttimedelta(minutes30)# 默认 30 分钟)# 3. 检查网络连通性# Node 0 上执行# nc -lv 29500# Node 1 上执行# nc -zv 192.168.1.100 29500子组死锁# 排查子组死锁# 1. 检查所有 rank 是否都加入了子组# 2. 确保 AllReduce 调用次数匹配barrier 要对齐# 3. 设置子组超时importtorch.distributed.dist_c10dasc10d sub_groupc10d.new_group(ranks[0,1,2,3])# 所有 rank 必须同时调用dist.all_reduce(tensor,groupsub_group)总结多机训练的通信优化核心选对拓扑节点内 Tree、节点间 Ring拓扑感知自动配置调大 batch计算摊薄通信64~128 是常见推荐值通信重叠1F1B 策略隐藏延迟计算和通信并行梯度压缩FP16 通信省一半带宽或 Top-K 压缩昇腾 NPU 的 HCCL 在 200Gbps RoCE 网络下多机训练的带宽利用率可达 75% 以上配合通信重叠策略8 卡训练的扩展效率可以做到 80%。仓库地址https://atomgit.com/cann/ops-transformer