batchnorm pytorch_Pytorch中的BatchNorm

batchnorm pytorch_Pytorch中的BatchNorm前言:本文主要介绍在pytorch中的BatchNormalization的使用以及在其中容易出现的各种小问题,本来此文应该归属于[1]中的,但是考虑到此文的篇幅可能会比较大,因此独立成篇,希望能够帮助到各位读者。如有谬误,请联系指出,如需转载,请注明出处,谢谢。∇∇\nabla∇联系方式:e-mail:FesianXu@163.comQQ:973926198github:https:/…

大家好,又见面了,我是你们的朋友全栈君。

前言:

本文主要介绍在pytorch中的Batch Normalization的使用以及在其中容易出现的各种小问题,本来此文应该归属于[1]中的,但是考虑到此文的篇幅可能会比较大,因此独立成篇,希望能够帮助到各位读者。如有谬误,请联系指出,如需转载,请注明出处,谢谢。

∇∇\nabla∇ 联系方式:

e-mail: FesianXu@163.com

QQ: 973926198

github: https://github.com/FesianXu

Batch Normalization,批规范化

Batch Normalization(简称为BN)[2],中文翻译成批规范化,是在深度学习中普遍使用的一种技术,通常用于解决多层神经网络中间层的协方差偏移(Internal Covariate Shift)问题,类似于网络输入进行零均值化和方差归一化的操作,不过是在中间层的输入中操作而已,具体原理不累述了,见[2-4]的描述即可。

在BN操作中,最重要的无非是这四个式子:

Unexpected text node: ‘ ’Unexpected text node: ‘ ’Input:Output:更新过程:μB​σB2​x^i​yi​​B={x1​,⋯,xm​},为m个样本组成的一个batch数据。需要学习到的是γ和β,在框架中一般表述成weight和bias。←m1​i=1∑m​xi​//得到batch中的统计特性之一:均值←m1​i=1∑m​(xi​−μB​)2//得到batch中的另一个统计特性:方差←σB2​+ϵ​xi​−μB​​//规范化,其中ϵ是一个很小的数,防止计算出现数值问题。←γx^i​+β≡BNγ,β​(xi​)//这一步是输出尺寸伸缩和偏移。​

注意到这里的最后一步也称之为仿射(affine),引入这一步的目的主要是设计一个通道,使得输出output至少能够回到输入input的状态(当γ=1,β=0γ=1,β=0\gamma=1,\beta=0γ=1,β=0时)使得BN的引入至少不至于降低模型的表现,这是深度网络设计的一个套路。

整个过程见流程图,BN在输入后插入,BN的输出作为规范后的结果输入的后层网络中。

forwardbackwardforwardbackwardinput batchBatch_NormOutput batch

好了,这里我们记住了,在BN中,一共有这四个参数我们要考虑的:

γ,βγ,β\gamma, \betaγ,β:分别是仿射中的weightweight\mathrm{weight}weight和biasbias\mathrm{bias}bias,在pytorch中用weight和bias表示。

μℬμB\mu_{\mathcal{B}}μB​和σ2ℬσB2\sigma_{\mathcal{B}}^2σB2​:和上面的参数不同,这两个是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。在pytorch中,用running_mean和running_var表示[5]

在Pytorch中使用

Pytorch中的BatchNorm的API主要有:

torch.nn.BatchNorm1d(num_features,

eps=1e-05,

momentum=0.1,

affine=True,

track_running_stats=True)1

2

3

4

5

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。

同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0γ=1,β=0\gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True[10]

trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

一般来说,trainning和track_running_stats有四种组合[7]

trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。

trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。

trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_mean和running_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。[6,8]

trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。

同时,我们要注意到,BN层中的running_mean和running_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。如

model.train() # 处于训练状态

for data, label in self.dataloader:

pred = model(data)

# 在这里就会更新model中的BN的统计特性参数,running_mean, running_var

loss = self.loss(pred, label)

# 就算不要下列三行代码,BN的统计特性参数也会变化

opt.zero_grad()

loss.backward()

opt.step()1

2

3

4

5

6

7

8

9

10

这个时候要将model.eval()转到测试阶段,才能固定住running_mean和running_var。有时候如果是先预训练模型然后加载模型,重新跑测试的时候结果不同,有一点性能上的损失,这个时候十有八九是trainning和track_running_stats设置的不对,这里需要多注意。 [8]

假设一个场景,如下图所示:

inputmodel_Amodel_Boutput

此时为了收敛容易控制,先预训练好模型model_A,并且model_A内含有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时就希望model_A中的BN的统计特性值running_mean和running_var不会乱变化,因此就必须将model_A.eval()设置到测试模式,否则在trainning模式下,就算是不去更新该模型的参数,其BN都会改变的,这个将会导致和预期不同的结果。

Reference

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/137457.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 马来西亚最大的电商平台_东南亚最受欢迎的跨境电商平台

    马来西亚最大的电商平台_东南亚最受欢迎的跨境电商平台一直以来,马来西亚电商市场几乎被Shopee和Lazada两大电商平台所统治,国际巨头占据主要市场。马来西亚电商平台TOP10中,Shopee和Lazada两大电商平台共占据了83.58%的网站流量,是马来电商入驻首选平台。然而直到2020年,Shopee超过了Lazada,拉开了距离,Shopee月均流量已达到Lazada的两倍以上。与此同时,马来西亚本土电商PGMall也在2020年的竞争中战胜Zalora与Lelong,稳固了他在马来西亚前三甲的地位。目前,无需注册马来西亚本地公司即可直接在

  • redis配置文件_redis怎么连接

    redis配置文件_redis怎么连接dd#redis配置开始#Redis数据库索引(默认为0)spring.redis.database=0#Redis服务器地址#redis.host=192.168.59.43redis.host1=192.168.58.11redis.host2=192.168.58.12redis.host3=192.168.58.13#Redis服务器连接端口redis.port=6379redis.master.port=6379redis.slave.port=6380#Re.

  • matlab插值实验目的,matlab插值实验报告数学实验.doc

    matlab插值实验目的,matlab插值实验报告数学实验.docmatlab插值实验报告数学实验.doc新乡学院数学与信息科学系实验报告实验项目名称插值实验所属课程名称数学实验实验类型综合性实验实验日期班级学号姓名成绩一、实验概述【实验目的】掌握用MATLAB插值的方法,了解拉格朗日插值、线性插值、样条插值的基本思想,了解三种网格节点数据的插值方法的基本思想,了解掌握用MATLAB计算一维差值和二维插值的方法。【实验原理】拉格朗日LAGRANGE插值。已知函…

  • excel多列合并关联数据[通俗易懂]

    excel多列合并关联数据[通俗易懂]假设现在有三张表第一张第二张第三张姓名与操作id相对应,现在想弄出这样的一个表,将多列数据整合起来那怎么做呢?需要用到函数vlookup这个查找值是合并时不变的那列,在这个案例下,就是指日期+姓名+操作id这三列,但是这里是不能写这么多的,只能是一列的第一个值,作为查找值,应该是像主键一样具有唯一的id。第一步,将三列合并为一列,需要用到函数concatenate公式

  • hashmap和hashtable的区别,说法错误的是_javamap的用法

    hashmap和hashtable的区别,说法错误的是_javamap的用法HashMap和Hashtable的区别一、HashMap简介HashMap是在JDK1.2中引入的Map的实现类。1.HashMap是基于哈希表实现的,每一个元素是一个key-value对,其内部通过单链表解决冲突问题,容量不足(超过了阀值)时,同样会自动增长。2.HashMap是非线程安全的,只是用于单线程环境下,多线程环境下可以采用concurrent并发包下的concurren…

  • Python for循环的使用

    Python for循环的使用Pythonfor循环的使用(一)for循环的使用场景1.如果我们想要某件事情重复执行具体次数的时候可以使用for循环。2.for循环主要用来遍历、循环、序列、集合、字典,文件、甚至是自定义类或函数。(二)for循环操作列表实例演示使用for循环对列表进行遍历元素、修改元素、删除元素、统计列表中元素的个数。1.for循环用来遍历整个列表#for循环主

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号