社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Git

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目刚发布就揽星600+

程序员遇见GitHub • 3 年前 • 373 次点击  

公众号关注 “程序员遇见GitHub

设为“星标”,重磅干货,第一时间送达

报道 | 量子位

CUDA error: out of memory.

多少人用PyTorch“炼丹”时都会被这个bug困扰。

一般情况下,你得找出当下占显存的没用的程序,然后kill掉。

如果不行,还需手动调整batch size到合适的大小……

有点麻烦。

现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。

有多厉害?

相关项目在GitHub才发布没几天就收获了600+星。

一行代码解决内存溢出错误

软件包名叫koila,已经上传PyPI,先安装一下:

pip install koila

现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。

先定义input、label和model:

# A batch of MNIST image
input = torch.randn(82828)

# A batch of labels
label = torch.randn(010, [8])

class NeuralNetwork(Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = Flatten()
        self.linear_relu_stack = Sequential(
            Linear(28 * 28512),
            ReLU(),
            Linear(512512),
            ReLU(),
            Linear(51210),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

然后定义loss函数、计算输出和losses。

loss_fn = CrossEntropyLoss()

# Calculate losses
out = nn(t)
loss = loss_fn(out, label)

# Backward pass
nn.zero_grad()
loss.backward()

好了,如何使用koila来防止内存溢出?

超级简单!

只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——

koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。

在本例中,batch=0,则修改如下:

input = lazy(torch.randn(82828), batch=0)

完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。

灵感来自TensorFlow的静态/懒惰评估

下面就来说说koila背后的工作原理。

“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。

koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)

它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。

而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。

又是算shape又是算内存的,koila听起来就很慢?

NO。

即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。

而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。

你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?

是的,它也可以。

但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。

koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。

不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU

以及现在只适用于常见的nn.Module类。

ps. koila作者是一位叫做RenChu Wang的小哥。

项目地址:
https://github.com/rentruewang/koila


参考链接:
https://www.reddit.com/r/MachineLearning/comments/r4zaut/p_eliminate_pytorchs_cuda_error_out_of_memory/



推荐阅读:

我教你如何读博!

牛逼!轻松高效处理文本数据神器

B站强化学习大结局!

如此神器,得之可得顶会!

兄弟们!神经网络画图,有它不愁啊

太赞了!东北大学朱靖波,肖桐团队开源《机器翻译:统计建模与深度学习方法》

当年毕业答辩!遗憾没有它...

已开源!所有李航老师《统计学习方法》代码实现

这个男人,惊为天人!手推PRML!

它来了!《深度学习》(花书) 数学推导、原理剖析与代码实现

你们心心念念的MIT教授Gilbert Strang线性代数彩板笔记!强烈推荐!

GitHub超过9800star!学习Pytorch,有这一份资源就够了!强推!

你真的懂神经网络?强推一个揭秘神经网络的工具,ANN Visualizer

诸位!看我如何白嫖2020 icassp!

这个时代研究情感分析,是最好也是最坏!

BERT雄霸天下!

玩转Pytorch,搞懂这个教程就可以了,从GAN到词嵌入都有实例

是他,是他,就是他!宝藏博主让你秒懂Transformer、BERT、GPT!

fitlog!复旦邱锡鹏老师组内部调参工具!一个可以节省一篇论文的调参利器

Github开源!查阅arXiv论文新神器,一行代码比较版本差别,我爱了!

开源!数据结构与算法必备的 50 个代码实现

他来了!吴恩达带着2018机器学习入门高清视频,还有习题解答和课程拓展来了!

太赞了!复旦邱锡鹏老师NLP实战code解读开源!

这块酷炫的Python神器!我真的爱了,帮助你深刻理解语言本质!实名推荐!

论文神器!易搜搭

不瞒你说!这可能是世界上最好的线性代数教程

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/124773
 
373 次点击