深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152102.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 数据结构哈希表例题_数据结构哈希算法

    数据结构哈希表例题_数据结构哈希算法各类介绍:各类实战代码如下:(包括五种,自己可以逐个测试)#include “pch.h”#include <iostream>using namespace std;//折半查找int BinarySearchFunc(int key, int a[], int n){ int low, mid, high; //查找标记 int count …

  • 闭关六个月整理出来的微机原理知识点(特别适用河北专接本)

    闭关六个月整理出来的微机原理知识点(特别适用河北专接本)笔者准备过程中的总结,是通过填空题,简答题等等总结出来的如有不足,还望大佬们指教A14运算器和控制器又称为中央处理器(CPU)。计算机由运算器控制器存储器输入设备输出设备五大部分组成。根据传送的信息类型,系统总线可以分为三类:数据总线地址总线控制总线8086CPU由总线接口部件BIU执行部件EU组成。半导体存储器按存取方式不同,分为读写存储器RAM只读存储器ROM。读写存储器RAM指可以随机地、个别地对任意一个存储单元进行读写的存.

  • 高通msm8937的BLSP学习

    高通msm8937的BLSP学习1.基础概念(1)  BusAccessModule(BAM),总线访问模块BAMisusedtomovedatato/fromtheperipheralbuffers.(2)  BAMLow-SpeedPeripheral(BLSP),低速接口的总线访问模块(3)  QUP:QualcommUniversalPeripheral,高通统一的…

    2022年10月19日
  • python 矩阵转置 transpose

    python 矩阵转置 transpose*forin嵌套列表deftranspose1(matrix):cols=len(matrix[0])return[[row[i]forrowinmatrix]foriinrange(0,cols)]deftranspose2(matrix):transposed=[]foriinrange(len(ma…

  • 御用导航提示提醒_AR实景导航,让你安全驾驶,不再“绕弯”

    御用导航提示提醒_AR实景导航,让你安全驾驶,不再“绕弯”虽然现在手机、车机的导航能力越来越强,但是当我们遇到不熟悉的路况,特别是在立交桥和高速匝道口时还是会出拐错弯或错过路口的情况,而往往错过了一个出口,就意味着你要多跑几公里甚至更远!!基于当前复杂的行车环境,EASYOWN联合高德地图,推出了AR系列行车记录仪,在应对相关行车痛点问题上拥有完美的解决方案。EASYOWN-E3AR行车记录仪通过连接高德地图,在真实的路况信息中,加入3D…

  • linux下 VSCode快捷键

    linux下 VSCode快捷键文章目录一、常用二、全部1、常规2、基本编辑3、richlanguagesediting4、多光标和选择5、显示6、搜索与替换7、导航8、编辑页面管理9、文件管理10、终端一、常用命令作用Ctrl+,用户设置Alt+↑/↓将当前行上移或下移Ctrl+Shift+K删除行Ctrl+Shift+\跳至相匹配的括号处Ctrl+Shift+[/Ctrl+Shift+]折叠/展开当前代码块Ctrl+KCtrl+0折叠所有代码块Ctrl+

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号