PyTorch 梯度累计 (Gradient Accumulation)

额,在内存不够的时候,须要通过梯度累计来凑。然后 PyTorch 进行梯度累加的话,网上搜索出来的教程可能有 bug。所以正确的操作应该是:

step = 8
for i, (x, y) in dataloader:
    pred = net(x)
    loss = criterion(pred, y) / step
    loss.backward()
    if (i + 1) % step == 0:
        optimizer.step()
        optimizer.zero_grad()

应该在 loss 这边 除去 step: loss = criterion(pred, y) / step

(测试了之后,还是有一些效果的。。,至少让随手搭建的 0.1B 的模型,在 MNIST 数据集上,用 3090 ,以 batch=6 的规模来训练,并在第一个 epoch loss 能将到 1 一下,然后 acc 为 94.6%。)