CANN 混合精度训练:从 FP16 到 BF16

CANN 混合精度训练:从 FP16 到 BF16 一、为什么需要混合精度单精度 FP32 虽然准确但训练大模型时显存和速度都是瓶颈。半精度 FP16 大幅节省显存但动态范围窄、溢出风险高。混合精度取两者之长—— Forward 和 Backward 用 FP16 加速Optimizer State 和关键梯度用 FP32 保精度。FP32 (优化器状态) ← 保精度 ↓ FP16 (前反向计算) ← 加速 ↓ FP32 (权重更新) ← 稳定二、BF16 对比 FP16特性FP16BF16总位数1616阶码位数58尾数位数107动态范围约 10^-5 ~ 10^5约 10^-4 ~ 10^4精度较高更多尾数位较低更多阶码位溢出风险高动态范围小低动态范围大适用场景推理、小模型大模型训练BF16 的阶码位更多动态范围比 FP16 大约 8 倍训练稳定性更好更适合 LLM 等大模型场景。三、昇腾 AMF 混合精度3.1 自动混合精度AMFimporttorchimporttorch.npu# 8.1 及之前手动转 FP16modelmodel.npu()modelmodel.half()# 全部转 FP16风险高# 8.2AMF 自动混合精度fromtorch.npu.ampimportautocast,GradScaler scalerGradScaler()modelMyModel().npu()forbatchindataloader:withautocast(enabledTrue):outputmodel(data)lossloss_fn(output,target)# 梯度缩放防止下溢scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()3.2 梯度缩放原理FP16 动态范围小梯度值可能小于最小表示值Underflow导致更新失效。GradScaler 通过乘法放大梯度反向传播后再 scale 回来classGradScaler:def__init__(self,init_scale2.**16):self.scaleinit_scale self._scale_factorinit_scaledefscale(self,loss):returnloss*self.scaledefstep(self,optimizer):# 放大后的梯度scaled_grads[p.grad*self.scaleforpinmodel.parameters()]optimizer.step()defupdate(self,overflow_foundFalse):ifoverflow_found:self.scale/2# 缩放系数减半else:# 每 2000 步检查是否可加倍self.scale*1.01四、BF16 训练配置8.2 新增4.1 启用 BF16# 8.2 支持 BF16fromtorch.npu.ampimportautocast modelMyModel().npu()# 使用 BF16 混合精度forbatchindataloader:withautocast(enabledTrue,dtypetorch.bfloat16):outputmodel(data)lossloss_fn(output,target)loss.backward()optimizer.step()optimizer.zero_grad()4.2 FP32 vs BF16 vs FP16 对比配置显存占用训练速度精度FP32100%基准最高FP16~50%1.5-2x中等BF16~50%1.3-1.8x较好4.3 损失放大Loss ScalingclassDynamicLossScaler:def__init__(self,scale_factor2.**16,growth_factor2.0,backoff_factor0.5):self.scalescale_factor self.growth_factorgrowth_factor self.backoff_factorbackoff_factor self.growth_interval2000self.step_count0defupdate(self,overflow):self.step_count1ifoverflow:# 溢出缩小 scaleself.scale*self.backoff_factorelifself.step_count%self.growth_interval0:# 正常尝试放大 scaleself.scale*self.growth_factorreturnself.scale五、CANN 混合精度最佳实践5.1 推荐配置# 推荐BF16 动态缩放scalerDynamicLossScaler(scale_factor2.**16,growth_factor1.01,backoff_factor0.5)forepochinrange(num_epochs):forbatchindataloader:withautocast(enabledTrue,dtypetorch.bfloat16):outputmodel(data)lossloss_fn(output,target)scaler.scale(loss).backward()# 检查溢出overflowcheck_overflow(model)scaler.update(overflow)scaler.step(optimizer)scaler.update()5.2 不适合混合精度的算子算子原因建议Softmax输出端溢出风险高保持 FP32LayerNorm数值不稳定保持 FP32CrossEntropy指数运算保持 FP32位置编码小值范围大保持 FP32# 8.2 新增按模块指定精度classModel(nn.Module):defforward(self,x):# 这些层保持 FP32xself.layer_norm(x)# FP32# 其他层用 BF16withautocast(enabledTrue,dtypetorch.bfloat16):xself.attention(x)xself.feed_forward(x)returnx六、常见问题问题原因解决方案loss 变为 nan梯度溢出增大 scale 或改用 BF16精度下降明显关键算子未保留 FP32LayerNorm、Softmax 保持 FP32速度提升不明显batch size 太小增大 batch 或检查 AMP 配置多卡训练不稳定梯度 allreduce 溢出启用梯度缩放相关仓库torch_npu- 混合精度接口 https://gitee.com/ascend/torch_npuASCEND- 混合精度最佳实践 https://gitee.com/ascend/ascend