- Links:
- An introduction to OHEM: Improving Detection Models: A Summary of OHEM and Focal Loss
- Code: OHEM
1. Introduction
For background on OHEM, see the links above; this post focuses on how OHEM actually runs in the code. The OHEM-specific code is fairly small, and the algorithm can be summarized in three steps (a minimal sketch follows the list):
1) Compute each RoI's loss with a read-only copy of the detection head. This branch shares its parameters with the final fc6/fc7 prediction head; its classification and bounding-box regression outputs are compared against the ground truth, and the per-RoI loss is the sum of the classification loss and the regression loss.
2) Pre-filter the boxes with NMS at a threshold of 0.7 (using the loss as the score) to remove redundant, highly overlapping boxes.
3) Sort the surviving boxes by loss in descending order and return a fixed number of them (the minimum of the number of RoIs and cfg.TRAIN.BATCH_SIZE).
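Putting the three steps together, here is a minimal NumPy sketch of the selection logic. It is only an illustration: ohem_select and its arguments are made-up names, the real layer works on Caffe blobs, and nms stands in for the IoU-based NMS routine the repo already ships with.
import numpy as np

def ohem_select(cls_prob, labels, bbox_loss, rois, batch_size, nms, nms_thresh=0.7):
    """Illustrative sketch of the three OHEM steps (not the repo's exact API)."""
    eps = np.finfo(float).eps
    # 1) per-RoI loss = classification loss + bbox regression loss
    cls_loss = -np.log(np.maximum(cls_prob[np.arange(len(labels)), labels.astype(int)], eps))
    loss = cls_loss + bbox_loss
    # 2) NMS with the loss as the score, to drop highly overlapping RoIs
    dets = np.hstack([rois[:, 1:], loss[:, np.newaxis]]).astype(np.float32)
    keep = np.asarray(nms(dets, nms_thresh))
    # 3) sort the surviving RoIs by loss (descending) and keep at most batch_size of them
    order = np.argsort(loss[keep])[::-1]
    return keep[order[:min(len(keep), batch_size)]]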
Below is the definition of the OHEM layer in the network definition file, for reference when matching the code against the corresponding blobs later.
layer {
  name: "hard_roi_mining"
  type: "Python"
  bottom: "cls_prob_readonly"
  bottom: "bbox_pred_readonly"
  bottom: "rois"
  bottom: "labels"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"
  bottom: "bbox_outside_weights"
  top: "rois_hard"
  top: "labels_hard"
  top: "bbox_targets_hard"
  top: "bbox_inside_weights_hard"
  top: "bbox_outside_weights_hard"
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  python_param {
    module: "roi_data_layer.layer"
    layer: "OHEMDataLayer"
    param_str: "'num_classes': 6" #6
  }
}
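Two details in this definition are easy to miss: there is one propagate_down: false per bottom blob, which keeps gradients from flowing back into the read-only prediction branch, and param_str is handed to the Python layer as plain text and parsed as YAML inside setup(). A quick illustration of that parsing (using yaml.safe_load here; the layer itself calls yaml.load):
import yaml

param_str = "'num_classes': 6"            # exactly the string from the prototxt
layer_params = yaml.safe_load(param_str)  # -> {'num_classes': 6}
print(layer_params['num_classes'])        # 6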
2. A quick walkthrough of the OHEM code
2.1 OHEMDataLayer
class OHEMDataLayer(caffe.Layer):
    """Online Hard-example Mining Layer."""

    def setup(self, bottom, top):
        """Setup the OHEMDataLayer."""
        # Parse the layer parameter string, which must be valid YAML
        layer_params = yaml.load(self.param_str_)
        self._num_classes = layer_params['num_classes']  # number of classes
        # Map each bottom blob name to its index
        self._name_to_bottom_map = {
            'cls_prob_readonly': 0,
            'bbox_pred_readonly': 1,
            'rois': 2,
            'labels': 3}
        if cfg.TRAIN.BBOX_REG:  # bounding-box regression is enabled
            self._name_to_bottom_map['bbox_targets'] = 4
            self._name_to_bottom_map['bbox_loss_weights'] = 5
        # Likewise, map each top blob name to its index
        self._name_to_top_map = {}
        # ...

    # Forward pass
    def forward(self, bottom, top):
        """Compute loss, select RoIs using OHEM. Use RoIs to get blobs and
        copy them into this layer's top blob vector."""
        cls_prob = bottom[0].data   # read the corresponding bottom blobs
        bbox_pred = bottom[1].data
        rois = bottom[2].data
        labels = bottom[3].data
        if cfg.TRAIN.BBOX_REG:
            bbox_target = bottom[4].data
            bbox_inside_weights = bottom[5].data
            bbox_outside_weights = bottom[6].data
        else:
            bbox_target = None
            bbox_inside_weights = None
            bbox_outside_weights = None

        flt_min = np.finfo(float).eps
        # Classification loss: -log of the probability assigned to the true class
        # (cast to an array so the regression loss below is added element-wise
        # rather than appended to the list)
        loss = np.array([-1 * np.log(max(x, flt_min))
            for x in [cls_prob[i, label] for i, label in enumerate(labels)]])
        # Bounding-box regression loss, added onto the classification loss
        if cfg.TRAIN.BBOX_REG:
            # bounding-box regression loss
            # d := w * (b0 - b1)
            # smoothL1(x) = 0.5 * x^2    if |x| < 1
            #               |x| - 0.5    otherwise
            def smoothL1(x):
                if abs(x) < 1:
                    return 0.5 * x * x
                else:
                    return abs(x) - 0.5

            bbox_loss = np.zeros(labels.shape[0])  # per-RoI regression loss
            for i in np.where(labels > 0)[0]:      # only foreground RoIs contribute
                indices = np.where(bbox_inside_weights[i, :] != 0)[0]
                bbox_loss[i] = sum(bbox_outside_weights[i, indices] * [smoothL1(x)
                    for x in bbox_inside_weights[i, indices] * (bbox_pred[i, indices] - bbox_target[i, indices])])
            loss += bbox_loss  # total loss = classification + regression

        # Keep only the RoIs with the largest losses
        blobs = get_ohem_minibatch(loss, rois, labels, bbox_target,
            bbox_inside_weights, bbox_outside_weights)
        # Write the selected blobs into this layer's top blobs
        for blob_name, blob in blobs.iteritems():
            top_ind = self._name_to_top_map[blob_name]
            # Reshape net's input blobs
            top[top_ind].reshape(*(blob.shape))
            # Copy data into net's input blobs
            top[top_ind].data[...] = blob.astype(np.float32, copy=False)
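To make the per-RoI loss concrete, here is a tiny worked example with made-up numbers for a single foreground RoI (the class count of 6 matches the prototxt above):
import numpy as np

cls_prob = np.array([0.05, 0.10, 0.60, 0.10, 0.10, 0.05])  # softmax output for one RoI
label = 2
cls_loss = -np.log(cls_prob[label])  # -log(0.6) ~= 0.51

def smoothL1(x):
    return 0.5 * x * x if abs(x) < 1 else abs(x) - 0.5

# one regression residual inside the quadratic zone, one outside it
print(cls_loss, smoothL1(0.4), smoothL1(1.6))  # ~0.51, 0.08, 1.1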
2.2 get_ohem_minibatch
# Build the minibatch of hard examples used for OHEM training
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
                       bbox_inside_weights=None, bbox_outside_weights=None):
    """Given rois and their loss, construct a minibatch using OHEM."""
    loss = np.array(loss)

    # Optionally pre-filter the boxes with NMS
    if cfg.TRAIN.OHEM_USE_NMS:  # NMS thresh = 0.7
        # Do NMS using loss for de-dup and diversity
        keep_inds = []
        nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH  # 0.7
        source_img_ids = [roi[0] for roi in rois]  # index of the source image within the minibatch
        for img_id in np.unique(source_img_ids):
            for label in np.unique(labels):
                sel_indx = np.where(np.logical_and(labels == label,
                                                   source_img_ids == img_id))[0]
                if not len(sel_indx):
                    continue
                # boxes are (x1, y1, x2, y2, loss); the loss acts as the NMS score
                boxes = np.concatenate((rois[sel_indx, 1:],
                                        loss[sel_indx][:, np.newaxis]), axis=1).astype(np.float32)
                keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])

        hard_keep_inds = select_hard_examples(loss[keep_inds])  # sort by loss and take the top ones
        hard_inds = np.array(keep_inds)[hard_keep_inds]         # indices of the retained hard examples
    else:
        hard_inds = select_hard_examples(loss)

    blobs = {
        'rois_hard': rois[hard_inds, :].copy(),
        'labels_hard': labels[hard_inds].copy()}
    if bbox_targets is not None:
        assert cfg.TRAIN.BBOX_REG
        blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
        blobs['bbox_inside_weights_hard'] = bbox_inside_weights[hard_inds, :].copy()
        blobs['bbox_outside_weights_hard'] = bbox_outside_weights[hard_inds, :].copy()
    return blobs

def select_hard_examples(loss):
    """Select hard rois."""
    # Sort by loss in descending order and keep the top examples.
    sorted_indices = np.argsort(loss)[::-1]
    hard_keep_inds = sorted_indices[0:np.minimum(len(loss), cfg.TRAIN.BATCH_SIZE)]
    # (explore more ways of selecting examples in this function; e.g., sampling)
    return hard_keep_inds
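select_hard_examples is just a descending argsort followed by a top-k cut at cfg.TRAIN.BATCH_SIZE. A toy run with made-up losses and a stand-in batch size of 3:
import numpy as np

loss = np.array([0.2, 1.7, 0.9, 0.1, 2.3])
batch_size = 3  # stands in for cfg.TRAIN.BATCH_SIZE

sorted_indices = np.argsort(loss)[::-1]                # [4, 1, 2, 0, 3]
hard_keep_inds = sorted_indices[:min(len(loss), batch_size)]
print(hard_keep_inds)  # [4 1 2], the three RoIs with the largest losses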