简短的来说, PyTorch 2.0 新引入进来了一个 torch.compile
API, 通过一定的编译的方式,加速伸进网络的训练和推理(虽然好像效果一般吧)。
但是这个东西,也带来了一个问题,就是说 内存泄漏。
PyTorch 官方 GitHub 仓库中是有一个这样的 issue 的,已经关闭了。
里面提到的解决方式是安装最新的 Triton 就解决了。我测试了一下,能够解决这个问题。但是这里须要提到的一点是,出问题的版本是 2.0.0
,而解决问题的版本是 2.0.0.post1
。 这个要安装这个版本,可能须要执行 pip install triton==2.0.0.post1
这个命令。
除此之外,通过简单的测试,在更新前, torch.compile
里面如果将 backend
设置成 aot_ts_nvfuser
或者 nvprims_nvfuser
的话,没有出现过内存溢出的问题。
测试代码
我是用的网络是,通过 monkey-patch 魔改的 ResNet 网络:
def replace(m: nn.Module):
for n, mx in m.named_children():
if isinstance(mx, pool):
m.add_module(n, nn.Identity())
elif isinstance(mx, nn.Conv2d):
if mx.stride != (1, 1) or mx.stride != 1:
conv = nn.Conv2d(
mx.in_channels, mx.out_channels, mx.kernel_size,
stride=1, padding=mx.padding, dilation=mx.dilation,
groups=mx.groups, bias=mx.bias, padding_mode=mx.padding_mode,
device=mx.weight.device, dtype=mx.weight.dtype
)
m.add_module(n, conv)
else:
replace(mx)
class ResNet(M.resnet.ResNet):
def __init__(
self,
block: Type[Union[M.resnet.BasicBlock, M.resnet.Bottleneck]],
layers: List[int],
planes: List[int] = [64, 64, 64, 64],
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
num_classes: int = 1000,
):
super().__init__(
block, layers,
zero_init_residual=zero_init_residual,
groups=groups, width_per_group=width_per_group,
replace_stride_with_dilation=replace_stride_with_dilation,
norm_layer=norm_layer,
num_classes=num_classes,
)
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
self.inplanes = 64
self.add_module('layer1', self._make_layer(block, planes[0], layers[0]))
self.add_module('layer2', self._make_layer(block, planes[1], layers[1],
stride=1, dilate=replace_stride_with_dilation[0]))
self.add_module('layer3', self._make_layer(block, planes[2], layers[2],
stride=1, dilate=replace_stride_with_dilation[1]))
self.add_module(
'layer4', self._make_layer(block, planes[3], layers[3],
stride=1, dilate=replace_stride_with_dilation[2])
)
project_planes = self.layer4[-1].bn3.num_features \
if isinstance(self.layer4[-1], M.resnet.Bottleneck) \
else self.layer4[-1].bn2.num_features
replace(self)
然后测试的代码是
import torch
import model.resnet as R
net = R.resnet101(None, planes = [64,32,128,32]).cuda()
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
net = torch.compile(R.resnet101(None, planes = [64,32,128,32]).cuda())
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
net = torch.compile(R.resnet101(None, planes = [64,32,128,32]).cuda(), backend="aot_ts_nvfuser")
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
net = torch.compile(R.resnet101(None, planes = [64,32,128,32]).cuda(), backend="nvprims_nvfuser")
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
for i in range(100):
net(torch.rand(1,3,128,128).cuda())
测试的时候,前项传播的过程是要执行两次的,因为如果只执行一次的话,不一定准确。
如何定为到 torch.compile 的问题的呢?
首先,一开始使用 objgraph 对内存使用状况进行监测。发现了内存虽然增长了,但是随着训练的时间增加,却没有检测出来任何增加的对象。
然后,就换了 memory_profiler。换之前,实际上通过各种方法,将可疑的的部份替换点,来看是否是有问题的,但是从来没有假设到 torch.compile
和 神经网络会有问题,所以就一直找不到问题。
然后更换了 memory_profiler 之后,就发现,模型执行的地方存在这个问题。 memory profiler 提示一直在分配内存。
一开始以为是我代码的实现上有问题,在模型或者损失函数的地方导致了内存泄漏。但是发现并不是,然后偶然将 torch.compile
关闭之后,就没有内存泄漏的问题了。
最后在网上搜索相关的信息后,发现问题并解决了。