OHEM代码梳理[通俗易懂]

OHEM代码梳理[通俗易懂]传送门:相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结代码地址:OHEM1.前言有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得…

大家好,又见面了,我是你们的朋友全栈君。

  • 传送门
  1. 相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结
  2. 代码地址:OHEM

1. 前言

有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:
1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得到分类和边界框回归的loss,这里的损失是将两种损失相加得到的;
2)使用阈值为0.7的NMS预先处理一遍检测框,去除一些无效的检测框;
3)NMS之后的检测框按照loss由大到小排列,选取一定数目(由两个数取最小决定)的边界框返回。
下面是OHEM在网络定义文件中的定义,方便后面查看相关代码的时候查找对应条目。

layer {
  name: "hard_roi_mining"
  type: "Python"
  bottom: "cls_prob_readonly"
  bottom: "bbox_pred_readonly"
  bottom: "rois"
  bottom: "labels"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"
  bottom: "bbox_outside_weights"
  top: "rois_hard"
  top: "labels_hard"
  top: "bbox_targets_hard"
  top: "bbox_inside_weights_hard"
  top: "bbox_outside_weights_hard"
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  python_param {
    module: "roi_data_layer.layer"
    layer: "OHEMDataLayer"
    param_str: "'num_classes': 6" #6
  }
}

2. OHEM代码简单梳理

2.1 OHEMDataLayer

class OHEMDataLayer(caffe.Layer):
    """Online Hard-example Mining Layer."""
    def setup(self, bottom, top):
        """Setup the OHEMDataLayer."""

        # parse the layer parameter string, which must be valid YAML
        layer_params = yaml.load(self.param_str_)

        self._num_classes = layer_params['num_classes']	 # 获取分类数目

        self._name_to_bottom_map = { 
     # 将bottom的blob名称与index使用dict关联
            'cls_prob_readonly': 0,
            'bbox_pred_readonly': 1,
            'rois': 2,
            'labels': 3}

        if cfg.TRAIN.BBOX_REG:  # 有边界框回归
            self._name_to_bottom_map['bbox_targets'] = 4
            self._name_to_bottom_map['bbox_loss_weights'] = 5

        self._name_to_top_map = { 
   }  # 同理top的blob名称也要与index关联起来
        ……
	
	# 前向传播函数
    def forward(self, bottom, top):
        """Compute loss, select RoIs using OHEM. Use RoIs to get blobs and copy them into this layer's top blob vector."""

        cls_prob = bottom[0].data  # 获取对应bottom的数据
        bbox_pred = bottom[1].data
        rois = bottom[2].data
        labels = bottom[3].data
        if cfg.TRAIN.BBOX_REG:
            bbox_target = bottom[4].data
            bbox_inside_weights = bottom[5].data
            bbox_outside_weights = bottom[6].data
        else:
            bbox_target = None
            bbox_inside_weights = None
            bbox_outside_weights = None

        flt_min = np.finfo(float).eps
        # 计算分类的损失
        loss = [ -1 * np.log(max(x, flt_min)) \
            for x in [cls_prob[i,label] for i, label in enumerate(labels)]]

        # 计算边界框回归的损失,并且将两个损失加起来
        if cfg.TRAIN.BBOX_REG:
            # bounding-box regression loss
            # d := w * (b0 - b1)
            # smoothL1(x) = 0.5 * x^2 if |x| < 1
            # |x| - 0.5 otherwise
            def smoothL1(x):
                if abs(x) < 1:
                    return 0.5 * x * x
                else:
                    return abs(x) - 0.5

            bbox_loss = np.zeros(labels.shape[0])  # 边界框损失
            for i in np.where(labels > 0 )[0]:
                indices = np.where(bbox_inside_weights[i,:] != 0)[0]
                bbox_loss[i] = sum(bbox_outside_weights[i,indices] * [smoothL1(x) \
                    for x in bbox_inside_weights[i,indices] * (bbox_pred[i,indices] - bbox_target[i,indices])])
            loss += bbox_loss  # 两者损失相加

        # 筛选出损失比较大的返回
        blobs = get_ohem_minibatch(loss, rois, labels, bbox_target, \
            bbox_inside_weights, bbox_outside_weights)

		# 给top blob赋值
        for blob_name, blob in blobs.iteritems():
            top_ind = self._name_to_top_map[blob_name]
            # Reshape net's input blobs
            top[top_ind].reshape(*(blob.shape))
            # Copy data into net's input blobs
            top[top_ind].data[...] = blob.astype(np.float32, copy=False)

2.2 get_ohem_minibatch

# 获取OHEM训练的batch
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
                       bbox_inside_weights=None, bbox_outside_weights=None):
    """Given rois and their loss, construct a minibatch using OHEM."""
    loss = np.array(loss)
	
	# 使用NMS过滤检测框
    if cfg.TRAIN.OHEM_USE_NMS:	# NMS thresh=0.7
        # Do NMS using loss for de-dup and diversity
        keep_inds = []
        nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH  # 0.7
        source_img_ids = [roi[0] for roi in rois] # 0或1(背景与前景)
        for img_id in np.unique(source_img_ids):
            for label in np.unique(labels):
                sel_indx = np.where(np.logical_and(labels == label, \
                                    source_img_ids == img_id))[0]
                if not len(sel_indx):
                    continue
                boxes = np.concatenate((rois[sel_indx, 1:],
                        loss[sel_indx][:,np.newaxis]), axis=1).astype(np.float32)
                keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])

        hard_keep_inds = select_hard_examples(loss[keep_inds])  # 按照损失排序选择样本
        hard_inds = np.array(keep_inds)[hard_keep_inds]  # 最后保留下来的困难样本索引
    else:
        hard_inds = select_hard_examples(loss)

    blobs = { 
   'rois_hard': rois[hard_inds, :].copy(),
             'labels_hard': labels[hard_inds].copy()}
    if bbox_targets is not None:
        assert cfg.TRAIN.BBOX_REG
        blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
        blobs['bbox_inside_weights_hard'] = bbox_inside_weights[hard_inds, :].copy()
        blobs['bbox_outside_weights_hard'] = bbox_outside_weights[hard_inds, :].copy()

    return blobs

def select_hard_examples(loss):
    """Select hard rois."""
    # Sort and select top hard examples.
    sorted_indices = np.argsort(loss)[::-1]
    hard_keep_inds = sorted_indices[0:np.minimum(len(loss), cfg.TRAIN.BATCH_SIZE)]
    # (explore more ways of selecting examples in this function; e.g., sampling)

    return hard_keep_inds
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/139136.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • minicom指令_minicom 串口通信设置

    minicom指令_minicom 串口通信设置L文件捕获开关。打开时,所有到屏幕的输出也将被捕获到文件中。M发送modem初始化串。若你online,且DCD线设为on,则modem被初始化前将要求你进行确认。O配置minicom。转到配置菜单。P通信参数。允许你改变bps速率,奇偶校验和位数。Q不复位modem就退出minicom。如果改变了macros,而且未存盘,会提供你一个save的机会。R接收文件。从各种协议(外部)中进行选择。若f…

  • 简单window.open()使用方法和按钮关闭window.open页面

    简单window.open()使用方法和按钮关闭window.open页面简单window.open()使用方法和按钮关闭window.open页面

  • 高通搜网流程_搜艺贝流程

    高通搜网流程_搜艺贝流程这里写自定义目录标题欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants创建一个自定义列表如何创建一个注脚注释也是必不可少的KaTeX数学公式新的甘特图功能,丰富你的文章UML图表FLowchart流程图导出与导入导出导入https:…

  • golang []byte和string相互转换

    golang []byte和string相互转换测试例子:packagemainimport(“fmt”)funcmain(){str2:=”hello”data2:=[]byte(str2)fmt.Println(data2)str2=string(data2[:])fmt.Println(str2)}

  • QMovie的使用

    QMovie的使用QMovie是一个可以存放动态视频的类今天第一次使用,记录一下一般是配合QLabel使用的,可以用来存放GIF动态图 m_background=newQLabel(this);m_background->setGeometry(0,0,MENU_WINDOW_WIDTH,MENU_WINDOW_HEIGHT);QMovie*backgroundMovie=newQMovie(“:/images/menu/MenuBackground.gif”,QByteArra

  • sql多表联合查询详解_sql多表查询例子

    sql多表联合查询详解_sql多表查询例子sql语句会用到许多查询语句,如果牵扯到多张表的时候一般会需要复杂查询方式:    嵌套查询:   select*frombi_BillItemwhereBillIDin(selectBillIDfrombi_BillwhereIsArchived=’0’andIsCheckOuting=’2′)groupbymenuId,MenuPri…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号