大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。
Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺
H5文件读取:
import torch.utils.data as data
import torch
import h5py
class DatasetFromHdf5(data.Dataset):
def __init__(self, file_path):
super(DatasetFromHdf5, self).__init__()
hf = h5py.File(file_path)
self.data = hf.get('data')
self.target = hf.get('label')
def __getitem__(self, index):
return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
def __len__(self):
return self.data.shape[0]
调用的时候,先用DataLoader将数据装入 training_data_loader中
train_set = DatasetFromHdf5(r"D:\PycharmProjects\pytorch-vdsr-master\data\train.h5")
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
在使用数据训练的时候写一个循环,iteration只是一个计数的,从1开始计数,表示已经取第iteration个批次了,batch就是每次取出一个批次的数值。
input和target是取出的输入和希望得到的输出,这里的返回顺序是在上边的DatasetFromHdf5中定义的。
def __getitem__(self, index):
return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
所以batch[0]表示input(也就是存储的data),batch[1]表示label(也就是label)。
index在这里应该是每次按第一个维度取出data中的数值。data[index,:,:,:],本来是维度是1000×1×41×41,每次取的是1×1×41×41。按照batch来,每次取出的就是batch×1×41×41
for iteration, batch in enumerate(training_data_loader, 1):
input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/195696.html原文链接:https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...