Py学习  »  Python

从头开始实现一个线性代数库:Python模块篇

一个普普通通简简单单平平凡凡的神 • 5 年前 • 320 次点击  

这两天用C/C++实现了一下线性代数库的Python模块,大部分操作已经封装完成,剩下的慢慢补坑吧= =

关于线性代数库的一些算法实现,可以参考我的前一篇文章从头开始实现一个线性代数库:算法实现篇。现在主要总结一下如何用C/C++编写Python模块,代码地址 github.com/netcan/LinA…

先来看看效果:

  1. In [1]: ls
  2. linalg.cpython-36m-darwin.so*
  3. In [2]: import linalg
  4. In [3]: m = linalg.Matrix([
  5. ...: [43, 63, 57, 35],
  6. ...: [63, 26, 32, 35],
  7. ...: [57, 32, 78, 76],
  8. ...: [35, 35, 76, 25],
  9. ...: ])
  10. In [4]: m.Jacobi()
  11. Out[4]: C(-19.248723 33.225382 197.775875 -39.752533)
  12. In [5]: m.det()
  13. Out[5]: 5028170.999999999

编写模块最麻烦的部分应该是引用计数了,如果能够妥善处理,将会事半功倍,我因为这个问题debug好久了。紧接着就是异常处理了。主要参考资料还是官方文档:

定义模块

首先定义一个名为linalg的模块,其下有两个类:linalg.Vectorlinalg.Matrix

  1. static PyModuleDef linalgmodule = {
  2. PyModuleDef_HEAD_INIT,
  3. "linalg", /* name of module */
  4. NULL, /* module documentation, may be NULL */
  5. -1, /* size of per-interpreter state of the module,
  6. or -1 if the module keeps state in global variables. */
  7. };

如果需要定义一些模块级方法,需要在定义linalgmodule的时候指定m_methods,这里不需要这类方法,就省略了。

接下来初始化模块,定义名为PyInit_linalg的函数:

  1. PyMODINIT_FUNC
  2. PyInit_linalg(void) {
  3. PyObject *m;
  4. m = PyModule_Create(&linalgmodule);
  5. if (m == NULL) return NULL;
  6. // 添加Vector/Matrix类
  7. if (PyType_Ready(&PyVectorType) < 0) return NULL;
  8. if (PyType_Ready(&PyMatrixType) < 0) return NULL;
  9. Py_INCREF(&PyVectorType);
  10. Py_INCREF(&PyMatrixType);
  11. PyModule_AddObject(m, "Vector", (PyObject*)&PyVectorType);
  12. PyModule_AddObject(m, "Matrix", (PyObject*)&PyMatrixType);
  13. return m;
  14. }

可以看到在初始化模块的时候,添加了linalg.Vectorlinalg.Matrix两个类。接下来定义这两个类,以及类方法。

定义linalg.Vector

由于linalg.Matrix类和linalg.Vector类定义的方法基本相同,这里只总结一下linalg.Vector类的定义以及实现。

首先定义VectorObject对象类:

  1. typedef struct {
  2. PyObject_HEAD;
  3. Vector ob_vector; // 需要封装的成员
  4. } PyVectorObject;

接着是VectorType类的定义:

  1. static PyTypeObject PyVectorType = {
  2. PyObject_HEAD_INIT(NULL)
  3. .tp_name = "linalg.Vector", // 模块类名
  4. .tp_doc = "Vector objects", // doc描述
  5. .tp_basicsize = sizeof(PyVectorObject), // 对象大小
  6. .tp_itemsize = 0,
  7. .tp_dealloc = (destructor) PyLinAlg_dealloc, // 析构函数
  8. .tp_new = PyType_GenericNew, // 构造函数
  9. .tp_init = (initproc) PyVector_init, // 初始化函数
  10. .tp_members = PyVector_members, // 类成员
  11. .tp_methods = PyVector_methods, // 类方法
  12. .tp_flags = Py_TPFLAGS_DEFAULT,
  13. .tp_str = (reprfunc)PyVector_str, // str(obj), print用
  14. .tp_repr = (reprfunc)PyVector_str,
  15. .tp_as_sequence = &PyVectorSeq_methods, // 一些序列类方法,例如vec[i]
  16. .tp_as_number = &PyVectorNum_methods, // 基本运算方法,例如vecA + vecB
  17. };

先来看看linalg.Vector的析构函数,若不是单例模式,所有类都需要这个函数在引用计数为0的时候来释放对象:

  1. static void
  2. PyLinAlg_dealloc(PyObject *self) {
  3. Py_TYPE(self)->tp_free((PyObject *)self);
  4. }

构造函数使用默认的PyType_GenericNew就好,初始化一个Vector类,在C++接口中是这样的:

  1. Vector v({16,22,32,44});
  2. v.show();

现在用Python模块方式进行初始化,为了简单起见,这里只支持列表类型初始化,即:

  1. In [6]: v = linalg.Vector([1,2,3,4])
  2. In [7]: v
  3. Out[7]: C(1.000000 2.000000 3.000000 4.000000)

初始化函数如下:

  1. static int
  2. PyVector_init(PyVectorObject *self, PyObject *args, PyObject *kwds) {
  3. PyObject *pList, *pItem;
  4. if (!PyArg_ParseTuple(args, "O!", &PyList_Type, &pList)) { // 初始化参数必须为列表
  5. PyErr_SetString(PyExc_TypeError, "parameter must be a list.");
  6. return -1;
  7. }
  8. Py_ssize_t n = PyList_Size(pList); // 获取列表长度
  9. if(n <= 0) {
  10. PyErr_SetString(PyExc_TypeError, "size of list must be greater than 0");
  11. return -1;
  12. }
  13. vector<double> data;
  14. for(Py_ssize_t i = 0; i < n; ++i) {
  15. pItem = PyList_GetItem(pList, i);
  16. if(! isNumber(pItem)) return -1; // 列表内的元素必须为数字
  17. else data.push_back(getNumber(pItem));
  18. }
  19. self->ob_vector = Vector(data); // 存入封装Vector的VectorObject类型中
  20. return 0;
  21. }

而对于类支持的方法,需要定义一个数组PyVector_methods说明,比如linalg.Vector类支持T(), copy()方法,那么:

  1. static PyMethodDef PyVector_methods[] = {
  2. {"T", (PyCFunction)PyVector_T, METH_NOARGS, "change vector type"},
  3. {"copy", (PyCFunction)PyVector_Copy, METH_NOARGS, "deep copy of vector"},
  4. {NULL} /* Sentinel */
  5. };

然后实现PyVector_T()PyVector_Copy即可。

而对于一些magic方法,例如__add__, __getitem__等等,需要在PyTypeObject.tp_as_number.tp_as_sequence数组中指明。

引用计数导致的BUG

之前因为返回对象的引用计数导致了非常麻烦的BUG,这里贴出来一下:

  1. static PyObject *
  2. PyVector_imul(PyVectorObject *self, PyObject *arg) {
  3. if(!arg || ! isNumber(arg)) return NULL;
  4. Py_XINCREF(self); // 不加这个会崩溃...因为返回的对象没有被接收会被释放掉
  5. self->ob_vector *= getNumber(arg);
  6. return (PyObject*)self;
  7. }

这里定义了一下*=类方法,当执行v *= 5等语句,可能导致崩溃,需要增加引用计数来获得所有权,因为操作的对象是borrowed references的,那么返回的时候由于没有被其他对象引用,将被释放,而后来要是有对象申请空间也用到这块内存的话,就会出现异常。

官方文档是这么说明的,对于PyNumber_InPlaceAdd这类方法应该返回的是新引用,而不是borrowed references

PyObject PyNumber_InPlaceAdd(PyObject o1, PyObject *o2)
Return value: New reference.
Returns the result of adding o1 and o2, or NULL on failure. The operation is done in-place when o1 supports it. This is the equivalent of the Python statement o1 += o2.

在StackOverflow上有这么一个回答:stackoverflow.com/questions/1…

According to the documentation, the inplace add operation for an object returns a new reference.
By returning self directly without calling Py_INCREF on it, your object will be freed while it is still referenced. If some other object is allocated the same piece of memory, those references would now give you the new object.

编译、链接模块

掌握了如何定义模块、类,剩下的就简单了,最后是通过编写setup.py来生成模块,如下:

  1. from distutils.core import setup, Extension
  2. linalgmodule = Extension('linalg',
  3. extra_compile_args = ['-std=c++14'],
  4. sources = ['src/linalgmodule.cpp'],
  5. include_dirs = ['include'],
  6. libraries = ['linalg'],
  7. )
  8. setup (name = 'linalg',
  9. version = '0.1',
  10. author = 'netcan',
  11. author_email = 'netcan1996@gmail.com',
  12. description = 'A Linear Algebra library for studying by netcan',
  13. ext_modules = [linalgmodule],
  14. )

编译的时候,通过如下编译,

  1. python setup.py build

可能会报错,例如找不到std::move,这时候需要指定标准头文件的位置了:

  1. CFLAGS='-I/Library/Developer/CommandLineTools/usr/include/c++/v1/' python setup.py build
文章目录
  1. 1. 定义模块
  2. 2. 定义linalg.Vector类
    1. 2.1. 引用计数导致的BUG
  3. 3. 编译、链接模块

今天看啥 - 高品质阅读平台
本文地址:http://www.jintiankansha.me/t/9NzsneaNH3
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/12988
 
320 次点击