PyTorch 中的数据类型 torch.utils.data.DataLoader

PyTorch 中的数据类型 torch.utils.data.DataLoaderDataLoader是PyTorch中的一种数据类型。在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?下面就研究一下:先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本) __init__(构造函数)中的几个重要的属性:1、dataset:(数据类型dataset)输入的数据类型。看名字感觉就像是数据库,…

大家好,又见面了,我是你们的朋友全栈君。

DataLoader是PyTorch中的一种数据类型。

在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?

下面就研究一下:

先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本)

 __init__(构造函数)中的几个重要的属性:

1、dataset:(数据类型 dataset)

输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。

2、batch_size:(数据类型 int)

每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。

3、shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

4、collate_fn:(数据类型 callable,没见过的类型)

将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)

5、batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。

6、sampler:(数据类型 Sampler)

采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

7、num_workers:(数据类型 Int)

工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。

8、pin_memory:(数据类型 bool)

内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

9、drop_last:(数据类型 bool)

丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

10、timeout:(数据类型 numeric)

超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

11、worker_init_fn(数据类型 callable,没见过的类型)

子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。

 

从DataLoader类的属性定义中可以看出,这个类的作用就是实现数据以什么方式输入到什么网络中。

代码一般是这么写的:

# 定义学习集 DataLoader

train_data = torch.utils.data.DataLoader(各种设置...) 

# 将数据喂入神经网络进行训练

for i, (input, target) in enumerate(train_data): 
    循环代码行……

 

如果全部采用默认设置输入数据,数据就是一行一行按顺序输入到神经网络。如果对数据的输入有特殊要求。

比如:想打乱一下数据的排序,可以设置 shuffle(洗牌)为True;

比如:想数据是一捆的输入,可以设置 batch_size 的数目;

比如:想随机抽取的模式输入,可以设置 sampler 或 batch_sampler。如何定义抽样规则,可以看sampler.py脚本。这里不是重点;

比如:像多线程输入,可以设置 num_workers 的数目

其他的就不太懂了,以后实际应用时碰到特殊要求再研究吧。

DataLoader类中还有3个函数:

def __setattr__(self, attr, val):
        if self.__initialized and attr in (‘batch_size’, ‘sampler’, ‘drop_last’):
            raise ValueError(‘{} attribute should not be set after {} is ‘
                             ‘initialized’.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

def __iter__(self):
        return _DataLoaderIter(self)

def __len__(self):
        return len(self.batch_sampler)

关键是第二个函数,

_DataLoaderIter 又是一个类,被一起写在DataLoader.py文件中。

主要是用来处理各种设置如何运作的,这里就不管那么多啦。

最后,如果要导入自己各种古灵精怪的数据,就要看看 DataSet 又是如何操作的。

torch.utils.data主要包括以下三个类: 
1. class torch.utils.data.Dataset

其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder. 
2. class torch.utils.data.sampler.Sampler(data_source) 
参数: data_source (Dataset) – dataset to sample from 
作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度. 
3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None) 
参数:

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/143863.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • mysql中的mvcc的使用和原理详解_mysql底层原理

    mysql中的mvcc的使用和原理详解_mysql底层原理什么是MVVCMVVC(Multi-VersionConcurrencyControl)(注:与MVCC相对的,是基于锁的并发控制,Lock-BasedConcurrencyControl)是一种基于多版本的并发控制协议,只有在InnoDB引擎下存在。MVCC是为了实现事务的隔离性,通过版本号,避免同一数据在不同事务间的竞争,你可以把它当成基于多版本号的一种乐观锁。当然,这种乐观锁只在…

    2022年10月31日
  • python激活码2021_在线激活

    (python激活码2021)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

  • pytorch之DataLoader

    pytorch之DataLoaderpytorch之DataLoader在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助实现这些功能。Dataset只负责数据的抽象,一次调用__getitem__只返回一个样本。DataLoader的函数定义如下:DataLoader(dataset,batch_size=1,shu…

  • hive建表并添加数据_hive和mysql的关系

    hive建表并添加数据_hive和mysql的关系在使用hive进行开发时,我们往往需要获得一个已存在hive表的建表语句(DDL),然而hive本身并没有提供这样一个工具。要想还原建表DDL就必须从元数据入手,我们知道,hive的元数据并不存放在hdfs上,而是存放在传统的RDBMS中,典型的如mysql,derby等,这里我们以mysql为元数据库,结合0.4.2版本的hive为例进行研究。连接上mysql后可以看到hive元数据对应的表约有…

    2022年10月31日
  • netstat命令参数和使用详解

    netstat命令参数和使用详解netstat-Printnetworkconnections,routingtables,interfacestatistics,masqueradeconnections,andmulticastmembershipsnetstat-打印网络连接、路由表、接口统计、伪装连接和多播成员关系参数usage:netstat[-…

  • whl文件安装方法

    whl文件安装方法   whl格式本质上是一个压缩包,里面包含了py文件,以及经过编译的pyd文件。使得可以在不具备编译环境的情况下,选择合适自己的python环境进行安装问题描述:whl下载了后不会安装解决方法:1.把下载的文件拖到桌面2.进入cmd命令行3.使用cd进入whl文件属性标识的目录)(红色框)4.使用“pipinstall文件名”安装下载的文件(绿色框)5.安装完成…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号