Py学习  »  机器学习算法

Keras vs PyTorch,哪一个更适合做深度学习?

小白学视觉 • 1 年前 • 125 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

如何选择工具对深度学习初学者是个难题,本文作者以 Keras 和 Pytorch 库为例,提供了解决该问题的思路。


(转自机器之心)


当你决定学习深度学习时,有一个问题会一直存在——学习哪种工具?

深度学习有很多框架和库。这篇文章对两个流行库 Keras 和 Pytorch 进行了对比,因为二者都很容易上手,初学者能够轻松掌握。

那么到底应该选哪一个呢?本文给大家分享了一个解决思路。

做出合适选择的最佳方法是对每个框架的代码样式有一个概览。开发任何解决方案时首先也是最重要的事就是开发工具。你必须在开始一项工程之前设置好开发工具。一旦开始,就不能一直换工具了,否则会影响你的开发效率。

作为初学者,你应该多尝试不同的工具,找到最适合你的那一个。但是当你认真开发一个项目时,这些事应该提前计划好。

每天都会有新的框架和工具投入市场,而最好的工具能够在定制和抽象之间做好平衡。工具应该和你的思考方式和代码样式同步。因此要想找到适合自己的工具,首先你要多尝试不同的工具。

我们同时用 Keras 和 PyTorch 训练一个简单的模型。如果你是深度学习初学者,对有些概念无法完全理解,不要担心。


从现在开始,专注于这两个框架的代码样式,尽量去想象哪个最适合你,使用哪个工具你最舒服,也最容易适应。


这两个工具最大的区别在于:PyTorch 默认为 eager 模式,而 Keras 基于 TensorFlow 和其他框架运行(现在主要是 TensorFlow),其默认模式为图模式。最新版本的 TensorFlow 也提供类似 PyTorch 的 eager 模式,但是速度较慢。

如果你熟悉 NumPy,你可以将 PyTorch 视为有 GPU 支持的 NumPy。此外,现在有多个具备高级 API(如 Keras)且以 PyTorch 为后端框架的库,如 Fastai、Lightning、Ignite 等。如果你对它们感兴趣,那你选择 PyTorch 的理由就多了一个。

在不同的框架里有不同的模型实现方法。让我们看一下这两种框架里的简单实现。本文提供了 Google Colab 链接。打开链接,试验代码。这可以帮助你找到最适合自己的框架。

我不会给出太多细节,因为在此,我们的目标是看一下代码结构,简单熟悉一下框架的样式。

Keras 中的模型实现

以下示例是数字识别的实现,代码很容易理解。你需要打开 colab,试验代码,至少自己运行一遍。


Keras 自带一些样本数据集,如 MNIST 手写数字数据集。以上代码可以加载这些数据,数据集图像是 NumPy 数组格式。Keras 还做了一点图像预处理,使数据适用于模型。


以上代码展示了模型,在 Keras(TensorFlow)上,我们首先需要定义要使用的东西,然后立刻运行。在 Keras 中,我们无法随时随地进行试验,不过 PyTorch 可以。


以上的代码用于训练和评估模型。我们可以使用 save() 函数来保存模型,以便后续用 load_model() 函数加载模型。predict() 函数则用来获取模型在测试数据上的输出。

现在我们概览了 Keras 基本模型实现过程,现在来看 PyTorch。

PyTorch 中的模型实现

研究人员大多使用 PyTorch,因为它比较灵活,代码样式也是试验性的。你可以在 PyTorch 中调整任何事,并控制全部,但控制也伴随着责任。

在 PyTorch 里进行试验是很容易的。因为你不需要先定义好每一件事再运行。我们能够轻松测试每一步。因此,在 PyTorch 中 debug 要比在 Keras 中容易一些。

接下来,我们来看简单的数字识别模型实现。


以上代码导入了必需的库,并定义了一些变量。n_epochs、momentum 等变量都是必须设置的超参数。此处不讨论细节,我们的目的是理解代码的结构。


以上代码旨在声明用于加载训练所用批量数据的数据加载器。下载数据有很多种方式,不受框架限制。如果你刚开始学习深度学习,以上代码可能看起来比较复杂。


在此,我们定义了模型。这是一种创建网络的通用方法。我们扩展了 nn.Module,在前向传递中调用 forward() 函数。

PyTorch 的实现比较直接,且能够根据需要进行修改。


以上代码段定义了训练和测试函数。在 Keras 中,我们需要调用 fit() 函数把这些事自动做完。


但是在 PyTorch 中,我们必须手动执行这些步骤。像 Fastai 这样的高级 API 库会简化它,训练所需的代码也更少。



最后,保存和加载模型,以进行二次训练或预测。这部分没有太多差别。PyTorch 模型通常有 pt 或 pth 扩展。

关于框架选择的建议

学会一种模型并理解其概念后,再转向另一种模型,并不是件难事,只是需要一些时间。本文作者给出的建议是两个都学,但是不需要两个都深入地学。

你应该从一个开始,然后在该框架中实现模型,同时也应当掌握另一个框架的知识。这有助于你阅读别人用另一个框架写的代码。永远不要被框架限制住。

先从适合自己的框架开始,然后尝试学习另一个。如果你发现另一个用起来更合适,那么转换成另一个。因为 PyTorch 和 Keras 的大多数核心概念是类似的,二者之间的转换非常容易。

Colab 链接:

PyTorch:

https://colab.research.google.com/drive/1irYr0byhK6XZrImiY4nt9wX0fRp3c9mx?usp=sharing


Keras:

https://colab.research.google.com/drive/1QH6VOY_uOqZ6wjxP0K8anBAXmI0AwQCm?usp=sharing


原文链接:

https://medium.com/@karan_jakhar/keras-vs-pytorch-dilemma-dc434e5b5ae0


版权说明: 本文内容转自机器之心、版权归原作者所有,如有侵权请联系删除。

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇




下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/147166
 
125 次点击