大家好,又见面了,我是你们的朋友全栈君。
这两天把DataLoader的源代码的主要内容进行了一些分析,基于版本0.4.1。当然,因为内容比较多,没有全部展开,这里的主要内容是DataLoader关于数据加载以及分析PyTorch是如何通过Python本身的multiprocessing和Threading等库来保证batch是顺序取出的。额外的内容都会给出链接,在这里不会详细展开。
一点推荐
作为CSDN的忠实用户,最近发现CSDN学院上了一些对新手比较友好的课程。以我的切身体会来看,对于想要了解机器学习算法或者python编程语言的同学,非常有帮助。还记得我最开始学习python的时候,看的是一本写给小孩子的书《趣学Python——教孩子学编程》。
虽然这本书不错,但是确实有些过于简单了,而CSDN提供的课程有两门对现在的我来讲还是有相当大的帮助,老师讲课水平高,配合丰富的例子,容易让人掌握知识点,下面推荐两门课程:
人工智能在网络领域的应用与实践:
https://edu.csdn.net/course/play/10319?utm_source=sooner
ps: 如果想要系统学习python的朋友,下面这门课是涵盖了python基础语法、web开发、数据挖掘以及机器学习,是CSDN强力推荐的课程,有需要的朋友可以看看哈:
Python全栈工程师:
https://edu.csdn.net/topic/python115?utm_source=sooner
0.前言(楔子)
本篇关于DataLoader
源码的分析是继PyTorch学习笔记(5)——论一个torch.Tensor是如何构建完成的?之后的第2篇源码分析,相比前一篇的内容。本篇内容完全基于Python语言范畴内,因为会比较直接一些,容易阅读。
输入数据PipeLine
pytorch 的数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset
对象
② 创建一个 DataLoader
对象
③ 循环这个 DataLoader
对象,将img, label加载到模型中进行训练
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
....
所以,作为直接对数据进入模型中的关键一步, DataLoader
非常重要。
首先简单介绍一下DataLoader
,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
官方对DataLoader
的说明是:
“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”
关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有
__iter__
和__next__
方法,而iterable只有__iter__
方法。
1.DataLoader
先介绍一下DataLoader(object)
的参数:
-
dataset(Dataset): 传入的数据集
-
batch_size(int, optional): 每个batch有多少个样本
-
shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
-
sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
-
batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
-
num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
-
collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
-
pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
-
drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。 -
timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
-
worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each
worker subprocess with the worker id (an int in[0, num_workers - 1]
) as
input, after seeding and before data loading. (default: None)
显然,根据上面参数的解释,DataLoader
这个类就是进行数据的初始化的操作,
class DataLoader(object):
__initialized = False
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
...
if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with "shuffle"')
...
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
self.__initialized = True
...
def __iter__(self):
return _DataLoaderIter(self)
...
这里我们主要看__init__()
和__iter__()
:
① 数据的shuffle和batch处理
- RandomSampler(dataset)
- SequentialSampler(dataset)
- BatchSampler(sampler, batch_size, drop_last)
② 因为DataLoader只有__iter__()
而没有实现__next__()
所以DataLoader是一个iterable而不是iterator。
这个iterator的实现在_DataLoaderIter
中
1.1 DataLoader之RandomSampler(dataset)、 SequentialSampler(dataset)
这两个类的实现是在dataloader.py
的同级目录下的torch/utils/data/sampler.py
sampler.py中实现了一个父类Sampler
,以及SequentialSampler
,RandomSampler
和BatchSampler
等五个继承Sampler
的子类
这里面的Sampler
的实现是用C/C++实现的,这里的细节暂且不表。
我们这里需要知道的是:对每个采样器,都需要提供__iter__方法,这个方法用以表示数据遍历的方式和__len__方法,用以返回数据的长度
class Sampler(object):
r"""Base class for all Samplers. Every Sampler subclass has to provide an __iter__ method, providing a way to iterate over indices of dataset elements, and a __len__ method that returns the length of the returned iterators. """
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order. Arguments: data_source (Dataset): dataset to sample from """
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
class RandomSampler(Sampler):
r"""Samples elements randomly, without replacement. Arguments: data_source (Dataset): dataset to sample from """
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).tolist())
def __len__(self):
return len(self.data_source)
if __name__ == "__main__":
print(list(RandomSampler(range(10))))
#[2, 8, 3, 5, 9, 4, 6, 0, 1, 7]
print(list(SequentialSampler(range(10))))
#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
可以看出RandomSampler
等方法返回的就是DataSet
中的索引位置(indices),其中,在子类中的__iter__
方法中,需要返回的是iter(xxx)
(即iterator)的形式:
#### 以下两个代码是等价的
for data in dataloader:
...
#### 等价与
iters = iter(dataloader)
while 1:
try:
next(iters)
except StopIteration:
break
此外,torch.randperm()
的用法如下:
1.2 DataLoader之BatchSampler(Sampler)
BatchSampler
是wrap一个sampler,并生成mini-batch的索引(indices)的方式
这里主要看__iter__
方法,可以看到,代码的思路很清楚明白的展示了batch indices的是如何取出的。
class BatchSampler(Sampler):
r"""Wraps another sampler to yield a mini-batch of indices. Args: sampler (Sampler): Base sampler. batch_size (int): Size of mini-batch. drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be less than ``batch_size`` Example: >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """
def __init__(self, sampler, batch_size, drop_last):
if not isinstance(sampler, Sampler):
raise ValueError("sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}"
.format(sampler))
if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
# 一旦达到batch_size的长度,说明batch被填满,就可以yield出去了
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
# 比如epoch有100个样本,batch_size选择为64,那么drop_last的结果为1,不drop_last的结果为2
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
if __name__ == "__main__":
print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
2._DataLoaderIter
这个_DataLoaderIter
其实就是DataLoader
类的__iter__()
方法的返回值:
注意,这个_DataLoaderIter
中*init(self, loader)*中的loader就是对应的DataLoader类的实例。
class _DataLoaderIter(object):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
def __init__(self, loader):
self.dataset = loader.dataset
# 将一个list的sample组成一个mini-batch的函数
...
# 监听事件完成与否——https://www.cnblogs.com/lcchuguo/p/4687348.html
self.done_event = threading.Event()
# self.sample_iter是iterator:迭代器
self.sample_iter = iter(self.batch_sampler)
# 随机种子,用于worker_init_fn的初始化
base_seed = torch.LongTensor(1).random_().item()
if self.num_workers > 0:
# worker_init_fn是worker初始化函数
self.worker_init_fn = loader.worker_init_fn
# index_queue 索引队列 每个worker进程对应一个:
self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker 队列索引
self.worker_queue_idx = 0
# worker_result_queue 进程间通信
# multiprocessing.SimpleQueue是multiprocessing.Queue([maxsize])的简化,只有三个方法------empty(), get(), put()
self.worker_result_queue = multiprocessing.SimpleQueue()
# batches_outstanding
# 当前已经准备好的 batch 的数量(可能有些正在准备中)
# 当为 0 时, 说明, dataset 中已经没有剩余数据了。
# 初始值为 0, 在 self._put_indices() 中 +1,在 self.__next__ 中-1
self.batches_outstanding = 0
self.worker_pids_set = False
# shutdown为True是关闭worker
self.shutdown = False
# send_idx, rcvd_idx——发送索引,接收索引
# send_idx 用来记录 这次要放 index_queue 中 batch 的 idx
self.send_idx = 0
# rcvd_idx 用来记录 这次要从 data_queue 中取出 的 batch 的 idx
self.rcvd_idx = 0
# 因为多进程,可能会导致 data_queue 中的batch乱序
# 用这个来保证 batch 的返回是按照send_idx升序出去的。
self.reorder_dict = {
}
# 创建num_workers个worker进程来处理
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queues[i],
self.worker_result_queue, self.collate_fn, base_seed + i,
self.worker_init_fn, i))
for i in range(self.num_workers)]
# 这里暂不分析CUDA或者timeout的情况
if self.pin_memory or self.timeout > 0:
...
else:
# data_queue就是self.worker_result_queue(MultiProcessing.SimpleQueue()类型)
# 这个唯一的队列
self.data_queue = self.worker_result_queue
# 设置守护进程
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
...
# prime the prefetch loop
# 初始化的时候,就将 2*num_workers 个 (batch_idx, sampler_indices) 放到 index_queue 中
for _ in range(2 * self.num_workers):
self._put_indices()
在_DataLoaderIter
中,首先来看self.workers
,这个成员变量对应是开个num_workers个进程来处理数据,对应的函数是_worker_loop
2.1 _worker_loop
这部分多进程执行的代码的目的:从index_queue
中取索引,然后通过collate_fn
处理数据,然后再将处理好的 batch 数据放到 data_queue
中。(发送到队列中的idx是self.send_idx
)
传入的参数:
args=(self.dataset, self.index_queues[i],self.worker_result_queue,
self.collate_fn, base_seed + i, self.worker_init_fn,
i)
- 1.dataset
- 2.index_queue中的其中之一(
multiprocessing.Queue()
) - 3.进程共享的data_queue(
multiprocessing.SimpleQueue()
) - 4.collate_fn
- 5.id(是pid?)
- 6.worker初始化函数
- 7.第i个worker
显然,可以看出,对应**_worker_loop**,数据队列是共享的SimpleQueue()
,而索引队列是每个worker独有的Queue()
。
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
global _use_shared_memory
_use_shared_memory = True
...
torch.set_num_threads(1)
random.seed(seed)
# 保证每个worker的随机种子相同
torch.manual_seed(seed)
# 初始化worker
if init_fn is not None:
init_fn(worker_id)
# 以Linux为例,
#class ManagerWatchdog(object):
# def __init__(self):
# self.manager_pid = os.getppid()
#
# def is_alive(self):
# os.getppid--->获得父进程的id
# return os.getppid() == self.manager_pid
watchdog = ManagerWatchdog()
# 处理代码
while True:
try:
# MANAGER_STATUS_CHECK_INTERVAL = 5.0
# r = 从索引队列里取索引
r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
except queue.Empty:
if watchdog.is_alive():
continue
else:
break
if r is None:
break
idx, batch_indices = r
try:
# 传到 collate_fn 的数据是 list of dataset[i] (i in batch_indices)
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
# 将从索引队列取出的数据放进data_queue中,并将samples删除
data_queue.put((idx, samples))
del samples
2.2 self._put_indices(self)
根据2.1,我们知道了_DataLoaderIter
是如何从不同的index_queue中消费数据并将数据转换为data放入同一个data_queue中。
但是在_DataLoaderIter
的构造函数中,index_queue还都是空队列,没法进行”消费”。所以,在构造函数的最后,有如下代码:
# prime the prefetch loop
# 初始化的时候,就将 2*num_workers 个 (batch_idx, sampler_indices) 放到 index_queue 中
for _ in range(2 * self.num_workers):
self._put_indices()
它其实就是初始化,这是因为之前的num_workers个index_queue都是空的,所以务必要初始化一下!
那么这个核心的内容self._put_indices()
,其代码不多,如下:
def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
indices = next(self.sample_iter, None)
if indices is None:
return
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
# 保证worker_queue_idx在[0, self.num_workers)之间。
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
# batches_outstanding表示index_queue队列里有几个batch可供"消费"
self.batches_outstanding += 1
# send_idx 发送索引,和rcvd_idx需要对应,后面会提到
self.send_idx += 1
① self.batches_outstanding
的内容在构造函数中说明,初始值为0,在_put_indices()
中会加1
② 从self.sample_iter
这个iterator中返回一个batch对应的索引,具体内容在之前的BatchSampler(Sampler)
提到
③ 向对应的self.index_queues[i]
中放入(send_idx, indices)内容,其中i = worker_queue_idx通过
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
总是保证在**[0, self.num_workers)**中(左闭右开区间)
④ batches_outstanding
+=1 表明batches加1
⑤ send_idx
+= 1 记录从sample_iter中发送索引到index_queue的次数
疑问
当我看到这里的时候,有一个疑问,因为在_DataLoaderIter
的构造函数中,num_workers个_worker_loop
进程已经开始从不同的index_queue
取数据,制作后放入data_queue
了。
但是以num_workers = 2为例,如果epoch有很多样本,比如10000个,但是batch的size不大,比如为32,那么所有的2个index_queue所得到的数据只有2项,即64个索引,并没有将数据全部制作成indices放入到index_queue里啊。
答疑
需要注意,_DataLoaderIter
是一个迭代器,接收的参数就是DataLoader
的一个实例,而_DataLoaderIter
的__next__
方法用yield
的方式(生成器)是很节省内存的,即数据不是一次性加载到内存中再一点点挤牙膏挤出来,而是需要的时候再取出,很安全且便捷。
所以说,对于迭代器,我们不需要一次性把数据全load进所有的index_queue
中,而是根据需要load就好,这样也避免了队列过大可能带来的额外开销。
2.3 self.__next__(self)
第一部分,就是如果num_workers = 0的话,
就用一个普通的iterator加collate_fn数据处理,没什么特殊。
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
下面才是重点内容!!
# check if the next sample has already been generated
① if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
② if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
③ while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch()
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch)
next = __next__ # Python 2 compatibility
将上面的核心代码分成①,②,③三部分,
我们分析的顺序是③ ① ②
③ While True:
因为这里我们还不知道self.rcvd_idx
和self.reorder_dict
的用法,所以先关注第③部分最后的while True内容:
在构造函数中,我们有:
self.shutdown = False
self._put_indices
使得self.batches_outstanding = 2 * num_workers
下面进入函数self._get_batch()
,如下所示,就是从data queue里面取数据,**idx是_put_indices()
中的self.send_idx **
def _get_batch(self):
if self.timeout > 0:
try:
return self.data_queue.get(timeout=self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
else:
return self.data_queue.get()
接着,对self.batches_outstanding
减1(也就是预备好的batch个数需要减1)。
因为**idx是_put_indices()
中的self.send_idx **,而self.rcvd_idx是接收到的idx,判断它们是否一致。
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
需要注意,self.rcvd_idx初始值为0,它只在_process_next_batch
中产生变化(+1)
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices()
...
return batch
# 调用_process_next_batch的时候,处理了接收索引(rcvd_idx),并且通过调用`self._put_indices()`,
# 向index_queue中扔数据,并使得发送索引数加1, 在data_queue中可以被处理的batch数量加1
# 而实际上batch本身不变
这里说一下为什么是在data_queue中可以被处理的batch数量加1:因为有num_workers个守护子进程是对index_queue中的数据进行处理的,当index_queue中有新的内容时,若这些守护子进程有空闲,则会对其从index_queue中取出,并进行处理,将batch size个索引经过处理放入data_queue中。
需要额外注意的是:当index_queue没有内容的时候,执行self._put_indices()是不会使得self.send_idx
和self.batches_outstanding
的值发生变化的,这也就是我们在_DataLoaderIter
的构造函数最后可以对其进行一个初始化的原因。
其实说到这里,可能还是很迷糊,下面在__next__()
的一些关键位置加注了信息输出
我们以num_workers = 2,为例
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
print('从不定序dict中获取对应的batch:', self.rcvd_idx)
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch()
# initial batches_outstanding = 4
self.batches_outstanding -= 1
print("batches outstanding:", self.batches_outstanding)
if idx != self.rcvd_idx:
# store out-of-order samples
print("send_idx != rcvd_idx:", idx, self.rcvd_idx)
self.reorder_dict[idx] = batch
continue
print("send_idx = rcvd_idx:", idx)
print('-' * 20)
return self._process_next_batch(batch)
自定义了一个DataLoader,并对其进行遍历,结果如下:
#### 第1个next
# 经过self._get_batch()之后,可以处理的batch数据-1,从4变为3
batches outstanding: 3
# 发送的idx(send_idx) = 1, 而第一次next的时候rcvd_idx = 0,此时用self.reorder_dict这个字典
# 把idx = 1对于的batch记录下来
send_idx != rcvd_idx: 1 0
# 这里self.reorder_dict = {1: correspond_batch}, 因为不满足idx == self.rcvd_idx,
# 所以继续执行循环语句。
# 经过self._get_batch()之后,可以处理的batch数据-1,从3变为2
batches outstanding: 2
# 这下子idx和rcvd_idx相等了!执行self._process_next_batch(batch)
send_idx = rcvd_idx: 0
#执行self._process_next_batch(batch),使rcvd_idx += 1, _put_indices()
# --->也就是send_idx += 1和batches_outstanding += 1(如果self.sample_iter不为空)
--------------------
#### 第2个next
# 对于`__next__()`中的代码段①
从不定序dict中获取对应的batch: 1
**执行self._process_next_batch(batch),使rcvd_idx += 1, _put_indices()--->也就是send_idx += 1和outstanding += 1**
#### 第3个next
batches outstanding: 3
send_idx != rcvd_idx: 3 2
batches outstanding: 2
send_idx = rcvd_idx: 2
--------------------
从不定序dict中获取对应的batch: 3
batches outstanding: 3
send_idx != rcvd_idx: 5 4
batches outstanding: 2
send_idx = rcvd_idx: 4
① 检查样本是否已经生成:
由上面的例子可以看出,因为rvcd_idx = 1对于的send_idx = 1样本已经存在且放置于self.reorder_dict
中,
所以self.reorder_dict
的目的是保证batch size数目的样本在每次next输出的时候是根据rcvd_idx进行升序输出的。
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
print('从不定序dict中获取对应的batch:', self.rcvd_idx)
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
② 检查是否还有剩余样本:
如果batch都被处理完了,那么就关闭所有的处理_worker_loop
进程。
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
2.4 default_collate(batch)
default_collate
是DataLoader
的默认collate_fn
,并传给了_DataLoaderIter
作为_worker_loop
处理数据的基本函数,这里我们只需要看torch.stack就好了,它的目的:将batch size个样本合成为一个batch(加了一个维度)
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
# elem_type = type(batch[0])
# if isinstance(batch[0], torch.Tensor):
# print(isinstance(batch[0], torch.Tensor))
if elem_type == torch.Tensor:
out = None
if _use_shared_memory:
...
return torch.stack(batch, 0, out=out)
...
我们暂时需要关注一个torch.stack的用法即可:
3. 总结
① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存
② Queue的特点
当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。
当数据满了: queue.put() 会阻塞
③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展
4. 参考资料
发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/132498.html原文链接:https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...