OHEM介绍

OHEM介绍在two-stage检测算法中,RPN阶段会生成大量的检测框,由于很多时候一张图片可能只会有少量几个目标,也就是说绝大部分框是没有目标的,为了减少计算就需要进行sample,一般来说fasterrcnn的sample机制是算框和label的IOU,大于0.7认为是正样本,小于0.3是负样本。但是单纯的random_sample选出来的框不一定是最容易错的框。那么ohem就是较好的一种正负样本策略

大家好,又见面了,我是你们的朋友全栈君。

目标检测之OHEM介绍

论文地址:https://arxiv.org/pdf/1604.03540.pdf

在two-stage检测算法中,RPN阶段会生成大量的检测框,由于很多时候一张图片可能只会有少量几个目标,也就是说绝大部分框是没有目标的,为了减少计算就需要进行sample,一般来说fasterrcnn的sample机制是算框和label的IOU,大于0.7认为是正样本,小于0.3是负样本。但是单纯的random_sample选出来的框不一定是最容易错的框。那么ohem就是这样的一种正负样本策略,通过根据框的loss得到最容易错的框。可以理解为错题集,我们只会把最容易错的题放到错题集。

首先是 negative,即负样本,其次是 hard,说明是困难样本,也可以说是容易将负样本看成正样本的那些样本,例如 RPN框里没有物体,全是背景,这时候分类器很容易正确分类成背景,这个就叫 easy negative;如果 框里有二分之一个物体,标签仍是负样本,这时候分类器就容易把他看成正样本,这时候就是 had negative。hard negative mining 就是多找一些 hard negative 加入负样本集,进行训练。
接下来我们来看看mmdection的ohem实现:

class OHEMSampler(BaseSampler):
r"""Online Hard Example Mining Sampler described in `Training Region-based
Object Detectors with Online Hard Example Mining
<https://arxiv.org/abs/1604.03540>`_.
"""
def __init__(self,
num,
pos_fraction,
context,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
self.context = context
if not hasattr(self.context, 'num_stages'):
self.bbox_head = self.context.bbox_head
else:
self.bbox_head = self.context.bbox_head[self.context.current_stage]
def hard_mining(self, inds, num_expected, bboxes, labels, feats):
with torch.no_grad():
rois = bbox2roi([bboxes])
if not hasattr(self.context, 'num_stages'):
bbox_results = self.context._bbox_forward(feats, rois)
else:
bbox_results = self.context._bbox_forward(
self.context.current_stage, feats, rois)
cls_score = bbox_results['cls_score']
loss = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
rois=rois,
labels=labels,
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction_override='none')['loss_cls']
_, topk_loss_inds = loss.topk(num_expected)
return inds[topk_loss_inds]
def _sample_pos(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
"""Sample positive boxes.
Args:
assign_result (:obj:`AssignResult`): Assigned results
num_expected (int): Number of expected positive samples
bboxes (torch.Tensor, optional): Boxes. Defaults to None.
feats (list[torch.Tensor], optional): Multi-level features.
Defaults to None.
Returns:
torch.Tensor: Indices  of positive samples
"""
# Sample some hard positive samples
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
else:
return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
assign_result.labels[pos_inds], feats)
def _sample_neg(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
"""Sample negative boxes.
Args:
assign_result (:obj:`AssignResult`): Assigned results
num_expected (int): Number of expected negative samples
bboxes (torch.Tensor, optional): Boxes. Defaults to None.
feats (list[torch.Tensor], optional): Multi-level features.
Defaults to None.
Returns:
torch.Tensor: Indices  of negative samples
"""
# Sample some hard negative samples
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
else:
neg_labels = assign_result.labels.new_empty(
neg_inds.size(0)).fill_(self.bbox_head.num_classes)
return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
neg_labels, feats)

上面代码就是整个ohem的sample过程,整个ohem分为三个函数分别是hard_mining,_sample_pos,_sample_neg,_sample_pos和_sample_neg是获得对应的困难正样本/困难负样本,由hard_mining完成整个sample过程:根据输入的box_list得到对应的bbox_loss的list取最大的256/512个,由于这一批box的loss最大,就可以认为是最难区分的box,这一批bbox就是所谓的
困难正样本/困难负样本。

至此ohem阶段完成,后面就是对候选框的分类和回归,因为ohem阶段得到了容易分错的样本框,所以在后续训练阶段模型会对这些容易分错的框重点关注,有利于困难样本的检测,提升了模型的效果。

实际上提升还是很明显的:
在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/139162.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 【18】进大厂必须掌握的面试题-15个Kafka面试

    点击上方“全栈程序员社区”,星标公众号 重磅干货,第一时间送达 1.什么是kafka? Apache Kafka是由Apache开发的一种发布订阅消息系统。 2.kafka的3个关…

  • pycharm怎么配置python环境anaconda_ug编程电脑配置要求

    pycharm怎么配置python环境anaconda_ug编程电脑配置要求1.介绍Python:一种解释型、面向对象、动态数据类型的高级程序设计语言。PyCharm:一款好用的集成开发环境。Conda:Python环境管理器,方便我们管理和切换编程环境。2.下载2.1Conda下载Miniconda下载链接Anaconda下载链接Miniconda是Anaconda的压缩版,Miniconda只包含conda的核心内容,Anaconda中包含了Spyder集成开发环境等扩充内容。Miniconda的功能足矣。根据计算机的实际情况选择下载安装包,上图中Py

  • 十六进制与十进制的互相换换计算

    十六进制与十进制的互相换换计算十六进制与十进制的互相换换计算

  • MySQL设计之三范式的理解

    MySQL设计之三范式的理解

    2021年11月10日
  • 【linux 】linux 命令:查看 Linux 服务器配置

    目录一、服务器型号二、操作系统三、CPU四、内存五、硬盘六、其他一、服务器型号dmidecode|grep”SystemInformation”-A9|egrep”Manufacturer|Product”二、操作系统(1)当前操作系统发行版信息:cat/etc/redhat-release(2)操作系统发行版详细信息:lsb_release-a三、CPU(1)CPU统计信息:lscpu(2)CPU型号

  • 渗透测试工具——SET「建议收藏」

    渗透测试工具——SET「建议收藏」社会工程学使用计谋、假情报或人际关系去获得利益和其他敏感信息。 攻击对象一-一人一-秘密信息的保存者,信息安全链中最薄弱的环节。 利用受害者的本能反应、好奇心、信任、贪婪等心理弱点进行欺骗、伤害。常见的社会工程学攻击方式环境渗透:对特定的环境进行渗透,是社会工程学为了获得所需的情报或敏感信息经常采用的手段之一。社会工程学攻击者通过观察目标对电子邮件的响应速度、重视程度以及可能提供的相关资料,比如一个人的姓名、生日、ID电话号码、管理员的IP地址、邮箱等,通过这些收集信息来判断目标的网

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号