公众号关注 “程序员遇见GitHub”
设为“星标”,重磅干货,第一时间送达

报道 | 量子位
CUDA error: out of memory.
多少人用PyTorch“炼丹”时都会被这个bug困扰。

一般情况下,你得找出当下占显存的没用的程序,然后kill掉。
如果不行,还需手动调整batch size到合适的大小……
有点麻烦。
现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。

有多厉害?
相关项目在GitHub才发布没几天就收获了600+星。

一行代码解决内存溢出错误
软件包名叫koila,已经上传PyPI,先安装一下:
pip install koila
现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。
先定义input、label和model:
# A batch of MNIST image
input = torch.randn(8, 28, 28)
# A batch of labels
label = torch.randn(0, 10, [8])
class NeuralNetwork(Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = Flatten()
self.linear_relu_stack = Sequential(
Linear(28 * 28, 512),
ReLU(),
Linear(512, 512),
ReLU(),
Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
然后定义loss函数、计算输出和losses。
loss_fn = CrossEntropyLoss()
# Calculate losses
out = nn(t)
loss = loss_fn(out, label)
# Backward pass
nn.zero_grad()
loss.backward()
好了,如何使用koila来防止内存溢出?
超级简单!
只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——
koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。
在本例中,batch=0,则修改如下:
input = lazy(torch.randn(8, 28, 28), batch=0)
完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。
灵感来自TensorFlow的静态/懒惰评估
下面就来说说koila背后的工作原理。
“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。
koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)。
它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。
而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。
又是算shape又是算内存的,koila听起来就很慢?

NO。
即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。
而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。
你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?
是的,它也可以。
但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。
而koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。
不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU。

以及现在只适用于常见的nn.Module类。

ps. koila作者是一位叫做RenChu Wang的小哥。

项目地址:
https://github.com/rentruewang/koila
参考链接:
https://www.reddit.com/r/MachineLearning/comments/r4zaut/p_eliminate_pytorchs_cuda_error_out_of_memory/
推荐阅读:
我教你如何读博!
牛逼!轻松高效处理文本数据神器
B站强化学习大结局!
如此神器,得之可得顶会!
兄弟们!神经网络画图,有它不愁啊
太赞了!东北大学朱靖波,肖桐团队开源《机器翻译:统计建模与深度学习方法》
当年毕业答辩!遗憾没有它...
已开源!所有李航老师《统计学习方法》代码实现
这个男人,惊为天人!手推PRML!
它来了!《深度学习》(花书) 数学推导、原理剖析与代码实现
你们心心念念的MIT教授Gilbert Strang线性代数彩板笔记!强烈推荐!
GitHub超过9800star!学习Pytorch,有这一份资源就够了!强推!
你真的懂神经网络?强推一个揭秘神经网络的工具,ANN Visualizer
诸位!看我如何白嫖2020 icassp!
这个时代研究情感分析,是最好也是最坏!
BERT雄霸天下!
玩转Pytorch,搞懂这个教程就可以了,从GAN到词嵌入都有实例
是他,是他,就是他!宝藏博主让你秒懂Transformer、BERT、GPT!
fitlog!复旦邱锡鹏老师组内部调参工具!一个可以节省一篇论文的调参利器
Github开源!查阅arXiv论文新神器,一行代码比较版本差别,我爱了!
开源!数据结构与算法必备的 50 个代码实现
他来了!吴恩达带着2018机器学习入门高清视频,还有习题解答和课程拓展来了!
太赞了!复旦邱锡鹏老师NLP实战code解读开源!
这块酷炫的Python神器!我真的爱了,帮助你深刻理解语言本质!实名推荐!
论文神器!易搜搭
不瞒你说!这可能是世界上最好的线性代数教程