pytorch-DataLoader(数据迭代器)

pytorch-DataLoader(数据迭代器)目录1.1dataset1.1.1Map-styledatasets实现方法一(简单直白法)实现方法二(借助TensorDataset直接将数据包装成dataset类)实现方法三(地址读取法)1.1.1Iterable-styledatasets我们一般使用一个for循环(或多层的)来训练神经网络,每一次迭代,加载一个batch的数据,神经网络前向反向传播各一次并更新一次参数。而这个过程中加载一个batch的数据这一步需要使用一个torch.utils.data.DataLoader对象,并且

大家好,又见面了,我是你们的朋友全栈君。

本博客讲解了pytorch框架下DataLoader的多种用法,每一种方法都展示了实例,虽然有一点复杂,但是小伙伴静下心看一定能看懂哦 :)

个人建议,在1.1.1节介绍的三种方法中,推荐 方法二>方法一>方法三 (方法三实在是过于复杂不做推荐),另外,第三节中的处理示例使用了非DataLoader的方法进行数据集处理,也可以借鉴~

我们一般使用一个for循环(或多层的)来训练神经网络,每一次迭代,加载一个batch的数据,神经网络前向反向传播各一次并更新一次参数。
而这个过程中加载一个batch的数据这一步需要使用一个torch.utils.data.DataLoader对象,并且DataLoader是一个基于某个dataset的iterable,这个iterable每次从dataset中基于某种采样原则取出一个batch的数据。
也可以这样说:Torch中可以创建一个torch.utils.data.Dataset对象,并与torch.utils.data.DataLoader一起使用,在训练模型时不断为模型提供数据。

1 torch.utils.data.DataLoader

定义:Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.

我们先来看一看其构造函数的参数

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)

接下来我们具体分析下它的参数。
其中最重要的参数是dataset,是一个抽象类,包含两种类型:map-style datasets 和 iterable-style datasets.

dataset (Dataset) – dataset from which to load the data.

batch_size (int, optional) – how many samples per batch to load (default: 1).

shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shuffle must not be specified.

batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers – 1]) as input, after seeding and before data loading. (default: None)

1.1 dataset

只支持两种类型的数据集:map-style datasets, iterable-style datasets.

1.1.1 Map-style datasets

是一个类,要求有 __getitem__()and__len__()这两个构造函数,代表一个从索引映射到数据样本。
(1)其中__getitem__函数的作用是根据索引index遍历数据
(2)__len__函数的作用是返回数据集的长度
(3)在创建的dataset类中可根据自己的需求对数据进行处理。可编写独立的数据处理函数,在__getitem__函数中进行调用;或者直接将数据处理方法写在__getitem__函数中或者__init__函数中,但__getitem__必须根据index返回响应的值,该值会通过index传到dataloader中进行后续的batch批处理。

即基本满足:

def __getitem__(self, index):
    return self.src[index], self.trg[index]
def __len__(self):
    return len(self.src)

看一下他的大概构造:

class Dataset(object):
    """An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs a index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """

    def __getitem__(self, index):
        raise NotImplementedError
        
    def __len__(self):
        raise NotImplementedError
        
    def __add__(self, other):
        return ConcatDataset([self, other])


    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

上述代码是pytorch中Datasets的源码,注意成员方法__getitem__和__len__都是未实现的。我们要实现自定义Datasets类来完成数据的读取,则只需要完成这两个成员方法的重写。

首先,getitem()方法用来从datasets中读取一条数据,这条数据包含训练图片(已CV距离)和标签,参数index表示图片和标签在总数据集中的Index。

len()方法返回数据集的总长度(训练集的总数)。

下面介绍两种简单实现MyDatasets类

实现方法一(简单直白法)

重点是把 x 和 label 都分别装入两个列表 self.src 和 self.trg ,然后通过 getitem(self, index)返回对应元素。

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
 
class My_dataset(Dataset):
    def __init__(self):
        super().__init__()
        # 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
        # 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
        self.x = torch.randn(1000,3)
        self.y = self.x.sum(axis=1)
        self.src,  self.trg = [], []
        for i in range(1000):
            self.src.append(self.x[i])
            self.trg.append(self.y[i])
           
    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        return len(self.src) 
        
 # 或者return len(self.trg), src和trg长度一样
 
data_train = My_dataset()
data_test = My_dataset()
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)
# i_batch的多少根据batch size和def __len__(self)返回的长度确定
# batch_data返回的值根据def __getitem__(self, index)来确定
# 对训练集:(不太清楚enumerate返回什么的时候就多print试试)
for i_batch, batch_data in enumerate(data_loader_train):
    print(i_batch)  # 打印batch编号
    print(batch_data[0])  # 打印该batch里面src
    print(batch_data[1])  # 打印该batch里面trg
# 对测试集:(下面的语句也可以)
for i_batch, (src, trg) in enumerate(data_loader_test):
    print(i_batch)  # 打印batch编号
    print(src)  # 打印该batch里面src的尺寸
    print(trg)  # 打印该batch里面trg的尺寸 

多说几句:生成的data_train可以通过 data_train[xxx] 直接索引某个元素,或者通过next(iter(data_train))得到一条条的数据。

实现方法二(借助TensorDataset直接将数据包装成dataset类)

另一种方法是直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader。

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
 
src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))
 
data = TensorDataset(src, trg)
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
    print(i_batch)  # 打印batch编号
    print(batch_data[0].size())  # 打印该batch里面src
    print(batch_data[1].size())  # 打印该batch里面trg

output:

0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
...

实现方法三(地址读取法)

适用于lfw这样的数据集,每一份数据都对应一个文件夹,或者说数据量过大,无法一次加载出来的数据集。并且要求这样的数据集,有一个txt文件可以进行索引!

import os

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.image as mpimg



# 对所有图片生成path-label map.txt 这个程序可根据实际需要适当修改
def generate_map(root_dir):
	#得到当前绝对路径
    current_path = os.path.abspath('.')
    #os.path.dirname()向前退一个路径
    father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")

    with open(root_dir + 'map.txt', 'w') as wfp:
        for idx in range(10):
            subdir = os.path.join(root_dir, '%d/' % idx)
            for file_name in os.listdir(subdir):
                abs_name = os.path.join(father_path, subdir, file_name)
                # linux_abs_name = abs_name.replace("\\", '/')
                wfp.write('{file_dir} {label}\n'.format(file_dir=linux_abs_name, label=idx))

# 实现MyDatasets类
class MyDatasets(Dataset):

    def __init__(self, dir):
        # 获取数据存放的dir
        # 例如d:/images/
        self.data_dir = dir
        # 用于存放(image,label) tuple的list,存放的数据例如(d:/image/1.png,4)
        self.image_target_list = []
        # 从dir--label的map文件中将所有的tuple对读取到image_target_list中
        # map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路径最好是绝对路径
        with open(os.path.join(dir, 'map.txt'), 'r') as fp:
            content = fp.readlines()
            #s.rstrip()删除字符串末尾指定字符(默认是字符)
            # 得到 [['d:/.../image_data/1/3.jpg', '1'], ...,]
            str_list = [s.rstrip().split() for s in content]
            # 将所有图片的dir--label对都放入列表,如果要执行多个epoch,可以在这里多复制几遍,然后统一shuffle比较好
            self.image_target_list = [(x[0], int(x[1])) for x in str_list]

    def __getitem__(self, index):
        image_label_pair = self.image_target_list[index]
        # 按path读取图片数据,并转换为图片格式例如[3,32,32]
        # 可以用别的代替
        img = mpimg.imread(image_label_pair[0])
        return img, image_label_pair[1]

    def __len__(self):
        return len(self.image_target_list)


if __name__ == '__main__':
    # 生成map.txt
    # generate_map('train/')

    train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True)

    for step in range(20000):
        for idx, (img, label) in enumerate(train_loader):
            print(img.shape)
            print(label.shape)

如果使用其他形式的数据,例如二进制文件,则需要字节读取文件,分割成每一张图片和label,然后从__getitem__中返回就可以了。例如cifar-10数据,我们只需要在__getitem__方法中,按index来读取对应位置的字节,然后转换为label和img,并返回。在__len__中返回cifar-10训练集的总样本数。DataLoader就可以根据我们提供的index,len以及batch_size,shuffle来返回相应的batch数据和label。

1.1.1 Iterable-style datasets

可迭代样式的数据集是IterableDataset的一个实例,该实例必须重写__iter__方法,该方法用于对数据集进行迭代。这种类型的数据集特别适合随机读取数据不太可能实现的情况,并且批处理大小batchsize取决于获取的数据。比如读取数据库,远程服务器或者实时日志等数据的时候,可使用该样式,一般时序数据不使用这种样式。

For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.
。。。。。
这里就不详细讲了,太特么复杂了~

2 torchvision.datasets

这个包的作用是方便提供现成数据集。
torchvision.datasets中包含了以下数据集

  • MNIST
    -COCO(用于图像标注和目标检测)(Captioning and Detection)
    -LSUN Classification
    -ImageFolder
    -Imagenet-12
    -CIFAR10 and CIFAR100
    -STL10

Datasets 拥有以下API: __getitem____len__
具体用法看参考第四条(搭配torch.utils.data.DataLoader)

2.1 ImageFolder

这个和DatasetFolder一样,适合用于已经下载好的并且符合一定要求的数据集,ImageFolder要求数据呈这样分布:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用方法:

my_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
     
torchvision.datasets.ImageFolder(root="./my_dataset/", transform=my_transform)

3 处理示例

我们在1.1.1节已经讨论了三种加载数据集的方法,现在以Crime数据集另介绍一种数据集加载办法。这种方法和 DataLoader 没有任何关系,实现起来的复杂度一般。

import numpy as np
from matplotlib import pyplot as plt
import os
import torch


class CrimeDataset():
    def __init__(self, device):
        reader = open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/communities.data'))

        attributes = []
        while True:
        	# 读取 用逗号作为分隔符的数据集文件
            line = reader.readline().split(',')
            if len(line) < 128:
                break
            # set the ? as -1 
            line = ['-1' if val == '?' else val for val in line]
            line = np.array(line[5:], dtype=np.float)
            attributes.append(line)
        reader.close()
        # attributes.shape=(1994, 123)
        attributes = np.stack(attributes, axis=0)
        
        
        # load the name of each column; total: 128
        reader = open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/names'))
        names = []
        for i in range(128):
            # reader.readline().split() = ['@attribute', 'county', 'numeric'] and we choose 'county'
            line = reader.readline().split()[1]
            # exclude the first 5 columns. Thus the number of column names = 123, arroding with attributes.shape
            if i >= 5:
                names.append(line)
        names = np.array(names)
        
        # shuffle the attribute by axis0
        attributes = attributes[np.random.permutation(range(attributes.shape[0])), :]

        val_size = 500
        # the last column of attributes is the labels
        self.train_labels = attributes[val_size:, -1:]
        self.test_labels = attributes[:val_size:, -1:]
        
        # exclude the last column of attributes. Thus attributes.shape = (1994,122)
        attributes = attributes[:, :-1]
        
        # select the column whose minimum >= 0. selected has 99 features
        selected = np.argwhere(np.array([np.min(attributes[:, i]) for i in range(attributes.shape[1])]) >= 0).flatten()
        self.train_features = attributes[val_size:, selected]
        self.test_features = attributes[:val_size:, selected]
        self.names = names[selected]
        # self.train_ptr is the counter which counts the number of data records having been loaded
        self.train_ptr = 0
        self.test_ptr = 0
        self.x_dim = self.train_features.shape[1]
        # train_size = 1494; test_size = 500
        self.train_size = self.train_features.shape[0]
        self.test_size = self.test_features.shape[0]
        self.device = device

    def train_batch(self, batch_size=None):
        # if batch_size is None, then each iteration outputs all the training set
        if batch_size is None:
            batch_size = self.train_features.shape[0]
            self.train_ptr = 0
        # if all data has been outputed, reset the trailoader.
        if self.train_ptr + batch_size > self.train_features.shape[0]:
            self.train_ptr = 0
        bx, by = self.train_features[self.train_ptr:self.train_ptr+batch_size], \
                 self.train_labels[self.train_ptr:self.train_ptr+batch_size]
        self.train_ptr += batch_size
        if self.train_ptr == self.train_features.shape[0]:
            self.train_ptr = 0
        return torch.from_numpy(bx).float().to(self.device), torch.from_numpy(by).float().to(self.device)

    def test_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self.test_features.shape[0]
            self.train_ptr = 0
        if self.test_ptr + batch_size > self.test_features.shape[0]:
            self.test_ptr = 0
        bx, by = self.test_features[self.test_ptr:self.test_ptr+batch_size], \
                 self.test_labels[self.test_ptr:self.test_ptr+batch_size]
        self.test_ptr += batch_size
        if self.test_ptr == self.test_features.shape[0]:
            self.test_ptr = 0
        return torch.from_numpy(bx).float().to(self.device), torch.from_numpy(by).float().to(self.device)



if __name__ == '__main__':
    dataset = CrimeDataset("cpu")
    print(dataset.names)
    print(dataset.train_features.shape, dataset.train_labels.shape)

5 实用功能

5.1 分割dataloader

有时候从 torchvision 里下载下来的是一个完整的数据集,包装成 dataloader `以后我们想把该数据集进行进一步划分:

def split(dataloader, batch_size, split=0.2):
    """Splits the given dataset into training/validation. Args: dataset[torch dataloader]: Dataset which has to be split batch_size[int]: Batch size split[float]: Indicates ratio of validation samples Returns: train_set[list]: Training set val_set[list]: Validation set """

    index = 0
    length = len(dataloader)

    train_set = []
    val_set = []

    for data, target in dataloader:
        if index <= (length * split):
            train_set.append([data, target])
        else:
            val_set.append([data, target])

        index += 1

    return train_set, val_set

还有更好的分割方法见:pytorch数据集的分割

参考:
https://www.cnblogs.com/leokale-zz/p/11275800.html
https://www.lagou.com/lgeduarticle/74174.html
太感谢啦!https://blog.csdn.net/zuiyishihefang/article/details/105985760
torchvision-datasets

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/131858.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 如何上传代码到github?

    如何上传代码到github?github是什么?github是Git远程仓库。github是一个基于git的代码托管平台Git是什么:Git是一个开源的分布式版本控制系统,用于敏捷高效地处理任何或小或大的项目。如何上传代码到github?参考:https://www.runoob.com/git/git-remote-repo.html在githbu上注册账号注册账号 创建一个项目 获得…

  • Android中文API(117)——WrapperListAdapter

    Android中文API(117)——WrapperListAdapter

  • Pytest(1)安装与入门「建议收藏」

    Pytest(1)安装与入门「建议收藏」pytest介绍pytest是python的一种单元测试框架,与python自带的unittest测试框架类似,但是比unittest框架使用起来更简洁,效率更高。根据pytest的官方网站介绍,它

  • cover letter 和response letter的写法

    cover letter 和response letter的写法http://emuch.net/bbs/viewthread.php?tid=988184&fpage=1投稿感受和体会bydingdang15fromemuch投稿感受和体会bydingdang15fromemuch几个月前认识了小木虫网站,从此就喜欢上了这里.每天有空都上这里,看一下虫友发表论文的经验,体会,怎么投稿,怎么回复审稿人的意见等,还有热心虫友提供的英文

  • 让旧Mac免费获得 iWork 套件的秘籍「建议收藏」

    让旧Mac免费获得 iWork 套件的秘籍「建议收藏」让旧Mac免费获得iWork套件的秘籍2013-10-2409:13iapps.im只要购买了苹果新设备就可以免费获得iWork和iLife套件。但是我们拥有旧Mac的人呢?昨夜大家是不是一夜无眠呀,数数手头有多少钱,银行卡可以刷多少,才能抱回几个心仪的设备呢!苹果对新Mac的政策也如当时对iPhone5s一样,只要购买了新设备就可以免费获得iW

    2022年10月31日
  • UpdatePanel概览

    UpdatePanel概览微软的asp.netajax为我们进入AJAX世界提供了方便的入口,让许多不熟悉js甚至不了解什么是ajax的人也能享受到ajax技术的好处.在asp.netajax1.0中,updatepan

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号