Original article: http://chenhao.space/post/d313d236.html

Using PyTorch's DataLoader
import torch
import torch.utils.data as Data

# Each row such as [1, 1, 1] stands in for one sentence's word embedding;
# this tensor holds ten sentences.
x = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5],
                  [6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9], [10, 10, 10]])
# The labels of the ten sentences.
y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

torch_dataset = Data.TensorDataset(x, y)

# dataset: a Dataset from which to load the data
# batch_size: int, optional. How many samples to load per batch
# shuffle: bool, optional. True reshuffles the data at every epoch
# sampler: Sampler, optional. Strategy for drawing samples from the dataset
# num_workers: int, optional. Number of subprocesses used for data loading.
#     The default 0 loads data in the main process
# collate_fn: callable, optional. Merges a list of samples into a batch
# pin_memory: bool, optional. True copies tensors into pinned (page-locked) memory
# drop_last: bool, optional. True drops the last batch if it is incomplete;
#     False keeps it
loader = Data.DataLoader(torch_dataset, batch_size=3, shuffle=True, num_workers=0)

data = iter(loader)
n = len(y) // 3 if len(y) % 3 == 0 else len(y) // 3 + 1  # number of batches
for i in range(n):
    print(next(data))
[tensor([[5, 5, 5],
[9, 9, 9],
[8, 8, 8]]), tensor([5, 9, 8])]
[tensor([[10, 10, 10],
[ 2, 2, 2],
[ 7, 7, 7]]), tensor([10, 2, 7])]
[tensor([[6, 6, 6],
[1, 1, 1],
[3, 3, 3]]), tensor([6, 1, 3])]
[tensor([[4, 4, 4]]), tensor([4])]
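Because the ten samples do not divide evenly by batch_size=3, the last batch above holds a single sample. As a minimal sketch of the drop_last option described earlier (shuffle is turned off here so the batch contents are deterministic), setting drop_last=True discards that incomplete batch:

```python
import torch
import torch.utils.data as Data

# Rebuild the same ten 3-dimensional "sentences" and their labels
x = torch.arange(1, 11).unsqueeze(1).repeat(1, 3)  # rows [1,1,1] ... [10,10,10]
y = torch.arange(1, 11)
dataset = Data.TensorDataset(x, y)

# drop_last=True: 10 samples with batch_size 3 -> only 3 full batches survive
loader = Data.DataLoader(dataset, batch_size=3, shuffle=False, drop_last=True)
batches = list(loader)
print(len(batches))    # 3 batches instead of 4
print(batches[-1][1])  # tensor([7, 8, 9]); sample 10 is dropped
```

With drop_last=False (the default) the same loader would yield the fourth, single-sample batch, exactly as in the output above.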
for epoch in range(5):  # iterate over all the data 5 times
    i = 0
    for sentence, label in loader:
        i += 1
        print('Epoch:{} | num:{} | sentence:{} | label:{}'.format(epoch, i, sentence, label))
Epoch:0 | num:1 | sentence:tensor([[10, 10, 10],
[ 2, 2, 2],
[ 8, 8, 8]]) | label:tensor([10, 2, 8])
Epoch:0 | num:2 | sentence:tensor([[7, 7, 7],
[9, 9, 9],
[5, 5, 5]]) | label:tensor([7, 9, 5])
Epoch:0 | num:3 | sentence:tensor([[6, 6, 6],
[4, 4, 4],
[1, 1, 1]]) | label:tensor([6, 4, 1])
Epoch:0 | num:4 | sentence:tensor([[3, 3, 3]]) | label:tensor([3])
Epoch:1 | num:1 | sentence:tensor([[9, 9, 9],
[3, 3, 3],
[4, 4, 4]]) | label:tensor([9, 3, 4])
Epoch:1 | num:2 | sentence:tensor([[8, 8, 8],
[6, 6, 6],
[5, 5, 5]]) | label:tensor([8, 6, 5])
Epoch:1 | num:3 | sentence:tensor([[ 1, 1, 1],
[10, 10, 10],
[ 2, 2, 2]]) | label:tensor([ 1, 10, 2])
Epoch:1 | num:4 | sentence:tensor([[7, 7, 7]]) | label:tensor([7])
Epoch:2 | num:1 | sentence:tensor([[4, 4, 4],
[6, 6, 6],
[7, 7, 7]]) | label:tensor([4, 6, 7])
Epoch:2 | num:2 | sentence:tensor([[10, 10, 10],
[ 8, 8, 8],
[ 5, 5, 5]]) | label:tensor([10, 8, 5])
Epoch:2 | num:3 | sentence:tensor([[3, 3, 3],
[2, 2, 2],
[9, 9, 9]]) | label:tensor([3, 2, 9])
Epoch:2 | num:4 | sentence:tensor([[1, 1, 1]]) | label:tensor([1])
Epoch:3 | num:1 | sentence:tensor([[7, 7, 7],
[5, 5, 5],
[3, 3, 3]]) | label:tensor([7, 5, 3])
Epoch:3 | num:2 | sentence:tensor([[10, 10, 10],
[ 1, 1, 1],
[ 6, 6, 6]]) | label:tensor([10, 1, 6])
Epoch:3 | num:3 | sentence:tensor([[9, 9, 9],
[8, 8, 8],
[4, 4, 4]]) | label:tensor([9, 8, 4])
Epoch:3 | num:4 | sentence:tensor([[2, 2, 2]]) | label:tensor([2])
Epoch:4 | num:1 | sentence:tensor([[ 5, 5, 5],
[ 7, 7, 7],
[10, 10, 10]]) | label:tensor([ 5, 7, 10])
Epoch:4 | num:2 | sentence:tensor([[9, 9, 9],
[3, 3, 3],
[4, 4, 4]]) | label:tensor([9, 3, 4])
Epoch:4 | num:3 | sentence:tensor([[2, 2, 2],
[8, 8, 8],
[1, 1, 1]]) | label:tensor([2, 8, 1])
Epoch:4 | num:4 | sentence:tensor([[6, 6, 6]]) | label:tensor([6])
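The collate_fn parameter listed above deserves a closer look: by default, DataLoader simply stacks samples into a batch, which fails as soon as the sentences have different lengths. Below is a hedged sketch of one way to handle this, padding each batch to its longest sentence; the SentenceDataset class, pad_collate function, and pad value 0 are illustrative assumptions, not part of the original post:

```python
import torch
import torch.utils.data as Data

# Variable-length "sentences": plain lists of token ids (hypothetical data)
sentences = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
labels = [0, 1, 0, 1]

class SentenceDataset(Data.Dataset):
    def __init__(self, sents, labs):
        self.sents, self.labs = sents, labs
    def __len__(self):
        return len(self.sents)
    def __getitem__(self, idx):
        return self.sents[idx], self.labs[idx]

def pad_collate(batch):
    # batch is a list of (sentence, label) pairs produced by __getitem__
    sents, labs = zip(*batch)
    max_len = max(len(s) for s in sents)
    padded = [s + [0] * (max_len - len(s)) for s in sents]  # pad with 0
    return torch.tensor(padded), torch.tensor(labs)

loader = Data.DataLoader(SentenceDataset(sentences, labels),
                         batch_size=2, shuffle=False, collate_fn=pad_collate)
for sent_batch, lab_batch in loader:
    print(sent_batch.shape, lab_batch)
```

Each batch is padded only to the longest sentence it contains, so the first batch here has shape (2, 3) and the second (2, 4).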