Pytorch-BN层详细解读

Pytorch-BN层详细解读Pytorch-BN层BN解决了InternalCovariateShift问题机器学习领域有个很重要的假设:独立同分布假设,即假设训练数据和测试数据是满足相同分布的。我们知道:神经网络的训练实际上就是在拟合训练数据的分布。如果不满足独立同分布假设,那么训练得到的模型的泛化能力肯定不好。再来思考一个问题:为什么传统的神经网络要求将数据归一化(训练阶段将训练数据归一化并记录均值和方差,测试…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

Pytorch-BN层

BN解决了Internal Covariate Shift问题

机器学习领域有个很重要的假设:独立同分布假设,即假设训练数据和测试数据是满足相同分布的。我们知道:神经网络的训练实际上就是在拟合训练数据的分布。如果不满足独立同分布假设,那么训练得到的模型的泛化能力肯定不好。

再来思考一个问题:为什么传统的神经网络要求将数据归一化(训练阶段将训练数据归一化并记录均值和方差,测试阶段利用记录的均值和方差将测试数据也归一化)

首先:做了归一化之后,可以近似的认为训练数据和测试数据满足相同分布(即均值为0,方差为1的标准正态),这样一来模型的泛化能力会得到提高。其次:如果不做归一化,使用mini-batch梯度下降法训练的时候,每批训练数据的分布不相同,那么网络就要在每次迭代的时候去适应不同的分布,这样会大大降低网络的训练速度。综合以上两点,所以需要对数据做归一化预处理。PS:如果是mini-batch梯度下降法,每个batch都可以计算出一个均值和方差,最终记录的均值和方差是所有batches均值和方差的期望,当然也有其它更复杂的记录方式,如pytorch使用的滑动平均。

Internal Covariate Shift问题:在训练的过程中,即使对输入层做了归一化处理使其变成标准正态,随着网络的加深,函数变换越来越复杂,许多隐含层的分布还是会彻底放飞自我,变成各种奇奇怪怪的正态分布,并且整体分布逐渐往非线性函数(也就是激活函数)的取值区间的上下限两端靠近。对于sigmoid函数来说,就意味着输入值是大的负数或正数,这导致反向传播时底层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

为了解决上述问题,又想到网络的某个隐含层相对于之后的网络就相当于输入层,所以BN的基本思想就是:把网络的每个隐含层的分布都归一化到标准正态。其实就是把越来越偏的分布强制拉回到比较标准的分布,这样使得激活函数的输入值落在该激活函数对输入比较敏感的区域,这样一来输入的微小变化就会导致损失函数较大的变化。通过这样的方式可以使梯度变大,就避免了梯度消失的问题,而且梯度变大意味着收敛速度快,能大大加快训练速度。

简单说来就是:传统的神经网络只要求第一个输入层归一化,而带BN的神经网络则是把每个输入层(把隐含层也理解成输入层)都归一化。

BN的具体步骤

BN实际上包含两步操作。 x i x_i xi是BN的输入, y i y_i yi是BN的输出。

  • 归一化到标准正态, ϵ \epsilon ϵ是一个非常小的数字,是为了防止除以0

μ = 1 m ∑ i = 1 m x i \mu=\frac{1}{m}\sum\limits_{i=1}^{m}x_i μ=m1i=1mxi
σ 2 = 1 m ∑ i = 1 m ( x i − μ ) 2 \sigma^2=\frac{1}{m}\sum\limits_{i=1}^{m}(x_i-\mu)^2 σ2=m1i=1m(xiμ)2
x ^ i ← x i − μ σ 2 + ϵ \hat{x}_i\larr\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} x^iσ2+ϵ
xiμ

以sigmoid函数为例,可以将其近似的看成两个部分,中间区域的线性部分以及两侧的非线性部分。Internal Covariate Shift问题就是:隐含层的输出都落在了sigmoid函数的非线性区域,这部分区域对损失函数的影响极小,所以梯度也极小。归一化操作就是把非线性区域的值拉回到线性区域,这样一来虽然增大了梯度,但也降低了数据的非线性表示能力。所以还需要缩放操作来弥补归一化操作降低的非线性表达能力。归一化从形式上看来就是把输入值减去一个数字,再除以一个数字。它的逆操作就是先乘以一个数字,在加上一个数字,这就是缩放。

  • 缩放

y i ← γ x i ^ + β ≡ B N γ , β ( x i ) y_i\larr\gamma\hat{x_i}+\beta\equiv BN_{\gamma,\beta}(x_i) yiγxi^+βBNγ,β(xi)

Pytorch中的BN

Pytorch中的BN操作为nn.BatchNorm2d(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)

  • num_features,输入数据的通道数,归一化时需要的均值和方差是在每个通道中计算的
  • eps,用来防止归一化时除以0
  • momentum,滑动平均的参数,用来计算running_mean和running_var
  • affine,是否进行仿射变换,即缩放操作
  • track_running_stats,是否记录训练阶段的均值和方差,即running_mean和running_var

BN层的状态包含五个参数:

  • weight,缩放操作的 γ \gamma γ
  • bias,缩放操作的 β \beta β
  • running_mean,训练阶段统计的均值,测试阶段会用到。
  • running_var,训练阶段统计的方差,测试阶段会用到。
  • num_batches_tracked,训练阶段的batch的数目,如果没有指定momentum,则用它来计算running_mean和running_var。一般momentum默认值为0.1,所以这个属性暂时没用。

weight和bias这两个参数需要训练,而running_mean、running_val和num_batches_tracked不需要训练,它们只是训练阶段的统计值。

在训练阶段,假设输入是[4, 3, 2, 2]的张量,如下图所示。

在这里插入图片描述

对于这四个数据块,每次取其中一个通道的数据,然后对这个16个数据求均值 μ \mu μ和方差 σ \sigma σ,并用求得的均值和方差归一化并缩放数据,得到BN层的输出

接下来用滑动平均公式来更新running_mean和running_var,momentum默认值为0.1。

r u n n i n g _ m e a n = ( 1 − m o m e n t u m ) ∗ r u n n i n g _ m e a n + m o m e n t u m ∗ μ r u n n i n g _ v a r = ( 1 − m o m e n t u m ) ∗ r u n n i n g _ v a r + m o m e n t u m ∗ σ running\_mean = (1 – momentum) * running\_mean + momentum * \mu \\ running\_var = (1 – momentum) * running\_var + momentum * \sigma running_mean=(1momentum)running_mean+momentumμrunning_var=(1momentum)running_var+momentumσ

在测试阶段,归一化操作不用再计算均值方差,而是直接使用训练阶段统计的running_mean和running_var。

Note: track_running_stats和self.training有四种可能的组合。

  • training=True, track_running_stats=True, 这是常见的训练时期待的行为,running_mean和running_var会跟踪不同batch数据的均值和方差,但是仍然用每个batch的均值和方差做归一化。
  • training=True, track_running_stats=False, 这时候running_mean和running_var不跟踪各个batch的均值和方差了,但仍然用每个batch的均值和方差做归一化。
  • training=False, track_running_stats=True, 这是常见的测试时期待的行为,即使用训练阶段统计的running_mean和running_var做归一化。
  • training=False, track_running_stats=False, 使用每个batch的均值和方差做归一化。

Pytorch代码示例

bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 2, 2)
y = bn(x)
a = (x[0, 0, :, :] + x[1, 0, :, :] + x[2, 0, :, :] + x[3, 0, :, :]).sum() / 16
b = (x[0, 1, :, :] + x[1, 1, :, :] + x[2, 1, :, :] + x[3, 1, :, :]).sum() / 16
c = (x[0, 2, :, :] + x[1, 2, :, :] + x[2, 2, :, :] + x[3, 2, :, :]).sum() / 16
print('The mean value of the first channel is %f' % a.data)
print('The mean value of the first channel is %f' % b.data)
print('The mean value of the first channel is %f' % c.data)
print('The output mean value of the BN layer is %f, %f, %f' % (bn.running_mean.data[0], bn.running_mean.data[1], bn.running_mean.data[2]))

细心的读者可能已经发现第一个通道的running_mean正好是真实的mean的0.1倍,这是为什么呢?因为在最开始的时候running_mean=0,然后用滑动平均公式去更新:
r u n n i n g _ m e a n = 0.9 ∗ r u n n i n g _ m e a n + 0.1 ∗ m e a n = 0.1 ∗ m e a n running\_mean = 0.9 * running\_mean + 0.1 * mean = 0.1 * mean running_mean=0.9running_mean+0.1mean=0.1mean

那么当我们前向传播两次之后,再来观察running_mean和mean的关系。

bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 2, 2)
y = bn(x)
y = bn(x)  # 前向传播两次
a = (x[0, 0, :, :] + x[1, 0, :, :] + x[2, 0, :, :] + x[3, 0, :, :]).sum() / 16
b = (x[0, 1, :, :] + x[1, 1, :, :] + x[2, 1, :, :] + x[3, 1, :, :]).sum() / 16
c = (x[0, 2, :, :] + x[1, 2, :, :] + x[2, 2, :, :] + x[3, 2, :, :]).sum() / 16
print('The mean value of the first channel is %f' % a.data)
print('The mean value of the first channel is %f' % b.data)
print('The mean value of the first channel is %f' % c.data)
print('The output mean value of the BN layer is %f, %f, %f' % (bn.running_mean.data[0], bn.running_mean.data[1], bn.running_mean.data[2]))

我们会发现它们存在着这样一个关系:
r u n n i n g _ m e a n = 0.9 ∗ ( 0.9 ∗ 0 + 0.1 ∗ m e a n ) + 0.1 ∗ m e a n = 0.19 ∗ m e a n running\_mean = 0.9 * (0.9 * 0 + 0.1 * mean) + 0.1 * mean = 0.19 * mean running_mean=0.9(0.90+0.1mean)+0.1mean=0.19mean

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/182082.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • SQLserver基础语句大全[通俗易懂]

    SQLserver基础语句大全[通俗易懂]SQL基础结构化查询语言(StructuredQueryLanguage)简称SQL,是一种特殊目的的编程语言,是一种数据库查询和程序设计语言,用于存取数据以及查询、更新和管理关系数据库系统;同时也是数据库脚本文件的扩展名。SQLDML和DDL可以把SQL分为两个部分:数据操作语言(DML)和数据定义语言(DDL)。SQL(结构化查询语言)是用于执行…

  • 重定向和转发的区别及应用[通俗易懂]

    重定向和转发的区别及应用[通俗易懂]重定向重定向和转发有一个重要的不同:当使用转发时,JSP容器将使用一个内部的方法来调用目标页面,新的页面继续处理同一个请求,而浏览器将不会知道这个过程。与之相反,重定向方式的含义是第一个页面通知浏览器发送一个新的页面请求。因为,当你使用重定向时,浏览器中所显示的URL会变成新页面的URL,而当使用转发时,该URL会保持不变。在客户浏览器路径栏显示的是其重定向的路径,客户可以观察到地址的变化的。重定向行为是浏览器做了至少两次的访问请求的。重定向的速度比转发慢,因为浏览器还得发出一个新的请求。同时,由于重

  • webstorm激活码最新2021【2021最新】

    (webstorm激活码最新2021)本文适用于JetBrains家族所有ide,包括IntelliJidea,phpstorm,webstorm,pycharm,datagrip等。IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.cn/100143.html…

  • 2019 Python最新面试题及答案16道题

    在Python相关的工作岗位面试中,基础语法是必考的一个部分,它考察求职者对Python语言的最基本掌握情况;其次是数据库相关的问题,如查询,修改,插入数据等,数据库所占的比重也很大,不容小觑。

  • JAX-RS 2.0如何验证查询参数

    JAX-RS 2.0如何验证查询参数

  • iOS PerformSelector 遗漏问题

    iOS PerformSelector 遗漏问题一基础用法performSelecor响应了OC语言的动态性:延迟到运行时才绑定方法。当我们在使用以下方法时:[objperformSelector:@selector(play)];[objperformSelector:@selector(play:)withObject:@"李周"];[objperformSelector:@selector(play:with:)…

    2022年10月28日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号