PyTorch 梯度累计 (Gradient Accumulation)
Johann Li | February 25, 2024
额,在内存不够的时候,须要通过梯度累计来凑。然后 PyTorch 进行梯度累加的话,网上搜索出来的教程可能有 bug。所以正确的操作应该是:
step = 8
for i, (x, y) in dataloader:
pred = net(x)
loss = criterion(pred, y) / step
loss.backward()
if (i + 1) % step == 0:
optimizer.step()
optimizer.zero_grad()
应该在 loss 这边 除去 step
: loss = criterion(pred, y) / step
。
(测试了之后,还是有一些效果的。。,至少让随手搭建的 0.1B 的模型,在 MNIST 数据集上,用 3090 ,以 batch=6
的规模来训练,并在第一个 epoch loss 能将到 1 一下,然后 acc 为 94.6%。)