python复现softmax损失函数详细版

python复现softmax损失函数详细版fromtorchimportnnimporttorchdefloss_func(output,target):one_hot=torch.zeros_like(output)foriinrange(target.size(0)):one_hot[i,target[i]]=1softmax_out=torch.exp(output)/(torch.unsque…

大家好,又见面了,我是你们的朋友全栈君。

主要内容

  • softmax和交叉熵公式
  • 单个样本求解损失
  • 多个样本求解损失

softmax和交叉熵公式

  • softmax

先来看公式,softmax的作用是将样本对应的输出向量,转换成对应的类别概率值。这里使用以e为底的指数函数,将向量值归一化为0-1的概率值;
Alt
使用numpy的代码实现也很简单,但是当数值过大时会发生溢出,此时会将向量中的其他值减去最大值,数值平移到0附近。会避免溢出现象。ps:这里暂时不考虑这种情况
在这里插入图片描述

  • softmax交叉熵
    交叉熵是用来衡量分布p和q之间的相似度,越相似交叉熵越小。其中 p ( x ) p(x) p(x)是真实标签的one_hot编码, q ( x ) q(x) q(x)是预测值。需要注意的是这里的 q ( x ) q(x) q(x)必须是经过softmax的概率值。
    Alt

单个样本求解损失

#conding=utf-8

from torch import nn
import torch
import numpy as np

def MySoftmax(vector):
    return np.exp(vector)/np.exp(vector).sum()

def LossFunc(target,output):
    output = MySoftmax(output)
    one_hot = np.zeros_like(output)
    one_hot[:,target] = 1
    # print(one_hot)
    loss = (-np.log(output)*one_hot).sum()
    return loss
target = np.array([1])
output = np.array([[8,-3.,10]])
softmax_out = MySoftmax(output)
np.set_printoptions(suppress=True)
print(softmax_out)

# torch自带的softmax实现
print(nn.Softmax()(torch.Tensor(output)))

print(LossFunc(target,output))
print(nn.CrossEntropyLoss(reduction="sum")(torch.Tensor(output),torch.Tensor(target).long()))

需要注意的是现有的框架中基本都会在损失函数内部进行softmax转换。我这里设置的loss值没有求平均,所以reduction=“sum”

多个样本求解损失

#conding=utf-8

from torch import nn
import torch
import numpy as np

# def MySoftmax(vector):
# return np.exp(vector)/np.exp(vector).sum()
#
# def LossFunc(target,output):
# output = MySoftmax(output)
# one_hot = np.zeros_like(output)
# one_hot[:,target] = 1
# # print(one_hot)
# loss = (-np.log(output)*one_hot).sum()
# return loss
# target = np.array([1])
# output = np.array([[8,-3.,10]])
# softmax_out = MySoftmax(output)
# np.set_printoptions(suppress=True)
# print(softmax_out)
#
# # torch自带的softmax实现
# print(nn.Softmax()(torch.Tensor(output)))
#
# print(LossFunc(target,output))
# print(nn.CrossEntropyLoss(reduction="sum")(torch.Tensor(output),torch.Tensor(target).long()))

def loss_func(output,target):
    one_hot = torch.zeros_like(output)
    for i in range(target.size(0)):
        one_hot[i,target[i]]=1

    softmax_out = torch.exp(output)/( torch.unsqueeze(torch.exp(output).sum(dim=1),dim=1))
    # 确保每一个样本维度的概率之和为1
    print(softmax_out.sum(dim=1))
    loss = (-torch.log(softmax_out) * one_hot).sum()
    return loss

target = torch.Tensor([1,1,1]).long()
output = torch.Tensor([[10.,-5,5],[5,2,-1],[4,-9,5]])
softmax = nn.Softmax(dim=1)


criterion = nn.CrossEntropyLoss(reduction="sum")
print(criterion(output,target))
print(loss_func(output,target))

我这里使用的是torch的计算,主要原因是想使用label smoothing技巧,torch版在项目中应用更方便。
只是将numpy换成torch的形式,基本的公式都没有改变的。需要注意的是在多个样本求解softmax值是在样本的维度求概率。

喜欢的童鞋点个赞哦!大家有什么要了解的请留言,老汤尽量满足

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/153145.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 不再年轻,尽管如此一遍

    不再年轻,尽管如此一遍

  • 网络编程socket原理_socket的基本概念和原理

    网络编程socket原理_socket的基本概念和原理一、客户机/服务器模式在TCP/IP网络中两个进程间的相互作用的主机模式是客户机/服务器模式(Client/Servermodel)。该模式的建立基于以下两点:1、非对等作用;2、通信完全是异步的。客户机/服务器模式在操作过程中采取的是主动请示方式:首先服务器方要先启动,并根据请示提供相应服务:(过程如下)1、打开一通信通道并告知本地主机,它愿意在某一个公认地址上接收客户请求。2、等待客户请求到

    2022年10月10日
  • J2EE是什么,主要包括哪些技术[通俗易懂]

    J2EE是什么,主要包括哪些技术[通俗易懂]https://blog.csdn.net/Ashes18/article/details/73614571最近最为深刻的认识:被面试老师问到了的知识,哪些是在自己心中模棱两可,似是而非的概念都一目了然。而后,只有在顿悟中不断总结才能进步。今天,我总结的部分是J2EE到底是什么东西,它包括了哪些技术。一、J2EE是什么?从整体上讲,J2EE是使用Java技术开发企业级应用的工业标准,它…

    2022年10月11日
  • html嵌入python代码(python做人脸识别)

    最近闲来无事,研究研究在安卓上跑Python。想起以前玩过的kivy技术,kivy[1]是一个跨平台的UI框架。当然对我们最有用的是,kivy可以把python代码打包成安卓App。但是由于安卓打包的工具链很长,包括androidsdk打包java代码、ndk编译python、编译各种python依赖包,经常花一整天从入门到放弃。这次使出认真研究的心态,终于找到一个解决方案,于是有了这篇文章:…

  • 计算机网络期末考试题库(超级多的那种)「建议收藏」

    计算机网络期末考试题库(超级多的那种)「建议收藏」废话不多说,不管是应对期末考试还是考研基础复习,刷题是必不可少的!!!大家冲就完了!!!!记得给罡罡同学点关注哦!后期还会更新其他题库的呢!!!点关注!!!点关注!!!点关注!!!谢谢另外还有4套模拟题哦!!!计算机网络试题库——选择题及答案(共500题)1、Internet中发送邮件协议是(B)。A、FTPB、SMTP C、HTTP D、POP2、在OSI模型中,第N层和其上的N+l层的关系是(A

  • oracle与mysql分页的区别_分段存储和分页存储的区别

    oracle与mysql分页的区别_分段存储和分页存储的区别oracle与MySQL分页区别(1)MySql的Limitm,n语句Limit后的两个参数中,参数m是起始下标,它从0开始;参数n是返回的记录数。(2)Oracle数据库的rownum在Oracle数据库中,分页方式没有MySql这样简单,它需要依靠rownum来实现。Rownum表示一条记录的行号,值得注意的是它在获取每一行后才赋予。因此,想指定rownum的区间来取得分页数据在一层查询语句中是无法做到的,要分页还要进行一次查询。两种sql写法:SELECT*FROM(SEL

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号