049、Focal Loss 核心思想:从 cross-entropy 到 α-balanced Focal Loss 的推导

049、Focal Loss 核心思想:从 cross-entropy 到 α-balanced Focal Loss 的推导 049、Focal Loss 核心思想从 cross-entropy 到 α-balanced Focal Loss 的推导一、一个让我失眠的调试问题去年做工业缺陷检测项目正样本缺陷占比不到0.1%负样本正常产品占99.9%。用标准交叉熵训练YOLOv5模型收敛后mAP0.5只有0.12——它学会了“永远输出负样本”。我试过过采样、欠采样、数据增强效果都像在泥潭里挣扎。直到我在RetinaNet论文里看到Focal Loss才意识到问题本质交叉熵对易分类样本的梯度贡献太大淹没了难分类样本的信号。这篇文章就带你从数学推导到代码实现彻底搞懂Focal Loss为什么能解决这个问题。二、交叉熵的“公平”其实是“不公平”先看二分类交叉熵的标准形式CE(p, y) -y * log(p) - (1-y) * log(1-p)其中y∈{0,1}是真实标签p∈[0,1]是模型预测为正类的概率。为了方便推导定义ptpt p if y1 else 1-p这样交叉熵可以简写为CE(p, y) -log(pt)问题出在哪里假设一个负样本y0模型预测p0.9即pt0.1这个样本被正确分类的概率只有10%属于“难分类样本”。但交叉熵给它的损失是 -log(0.1) ≈ 2.3。再看一个负样本模型预测p0.01pt0.99这是“易分类样本”损失是 -log(0.99) ≈ 0.01。关键点来了易分类样本的损失虽然小但它们的数量是难分类样本的成千上万倍。累加起来易分类样本的总损失占据了主导地位梯度更新方向被它们牵着鼻子走。这就是类别不平衡问题的本质——不是正负样本数量不平衡而是“易分类样本”和“难分类样本”的梯度贡献不平衡。三、Focal Loss的直觉让模型“聚焦”难样本Focal Loss的核心修改只有一行公式FL(pt) -(1-pt)^γ * log(pt)对比交叉熵多了一个调制因子 (1-pt)^γ。这个因子做了什么当pt接近1易分类样本(1-pt)γ接近0损失被大幅压低。当pt接近0难分类样本(1-pt)γ接近1损失几乎不变。γ是聚焦参数论文推荐γ2。我实际调参的经验是γ0退化为交叉熵γ1效果不明显γ2~3效果最好γ5会导致训练不稳定这里踩过坑梯度消失严重。举个例子γ2时易分类样本pt0.99的损失从0.01降为0.01*(0.01)21e-6几乎被忽略。难分类样本pt0.1的损失从2.3降为2.3*(0.9)2≈1.86只降低了20%。这样模型就会把注意力集中在那些“模棱两可”的样本上。四、α-balanced Focal Loss再加一层保险Focal Loss解决了“难易样本”问题但没解决“正负样本”问题。实际场景中正样本往往既是“少数”又是“难分类”的。如果只用Focal Loss模型可能过度关注负样本中的难分类样本比如背景中的噪声而忽略正样本。解决方案是引入α平衡因子FL(pt) -α_t * (1-pt)^γ * log(pt)其中α_t的定义和pt类似α_t α if y1 else 1-αα通常取0.25~0.75具体值取决于正负样本比例。我的经验公式α 1 / (1 正负样本比)。比如正负样本比1:1000α≈0.001。但别直接套用这个公式它只是初始值最终需要调参。注意α和γ不是独立参数。α控制正负样本的权重分配γ控制难易样本的权重分配。两者协同工作α让模型“看到”正样本γ让模型“聚焦”难样本。五、PyTorch实现别踩这些坑直接上代码注释里写满了我的血泪史importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassFocalLoss(nn.Module):def__init__(self,alpha0.25,gamma2.0,reductionmean):super().__init__()self.alphaalpha# 正样本权重别设成0.5那是交叉熵self.gammagamma# 聚焦参数推荐2.0self.reductionreduction# mean或sumdefforward(self,inputs,targets):# inputs: 模型输出logitsshape [N, C] 或 [N]# targets: 真实标签shape [N]值域{0,1,...,C-1}# 别这样写直接用F.binary_cross_entropy_with_logits# 因为Focal Loss需要手动计算pt# 先计算交叉熵的log部分ce_lossF.cross_entropy(inputs,targets,reductionnone)# ce_loss shape: [N]# 计算pt模型预测正确类别的概率# 这里踩过坑直接用softmax再gather但数值不稳定pttorch.exp(-ce_loss)# 因为ce_loss -log(pt)# pt shape: [N]# 计算Focal Lossfocal_loss(1-pt)**self.gamma*ce_loss# 加入alpha平衡因子# 需要根据targets构建alpha_tifself.alphaisnotNone:# 假设二分类alpha_t alpha if y1 else 1-alpha# 多分类时alpha可以是一个列表alpha_tself.alpha*targets(1-self.alpha)*(1-targets)# 别这样写直接乘alpha_t因为targets可能是long类型focal_lossalpha_t*focal_loss# 根据reduction聚合ifself.reductionmean:returnfocal_loss.mean()elifself.reductionsum:returnfocal_loss.sum()else:returnfocal_loss几个容易翻车的地方数值稳定性不要直接用softmax计算pt然后用log。上面用torch.exp(-ce_loss)是安全的因为ce_loss已经包含了log_softmax。alpha的维度如果做多分类alpha应该是一个长度为C的tensor每个类别一个权重。别偷懒用一个标量。reduction的选择YOLO系列通常用’sum’因为每个anchor的损失需要独立累加。分类任务用’mean’更稳定。六、在YOLO中集成Focal LossYOLOv5/v8的损失函数在loss.py里替换交叉熵部分# 原始YOLOv5的分类损失self.bcenn.BCEWithLogitsLoss(reductionmean)# 替换为Focal Lossself.flFocalLoss(alpha0.25,gamma2.0,reductionmean)注意YOLO的输出是multi-label分类每个类别独立二分类所以Focal Loss需要按二分类方式计算。上面的实现已经兼容了这种情况。调参建议先从γ2.0, α0.25开始论文默认值如果正样本极少0.1%尝试α0.1~0.15如果模型过拟合增大γ到3.0~4.0如果模型欠拟合减小γ到1.0~1.5七、个人经验什么时候用Focal Loss强烈推荐使用目标检测中的一阶段检测器YOLO、SSD、RetinaNet正负样本比例超过1:100的分类任务存在大量“简单负样本”的场景比如背景占90%以上的图像不建议使用二阶段检测器Faster R-CNN因为RPN已经做了样本筛选正负样本比例接近1:1的任务模型已经过拟合的情况Focal Loss会加剧过拟合一个容易被忽视的点Focal Loss会改变损失函数的尺度。如果从交叉熵切换到Focal Loss学习率可能需要调低1~2个数量级。我习惯先用γ0即交叉熵跑一个epoch观察损失量级再调整学习率后启用Focal Loss。最后说句实在话Focal Loss不是万能药。如果你的数据极度不平衡正样本比例0.01%建议先做数据增强和难例挖掘再用Focal Loss做精细调优。我那个缺陷检测项目最终方案是Online Hard Example Mining Focal Loss 数据增强才把mAP从0.12拉到0.87。工具是死的组合拳才是活的。