社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结

小白学视觉 • 2 年前 • 283 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

作者丨奔腾的黑猫@知乎
来源丨https://zhuanlan.zhihu.com/p/158643792
本文仅用于学术分享,如有侵权,请联系后台作删文处理。

导读

 

关于PyTorch构建扩展的一些基础操作,官方往往已经出具了完整的教程。本文对这些官方教程的链接进行了整理,以供读者查阅。

在做毕设的时候,需要实现一个PyTorch原生代码中没有的并行算子,所以用到了这部分的知识,再不总结就要忘光了= =
本文内容主要是PyTorch的官方教程的各种传送门,这些官方教程写的都很好,以后就可以不用再浪费时间在百度上了。由于图神经网络计算框架PyG的代码实现也是采用了扩展的方法,因此也可以当成下面总结PyG源码文章的前导知识吧 。

第一种情况:使用PyThon扩展PyTorch

使用PyThon扩展PyTorch准确的来说是在PyTorch的Python前端实现自定义算子或者模型,不涉及底层C++的实现。这种扩展方式是所有扩展方式中最简单的,也是官方首先推荐的,这是因为PyTorch在NVIDIA cuDNN,Intel MKL或NNPACK之类的库的支持下已经对可能出现的CPU和GPU操作进行了高度优化,因此用Python扩展的代码通常足够快。
比如要扩展一个新的PyThon算子(torch.nn)只需要继承torch.nn.Module并实现其forward方法即可。详细的过程请参考官方教程传送门:
https://pytorch.org/docs/master/notes/extending.html

第二种情况:使用pybind11构建共享库形式的C++和CUDA扩展

但是如果我们想对代码进行进一步优化,比如对自己的算子添加并行的CUDA实现或者连接个OpenCV的库什么的,那么仅仅使用Python进行扩展就不能满足需求;其次如果我们想序列化模型,在一个没有Python环境的生产环境下部署,也需要我们使用C++重写算法;最后考虑到考虑到多线程执行和性能原因,一般Python代码也并不适合做部署。因此在对性能有要求或者需要序列化模型的场景下我们还是会用到C++扩展。
下面我先把官方教程传送门放在这里:
https://pytorch.org/tutorials/advanced/cpp_extension.html
对于一种典型的扩展情况,比如我们要设计一个全新的C++底层算子,其过程其实就三步:
第一步:使用C++编写算子的forward函数和backward函数
第二步:将该算子的forward函数和backward函数使用**pybind11**绑定到python上
第三步:使用setuptools/JIT/CMake编译打包C++工程为so文件
注意到在第一步中,我们不仅仅要实现forward函数也要实现backward函数,这是因为在C++端PyTorch目前不支持自动根据forward函数推导出backward函数,所以我们必须要对自己算子的反向传播过程完全清楚。一个需要注意的地方是,你可以选择直接在C++中继承torch::autograd类进行扩展;也可以像官方教程中那样在C++代码中实现forward和backward的核心过程,而在python端继承PyTorch的torch.autograd.Function类。
在C++端扩展forward函数和backward函数的需要注意以下规则:
(1)首先无论是forward函数还是backward函数都需要声明为静态函数
(2)forward函数可以接受任意多的参数并且应该返回一个 variable list或者variable;forward函数需要将[torch::autograd::AutogradContext](https://link.zhihu.com/?target=https%3A//pytorch.org/cppdocs/api/structtorch_1_1autograd_1_1_autograd_context.html%23structtorch_1_1autograd_1_1_autograd_context) 作为自己的第一个参数。Variables可以被使用ctx->save_for_backward保存,而其他数据类型可以使用ctx->saved_data以<:string>pairs的形式保存在一个map中。
(3)backward函数第一个参数同样需要为torch::autograd::AutogradContext,其余的参数是一个variable_list,包含的变量数量与forward输出的变量数量相等。它应该返回和forward输入一样多的变量。保存在forward中的Variable变量可以通过ctx->get_saved_variables而其他的数据类型可以通过ctx->saved_data获取。
请注意,backward的输入参数是自动微分系统反传回来的参数梯度值,其需要和forward函数的返回值位置一一对应的;而backward的返回值是对各参数根据自动微分规则求导后的梯度值,其需要和forward函数的输入参数位置一一对应,对于不需要求导的参数也需要使用空Variable占位。
// PyG的C++扩展就选择的是直接继承PyTorch的C++端的torch::autograd类进行扩展// 下面是PyG的一个ScatterSum算子的扩展示例// 不用纠结这个算子的具体内容,对扩展的算子的结构有一个大致了解即可class ScatterSum : public torch::autograd::Function {public:  // AutogradContext *ctx指针可以操作  static variable_list forward(AutogradContext *ctx, Variable src,                               Variable index, int64_t dim,                               torch::optional optional_out,                               torch::optional<int64_t> dim_size) {    dim = dim < 0 ? src.dim() + dim : dim;    ctx->saved_data["dim"] = dim;


    
    ctx->saved_data["src_shape"] = src.sizes();    index = broadcast(index, src, dim);    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");    auto out = std::get<0>(result);    ctx->save_for_backward({index});    // 如果在扩展的C++代码中使用非Aten内建操作修改了tensor的值,需要对其进行脏标记    if (optional_out.has_value())      ctx->mark_dirty({optional_out.value()});      return {out};  } // grad_outs是out参数反传回来的梯度值  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {    auto grad_out = grad_outs[0];    auto saved = ctx->get_saved_variables();    auto index = saved[0];    auto dim = ctx->saved_data["dim"].toInt();    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());    auto grad_in = torch::gather(grad_out, dim, index, false);    // 不需要求导的参数需要空Variable占位    return {grad_in, Variable(), Variable(), Variable(), Variable()};  }};
由于涉及到在C++环境下操作张量和反向传播等操作,因此我们需要对PyTorch的C++后端的库有所了解,主要就是Torch和Aten这两个库,下面我简要介绍一下这两兄弟。
其中Torch是PyTorch的C++底层实现(PS:其实是先有的Torch后有的PyTorch,从名字也能看出来),FB在编码PyTorch的时候就有意将PyTorch的接口和Torch的接口设计的十分类似,因此如果你对PyTorch很熟悉的话那么你也会很快的对Torch上手。
Torch官方文档传送门:
https://pytorch.org/cppdocs/frontend.html
安装PyTorch的C++前端的官方教程:
https://pytorch.org/cppdocs/installing.html
而Aten是ATen从根本上讲是一个张量库,在PyTorch中几乎所有其他Python和C ++接口都在其上构建。它提供了一个核心Tensor类,在其上定义了数百种操作。这些操作大多数都具有CPU和GPU实现,Tensor该类将根据其类型向其动态调度。和Torch相比Aten更接近底层和核心逻辑。
Aten源代码传送门:
https://github.com/zdevito/ATen/tree/master/aten/srcgithub.com
使用Aten声明和操作张量的教程:
https://pytorch.org/cppdocs/notes/tensor_basics.html
由于Pyorch的C++后端文档比较少,因此要多参考官方的例子,尝试去模仿官方教程的代码,同时可以通过Python前端的接口猜测后端接口的功能,如果没有文档了就读一读源码,还是有不少注释的,还能理解实现的逻辑。

第三种情况:为TORCHSCRIPT添加C++和CUDA扩展

首先简单解释一下TorchScript是什么,如果用官方的定义来说:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从一个Python进程中保存并可以在一个没有Python环境的进程中被加载。”通俗来说TorchScript就是一个序列化模型(即Inference)的工具,它可以让你的PyTorch代码方便的在生产环境中部署,同时在将PyTorch代码转化TorchScript代码时还会对你的模型进行一些性能上的优化。使用TorchScript完成模型的部署要比我们之前提到的使用C++重写要简单的多,因为是自动生成的。
TorchScript包含两种序列化模型的方法:tracingscript,两种方法各有其适用场景,由于和本文关系不大就不详细展开了,具体的官方教程传送门在此:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
但是,TorchScript只能自动化的构造PyTorch的原生代码,如果我们需要序列化自定义的C++扩展算子,则需要我们显式的将这些自定义算子注册到TorchScript中,所幸的是,这一过程其实非常简单,整个过程和第二小节中使用pybind11构建共享库的形式的C++和CUDA扩展十分类似。官方教程传送门如下:
https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
而对于自定义的C++类,如果要注册到TorchScript要稍微复杂一些,官方教程传送门如下:
https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html?highlight=registeroperators
另外需要注意的是,如果想要编写能够被TorchScript编译器理解的代码,需要注意在C++自定义扩展算子参数中的数据类型,目前被TorchScript支持的参数数据类型有torch::Tensortorch::Scalar(标量类型),doubleint64_tstd::vector,而像float,int,short这些是不能作为自定义扩展算子的参数数据类型的。
目前就先总结这么多吧~

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇



下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复: 扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/148447
 
283 次点击