resnet源码pytorch_pytorch conv1d

resnet源码pytorch_pytorch conv1d#Pytorch 0.4.0 ResNet34实现cifar10分类.#@Time:2018/6/17#@Author:xfLiimporttorchvisionastvimporttorchastimporttorchvision.transformsastransformsfromtorchimportnnfromtorch.utils.da…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用


# Pytorch 0.4.0 ResNet34实现cifar10分类.
# @Time: 2018/6/17
# @Author: xfLi

import torchvision as tv
import torch as t
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
t.set_num_threads(8)


class ResidualBloak(nn.Module):
    #残差块
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBloak, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel))
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

class ResNet34(nn.Module):
    #  实现主module:ResNet34  
    #  ResNet34 包含多个layer,每个layer又包含多个residual block  
    #  用子module来实现residual block,用_make_layer函数来实现layer 
    def __init__(self, num_classes):
        super(ResNet34, self).__init__()
        #前几层图像转换
        self.pre = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1))

        # 重复的layer,分别有3,4,6,3个residual block
        self.layer1 = self._make_layer(16, 16, 3, stride=1)
        self.layer2 = self._make_layer(16, 32, 4, stride=1)
        self.layer3 = self._make_layer(32, 64, 6, stride=1)
        self.layer4 = self._make_layer(64, 64, 3, stride=1)
        #分类用的全连接
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        #构建layer,包含多个residual block
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel))
        layer = []
        layer.append(ResidualBloak(inchannel, outchannel, stride, shortcut))
        for i in range(1, block_num):
            layer.append(ResidualBloak(outchannel, outchannel))
        return nn.Sequential(*layer)

    def forward(self, x):
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 7)
        x = x.view(x.size(0), -1)
        return self.fc(x)

def getData(): # 定义对数据的预处理  
    transform = transforms.Compose([
        transforms.Resize(40),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor()])
    #训练集
    trainset = tv.datasets.CIFAR10(root='/data/', train=True, transform=transform, download=True)
    trainset_loader = DataLoader(trainset, batch_size=4, shuffle=True)
    #测试集
    testset = tv.datasets.CIFAR10(root='/data/', train=False, transform=transform, download=True)
    testset_loader = DataLoader(testset, batch_size=4, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainset_loader, testset_loader, classes

def train(): #训练
    trainset_loader, testset_loader, _ = getData() #获取数据
    net = ResNet34(10)
    print(net)
    criterion = nn.CrossEntropyLoss()
    optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #优化器

    for epoch in range(1):
        for step, (inputs,labels) in enumerate(trainset_loader):
            optimizer.zero_grad() #梯度清零
            output = net(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            if step % 10 ==9:
                acc = test(net, testset_loader)
                print('Epoch', epoch, '|step ', step, 'loss: %.4f' %loss.item(), 'test accuracy:%.4f' %acc)
    print('Finished Training')
    return net

def test(net, testdata): #测试集
    correct, total = .0, .0
    for inputs, label in testdata:
        net.eval()
        output = net(inputs)
        _, predicted = t.max(output, 1) #分类结果
        total += label.size(0)
        correct += (predicted == label).sum()
    return float(correct) / total

if __name__ == '__main__':
    net = train()








版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/185357.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • StoredProcedure “存储过程名” 的TextHeader 中存在语法错误

    StoredProcedure “存储过程名” 的TextHeader 中存在语法错误修改存储过程的时候出现StoredProcedure“存储过程名”的TextHeader中存在语法错误出现这样的问题的解决方法(本人修改已成功)在创建存储过程的时候加了注释,把注释删掉就没有问题了(或者把注释放到其他地方)错误代码如下:CREATEPROCEDURE[dbo].[tableToTxtExport]@dbTabNamenvarchar(4000),@dbBoo…

  • pycharm 滚轮字体大小设置_pycharm中文字体设置

    pycharm 滚轮字体大小设置_pycharm中文字体设置pycharm用鼠标滚轮控制字体大小一、file–>settings或者工具栏中点击二、搜索increase三、双击increaseFrontsize进入AddMouseShortcut四、摁住ctrl向上滚动鼠标滑轮。点击ok,即可实现ctrl+向上滚轮增加字体 大小。五、实现减小字体大小搜索decrease同上类似操作即可完成…

    2022年10月24日
  • cisco光纤交换机配置「建议收藏」

    cisco光纤交换机配置「建议收藏」1.初始化信息首次设置,必须通过console进行连接(需要U口转DB9针的接口线,专门卖接口线的有卖大约30元),然后进行初始化设计,以后设定IP后可通过LAN进行登陆具体步骤:(红色字体部分着重注意,需要进行设置,大部分按照默认设置即可,而且设置的部分进入管理工具软件可以更改)—-SystemAdminAccountSetup—-Enterthep…

  • qt 当前窗口句柄_QT获取窗口句柄

    qt 当前窗口句柄_QT获取窗口句柄mac安装paramiko$brewinstallopenssl$/usr/local/opt/openssl/bin/c_rehash$exportARCHFLAGS=”-archx86_64&q…用Canvas制作剪纸效果在做剪纸效果之前,先介绍剪纸效果运用到的一些知识:1.阴影:在Canvas之中进行绘制时,可以通过修改绘图环境中的如下4个属性值来指定阴影…

  • 静态路由(静态汇总路由,静态默认路由,负载均衡,浮动静态路由)介绍

    静态路由(静态汇总路由,静态默认路由,负载均衡,浮动静态路由)介绍网络上通过硬件设备传递数据。最常见的就是路由器和交换机。本篇介绍路由器如何使用静态路由条目来转发数据。一个数据包从源IP地址到目标IP地址间可能穿过多个路由器,也可能有多条路径通往目标IP地址。那路由器收到数据后,如何知道哪个端口能通往目标地址呢?如果多个端口都可通往目标地址,又如何选择用哪个端口转发才是最优路径呢?依据的就是路由表。路由表就是路由器的灵魂

  • 阿里云服务器开放端口设置_阿里云服务器开启全部端口

    阿里云服务器开放端口设置_阿里云服务器开启全部端口一、问题未开放端口号,如何开放端口号呢?咱们下边以redis为例二、操作1、阿里云部分先把服务器上的实例配置打开进入安全组规则选择添加或者手动编辑,我这里已经有了redis,所以随意添加一个为例这样就添加成功了!2、在linux系统中检查端口号是否存在#查看是否开启了6379端口号firewall-cmd–list-ports发现报如下错误:表示没有开启防火墙,下面我们先开启防火墙#开启防火墙systemctls…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号