OHEM代码梳理[通俗易懂]

OHEM代码梳理[通俗易懂]传送门:相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结代码地址:OHEM1.前言有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得…

大家好,又见面了,我是你们的朋友全栈君。

  • 传送门
  1. 相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结
  2. 代码地址:OHEM

1. 前言

有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:
1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得到分类和边界框回归的loss,这里的损失是将两种损失相加得到的;
2)使用阈值为0.7的NMS预先处理一遍检测框,去除一些无效的检测框;
3)NMS之后的检测框按照loss由大到小排列,选取一定数目(由两个数取最小决定)的边界框返回。
下面是OHEM在网络定义文件中的定义,方便后面查看相关代码的时候查找对应条目。

layer {
  name: "hard_roi_mining"
  type: "Python"
  bottom: "cls_prob_readonly"
  bottom: "bbox_pred_readonly"
  bottom: "rois"
  bottom: "labels"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"
  bottom: "bbox_outside_weights"
  top: "rois_hard"
  top: "labels_hard"
  top: "bbox_targets_hard"
  top: "bbox_inside_weights_hard"
  top: "bbox_outside_weights_hard"
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  python_param {
    module: "roi_data_layer.layer"
    layer: "OHEMDataLayer"
    param_str: "'num_classes': 6" #6
  }
}

2. OHEM代码简单梳理

2.1 OHEMDataLayer

class OHEMDataLayer(caffe.Layer):
"""Online Hard-example Mining Layer."""
def setup(self, bottom, top):
"""Setup the OHEMDataLayer."""
# parse the layer parameter string, which must be valid YAML
layer_params = yaml.load(self.param_str_)
self._num_classes = layer_params['num_classes']	 # 获取分类数目
self._name_to_bottom_map = { 
  # 将bottom的blob名称与index使用dict关联
'cls_prob_readonly': 0,
'bbox_pred_readonly': 1,
'rois': 2,
'labels': 3}
if cfg.TRAIN.BBOX_REG:  # 有边界框回归
self._name_to_bottom_map['bbox_targets'] = 4
self._name_to_bottom_map['bbox_loss_weights'] = 5
self._name_to_top_map = { 
}  # 同理top的blob名称也要与index关联起来
……
# 前向传播函数
def forward(self, bottom, top):
"""Compute loss, select RoIs using OHEM. Use RoIs to get blobs and copy them into this layer's top blob vector."""
cls_prob = bottom[0].data  # 获取对应bottom的数据
bbox_pred = bottom[1].data
rois = bottom[2].data
labels = bottom[3].data
if cfg.TRAIN.BBOX_REG:
bbox_target = bottom[4].data
bbox_inside_weights = bottom[5].data
bbox_outside_weights = bottom[6].data
else:
bbox_target = None
bbox_inside_weights = None
bbox_outside_weights = None
flt_min = np.finfo(float).eps
# 计算分类的损失
loss = [ -1 * np.log(max(x, flt_min)) \
for x in [cls_prob[i,label] for i, label in enumerate(labels)]]
# 计算边界框回归的损失,并且将两个损失加起来
if cfg.TRAIN.BBOX_REG:
# bounding-box regression loss
# d := w * (b0 - b1)
# smoothL1(x) = 0.5 * x^2 if |x| < 1
# |x| - 0.5 otherwise
def smoothL1(x):
if abs(x) < 1:
return 0.5 * x * x
else:
return abs(x) - 0.5
bbox_loss = np.zeros(labels.shape[0])  # 边界框损失
for i in np.where(labels > 0 )[0]:
indices = np.where(bbox_inside_weights[i,:] != 0)[0]
bbox_loss[i] = sum(bbox_outside_weights[i,indices] * [smoothL1(x) \
for x in bbox_inside_weights[i,indices] * (bbox_pred[i,indices] - bbox_target[i,indices])])
loss += bbox_loss  # 两者损失相加
# 筛选出损失比较大的返回
blobs = get_ohem_minibatch(loss, rois, labels, bbox_target, \
bbox_inside_weights, bbox_outside_weights)
# 给top blob赋值
for blob_name, blob in blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
# Reshape net's input blobs
top[top_ind].reshape(*(blob.shape))
# Copy data into net's input blobs
top[top_ind].data[...] = blob.astype(np.float32, copy=False)

2.2 get_ohem_minibatch

# 获取OHEM训练的batch
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
bbox_inside_weights=None, bbox_outside_weights=None):
"""Given rois and their loss, construct a minibatch using OHEM."""
loss = np.array(loss)
# 使用NMS过滤检测框
if cfg.TRAIN.OHEM_USE_NMS:	# NMS thresh=0.7
# Do NMS using loss for de-dup and diversity
keep_inds = []
nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH  # 0.7
source_img_ids = [roi[0] for roi in rois] # 0或1(背景与前景)
for img_id in np.unique(source_img_ids):
for label in np.unique(labels):
sel_indx = np.where(np.logical_and(labels == label, \
source_img_ids == img_id))[0]
if not len(sel_indx):
continue
boxes = np.concatenate((rois[sel_indx, 1:],
loss[sel_indx][:,np.newaxis]), axis=1).astype(np.float32)
keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])
hard_keep_inds = select_hard_examples(loss[keep_inds])  # 按照损失排序选择样本
hard_inds = np.array(keep_inds)[hard_keep_inds]  # 最后保留下来的困难样本索引
else:
hard_inds = select_hard_examples(loss)
blobs = { 
'rois_hard': rois[hard_inds, :].copy(),
'labels_hard': labels[hard_inds].copy()}
if bbox_targets is not None:
assert cfg.TRAIN.BBOX_REG
blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
blobs['bbox_inside_weights_hard'] = bbox_inside_weights[hard_inds, :].copy()
blobs['bbox_outside_weights_hard'] = bbox_outside_weights[hard_inds, :].copy()
return blobs
def select_hard_examples(loss):
"""Select hard rois."""
# Sort and select top hard examples.
sorted_indices = np.argsort(loss)[::-1]
hard_keep_inds = sorted_indices[0:np.minimum(len(loss), cfg.TRAIN.BATCH_SIZE)]
# (explore more ways of selecting examples in this function; e.g., sampling)
return hard_keep_inds
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/139136.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 在计算机中1 KB等于多少字节,字节、kb、MB、GB 等单位怎么换算的?1M等于多少kb,1g等于多少kb?…[通俗易懂]

    在计算机中1 KB等于多少字节,字节、kb、MB、GB 等单位怎么换算的?1M等于多少kb,1g等于多少kb?…[通俗易懂]字节、kb、MB、GB等单位怎么换算的?1M等于多少kb,1g等于多少kb?我们查看文件属性时可以看到很多文件和大小是以kb来显示的,很多朋友都知道电脑中文件大小、容量等采用的是字节、kb、MB、GB等单位,那么你知道它们之间怎么换算的吗,如1M等于多少kb,1g等于多少kb,下面小编就和大家一起来分享下相关知识。1M等于多少kb?1MB=1024KB=1048576字节1G等于多少KB?1G=…

  • webstorm 永久激活方法【2021免费激活】

    (webstorm 永久激活方法)2021最新分享一个能用的的激活码出来,希望能帮到需要激活的朋友。目前这个是能用的,但是用的人多了之后也会失效,会不定时更新的,大家持续关注此网站~https://javaforall.cn/100143.htmlIntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,上面是详细链接哦~3S…

  • APUE学习笔记——10.15 sigsetjmp和siglongjmp[通俗易懂]

    APUE学习笔记——10.15 sigsetjmp和siglongjmp

  • Source Insight 4.0 序列号 license文件

    Source Insight 4.0 序列号 license文件安装程序下载在官网上下载SourceInsight4.0的安装程序.目前版本4.00.0098可用30天的试用安装首次启动选择授权方式,这里选择第二个选项,30天试用。点击下一步,输入名称、公司或组织名称、邮箱信息,申请30天的试用。输入完成后,点击下一步,直到安装完成。修改sourceinsight4.exe用16进制编辑器(sublimetext)打开s…

  • java velocity 语法_Velocity语法

    java velocity 语法_Velocity语法1.变量(1)变量的定义:#set($name=”hello”)说明:velocity中变量是弱类型的。当使用#set指令时,括在双引号中的字面字符串将解析和重新解释,如下所示:#set($directoryRoot=”www”)#set($templateName=”index.vm”)#set($template=”$directoryRoot/$templateName”…

  • Nginx sendfile原理详解[通俗易懂]

    Nginx sendfile原理详解[通俗易懂]配置语法语法:sendfileon|off;默认值:sendfileoff;上下文:http,server,location,ifinlocation说明sendfile值为on,指定使用sendfile系统调用来传输文件。sendfile系统调用在两个文件描述符之间直接传递数据(完全在内核中操作),从而避免了数据在内核缓冲区和用户缓冲区之间的拷贝,操作效率很高,被…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号