(深度学习)Pytorch之dropout训练

(深度学习)Pytorch之dropout训练(深度学习)Pytorch学习笔记之dropout训练Dropout训练实现快速通道:点我直接看代码实现Dropout训练简介在深度学习中,dropout训练时我们常常会用到的一个方法——通过使用它,我们可以可以避免过拟合,并增强模型的泛化能力。通过下图可以看出,dropout训练训练阶段所有模型共享参数,测试阶段直接组装成一个整体的大网络:那么,我们在深度学习的有力工具——Pytor…

大家好,又见面了,我是你们的朋友全栈君。

(深度学习)Pytorch学习笔记之dropout训练

Dropout训练实现快速通道:点我直接看代码实现

Dropout训练简介

在深度学习中,dropout训练时我们常常会用到的一个方法——通过使用它,我们可以可以避免过拟合,并增强模型的泛化能力。

通过下图可以看出,dropout训练训练阶段所有模型共享参数,测试阶段直接组装成一个整体的大网络:
在这里插入图片描述

那么,我们在深度学习的有力工具——Pytorch中如何实现dropout训练呢?

易错大坑

网上查找到的很多实现都是这种形式的:

        out = F.dropout(out, p=0.5)

这种形式的代码非常容易误导初学者,给人带来很大的困扰:

  • 首先,这里的F.dropout实际上是torch.nn.functional.dropout的简写(很多文章都没说清这一点,就直接给个代码),我尝试了一下我的Pytorch貌似无法使用,可能是因为版本的原因。
  • 其次,torch.nn.functional.dropout()还有个大坑:F.dropout()相当于引用的一个外部函数,模型整体的training状态变化也不会引起F.dropout这个函数的training状态发生变化。因此,上面的代码实质上就相当于out = out

因此,如果你非要使用torch.nn.functional.dropout的话,推荐的正确方法如下(这里默认你已经import torch.nn as nn了):

       out = nn.functional.dropout(out, p=0.5, training=self.training)

推荐代码实现方法

这里更推荐的方法是:nn.Dropout(p),其中p是采样概率。nn.Dropout实际上是对torch.nn.functional.dropout的一个包装, 也将self.training传入了其中,可以有效避免前面所说的大坑。

下面给出一个三层神经网络的例子:

import torch.nn as nn


input_size = 28 * 28   
hidden_size = 500   
num_classes = 10    


# 三层神经网络
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入层到影藏层
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)  # 影藏层到输出层
        self.dropout = nn.Dropout(p=0.5)  # dropout训练

    def forward(self, x):
        out = self.fc1(x)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out
   

model = NeuralNet(input_size, hidden_size, num_classes)
model.train()
model.eval()

另外还有一点需要说明的是,训练阶段随机采样时需要用model.train(),而测试阶段直接组装成一个整体的大网络时需要使用model.eval():

  • 如果你二者都没使用的话,默认情况下实际上是相当于使用了model.train(),也就是开启dropout随机采样了——这样你如果你是在测试的话,准确率可能会受到影响。
  • 如果你不希望开启dropout训练,想直接以一个整体的大网络来训练,不需要重写一个网络结果,而只需要在训练阶段开启model.eval()即可。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/133466.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 语音信号处理入门系列(1)—— 语音信号处理概念「建议收藏」

    语音信号处理入门系列(1)—— 语音信号处理概念「建议收藏」文章目录1.语音交互2.复杂的声学环境2.1声学回声消除2.2解混响2.3语音分离2.4波束形成2.5噪声抑制2.6幅度控制2.7前端信号处理的技术路线3.参考4.推荐开源项目原博客地址:https://www.cnblogs.com/LXP-Never/p/13620804.html1.语音交互你知道苹果手机有几个麦克风吗?语音交互(VUI)是指人与人/设备通过自然语音进行信息传递的过程。语音交互的优势:输入效率高。语音输入的速度是传统键盘输入方式的3倍以上。例如:语

  • 数据挖掘应用实例分析

    数据挖掘应用实例分析数据挖掘应用实例分析——个性化推荐系统​ 数据挖掘技术,一门基于计算机技术与大数据时代信息处理需求的技术产物,从世纪之交的火热发展以来,不知不觉间,早已应用到我们生活的方方面面:电子邮箱中的垃圾邮件分类、电影院的票房预测、网页上的广告推荐、语音识别、电网语义精确搜索等。还有人工智能、自然语言处理、数据修正等。我们认为,数据挖掘技术将成为互联网时代应用最广泛的技术之一,它有可能为人类社会带来一个新的时代。​ 但是由于笔者才疏学浅,今天我们暂不谈得那么高深,只分析的一个常见的应用实例——个性化推荐系统。

  • spinner:获取选中值的三种方法

    spinner:获取选中值的三种方法

  • Shell if else语句「建议收藏」

    Shell if else语句「建议收藏」Shellifelse语句if语句通过关系运算符判断表达式的真假来决定执行哪个分支。Shell有三种if…else语句:if…fi语句; if…else…fi语句; if…elif…else…fi语句。1)if…else语句if…else语句的语法:if[expression…

  • java cloneable 接口_Cloneable 接口 记号接口(标记接口)「建议收藏」

    java cloneable 接口_Cloneable 接口 记号接口(标记接口)「建议收藏」Cloneable接口指示了一个类提供了一个安全的clone方法。首先了解Object.clone()方法:clone是Object超类的一个protected方法,用户代码不能直接调用这个方法。Object的子类只能调用Object超类中受保护的clone方法来克隆它自己的对象,必须重新定义clone为public才能允许所有方法调用这个类的实例的clone方法克隆对象。clone方法的作用:…

  • html表单提交_html表单标签有哪些

    html表单提交_html表单标签有哪些1.表单属性设置<form>标签表示表单标签,定义整体的表单区域action属性设置表单数据提交地址method属性设置表单提交的方式,一般有“GET”方式和“POST”方式,不区分大小写2.表单元素属性设置name属性设置表单元素的名称,该名称是提交数据时的参数名value属性设置表单元素的值,该值是提交数据时参数名所对应的值3.示例代码<formaction=”https://www.baidu.com”method=”GET”>

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号