Pytorch-BN层详细解读

Pytorch-BN层详细解读Pytorch-BN层BN解决了InternalCovariateShift问题机器学习领域有个很重要的假设:独立同分布假设,即假设训练数据和测试数据是满足相同分布的。我们知道:神经网络的训练实际上就是在拟合训练数据的分布。如果不满足独立同分布假设,那么训练得到的模型的泛化能力肯定不好。再来思考一个问题:为什么传统的神经网络要求将数据归一化(训练阶段将训练数据归一化并记录均值和方差,测试…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

Pytorch-BN层

BN解决了Internal Covariate Shift问题

机器学习领域有个很重要的假设:独立同分布假设,即假设训练数据和测试数据是满足相同分布的。我们知道:神经网络的训练实际上就是在拟合训练数据的分布。如果不满足独立同分布假设,那么训练得到的模型的泛化能力肯定不好。

再来思考一个问题:为什么传统的神经网络要求将数据归一化(训练阶段将训练数据归一化并记录均值和方差,测试阶段利用记录的均值和方差将测试数据也归一化)

首先:做了归一化之后,可以近似的认为训练数据和测试数据满足相同分布(即均值为0,方差为1的标准正态),这样一来模型的泛化能力会得到提高。其次:如果不做归一化,使用mini-batch梯度下降法训练的时候,每批训练数据的分布不相同,那么网络就要在每次迭代的时候去适应不同的分布,这样会大大降低网络的训练速度。综合以上两点,所以需要对数据做归一化预处理。PS:如果是mini-batch梯度下降法,每个batch都可以计算出一个均值和方差,最终记录的均值和方差是所有batches均值和方差的期望,当然也有其它更复杂的记录方式,如pytorch使用的滑动平均。

Internal Covariate Shift问题:在训练的过程中,即使对输入层做了归一化处理使其变成标准正态,随着网络的加深,函数变换越来越复杂,许多隐含层的分布还是会彻底放飞自我,变成各种奇奇怪怪的正态分布,并且整体分布逐渐往非线性函数(也就是激活函数)的取值区间的上下限两端靠近。对于sigmoid函数来说,就意味着输入值是大的负数或正数,这导致反向传播时底层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

为了解决上述问题,又想到网络的某个隐含层相对于之后的网络就相当于输入层,所以BN的基本思想就是:把网络的每个隐含层的分布都归一化到标准正态。其实就是把越来越偏的分布强制拉回到比较标准的分布,这样使得激活函数的输入值落在该激活函数对输入比较敏感的区域,这样一来输入的微小变化就会导致损失函数较大的变化。通过这样的方式可以使梯度变大,就避免了梯度消失的问题,而且梯度变大意味着收敛速度快,能大大加快训练速度。

简单说来就是:传统的神经网络只要求第一个输入层归一化,而带BN的神经网络则是把每个输入层(把隐含层也理解成输入层)都归一化。

BN的具体步骤

BN实际上包含两步操作。 x i x_i xi是BN的输入, y i y_i yi是BN的输出。

  • 归一化到标准正态, ϵ \epsilon ϵ是一个非常小的数字,是为了防止除以0

μ = 1 m ∑ i = 1 m x i \mu=\frac{1}{m}\sum\limits_{i=1}^{m}x_i μ=m1i=1mxi
σ 2 = 1 m ∑ i = 1 m ( x i − μ ) 2 \sigma^2=\frac{1}{m}\sum\limits_{i=1}^{m}(x_i-\mu)^2 σ2=m1i=1m(xiμ)2
x ^ i ← x i − μ σ 2 + ϵ \hat{x}_i\larr\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} x^iσ2+ϵ
xiμ

以sigmoid函数为例,可以将其近似的看成两个部分,中间区域的线性部分以及两侧的非线性部分。Internal Covariate Shift问题就是:隐含层的输出都落在了sigmoid函数的非线性区域,这部分区域对损失函数的影响极小,所以梯度也极小。归一化操作就是把非线性区域的值拉回到线性区域,这样一来虽然增大了梯度,但也降低了数据的非线性表示能力。所以还需要缩放操作来弥补归一化操作降低的非线性表达能力。归一化从形式上看来就是把输入值减去一个数字,再除以一个数字。它的逆操作就是先乘以一个数字,在加上一个数字,这就是缩放。

  • 缩放

y i ← γ x i ^ + β ≡ B N γ , β ( x i ) y_i\larr\gamma\hat{x_i}+\beta\equiv BN_{\gamma,\beta}(x_i) yiγxi^+βBNγ,β(xi)

Pytorch中的BN

Pytorch中的BN操作为nn.BatchNorm2d(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)

  • num_features,输入数据的通道数,归一化时需要的均值和方差是在每个通道中计算的
  • eps,用来防止归一化时除以0
  • momentum,滑动平均的参数,用来计算running_mean和running_var
  • affine,是否进行仿射变换,即缩放操作
  • track_running_stats,是否记录训练阶段的均值和方差,即running_mean和running_var

BN层的状态包含五个参数:

  • weight,缩放操作的 γ \gamma γ
  • bias,缩放操作的 β \beta β
  • running_mean,训练阶段统计的均值,测试阶段会用到。
  • running_var,训练阶段统计的方差,测试阶段会用到。
  • num_batches_tracked,训练阶段的batch的数目,如果没有指定momentum,则用它来计算running_mean和running_var。一般momentum默认值为0.1,所以这个属性暂时没用。

weight和bias这两个参数需要训练,而running_mean、running_val和num_batches_tracked不需要训练,它们只是训练阶段的统计值。

在训练阶段,假设输入是[4, 3, 2, 2]的张量,如下图所示。

在这里插入图片描述

对于这四个数据块,每次取其中一个通道的数据,然后对这个16个数据求均值 μ \mu μ和方差 σ \sigma σ,并用求得的均值和方差归一化并缩放数据,得到BN层的输出

接下来用滑动平均公式来更新running_mean和running_var,momentum默认值为0.1。

r u n n i n g _ m e a n = ( 1 − m o m e n t u m ) ∗ r u n n i n g _ m e a n + m o m e n t u m ∗ μ r u n n i n g _ v a r = ( 1 − m o m e n t u m ) ∗ r u n n i n g _ v a r + m o m e n t u m ∗ σ running\_mean = (1 – momentum) * running\_mean + momentum * \mu \\ running\_var = (1 – momentum) * running\_var + momentum * \sigma running_mean=(1momentum)running_mean+momentumμrunning_var=(1momentum)running_var+momentumσ

在测试阶段,归一化操作不用再计算均值方差,而是直接使用训练阶段统计的running_mean和running_var。

Note: track_running_stats和self.training有四种可能的组合。

  • training=True, track_running_stats=True, 这是常见的训练时期待的行为,running_mean和running_var会跟踪不同batch数据的均值和方差,但是仍然用每个batch的均值和方差做归一化。
  • training=True, track_running_stats=False, 这时候running_mean和running_var不跟踪各个batch的均值和方差了,但仍然用每个batch的均值和方差做归一化。
  • training=False, track_running_stats=True, 这是常见的测试时期待的行为,即使用训练阶段统计的running_mean和running_var做归一化。
  • training=False, track_running_stats=False, 使用每个batch的均值和方差做归一化。

Pytorch代码示例

bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 2, 2)
y = bn(x)
a = (x[0, 0, :, :] + x[1, 0, :, :] + x[2, 0, :, :] + x[3, 0, :, :]).sum() / 16
b = (x[0, 1, :, :] + x[1, 1, :, :] + x[2, 1, :, :] + x[3, 1, :, :]).sum() / 16
c = (x[0, 2, :, :] + x[1, 2, :, :] + x[2, 2, :, :] + x[3, 2, :, :]).sum() / 16
print('The mean value of the first channel is %f' % a.data)
print('The mean value of the first channel is %f' % b.data)
print('The mean value of the first channel is %f' % c.data)
print('The output mean value of the BN layer is %f, %f, %f' % (bn.running_mean.data[0], bn.running_mean.data[1], bn.running_mean.data[2]))

细心的读者可能已经发现第一个通道的running_mean正好是真实的mean的0.1倍,这是为什么呢?因为在最开始的时候running_mean=0,然后用滑动平均公式去更新:
r u n n i n g _ m e a n = 0.9 ∗ r u n n i n g _ m e a n + 0.1 ∗ m e a n = 0.1 ∗ m e a n running\_mean = 0.9 * running\_mean + 0.1 * mean = 0.1 * mean running_mean=0.9running_mean+0.1mean=0.1mean

那么当我们前向传播两次之后,再来观察running_mean和mean的关系。

bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 2, 2)
y = bn(x)
y = bn(x)  # 前向传播两次
a = (x[0, 0, :, :] + x[1, 0, :, :] + x[2, 0, :, :] + x[3, 0, :, :]).sum() / 16
b = (x[0, 1, :, :] + x[1, 1, :, :] + x[2, 1, :, :] + x[3, 1, :, :]).sum() / 16
c = (x[0, 2, :, :] + x[1, 2, :, :] + x[2, 2, :, :] + x[3, 2, :, :]).sum() / 16
print('The mean value of the first channel is %f' % a.data)
print('The mean value of the first channel is %f' % b.data)
print('The mean value of the first channel is %f' % c.data)
print('The output mean value of the BN layer is %f, %f, %f' % (bn.running_mean.data[0], bn.running_mean.data[1], bn.running_mean.data[2]))

我们会发现它们存在着这样一个关系:
r u n n i n g _ m e a n = 0.9 ∗ ( 0.9 ∗ 0 + 0.1 ∗ m e a n ) + 0.1 ∗ m e a n = 0.19 ∗ m e a n running\_mean = 0.9 * (0.9 * 0 + 0.1 * mean) + 0.1 * mean = 0.19 * mean running_mean=0.9(0.90+0.1mean)+0.1mean=0.19mean

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/182082.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • WinRAR3.71注册激活成功教程方法

    WinRAR3.71注册激活成功教程方法步骤:1.下载WinRAR3.71正式版(网上资源多);如:WinRAR3.71简体中文版2.将下面的数据其中一个复制到“记事本”中,另存为“rarreg.key”,(注意key是后缀名,需要

  • 深入理解Java自定义注解(二)-使用自定义注解

    深入理解Java自定义注解(二)-使用自定义注解

  • 腾讯云视频直播sdk开发攻略

    腾讯云视频直播sdk开发攻略视频直播这一两年在移动互联网上可以说是非常的火,各种视频直播软件层出不穷。有的通过自己的推广宣传确实火了起来,比如:映客。我之前也是在一家专门做视频直播的公司打酱油,当时对这个概念还是很模糊,后来才慢慢的了解清楚视频直播的这个概念。后来离开这家公司,到现在的公司,做了一段时间,又有需要做视频直播的需求。由于公司各方面原因,只能引用第三方的sdk,最后选择的腾讯云。所以我下面给大家讲一下开发过程

  • html滑动解锁,js实现滑动解锁效能(PC+Moblie)

    html滑动解锁,js实现滑动解锁效能(PC+Moblie)js实现滑动解锁功能(PC+Moblie)实现效果:css样式代码略。html代码:页面上导入了jquery.mobile、jquerySlidetoconfirmIamhuman!js代码:window.onload=function(){varslider1=newSlider();slider1.Init();///屏幕大小发生改变时触发$(window).res…

  • elementuitable样式更改_elementui下拉框

    elementuitable样式更改_elementui下拉框表格样式修改(表头高、表头边框、表格内边框、表格行高)//控制表头高度.el-table/deep/.el-table__headerth{padding:0;height:40px;line-height:40px;//表头边框设置border:solid#cccccc;border-width:1px0px0px1px;}//添加表格行边框.el-table/deep/td{border:solid#cccccc;border-width:1px0

  • java线程通信的三种方式「建议收藏」

    java线程通信的三种方式「建议收藏」1、传统的线程通信。在synchronized修饰的同步方法或者修饰的同步代码块中使用Object类提供的wait(),notify()和notifyAll()3个方法进行线程通信。关于这3个方法的解释:wait():导致当前线程等待,直到其他线程调用该同步监视器的notify()方法或notifyAll()方法来唤醒该线程。notify():唤醒在此…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号