pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/140952.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • UnityShader-BilateralFilter(双边滤波,磨皮滤镜)「建议收藏」

    UnityShader-BilateralFilter(双边滤波,磨皮滤镜)「建议收藏」双边滤波(BilateralFilter),可能没有高斯滤波那样著名,但是如果说磨皮滤镜,那肯定是无人不知无人不晓了,用双边滤波就可以实现很好的皮肤滤镜效果,不管脸上有多少麻子,用完双边滤波,瞬间变身白富美。

  • Mac配置Android开发环境

    Mac配置Android开发环境1、下载jdk和AndroidStudio下载地址如下:jdk:https://www.oracle.com/java/technologies/javase-downloads.htmlAS:https://developer.android.google.cn/studio2、安装jdk安装,一直下一步,安装完成后打开“终端”,执行命令:java-version即可查看…

  • pandas fillna详解

    pandas fillna详解pandas中补全nan具体的参数Series.fillna(self,value=None,method=None,axis=None,inplace=False,limit=None,downcast=None,**kwargs)[source]参数: value:scalar,dict,Series,orDataFrameValuetouset…

  • js也能写3D游戏?

    js也能写3D游戏?看完这本书《3DGameProgramingforKids》之后,发现3D游戏的学习,也可以使用javascript来写的。先要上这个网站https://threejs.org,然后下载它的three.js源码放到一个目录,比如js。然后放入这段代码: Myfirstthree.jsapp body{margin:0;} canvas{w

  • LoRaWAN地区参数更新至版本B,新增印度865频段「建议收藏」

    LoRaWAN地区参数更新至版本B,新增印度865频段「建议收藏」LoRaWAN地区参数更新至版本B,新增了印度865频段。这为塔塔通讯近期宣布的20万传感器和基站节点建设计划进行了规范铺路。

    2022年10月21日
  • pycharm安装教程中文_java将对象转为json

    pycharm安装教程中文_java将对象转为json#1.下载安装包下载地址(http://www.jetbrains.com/pycharm/download/#section=windows)#2.安装#3.激活选择Activationcode在http://idea.lanyus.com/获取注册码修改hosts文件,加入以下字段0.0.0.0account.jetbrains.com#4.中文界面下载语言包https://github…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号