067、混合精度训练 autocast 源码:前向 FP16到Loss Scale到反向 FP32 的完整机制

067、混合精度训练 autocast 源码:前向 FP16到Loss Scale到反向 FP32 的完整机制 067、混合精度训练 autocast 源码前向 FP16到Loss Scale到反向 FP32 的完整机制一、从一次显存爆炸说起去年有个项目训练YOLOv8-Lbatch size设到32RTX 4090 24G显存直接爆了。同事说“降batch size呗”我说“试试混合精度”。结果改了四行代码batch size拉到48显存占用反而降了30%训练速度还快了近一倍。但第二天发现loss曲线在某个epoch后突然变成NaN模型直接废了。这就是混合精度训练的典型坑FP16精度不够梯度下溢模型崩了。后来我翻autocast源码才彻底搞明白这东西不是简单地把float32转成float16就完事了背后有一套完整的动态Loss Scale机制在兜底。二、autocast到底干了什么很多人以为with torch.cuda.amp.autocast():就是自动把模型输入转成FP16。错。它做的是选择性精度转换——哪些操作用FP16算哪些必须用FP32保精度autocast内部有个白名单。看PyTorch 2.0的源码autocast核心逻辑在torch/cuda/amp/autocast.py。它维护了一个_cast字典里面记录了每个op的精度偏好# 源码简化版实际在torch/amp/autocast_mode.py_CASTS{torch.addmm:[fp16,fp16,fp16],# 三个参数都转fp16torch.mm:[fp16,fp16],torch.bmm:[fp16,fp16],torch.conv2d:[fp16,fp16,fp32],# 权重和输入转fp16bias保持fp32torch.softmax:[fp32],# softmax必须fp32否则梯度爆炸torch.layer_norm:[fp32],# layer norm也是}这里踩过坑softmax和layer norm在FP16下精度损失极大尤其是YOLO的检测头里用了softmax做分类如果autocast没把它保护成FP32训练到一半loss直接飞掉。PyTorch官方已经把这些op写死了但如果你自己写了个自定义op记得手动加装饰器torch.cuda.amp.custom_fwd。三、前向传播FP16的“偷懒”哲学前向时autocast的工作流程是这样的拦截op调用每个torch操作被_cast函数包裹检查输入类型精度转换如果op在白名单里把float32 tensor转成float16半精度执行计算用FP16做矩阵乘法、卷积等计算速度翻倍结果缓存输出保持FP16但autocast会记录哪些tensor是“关键节点”别这样写x x.half()手动转。autocast会自动处理你手动转反而可能破坏它的精度选择逻辑。我见过有人把输入手动转成FP16结果softmax也变成FP16算loss直接NaN。关键点autocast只在with块内生效块外的操作不受影响。所以训练循环里前向和loss计算要包在autocast里反向传播和优化器更新在外面。四、Loss Scale那个救命的缩放因子FP16的数值范围是[-65504, 65504]但梯度通常很小比如1e-5在FP16下直接变成0这就是下溢。Loss Scale就是把这个梯度放大算完再缩回去。PyTorch的GradScaler源码在torch/cuda/amp/grad_scaler.py核心逻辑classGradScaler:def__init__(self,init_scale2.**16,growth_factor2.0,backoff_factor0.5,growth_interval2000):self._scaletorch.tensor(init_scale,dtypetorch.float32)self._growth_factorgrowth_factor self._backoff_factorbackoff_factor self._growth_intervalgrowth_interval self._growth_tracker0# 记录连续无溢出步数defscale(self,loss):# 把loss放大避免梯度下溢returnloss*self._scaledefunscale_(self,optimizer):# 反向传播后把梯度缩回来forgroupinoptimizer.param_groups:forpingroup[params]:ifp.gradisnotNone:p.grad.data.div_(self._scale)defstep(self,optimizer):# 检查是否有梯度溢出NaN或Infifself._has_inf_or_nan(optimizer):self._scale*self._backoff_factor# 发现溢出缩小scaleself._growth_tracker0optimizer.zero_grad()# 跳过这步更新else:self._scale*self._growth_factor# 连续无溢出放大scaleself._growth_tracker1ifself._growth_trackerself._growth_interval:self._scale*self._growth_factor self._growth_tracker0这里踩过坑unscale_必须在step之前调用。如果你先调了optimizer.step()再unscale梯度已经被更新了scale白做了。PyTorch官方推荐写法scaler.scale(loss).backward()# 前向反向scaler.unscale_(optimizer)# 先缩梯度scaler.step(optimizer)# 再更新参数scaler.update()# 最后调整scale别这样写scaler.step(optimizer)之后才调unscale_。我debug过一整天发现loss曲线震荡就是因为顺序搞反了。五、反向传播FP32的“救场”机制反向传播时autocast会自动把梯度转回FP32。为什么因为梯度计算需要高精度。FP16的梯度更新参数相当于用一把尺子量头发丝误差太大。看反向传播的源码逻辑# 在autocast模式下每个op的反向函数被包装classAutocastFunction(torch.autograd.Function):staticmethoddefforward(ctx,input,weight,biasNone):# 前向用FP16input_fp16input.half()weight_fp16weight.half()outputtorch.conv2d(input_fp16,weight_fp16,bias)ctx.save_for_backward(input_fp16,weight_fp16,bias)returnoutputstaticmethoddefbackward(ctx,grad_output):# 反向时梯度自动转回FP32input_fp16,weight_fp16,biasctx.saved_tensors grad_inputtorch.conv2d_backward(grad_output.float(),# 这里转成FP32input_fp16.float(),weight_fp16.float(),bias)returngrad_input关键点梯度计算用FP32但参数更新时又转回FP16。所以模型权重在内存里是FP16但梯度计算时临时转成FP32算完再转回去。这解释了为什么混合精度能省显存——权重存FP16但计算时临时用FP32用完就释放。六、YOLO实战中的坑与优化在YOLOv8的训练代码里混合精度配置是这样的scalertorch.cuda.amp.GradScaler(enabledargs.amp)forbatchindataloader:withtorch.cuda.amp.autocast(enabledargs.amp):predsmodel(images)losscompute_loss(preds,targets)scaler.scale(loss).backward()scaler.unscale_(optimizer)torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm10.0)scaler.step(optimizer)scaler.update()optimizer.zero_grad()这里有个隐藏坑clip_grad_norm_必须在unscale_之后。因为梯度被scale放大了直接clip会剪掉正常梯度。我见过有人把clip放在unscale前面结果梯度被剪成0模型不收敛。另一个坑YOLO的检测头里用了nn.SiLU激活函数它在FP16下表现不稳定。解决方案是在autocast块外手动把检测头的输入转成FP32classDetect(nn.Module):defforward(self,x):withtorch.cuda.amp.autocast(enabledFalse):# 关闭autocastxx.float()# 强制FP32# 检测头计算...别这样写整个模型都包在autocast里然后指望它自动处理所有层。YOLO的检测头对精度敏感必须手动干预。七、个人经验什么时候该用什么时候别用该用的情况显存不够batch size上不去模型大YOLOv8-L以上训练速度慢梯度值在1e-3到1e-5之间不会下溢别用的情况模型很小YOLOv5-n显存充足用FP32更快任务对精度极度敏感比如医学图像检测FP16的误差不可接受自定义op太多且没有注册精度偏好调试技巧第一次训练先用FP32跑10个epoch记录loss范围如果loss在1e-4以下说明梯度太小需要调高init_scale比如2^20如果频繁出现inf或nan检查是否有op没被autocast保护用torch.cuda.amp.autocast(enabledTrue, dtypetorch.bfloat16)试试BF16数值范围更大不容易溢出最后说一句混合精度不是银弹。我见过有人为了省显存强行用FP16训练YOLOv8-x结果精度掉了2个点还不如降batch size用FP32。工具是死的人是活的根据你的硬件和任务灵活选择。