OHEM(Online Hard Example Mining)在线难例挖掘(在线困难样例挖掘) & HNM

OHEM(Online Hard Example Mining)在线难例挖掘(在线困难样例挖掘) & HNMHardNegatieMining与OnlineHardExampleMining(OHEM)都属于难例挖掘,它是解决目标检测老大难问题的常用办法,运用于R-CNN,fastR-CNN,fasterrcnn等two-stage模型与SSD等(有anchor的)one-stage模型训练时的训练方法。OHEM和难负例挖掘名字上的不同。HardNegativeMining只注意难负例 OHEM则注意所有难例,不论正负(Loss大的例子)难例挖掘的思想…

大家好,又见面了,我是你们的朋友全栈君。

      Hard Negatie Mining与Online Hard Example Mining(OHEM)都属于难例挖掘,它是解决目标检测老大难问题的常用办法,运用于R-CNN,fast R-CNN,faster rcnn等two-stage模型与SSD等(有anchor的)one-stage模型训练时的训练方法。

OHEM和难负例挖掘名字上的不同。

  • Hard Negative Mining只注意难负例
  • OHEM 则注意所有难例,不论正负(Loss大的例子)

      难例挖掘的思想可以解决很多样本不平衡/简单样本过多的问题,比如说分类网络,将hard sample 补充到数据集里,重新丢进网络当中,就好像给网络准备一个错题集,哪里不会点哪里。

      难例挖掘与非极大值抑制 NMS 一样,都是为了解决目标检测老大难问题(样本不平衡+低召回率)及其带来的副作用。

      根据每个RoIs的loss的大小来决定哪些是难样例, 哪些是简单样例, 通过这种方法, 可以更高效的训练网络, 并且可以使得网络获得更小的训练loss

Pytorch实现

def ohem_loss(
    batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
    """
    Arguments:
        batch_size (int): number of sampled rois for bbox head training
        loc_pred (FloatTensor): [R, 4], location of positive rois
        loc_target (FloatTensor): [R, 4], location of positive rois
        pos_mask (FloatTensor): [R], binary mask for sampled positive rois
        cls_pred (FloatTensor): [R, C]
        cls_target (LongTensor): [R]
    Returns:
        cls_loss, loc_loss (FloatTensor)
    """
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
    #这里先暂存下正常的分类loss和回归loss
    loss = ohem_cls_loss + ohem_loc_loss
    #然后对分类和回归loss求和
 
  
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)
    #再对loss进行降序排列
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)
    #得到需要保留的loss数量
    if keep_num < sorted_ohem_loss.size()[0]:
    #这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
        keep_idx_cuda = idx[:keep_num]
        #保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
        #分类和回归保留相同的数目
    cls_loss = ohem_cls_loss.sum() / keep_num
    loc_loss = ohem_loc_loss.sum() / keep_num
    #然后分别对分类和回归loss求均值
    return cls_loss, loc_loss

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/139390.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • Java面试之字符串常量池「建议收藏」

    Java面试之字符串常量池「建议收藏」来源:https://segmentfault.com/a/1190000009888357作为最基础的引用数据类型,Java设计者为String提供了字符串常量池以提高其性能,那么字符串常量池的具体原理是什么,我们带着以下三个问题,去理解字符串常量池:字符串常量池的设计意图是什么?字符串常量池在哪里?如何操作字符串常量池?字符串常量池的设计思想a.字符串的分…

  • 自己定义对象的监听方式

    自己定义对象的监听方式

    2021年11月16日
  • 二叉树 二叉搜索树_二叉树和二叉搜索树

    二叉树 二叉搜索树_二叉树和二叉搜索树原题链接一棵二叉搜索树可被递归地定义为具有下列性质的二叉树:对于任一结点,其左子树中所有结点的键值小于该结点的键值;其右子树中所有结点的键值大于等于该结点的键值;其左右子树都是二叉搜索树。所谓二叉搜索树的“镜像”,即将所有结点的左右子树对换位置后所得到的树。给定一个整数键值序列,现请你编写程序,判断这是否是对一棵二叉搜索树或其镜像进行前序遍历的结果。输入格式:输入的第一行给出正整数 N(≤1000)。随后一行给出 N 个整数键值,其间以空格分隔。输出格式:如果输入序列是对一棵二叉搜索树或

  • serialized objects

    serialized objectsThisstartedagain…athreadfrom*****:WhatdoyouconsiderabestpracticeforserialVersionUID?T______________________________________________From:*******Sent:Thursday,Nov…

  • linux下连接mysql数据库命令,linux连接mysql命令[通俗易懂]

    linux下连接mysql数据库命令,linux连接mysql命令[通俗易懂]linux连接mysql是最基本的操作之一,对于初学者来说我们可以通过命令来连接mysql,下面由学习啦小编为大家整理了linux下连接mysql命令的相关知识,希望对大家有所帮助!linux连接MYSQL命令格式:mysql-h主机地址-u用户名-p用户密码linux连接mysql命令实例1、连接到本机上的MYSQL找到mysql的安装目录,一般可以直接键入命令mysql-uroot…

  • 线程通信机制—共享内存:消息传递

    线程通信机制—共享内存:消息传递在并发编程中,我们必须考虑的问题时如何在两个线程间进行通讯。这里的通讯指的是不同的线程之间如何交换信息。目前有两种方式:1、共享内存2、消息传递(actor模型) 共享内存共享内存这种方式比较常见,我们经常会设置一个共享变量。然后多个线程去操作同一个共享变量。从而达到线程通讯的目的。例如,我们使用多个线程去执行页面抓取任务,我们可以使用一个共享变量count来记录任务

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号