PyTorch 检查显存使用情况

如果你有遇到代码疯狂占用显存,但是,不知道显存被用在了哪里的话,可以尝试看看这篇博文。博文的源头来自于 Understanding CUDA Memory Usage。 是我为了找到显存占用的代码,在网上游荡的时候,找见的这个代码。下面我简单说一下这个玩意是如何帮助大家找到

探针

就是要知道程序在运行的时候,内存的申请状况啥的,就需要有一个探针来对内存的分配情况进行记录,然后我们才能知道当前内存的使用或者相关的信息。 因此就需要在所执行的代码的前后添加代码:

import torch
...

torch.cuda.memory._record_memory_history()
train_or_eval()
torch.cuda.memory._dump_snapshot("snapshot.pickle")

然后就会得到一个 snapshot.pickle 这个文件里面就是相关的内存的信息。

可视化

上面获取的信息,还是须要可视化才能用的。(应该没有人能做到裸眼看二进制数据,然后就知道如何优化代码了吧?)

可视化的网站是 pytorch.org/memory_viz。将产生的文件(snapshot.pickle 或者其他什么的)上传到这个网页上(鼠标拖动上传)。 然后网站就会给出可视化结果(鼠标移动到统计图上还会显示分配内存对应的调用栈):