Pytorch 构建数据集&数据加载

Pytorch 专栏收录该内容
17 篇文章 1 订阅

在这里插入图片描述

为了能够使用 DataLoader类,首先需要构造关于单个数据的 torch.untils.data.Dataset 类


1、构建 My_Dataset类

对于数据集的处理,Pytorch 提供了 torch.utils.data.Dataset 这个抽象类,在使用时只需要继承该类,并重写 __len__() & __getitem__() 函数,便可以方便的进行数据集的迭代。

Dataset中主要有3个方法

  • __init__:在这个方法主要就是初始化信息

  • __getitem__:在这个方法里根据传入的下标返回label和transform之后的图片tensor

  • __len__:返回dataset的长度

from torch.utils.data import Dataset

class my_dataset(Dataset):
	def __init__(self, image_path ,annotation_path ,transform=None):
		# 初始化,读取数据集
	
	def __len__(self):
		# 获取数据集的总大小
	
	def __getitem__(self, id):
		# 对指定的 id ,读取该数据并返回


# 对上述类进行实例化,即可进行迭代

dataset = my_dataset("your image path" , "your annotation path") 			# 实例化该类

for data in dataset:
	print(data)


1、映射类型的数据集

对于这个类型,每个数据有一个对应的索引,通过输入具体的索引,就能得到对应的数据,其构造方法如下:


构造方法:

class Dataset(object):
	def __getitem__(self,index):
		# index:数据索引(整数,范围从0开始)
		...
		...
		# 返回数据张量

	def __len__(self):
		# 返回数据的数目
		...
		...

这个类主要重写了两个方法:

  • __gtitem__:该方法是 Python内置的操作符方法,对应的操作符是索引操作符 [ ],通过输入整数数据索引,其大小在 0 至 N-1之间,,返回具体的某一条数据记录,这就是该方法需要完成的任务;而具体的内部逻辑需要根据数据集的类型来决定;
  • __len__:该方法返回数据的总数;

2、可迭代类型的数据集

相比映射类型的数据集,这个数据集并不需要实现__getitem__ 方法 或者 __len__ 方法,它本身更像一个 Python迭代器。

不同于映射类型,因为索引之间相互独立,在使用多进程载入数据的情况下(DataLoader 中的参数 num_works>1),多个进程可以独立分配索引,迭代器在使用过程中,因为索引之间有前后顺序关系,需要考虑如何分割数据,使得不同的进程可以得到不同的数据。


构造方法:

from torch.utils.data import 

class MyIterableDataset(IterableDataset):
     def __init__(self, start, end):
         super(MyIterableDataset).__init__()
         assert end > start, "this example code only works with end >= start"
         self.start = start
         self.end = end

     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is None:  		# single-process data loading, return the full iterator
             iter_start = self.start
             iter_end = self.end
         else:  						# in a worker process ,split workload
             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
             worker_id = worker_info.id
             iter_start = self.start + worker_id * per_worker
             iter_end = min(iter_start + per_worker, self.end)
         return iter(range(iter_start, iter_end))

根据不同工作进程的序号 worker_id ,设定了不同进程数据迭代器取值的范围,这样就能保证不同的进程获取不同的迭代器,而且迭代器返回的数据各不相同。


4、数据集增强

虽然将数据打包好,但在实际应用是,数据集中的图片有可能存在大小不一的情况,并且原始图片像素RGB值(0~255)比较大,不利于模型的训练收敛,因此需要进行一些图像变换工作。

Pytorch 为此提供了 torchvision.transform 工具包,可以方便地进行图像缩放、裁剪、随机翻转、填充及张量的归一化等操作,操作对象是 PIL 的 Image or Tensor。

如果需要进行多个变换功能,可以利用 transforms.Compose 将多个变换整合起来,并且在实际使用时,通常会将变换操作集成到 Dataset 的继承类中。

from torchvision import transfroms

# 将 transforms 集成到 Dataset类中,使用 Compose将多个变换整合到一起
dataset = my_data( "your image path" , "your annotation path"
					transforms = transforms.Compose([ transforms.Resize(256) ,
													  transforms.RandomHorizontalFilp() ,
													  transforms.ToTensor() ,
													  transforms.Normalize([0.5 ,0.5 , 0.5] , [0.5 ,0.5 , 0.5])] ) )


5、加载数据集 DataLoader

在使用 Pytorch 构建 & 训练模型的过程中,经常需要把原始的数据转化为张量的格式。为了能够方便地批量处理图片数据, Pytorch 引入了一系列工具来对这个过程进行包装。

通常来说,Pytorch 数据的载入使用的是 torch.utils.data.DataLoader 类。
DataLoader :将自定义的 Dataset 根据batch size大小、是否shuffle等封装成一个BatchSize大小的 Tensor,用于后面的训练。

DataLoader( dataset  , batch_size=1 , shuffle=False    , sampler=None 	, 
		    batch_sampler=None      , num_workers=0   , collate_fn=None	,
		    pin_memory=False        , drop_last=False , timeout=0    	, 
		    worker_init_fn=None )



  dataset 				   	# 加载的数据集(Dataset对象);
  batch_size=1			   	# 将数据集分成 mini-batch数量;
  shuffle=True            	# 要不要打乱数据 (打乱比较好);
  sampler=None				# 自定义的采样器(shuffle=true时会构建默认的采样器,如果想设置自定义采样方法;
  							# 可以构造一个 torch.utils.dada.Sample 的实例来进行采样,并设置shuffle=False),
  							# 采样器是一个Python迭代器,每次迭代的时候会返回一个数据的下标索引;
  							
  batch_sampler=None		# 类似于 sampler,不过返回的是一个迷你批次的数据索引,而sampler返回的仅仅是一个下标索引;
  num_workers=0		   		# 使用多进程加载的进程数,0 代表不使用多进程;
  collate_fn=None		   	# 将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可;
  pin_memory=False		   	# 将数据保存在pin_memory区,数据转到GPU会快一些;
  drop_last=False         	# drop_last为True会将最后一批 mini-batch 多出来不足一个 mini-batch的数据丢弃;
  timeout=0					# 如果timeout>0,就会决定在多进程情况下对数据的等待时间;
  worker_init_fn=None)    	# 决定了每个数据载入的子进程开始时运行的函数,这个函数运行在随机种子设置以后、数据载入之前;
  

  • 1
    点赞
  • 0
    评论
  • 1
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

相关推荐
©️2020 CSDN 皮肤主题: 岁月 设计师:pinMode 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值