pytorch中resnet_通过Pytorch实现ResNet18

pytorch中resnet_通过Pytorch实现ResNet18对于像我这样刚刚入门深度学习的同学来说,可能接触学习了一个开发工具,却没有通过运用来熟练的掌握它。而ResNet是深度学习里面一个非常重要的backbone,并且ResNet18实现起来又足够简单,所以非常适合拿来练手。我们这里的开发环境是:python3.6.10pytorch1.5.0torchvision0.6.0cudatoolkit10.2.89cudnn7.6.5首先,我们需…

大家好,又见面了,我是你们的朋友全栈君。

对于像我这样刚刚入门深度学习的同学来说,可能接触学习了一个开发工具,却没有通过运用来熟练的掌握它。而ResNet是深度学习里面一个非常重要的backbone,并且ResNet18实现起来又足够简单,所以非常适合拿来练手。

我们这里的开发环境是:

python 3.6.10

pytorch 1.5.0

torchvision 0.6.0

cudatoolkit 10.2.89

cudnn 7.6.5

首先,我们需要明确ResNet18的网络结构。在我自己学习的一开始,我对于ResNet的ShortCut机制的实现不是很清楚,当你知道怎么实现这个机制之后,那么剩下的部分也就没有什么挑战了。

论文中,ResNet各种层数的结构如下:pytorch中resnet_通过Pytorch实现ResNet18

我们观察,实际可以将ResNet18分成6个部分:

1. Conv1:也就是第一层卷积,没有shortcut机制。

2. Conv2:第一个残差块,一共有2个。

3. Conv3:第二个残差块,一共有2个。

4. Conv4:第三个残差块,一共有2个。

5. Conv5:第四个残差块,一共有2个。

6. fc:全连阶层。pytorch中resnet_通过Pytorch实现ResNet18

明确这些部分之后,我们就可以开始着手实现啦!

首先,咱们实现残差块:

import torch

import torch.nn as nn

import torch.nn.functionl as F

#定义残差块ResBlock

class ResBlock(nn.Module):

def __init__(self, inchannel, outchannel, stride=1):

super(ResBlock, self).__init__()

#这里定义了残差块内连续的2个卷积层

self.left = nn.Sequential(

nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),

nn.BatchNorm2d(outchannel),

nn.ReLU(inplace=True),

nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),

nn.BatchNorm2d(outchannel)

)

self.shortcut = nn.Sequential()

if stride != 1 or inchannel != outchannel:

#shortcut,这里为了跟2个卷积层的结果结构一致,要做处理

self.shortcut = nn.Sequential(

nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),

nn.BatchNorm2d(outchannel)

)

def forward(self, x):

out = self.left(x)

#将2个卷积层的输出跟处理过的x相加,实现ResNet的基本结构

out = out + self.shortcut(x)

out = F.relu(out)

return out

接着,我们实现ResNet18模型:

class ResNet(nn.Module):

def __init__(self, ResBlock, num_classes=10):

super(ResNet, self).__init__()

self.inchannel = 64

self.conv1 = nn.Sequential(

nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),

nn.BatchNorm2d(64),

nn.ReLU()

)

self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)

self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)

self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)

self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)

self.fc = nn.Linear(512, num_classes)

#这个函数主要是用来,重复同一个残差块

def make_layer(self, block, channels, num_blocks, stride):

strides = [stride] + [1] * (num_blocks – 1)

layers = []

for stride in strides:

layers.append(block(self.inchannel, channels, stride))

self.inchannel = channels

return nn.Sequential(*layers)

def forward(self, x):

#在这里,整个ResNet18的结构就很清晰了

out = self.conv1(x)

out = self.layer1(out)

out = self.layer2(out)

out = self.layer3(out)

out = self.layer4(out)

out = F.avg_pool2d(out, 4)

out = out.view(out.size(0), -1)

out = self.fc(out)

return out

到此,一个ResNet18网络就搭建完成了,不过,仅仅是搭建完成还是远远不够的,让我们拿它来练练手。笔者在jupyter notebook上使用CIFAR10数据集来测试我们的ResNet18模。

from resnet import ResNet18

#Use the ResNet18 on Cifar-10

import torch.optim as optim

import torchvision

import torchvision.transforms as transforms

#check gpu

device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)

#set hyperparameter

EPOCH = 10

pre_epoch = 0

BATCH_SIZE = 128

LR = 0.01

#prepare dataset and preprocessing

transform_train = transforms.Compose([

transforms.RandomCrop(32, padding=4),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

])

transform_test = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

])

trainset = torchvision.datasets.CIFAR10(root=’../data’, train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=’../data’, train=False, download=True, transform=transform_test)

testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

#labels in CIFAR10

classes = (‘plane’, ‘car’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’)

#define ResNet18

net = ResNet18().to(device)

#define loss funtion & optimizer

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

然后开始跑模型:

#train

for epoch in range(pre_epoch, EPOCH):

print(‘\nEpoch:%d’ % (epoch + 1))

net.train()

sum_loss = 0.0

correct = 0.0

total = 0.0

for i, data in enumerate(trainloader, 0):

#prepare dataset

length = len(trainloader)

inputs, labels = data

inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()

#forward & backward

outputs = net(inputs)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

#print ac & loss in each batch

sum_loss += loss.item()

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += predicted.eq(labels.data).cpu().sum()

print(‘[epoch:%d, iter:%d] Loss:%.03f| Acc:%.3f%%’

% (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))

#get the ac with testdataset in each epoch

print(‘Waiting Test…’)

with torch.no_grad():

correct = 0

total = 0

for data in testloader:

net.eval()

images, labels = data

images, labels = images.to(device), labels.to(device)

outputs = net(images)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum()

print(‘Test\’s ac is:%.3f%%’ % (100 * correct / total))

print(‘Train has finished, total epoch is%d’ % EPOCH)

如果不出意外,这个模型就已经跑起来了,到这里,咱们就已经完成的实现了一个ResNet18网络,这个模型的jupyter notebook源码我已经放到了github上,如果这片文章对你有帮助,那就给我star一下吧:samcw/ResNet18-Pytorch​github.compytorch中resnet_通过Pytorch实现ResNet18

参考:

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/141371.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • jquery checkbox 设置选中和不选中

    jquery checkbox 设置选中和不选中1.设置选中:$(“#hasApply”).prop(“checked”,true);设置不选中:$(“#hasApply”).prop(“checked”,false);或如下方法://$(“#ck”).attr(“checked”,true)//选中//$(“#ck”).attr(“checked”,false)//未选中2.获取选中的状态:varstatus…

  • MySQL中多表删除方法

    MySQL中多表删除方法

  • mysql != 索引_Mysql语法

    mysql != 索引_Mysql语法转:https://www.cnblogs.com/huanzi-qch/p/15238604.html介绍通常情况下,全文检索引擎我们一般会用ES组件(传送门:SpringBoot系列——ElasticSearch),但不是所有业务都有那么大的数据量、那么大的并发要求,MySQL5.7之后内置了ngram分词器,支持中文分词,使用全文索引,即可实现对中文语义分词检索MySQL支持全文索引和搜索:  MySQL中的全文索引是FULLTEXT类型的索引。  全文索引只能用于InnoDB或My

  • 【MyBatis学习13】MyBatis中的二级缓存[通俗易懂]

    【MyBatis学习13】MyBatis中的二级缓存[通俗易懂]1.二级缓存的原理  前面介绍了,mybatis中的二级缓存是mapper级别的缓存,值得注意的是,不同的mapper都有一个二级缓存,也就是说,不同的mapper之间的二级缓存是互不影响的。为了更加清楚的描述二级缓存,先来看一个示意图:  从图中可以看出:sqlSession1去查询用户id为1的用户信息,查询到用户信息会将查询数据存储到该UserMapper的二级缓存中。

  • mysql中phpmyadmin安装教程_安装phpMyAdmin图文教程

    mysql中phpmyadmin安装教程_安装phpMyAdmin图文教程phpmyadmin的安装配置已经是老生常谈的话题了,网络上到处都可以找到相关的配置教程。但是,那些大多都是手动配置的,稍不留神,容易出错。因此站长今天在这里介绍的是,被很多phpmyadmin用户所忽略的phpmyadmin自带的安装程序,下面我们就开始一步一步来安装phpmyadmin。1、首先下载phpmyadmin3.4.11,这是目前最稳定无bug的版本,点击下载2、在你的web根目录新…

  • Struts2–自定义拦截器三种方式(实现Interceptor接口、继承抽象类AbstractInterceptor、继承MethodFilterInterceptor)「建议收藏」

    Struts2–自定义拦截器三种方式(实现Interceptor接口、继承抽象类AbstractInterceptor、继承MethodFilterInterceptor)「建议收藏」实现自定义拦截器在实际的项目开发中,虽然Struts2的内建拦截器可以完成大部分的拦截任务,但是,一些与系统逻辑相关的通用功能(如权限的控制和用户登录控制等),则需要通过自定义拦截器实现。本节将详细讲解如何自定义拦截器。1.实现Interceptor接口在Struts2框架中,通常开发人员所编写的自定义拦截器类都会直接或间接地实现com.opensymphony.xwork2.in…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号