公式推导:https://github.com/zimenglan-sysu-512/paper-note/blob/master/focal_loss.pdf

使用的代码:https://github.com/zimenglan-sysu-512/Focal-Loss

在onestage的网络中,正负样本达到1:1000,这就会出现两个问题:1.样本不平衡   2.负样本主导loss。虽然负样本的loss小(因为大量的负样本是easy example,大量负样本是准确率很高的第0类),但个数众多,加起来的loss甚至大于了正样本的loss

focal loss先解决了样本不平衡的问题,即在CE上加权重,当class为1的时候,乘以权重alpha,当class为0的时候,乘以权重1-alpha,这是最基本的解决样本不平衡的方法,也就是在loss计算时乘以权重。注意下面的图:alpha下面有个坐标t,也就是说alpha针对不同类别,值并不一样

尝试了样本不平衡改变权重后,paper就给出了focal loss的公式,实际上就是在CE前加了一个(1-pt)的gama次方,也就是说,如果你的准确率越高,gama次方的值越小,整个loss的值也就越小。也就是说,gama次方就是用来衰减的,准确率越高的样本衰减越多,越低的衰减的越少,这样整个loss就是由准确率较低的样本主导了。对于onestage的网络,loss由负样本主导,但这些负样本大多是准确率很高的,经过focal loss后就变成了正负样本共同主导,或者说是概率低的主导。这一点和ohem很像,ohem是让loss大的进行训练。

最后总的公式是两者结合:

知乎上有人说:

OHEM是只取3:1的负样本去计算loss,之外的负样本权重置零,而focal loss取了所有负样本,根据难度给了不同的权重。
focal loss相比OHEM的提升点在于,3:1的比例比较粗暴,那些有些难度的负样本可能游离于3:1之外。之前实验中曾经调整过OHEM这个比例,发现是有好处的,现在可以试试focal loss了。

paper中单独做了一个实验,就是直接在CE上加权重,得到的结果是alpha=0.75的时候效果最好,也就是说,正样本的权重为0.75,负样本的权重为0.25,正样本的权重大于负样本,因为本身就是正样本个数远少于负样本。加了gama次方后,alpha取0.25的时候效果最好,也就是说,正样本的权重为0.25,负样本的权重为0.75,这个时候反而负样本的权重在增加,按道理来说,负样本个数这么多,应该占loss主导,这说明gama次方已经把负样本整体的loss衰减到需要加权重的地步。

paper中alpha取0.25,gama取2效果最好

代办项:知乎上很多人说这个很boosting很像,为什么?

上面自己做的总结是有问题的,自己认为ssd中正负样本严重失衡(1:1000),focal loss是在解决ssd的样本不平衡问题。但是实际上,ssd训练的时候通过hard mining选负样本,实现了正负样本1:3。focal loss主旨是:ssd按照ohem选出了loss较大的,但忽略了那些loss较小的easy的负样本,虽然这些easy负样本loss很小,但数量多,加起来的loss较大,对最终loss有一定贡献。作者想把这些loss较小的也融入到loss计算中。但如果直接计算所有的loss,loss会被那些easy的负样本主导,因为数量太多,加起来的loss就大了。也就是说,作者是想融入一些easy example,希望他们能有助于训练,但又不希望他们主导loss。这个时候就用了公式进行衰减那些easy example,让他们对loss做贡献,但又不至于主导loss,并且通过balanced crossentropy平衡类别。

对于two stage来说,rpn阶段是保持的1:1的正负样本比例,但rpn阶段也是有大量的负样本,这个阶段类似于ssd。按道理来说,实现应该是把所有的rpn的anchor都拿来训练使用focal-loss,但我做的时候还是用的512个正负样本,其实这个实验稍稍有点不正确,但性能上依旧能将ap值提升1.5个点。对于一般物体检测fast阶段,提取的roi实际上只有2000个,并且负样本一般不会太多,不会像ssd那样1:1000,所以王乃岩觉得这个阶段使用focal显然是没有太大意义的。但有个特殊的地方,小物体,我自己做小物体的时候,选512个roi(根据wk的来的,但是faster原版的代码应该是128),很多图片正样本只有几个,但负样本几百个,这个时候其实也可以考虑使用focal-loss.

随机推荐

  1. 利用if else 求已发奖金总数

    class Program    {        static void Main(string[] args)        {            while (true)           ...

  2. 【转】Android:控件Spinner实现下拉列表

    原文网址:http://www.cnblogs.com/tinyphp/p/3858920.html 在Web开发中,HTML提供了下拉列表的实现,就是使用<select>元素实现一个下拉 ...

  3. Linux高性能server规划——处理池和线程池

    进程池和线程池 池的概念 由于server的硬件资源"充裕".那么提高server性能的一个非常直接的方法就是以空间换时间.即"浪费"server的硬件资源.以 ...

  4. 数据库中间件 Sharding-JDBC 源码分析 —— SQL 解析(三)之查询SQL

  5. python数据挖掘orange

    http://blog.csdn.net/pipisorry/article/details/52845804 orange的安装 linux下的安装 先安装依赖pyqt4[PyQt教程 - pyth ...

  6. Linux跑脚本用sh和./有什么区别?(转)

    sh是一个shell.运行sh a.sh,表示我使用sh来解释这个脚本:如果我直接运行./a.sh,首先你会查找脚本第一行是否指定了解释器,如果没指定,那么就用当前系统默认的shell(大多数linu ...

  7. GWAS文献解读:The stability of educational achievement across school years is largely explained by genetic factors

    方法 从NPD(英国数据库,收集有关学生在学年中学业成绩的数据)和TEDS(英国国家课程指南报告成绩数据库,由国家教育研究基金会和资格与课程管理局制定标准化核心学术课程)数据库获得双胞胎的学业成绩数据 ...

  8. 读取Excel,单元格内容大于255个字符自动被截取的问题

    DataSet ds = new DataSet(); cl_initPage.v_DeBugLog("ExcelDataSource进入"); string strConn; s ...

  9. 华为交换机配置NTP服务端/客户端

    作者:邓聪聪 配置设备作为NTP服务器 单播客户端/服务器模式 # 配置NTP主时钟,层数为2. <HUAWEI> system-view [HUAWEI] ntp refclock-ma ...

  10. haproxy + keepalived 实现高可用负载均衡集群

    1. 首先准备两台tomcat机器,作为集群的单点server. 第一台: 1)tomcat,需要Java的支持,所以同样要安装Java环境. 安装非常简单. tar xf  jdk-7u65-lin ...