深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152102.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • char c=168_char16_t

    char c=168_char16_taboutaarch64FocusonhighperformanceExceptionlevelsinsteadofdifferentmodesvirtualisationsupportbuilt-in32bitfixedlengthinstructionmoreregistersdivideinstructioncompare&jumpin

    2022年10月10日
  • FileSystemWatcher 基础用法[通俗易懂]

    FileSystemWatcher 基础用法[通俗易懂]转载自[http://blog.sina.com.cn/s/blog_532804fc0100dzuz.html]1.FileSystemWatcher基础在应用FileSystemWatcher对象之前,必须了解这个对象的一些基本属性和事件。毫无疑问,这个对象的最重要的

  • SpringMVC中使用Ajax POST请求以json格式传递参数服务端通过request.getParameter(“name”)无法获取参数值问题分析

    SpringMVC中使用Ajax POST请求以json格式传递参数服务端通过request.getParameter(“name”)无法获取参数值问题分析一:问题demo展示在开发新需求,调试代码的时候发现一个问题,就是HttpServletRequest 获取不到ajax post请求的json参数!下面是伪代码是整个请求的逻辑!1.前台JS请求代码(伪代码)

  • 预测算法用java实现吗_java 数据结构与算法

    预测算法用java实现吗_java 数据结构与算法常见的预测算法有1.简易平均法,包括几何平均法、算术平均法及加权平均法;2.移动平均法,包括简单移动平均法和加权移动平均法;3,指数平滑法,包括一次指数平滑法和二次指数平滑法,三次指数平滑法;4,线性回归法,包括一元线性回归和二元线性回归,下面我一一的简单介绍一下各种方法。 4P5?.C(B4j”^5_2h  一,简易平均法,是一种简便的时间序列法。是以一定观察期的数据求得

    2022年10月31日
  • mysql explain ref const_MySQL EXPLAIN 详解「建议收藏」

    mysql explain ref const_MySQL EXPLAIN 详解「建议收藏」一.介绍EXPLAIN命令用于SQL语句的查询执行计划。这条命令的输出结果能够让我们了解MySQL优化器是如何执行SQL语句的。这条命令并没有提供任何调整建议,但它能够提供重要的信息帮助你做出调优决策。先解析一条sql语句,你可以看出现什么内容EXPLAINSELECT*FROMperson,deptWHEREperson.dept_id=dept.didandper…

    2022年10月18日
  • 计算机如何修改任务管理器,win7如何更改任务管理器快捷键_win7更改任务管理器快捷键的教程…

    计算机如何修改任务管理器,win7如何更改任务管理器快捷键_win7更改任务管理器快捷键的教程…我们在打开任务管理器的时候,通常是CTRL+ALT+DEL就可以快速打开,不过有许多用户装完win7系统之后,发现任务管理器快捷键变成了Ctrl+Shift+Esc,这让用户们用着很不习惯,其实我们也可以自己手动更改快捷键,现在给大家带来win7更改任务管理器快捷键的教程。具体步骤如下:1、在“开始”菜单的搜索框输入指令gpedit.msc,回车打开Win7系统的组策略编辑器。2、在组策略编辑器里…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号