pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/140952.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 群晖linux怎么进入u盘,黑群晖菜鸟安装教程(一)制作U盘引导及软洗白!

    群晖linux怎么进入u盘,黑群晖菜鸟安装教程(一)制作U盘引导及软洗白!教程多都是参考网络上的一些大师们的教程做一些简化和把一些要点易出错的地方给大家指出,让大家能更快加入到群晖一起折腾。什么是黑群晖最简单的理解就是用普通的PC机安装了群晖NAS系统让普通的PC机可以体验白群晖的大多数功能。黑群晖对电脑的要求很低最是一般要求CPU为64位不然安装不了的。而且一般我们采用的PC机为低功率集成CPU的ITX主板。常用的主板有集成CPUD525E-240等低功率主板在正…

  • 网卡的流量控制flow control

    网卡的流量控制flow controlopenstack运维中发现有一个计算节点的存储网卡不停的down,摘取部分系统日志如下:Aug2906:25:59compute8kernel:ixgbe0000:86:00.0enp134s0f0:NICLinkisDownAug2906:26:00compute8kernel:ixgbe0000:86:00.0enp134s0f0:NICLink…

  • idea2021 激活码【最新永久激活】

    (idea2021 激活码)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

  • Oracle函数——TO_DATE[通俗易懂]

    Oracle函数——TO_DATE[通俗易懂]TO_DATE含义:将具有固定格式的字符串类型的数据转化为相对应的Date类型数据,官网解释如下图使用方法TO_DATE("需要转换的字符串","日期格式&quot

  • matlab三维图形的绘制[通俗易懂]

    matlab三维图形的绘制[通俗易懂]采用matlab进行三维图绘制mesh函数:网格图mesh(x,y,z)x是n维向量,y是m维向量,z是m*n维向量等高线,底座surf函数:曲面符号隐函数绘图

    2022年10月11日
  • ncodeURIComponent() 函数 vue内容

    ncodeURIComponent() 函数 vue内容

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号