ResNet18复现「建议收藏」

ResNet18复现「建议收藏」ResNet18的网络架构图首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。代码如下:importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFclassBasicBlock(nn.Module):def__ini

大家好,又见面了,我是你们的朋友全栈君。

ResNet18的网络架构图

ResNet18复现「建议收藏」

首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。

代码如下:

import torch
from torch import nn
from torch.nn import functional as F

class BasicBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride):
        super(BasicBlock,self).__init__()
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding=1)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size,stride,padding=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        output=self.bn1(self.conv1(x))
        output=self.bn2(self.conv2(output))
        return F.relu(x+output)
    

class BasicDownBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride):
        super(BasicDownBlock,self).__init__()     
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size[0],stride[0],padding=1)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size[0],stride[1],padding=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        self.conv3=nn.Conv2d(in_channels,out_channels,kernel_size[1],stride[0])
        self.bn3=nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        output=self.bn1(self.conv1(x))
        output=self.bn2(self.conv2(output))
        output1=self.bn3(self.conv3(x))
        return F.relu(output1+output)

class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        # 3x224x224-->64x112x112
        self.conv1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=7,stride=2,padding=3)
        self.bn1=nn.BatchNorm2d(64)
        # 64x112x112-->64x56x56
        self.pool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        # 64x56x56-->64x56x56
        self.layer1=nn.Sequential(
            BasicBlock(64,64,3,1),
            BasicBlock(64,64,3,1)
        )
        # 64x56x56-->128*28*28
        self.layer2=nn.Sequential(
            BasicDownBlock(64,128,[3,1],[2,1]),
            BasicBlock(128,128,3,1)
        )
        # 128*28*28-->256*14*14
        self.layer3=nn.Sequential(
            BasicDownBlock(128,256,[3,1],[2,1]),
            BasicBlock(256,256,3,1)
        )
        # 256*14*14-->512x7x7
        self.layer4=nn.Sequential(
            BasicDownBlock(256,512,[7,1],[2,1]),
            BasicBlock(512,512,3,1)
        )
        # 512x7x7-->512x1x1
        self.avgpool=nn.AdaptiveMaxPool2d(output_size=(1,1))
        self.flat=nn.Flatten()
        self.linear=nn.Linear(512,10)
        
    def forward(self,x):
        output=self.pool1(F.relu(self.bn1(self.conv1(x))))
        output=self.layer1(output)
        output=self.layer2(output)
        output=self.layer3(output)
        output=self.layer4(output)
        output=self.avgpool(output)
        output=self.flat(output)
        output=self.linear(output)
        return output
    

net=ResNet18()
x=torch.randn(32,3,224,224)
print(x.shape)
y=net(x)
print(y.shape)

代码中BasicBlock为普通的残差块,注意步长和卷积核的大小,BasicDownBlock为下采样的残差块,然后将四层的网络表示出来,最后进行验证x.shape为torch.Size([32, 3, 224, 224]),y.shape为torch.Size([32, 10])。 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/141545.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • goland激活码最新【2021.7最新】

    (goland激活码最新)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

  • strip 命令的使用方法

    strip 命令的使用方法

  • linux cat /etc/passwd 说明

    linux cat /etc/passwd 说明

    2021年10月27日
  • java实体entity转map对象[通俗易懂]

    java实体entity转map对象[通俗易懂]实体转对象方法一,一句搞定,直接返回map对象:importorg.springframework.cglib.beans.BeanMap;BeanMap.create(entityObj);方法二:利用反射——详见原文

  • 图析,Pycharm 上如何设置QT环境[通俗易懂]

    图析,Pycharm 上如何设置QT环境[通俗易懂]一、参数设置文件–设置–外部工具–“+”–“ExternalTools”下两个设置1.QtDesigner和2.PyUIC1.QtDesigner参数设置:(1.)名称框:QtDesigner(2.)工具设置–程序框:填写Qta安装的路径Designer.exe(例:C:\ProgramData\Anaconda3\Library\bin\designer.exe注:Anaconda3目录下有designer.exe软件,无需下载.

  • 大数据开发步骤和流程「建议收藏」

    大数据项目开发步骤:第一步:需求:数据的输入和数据的产出;第二步:数据量、处理效率、可靠性、可维护性、简洁性;第三步:数据建模;第四步:架构设计:数据怎么进来,输出怎么展示,最最重要的是处理流出数据的架构;第五步:再次思考大数据系统和企业IT系统的交互;第六步:最终确定选择、规范等;第七步:基于数据建模写基础服务代码;第八步:正式编写第一个模块;第九步:实现其它…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号