OHEM的pytorch代码实现细节

OHEM的pytorch代码实现细节详细解读一下OHEM的实现代码:defohem_loss(batch_size,cls_pred,cls_target,loc_pred,loc_target,smooth_l1_sigma=1.0):”””Arguments:batch_size(int):numberofsampledroisforbboxhe…

大家好,又见面了,我是你们的朋友全栈君。

详细解读一下OHEM的实现代码:

def ohem_loss(
    batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
    """
    Arguments:
        batch_size (int): number of sampled rois for bbox head training
        loc_pred (FloatTensor): [R, 4], location of positive rois
        loc_target (FloatTensor): [R, 4], location of positive rois
        pos_mask (FloatTensor): [R], binary mask for sampled positive rois
        cls_pred (FloatTensor): [R, C]
        cls_target (LongTensor): [R]

    Returns:
        cls_loss, loc_loss (FloatTensor)
    """
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
    #这里先暂存下正常的分类loss和回归loss
    loss = ohem_cls_loss + ohem_loc_loss
    #然后对分类和回归loss求和

  
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)
    #再对loss进行降序排列
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)
    #得到需要保留的loss数量
    if keep_num < sorted_ohem_loss.size()[0]:
    #这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
        keep_idx_cuda = idx[:keep_num]
        #保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
        #分类和回归保留相同的数目
    cls_loss = ohem_cls_loss.sum() / keep_num
    loc_loss = ohem_loc_loss.sum() / keep_num
    #然后分别对分类和回归loss求均值
    return cls_loss, loc_loss

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/139445.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • matlab中的imwrite_medfilt2函数

    matlab中的imwrite_medfilt2函数1.imwrite函数imwrite函数的作用是将图像写入图形文件。2.语法imwrite(A,filename)imwrite(A,map,filename)imwrite(___,fmt)imwrite(___,Name,Value)(1)imwrite(A,filename)将图像数据A写入filename指定的文件,并从扩展名推断出文件格式。imwrite在当前文件夹中创建新文件。输出图像的位深度取决于A的数据类型和文件格式。对于大多数格式来说: 如果……

  • html 鼠标形状箭头,CSS各种鼠标样式介绍

    html 鼠标形状箭头,CSS各种鼠标样式介绍大家否曾注意到有些网站的鼠标不是规则的斜向上箭头的形状,而是十字形,或者是向左的箭头,或者是个问号等等。当你想在网页的不同位置让鼠标显示不同形状,以体现不同的功能区;当你想让你的网站体现与众不同的风格时,考虑一下在鼠标样式上下功夫吧。其实鼠标样式的用途还是极为广泛的,那么怎样才能实现鼠标的不同样式呢?这就要用到css层叠样式表中的cursor属性了。cursor的属性:pointer:手型c…

  • html静态页面代码_静态网页设计代码

    html静态页面代码_静态网页设计代码这个例子我们做一个游戏静态页面,自动跳转到我们想要玩的游戏或者视频等网站大家也可以根绝我的代码,适当修改一些信息,但是套用我的这个模板请注释下来自我这,我也是初学者,辛辛苦苦写了几个小时,尊重下劳动成果先看效果图:我以张杰为背景图,里面是各种网站跳转,比如我点击:冰火人,他就会跳转到4399的冰火人游戏界面。ok,上代码,我觉得比较简单,就没注释,希望能看懂:<!DOCTYPEhtml><html><headlang=”en”><metacha

  • Centos下添加用户到用户组

    Centos下添加用户到用户组

    2021年10月23日
  • JAVA语言程序设计(一)04747

    JAVA语言程序设计(一)04747windows常用快捷键和常见命令省略100万行二进制=>0、1一个字节是八位。每个0或者每个1都叫做是bit二进制的计算,除2除到余数为一,一算到最后一位,结果需要倒过来。上述直接操作字节是计算机中最小的存储单元,计算机储存的任何数据都是以字节的形式存储的。1KB=1024Byte1MB=1024KB命令提示符常用的命令D:可以直接切换到d盘根路径…

  • HTTP和HTTPS有什么区别? 什么是SSL证书?使用ssl证书优势?

    HTTP和HTTPS有什么区别? 什么是SSL证书?使用ssl证书优势?

    2021年10月14日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号