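Focal Loss is a loss function proposed by Kaiming He and co-authors to handle the extreme imbalance between foreground and background classes (for example, 1:1000) that arises when training one-stage object detectors. It is derived from the binary cross-entropy.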

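α-balanced cross-entropy:

$$\mathrm{CE}(p_t) = -\alpha_t \log(p_t)$$

Here $p$ is the predicted probability of the positive class and $y$ is the label, with $p_t = p$ and $\alpha_t = \alpha$ when $y = 1$, and $p_t = 1 - p$ and $\alpha_t = 1 - \alpha$ otherwise.

Definition of Focal Loss:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

gamma modulates the loss: the factor $(1 - p_t)^{\gamma}$ shrinks the contribution of examples that are already classified well. When gamma = 0, Focal Loss is equivalent to α-CE.

[Figure: effect of gamma; the image is missing from the source]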
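
As a rough numerical check of the definitions above, the following minimal sketch evaluates the α-balanced cross-entropy and Focal Loss for single binary predictions. The function names `alpha_ce` and `focal_loss` and the sample probabilities are illustrative assumptions, not part of the original post; α = 0.25 and γ = 2 match the values used in the implementation in this post.

```python
import math


def alpha_ce(p, y, alpha=0.25):
    """Alpha-balanced cross-entropy for one binary prediction.

    p: predicted probability of the positive class, y: label (1 or 0).
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * math.log(p_t)


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Focal loss: alpha-balanced CE scaled by the factor (1 - p_t) ** gamma."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# Hypothetical background (y = 0) examples: one easy, one hard.
for name, p in [("easy negative", 0.01), ("hard negative", 0.9)]:
    ratio = focal_loss(p, 0) / alpha_ce(p, 0)
    print(f"{name}: FL / alpha-CE = {ratio:.4g}")

# With gamma = 0 the modulating factor is 1, so both losses coincide.
print(math.isclose(focal_loss(0.3, 1, gamma=0.0), alpha_ce(0.3, 1)))
```

With γ = 2, a background example predicted at p = 0.01 has p_t = 0.99, so its loss is scaled by (1 − 0.99)² = 10⁻⁴, whereas a misclassified background example at p = 0.9 keeps a factor of 0.81. This is how the many easy negatives stop dominating the total loss.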

PyTorch implementation of Focal Loss (the focal loss part was shown in blue font in the original post):

```python
import torch
import torch.nn as nn


def calc_iou(a, b):
    """Pairwise IoU between boxes a (N, 4) and b (M, 4), given as (x1, y1, x2, y2)."""
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    # width and height of every pairwise intersection (0 when boxes do not overlap)
    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    # union = area(a) + area(b) - intersection
    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih
    IoU = intersection / ua

    return IoU


class FocalLoss(nn.Module):

    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        device = classifications.device  # follow the inputs instead of hard-coding .cuda()
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            # drop padded annotations (class id -1)
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                regression_losses.append(torch.tensor(0.0, device=device))
                classification_losses.append(torch.tensor(0.0, device=device))
                continue

            # keep probabilities away from 0 and 1 so the logs stay finite
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4])  # num_anchors x num_annotations
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # num_anchors

            # compute the loss for classification
            # targets: -1 = ignored, 0 = background, 1 = assigned object class
            targets = torch.ones_like(classification) * -1
            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)
            num_positive_anchors = positive_indices.sum()
            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            # focal loss: alpha_t * (1 - p_t)^gamma * cross-entropy
            alpha_factor = torch.ones_like(targets) * alpha
            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
            cls_loss = focal_weight * bce

            # ignored anchors (target -1) contribute no loss
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros_like(cls_loss))

            # normalize by the number of positive anchors
            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            # compute the loss for regression
            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths and heights to at least 1
                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                # encode ground-truth boxes relative to their anchors
                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                targets = targets / torch.tensor([[0.1, 0.1, 0.2, 0.2]], device=device)

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                # smooth L1 loss with the quadratic/linear switch at 1/9
                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                regression_losses.append(torch.tensor(0.0, device=device))

        # Assumption: the excerpt ends before the return statement;
        # here the per-image losses are averaged over the batch.
        return (torch.stack(classification_losses).mean(dim=0, keepdim=True),
                torch.stack(regression_losses).mean(dim=0, keepdim=True))
```
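
The classification targets in the listing above depend on how each anchor overlaps the ground-truth boxes: IoU ≥ 0.5 makes an anchor positive, IoU < 0.4 makes it background, and anything in between is ignored. The sketch below replays that rule for single boxes in plain Python. The helper names `iou` and `label_anchor` and the example boxes are hypothetical, chosen only for illustration.

```python
def iou(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    ih = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = iw * ih
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = max(area_a + area_b - intersection, 1e-8)  # avoid division by zero
    return intersection / union


def label_anchor(anchor, gt_boxes, neg_thresh=0.4, pos_thresh=0.5):
    """Label an anchor by its best IoU with any ground-truth box."""
    best = max(iou(anchor, gt) for gt in gt_boxes)
    if best >= pos_thresh:
        return "positive"
    if best < neg_thresh:
        return "negative"
    return "ignored"  # excluded from the classification loss


gt_boxes = [(0, 0, 10, 10)]  # one hypothetical ground-truth box
for anchor in [(0, 0, 10, 10), (20, 20, 30, 30), (0, 0, 10, 22)]:
    print(anchor, label_anchor(anchor, gt_boxes))
```

The third anchor overlaps the box with IoU = 100 / 220 ≈ 0.45, which falls between the two thresholds, so it is neither a positive nor a background example and adds nothing to the focal loss.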
