大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。
Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺
损失函数
文章目录
含义:
用于衡量在训练集上模型的输出与真实输出的差异
标准:
损失函数越小,模型输出与真实输出越相似,模型效果越好
常用的两种损失函数
均方误差损失函数
计算公式
M S E = 1 m ∑ i = 1 m ( y ^ ( i ) − y ( i ) ) 2 MSE=\frac{1}{m}\sum^m_{i=1}(\hat y^{(i)}-y^{(i)})^2 MSE=m1i=1∑m(y^(i)−y(i))2
含义解释
符号 | 含义 |
---|---|
m | 样本数量 |
y ^ ( i ) \hat y^{(i)} y^(i) | 第i个样本的模型预测输出的结果 |
y ( i ) y^{(i)} y(i) | 第i个样本的真实输出的结果 |
代码实现
''' MSE Loss '''
import torch
import torch.nn as nn
torch.manual_seed(1)
# create data
x = torch.linspace(0,10,10).reshape(2,5)
w = torch.randn((5,2))
bias = torch.randn((2,1))*0.1
y = x@w
y_ = y+bias
print(y)
print(y_)
# calulate the MSE loss between y and y_
MESLoss = torch.tensor([(y1-y2)**2 for y1,y2 in zip(y_.flatten(),y.flatten())]).mean()
print(MESLoss)
# MSELoss func
MSELoss_func = nn.MSELoss()
print(MSELoss_func(y_,y))
适用范围
-
回归问题,如线性回归
-
使用均方误差处理分类问题,公式
M S E c l a s s i f i c a t i o n = 1 m ∑ i = 1 m ∑ j = 1 c ( y ^ j ( i ) − y j ( i ) ) 2 MSE_{classification}=\frac{1}{m}\sum^m_{i=1}\sum^c_{j=1}(\hat y^{(i)}_{j}-y^{(i)}_{j})^2 MSEclassification=m1i=1∑mj=1∑c(y^j(i)−yj(i))2符号 含义 m 样本数量 y ^ j ( i ) \hat y^{(i)}_{j} y^j(i) 第i个样本的第j类上的模型预测输出的结果 y j ( i ) y^{(i)}_{j} yj(i) 第i个样本的第j类上的真实输出的结果
交叉熵损失函数
计算公式
C E = − 1 m ∑ i = 1 m ∑ j = 1 c y j ( i ) l o g ( y ^ j ( i ) ) CE=-\frac{1}{m}\sum^m_{i=1}\sum^c_{j=1}y^{(i)}_{j}log(\hat y^{(i)}_{j}) CE=−m1i=1∑mj=1∑cyj(i)log(y^j(i))
含义解释
符号 | 含义 |
---|---|
m | 样本数量 |
y ^ j ( i ) \hat y^{(i)}_{j} y^j(i) | 模型对第i个样本属于第j类上的预测结果 |
y j ( i ) y^{(i)}_{j} yj(i) | 第i个样本的第j类上的真实输出的结果,正确类别输出为1,其他输出0 |
交叉熵损失取决于模型对正确类别预测概率的对数值。
代码实现
''' CE Loss '''
import torch
def CrossEntropyLoss(input, target):
res = -input.gather(dim=1, index=target.view(-1, 1))
print(res.shape)
res += torch.log(torch.exp(input).sum(dim=1).view(-1, 1))
print(res.shape)
res = res.mean()
print(res.shape)
return res
input = torch.tensor([
[1, 2, 3],
[4, 5, 6]
], dtype=torch.float32)
target = torch.tensor(
[0, 1],
)
print(torch.nn.CrossEntropyLoss()(input, target))
print(CrossEntropyLoss(input, target))
适用范围
- 分类问题,又叫负对数似然损失
发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/192846.html原文链接:https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...