# How Do You Train PyTorch on Large Batches of Data?


1. When GPU memory is too small for a batch of training samples, or cannot even fit a single sample, how can you still train on one or more GPUs?
2. How can multiple GPUs be used as efficiently as possible?

## 1. Large-batch training on a single GPU or multiple GPUs

If you report an out-of-memory error as a bug, the PyTorch developers will probably stare at you deadpan: buddy, this is not a bug, you simply don't have enough memory…

A standard PyTorch training iteration looks like this:

```python
predictions = model(inputs)               # forward pass
loss = loss_function(predictions, labels) # compute the loss
loss.backward()                           # backward pass, compute gradients
optimizer.step()                          # optimizer updates the parameters
predictions = model(inputs)               # next forward pass with the updated parameters
```


If a full batch does not fit in memory, use gradient accumulation: run several smaller forward/backward passes, accumulate the gradients, and only then take an optimizer step.

```python
model.zero_grad()                                   # reset the gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # forward pass
    loss = loss_function(predictions, labels)       # compute the loss
    loss = loss / accumulation_steps                # normalize the loss (to average over the accumulated steps)
    loss.backward()                                 # backward pass, accumulate gradients
    if (i + 1) % accumulation_steps == 0:           # every accumulation_steps mini-batches
        optimizer.step()                            # update the parameters
        model.zero_grad()                           # reset the gradient tensors for the next cycle
```
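
For example, with mini-batches of 8 samples and `accumulation_steps = 4`, this behaves like training with an effective batch size of 32 (up to batch-norm statistics, which are still computed per mini-batch), while only ever holding the activations of 8 samples in memory.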


## 2. What if even a single sample does not fit?

#### torch.utils.checkpoint

The answer is gradient checkpointing. The following is excerpted from the PyTorch documentation for `torch.utils.checkpoint`:

NOTE:

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced more than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.

The stashing logic saves and restores the RNG state for the current device and the device of all cuda Tensor arguments to the run_fn. However, the logic has no way to anticipate if the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device (“new” meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.

• `torch.utils.checkpoint.checkpoint(function, *args, **kwargs)`

Checkpoint a model or part of the model.

Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, but instead recomputes them in the backward pass. It can be applied to any part of a model.

Specifically, in the forward pass, function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on function again, now tracking the intermediate activations, and then the gradients are calculated using these activation values.

WARNING:

Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward().

WARNING:

If function invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won't be equivalent, and unfortunately it can't be detected.

Parameters:

• function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in an LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden.
• preserve_rng_state (bool, optional, default=True) – If False, omit stashing and restoring the RNG state during each checkpoint.
• args – tuple containing inputs to the function

Returns: Output of running function on *args
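
To make the calling convention concrete, here is a minimal sketch; the two-layer `run_fn` segment, the layer sizes, and the batch size are made up for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer1 = nn.Linear(512, 512)
layer2 = nn.Linear(512, 512)

def run_fn(x):                 # the `function` argument: one segment of the forward pass
    return layer2(torch.relu(layer1(x)))

x = torch.randn(4, 512, requires_grad=True)
y = checkpoint(run_fn, x)      # forward runs under no_grad; only x and run_fn are saved
y.sum().backward()             # run_fn is re-executed here to rebuild the activations
# checkpoint(run_fn, x, preserve_rng_state=False) would skip the RNG stash/restore
```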

• `torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)`

A helper function for checkpointing sequential models.

Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
See checkpoint() on how checkpointing works.


Parameters:

• functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
• segments – Number of chunks to create in the model
• input – A Tensor that is input to functions
• preserve_rng_state (bool, optional, default=True) – If False, omit stashing and restoring the RNG state during each checkpoint.

Returns: Output of running functions sequentially on *inputs

Example:

```python
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
```
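
And a runnable sketch of the same idea; the 20-layer stack, the sizes, and the number of segments below are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# a deep stack whose intermediate activations would otherwise all be stored
model = nn.Sequential(*[nn.Linear(512, 512) for _ in range(20)])

input_var = torch.randn(4, 512, requires_grad=True)
segments = 4                                      # checkpoint the model as 4 chunks
output = checkpoint_sequential(model, segments, input_var)
output.sum().backward()                           # each segment is re-run during backward
```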


## 3. Multi-GPU training

The standard tool is `torch.nn.DataParallel`:

```python
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
```


• module: the model to be run on multiple GPUs
• device_ids: the GPU ids to use (defaults to all visible GPUs)
• output_device: the device on which outputs are gathered (defaults to device_ids[0])
• dim: the dimension along which input tensors are scattered across GPUs (default 0, the batch dimension)
```python
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Model()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 1, 2])

model.to(device)
```


Parallelizing the model is then a single line, and training proceeds as before:

```python
parallel_model = torch.nn.DataParallel(model) # this is the key line!

predictions = parallel_model(inputs)          # forward pass, scattered across GPUs
loss = loss_function(predictions, labels)     # compute the loss
loss.mean().backward()                        # average the per-GPU losses and compute gradients
optimizer.step()                              # optimizer updates the parameters
predictions = parallel_model(inputs)          # next forward pass with the updated parameters
```


With plain DataParallel, outputs are gathered and the loss is computed on a single GPU, which leaves the load and memory use unbalanced. An alternative is to parallelize the loss computation as well, using the DataParallelModel / DataParallelCriterion helpers (the `parallel` module here is not part of PyTorch; it refers to the parallel.py helper distributed with the PyTorch-Encoding package):

```python
from parallel import DataParallelModel, DataParallelCriterion

parallel_model = DataParallelModel(model)               # parallelize the model
parallel_loss  = DataParallelCriterion(loss_function)   # parallelize the loss function

predictions = parallel_model(inputs)      # parallel forward pass
# "predictions" is a tuple of per-GPU results
loss = parallel_loss(predictions, labels) # compute the loss in parallel
loss.backward()                           # compute gradients
optimizer.step()                          # optimizer updates the parameters
predictions = parallel_model(inputs)      # next forward pass with the updated parameters
```


If the model returns several outputs per GPU, the tuple of per-GPU results can be unzipped into one tuple per output:

```python
output_1, output_2 = zip(*predictions)
```


And if you do need the predictions gathered on a single GPU:

```python
gathered_predictions = parallel.gather(predictions)
```
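
Note that gathering everything onto one GPU reintroduces exactly the memory imbalance that DataParallelCriterion is meant to avoid, so only do this when you actually need the full set of predictions in one place.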

