Pytorch(五)入门:DataLoader 和 Dataset

Pytorch(五)入门:DataLoader 和 DatasetDataLoader和Dataset构建模型的基本方法,我们了解了。接下来,我们就要弄明白怎么对数据进行预处理,然后加载数据,我们以前手动加载数据的方式,在数据量小的时候,并没有太大问题,但是到了大数据量,我们需要使用shuffle,分割成mini-batch等操作的时候,我们可以使用PyTorch的API快速地完成这些操作。Dataset是一个包装类,用来将数据包装为Datas…

大家好,又见面了,我是你们的朋友全栈君。

DataLoader 和 Dataset

构建模型的基本方法,我们了解了。
接下来,我们就要弄明白怎么对数据进行预处理,然后加载数据,我们以前手动加载数据的方式,在数据量小的时候,并没有太大问题,但是到了大数据量,我们需要使用 shuffle, 分割成mini-batch 等操作的时候,我们可以使用PyTorch的API快速地完成这些操作。

在这里插入图片描述

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。

DataLoader是一个比较重要的类,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)

现在,我们先展示直接使用 TensorDataset 来将数据包装成Dataset类

在这里插入图片描述

这里差个题外话,不知道为什么,出现这个错误,

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
TypeError: __init__() got an unexpected keyword argument 'data_tensor'

但是,改成deal_dataset = TensorDataset(x_data, y_data)这样就OK了。

在这里插入图片描述

接下来,我们来继承 Dataset类 ,写一个将数据处理成DataLoader的类。

当我们集成了一个 Dataset类之后,我们需要重写 len 方法,该方法提供了dataset的大小; getitem 方法, 该方法支持从 0 到 len(self)的索引

class DealDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """
    def __init__(self):
        xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        self.len = xy.shape[0]
    
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。    
dealDataset = DealDataset()

train_loader2 = DataLoader(dataset=dealDataset,
                          batch_size=32,
                          shuffle=True)


for epoch in range(2):
    for i, data in enumerate(train_loader2):
        # 将数据从 train_loader 中读出来,一次读取的样本数是32个
        inputs, labels = data

        # 将这些数据转换成Variable类型
        inputs, labels = Variable(inputs), Variable(labels)

        # 接下来就是跑模型的环节了,我们这里使用print来代替
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.data.size(), "labels", labels.data.size())

在这里插入图片描述

torchvision 包的介绍

torchvision 是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程,也会让你安装上这个包。

这个包中有四个大类。

torchvision.datasets

torchvision.models

torchvision.transforms

torchvision.utils

这里我们主要介绍前三个。

torchvision.datasets

torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。

  • MNISTCOCO
  • Captions
  • Detection
  • LSUN
  • ImageFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • SVHN
  • PhotoTour

我们可以直接使用,示例如下:
在这里插入图片描述

torchvision.models

torchvision.models 中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。

torchvision.models模块的 子模块中包含以下模型结构。

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet

我们可以直接使用如下代码来快速创建一个权重随机初始化的模型

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

也可以通过使用 pretrained=True 来加载一个别人预训练好的模型

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
import torchvision.models as models
# 加载一个 resnet18 模型
resnet18 = models.resnet18()
print(resnet18)
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True) # 加载一个已经预训练好的模型, 需要下载一段时间...

在这里插入图片描述

# 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式
from torchvision import transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# 图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])

DOWNLOAD = True
BATCH_SIZE = 32

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=DOWNLOAD)


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)

print(len(train_dataset))
print(len(train_loader))

以上代码参考:https://github.com/LianHaiMiao/pytorch-lesson-zh/

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/130365.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • mysql整型转字符串_java中如何将字符串转换为字符数组

    mysql整型转字符串_java中如何将字符串转换为字符数组cast(字段asunsigned)例如1:把表结构中的name(字符串)字段转化成整型cast(nameasunsigned)应用:将表A记录按name字段从小到大排列select*fromAorderbycast(nameasunsigned); http://xuyemao.blog.163.com/blog/static/24454858…

  • 最新tracker服务器网站_服务器网速测试

    最新tracker服务器网站_服务器网速测试Tracker服务器是对于BT下载必须的,网上随便搜索一下就有很多Tracker服务器列表,一个服务器文件少则几十个,多则上百个,但”杂乱无章”,要不就是用不了,要不就是在中不可用,要不就是重复,真正能用的就比较少了。提供的服务器在中国都ping得通,每24小时自动更新,自动检测,从几个百服务器筛选出在中国可通的服务器列表。[下载中国可用Tracker服务器列表-每24小时更新]提供的服务…

  • pyqt退出窗口_win10电脑软件闪退

    pyqt退出窗口_win10电脑软件闪退1.使用qtdesigner创建窗口界面这个都很熟悉了,就不重复说明了。(自行百度)2.pyqt将.ui文件转成python代码cd到.ui文件的目录,使用指令即可完成。得到一个py文件(一个类)红色部分是我自己加上去的,只是为了更好看懂代码,调试代码。3.运行pyqt生成的python代码,生成界面这里,需要添加几行代码!直接在Ui_Dialog类的py文件尾部添加如下代码:if__name__==”__main__”:app=QApplication(

  • tomcat出现乱码怎么办_tomcat输出日志乱码

    tomcat出现乱码怎么办_tomcat输出日志乱码1.打开tomcat如下位置:找到logging-properties文件,选择用代码编辑器打开(我这里选择用idea)2.在25-47行中把五个红框起来的UTF-8改为GB2312此时点击bin,目录下的startup.bat(window用户)或startup.sh(mac用户)启动tomcat,控制台的乱码问题解决。如果此时还没有解决乱码问题,需要1.windows+R打开运行,在运行框中输入regedit,进入注册表编辑器中2.如果没有Tomcat或者CodePag(1)

  • flutter 存储_map根据key获取value

    flutter 存储_map根据key获取valueFlutter持久化存储之key-value存储

  • pycharm鼠标滚动控制字体大小_pycharm窗口放大

    pycharm鼠标滚动控制字体大小_pycharm窗口放大1、放大页面方法第一步:打开file里面的setting,然后打开Keymap,再搜索框中输入increase,点击increaseFontSize,双击AddMouseShortcut(先不用点OK)第二步:点击AddMouseShortcut弹出下面对话框,然后按住ctrl并向上滚动鼠标滑轮,将变成第二个对话框,点击OK;第三步:显示下面页面表示设置放大成功,点击OK即可。2、缩小页面方法与上面方法类似,将increase变成decrease输入即可;…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号