Dataloader and Dataset
引入
在使用PyTorch进行模型训练时,第一步就是要进行数据处理,然后将数据送入网络中。在梯度下降的文章中说到过,使用SGD会使得程序无法进行并行,而使用梯度下降会容易使程序陷入局部最优,所以一般会采用mini-batch进行训练。PyTorch为我们提供了Dataset和Dataloader两个类来对数据进行方便的构建。
一些概念
我们在进行训练循环时,框架如下所示,其中会涉及到3个概念。
1 2 3
| for epoch in range(training_epochs): for i in range(total_batch):
|
-
**Epoch:**全部样本都进行了一次训练(前向传播,反向传播,梯度更新)
-
**Batch-Size:**一次训练中所使用的样本数量
-
**Iteration:**一共有多少个Batch(内层迭代的次数)
Dataset
在torch中使用Dataset需要import抽象类:torch.utils.data.Dataset
抽象类不能够进行实例化,使用自己的数据集时,需要继承Dataset类,并实现下面三个函数
1 2 3 4 5 6 7 8 9
| from torch.utils.data import Dataset Class textDataset(Dataset): def __init__(self): pass def __getitem__(self, index): pass def __len__(self): pass t_dataset = textDataset()
|
我们使用糖尿病数据集构建Dataset
1 2 3 4 5 6 7 8 9 10 11
| class DiabetesDataset(Dataset): def __init__(self,filepath): xy = np.loadtxt(filepath, delimiter ==',', dtype =np.float32) self.len = xy.shape[0] self.x_data = torch.from_numpy(xy[:,:-1]) self.y_data = torch.from_numpy(xy[:,[-1]]) def __getitem__(self, index): return self.x_data[index], self.y_data[index] def __len__(self): return self.len
|
Dataloader
需要使用Dataloader从dataset中加载数据,可以获取每个mini-batch,进行shuffle等。
需要初始化以下的参数:batch-size ,shuffle , process number
下面的代码是常用的训练模型的框架:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| t_dataset = DiabetesDataset() train_loader = DataLoader(dataset = t_dataset,batch_size=32,shuffle=True,num_workers=2) for epoch in range(10): for index, data in enumerate(train_loader): x_data, y_data = data y_pred = model(x_data) loss = criterion(y_pred, labels) print(epoch, i, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step()
|
参考