社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

【Python】这个装饰器竟让 Python 提速了 30 倍!

机器学习初学者 • 2 年前 • 266 次点击  

Python是一种解释语言,其代码不是直接编译成机器码,而是由另一个叫做解释器的程序实时解释的(一般是 cpython )。因此,与其他编译语言相比,Python灵活性高(动态类型,兼容性高,...)。但这也造成了Python非常慢的缺点。

加速 Python的方法

实际上,有多种解决方案可以解决python的缓慢问题。

  • 使用 cython:一种编程语言,是python的超集。Cython是Python编程语言和扩展 Cython 编程语言(基于Pyrex)的优化静态编译器。它使得为 Python 编写 C 扩展就像 Python 本身一样容易。
  • 使用C/C++语言结合 ctypes pybind11 CFFI 来编写Python的绑定程序
  • 用C/C++扩展Python
  • 使用其他编译过的语言,如rust[1]

而所有这些方法,都需要使用除Python外的另一种语言,并编译代码使之与Python一起工作。尽管这些方法都很不错,但并不是最适合我们初学者的使python更快的方法,更别提他们通常比较难以设置了。

Numba & JIT 编译器

Numba[2]是一个Python包,在兼具Python的便利的同时,可以使你的代码更快。

numba使用Just-in-time (JIT)编译(即在Python代码执行过程中的实时编译的),使用起来非常方便,无需向其他工具一样,还需安装一个C/C++编译器,它仅需用 pip/conda 安装它即可。

pip install numba

接下来试一个例子:用蒙特卡洛模拟来计算π的估计值。

import random
from numba import njit
def monte_carlo_pi_without_numba(nsamples):
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x ** 2 + y ** 2) 1.0:
            acc += 1
    return 4.0 * acc / nsamples

# 添加numba的装饰器,使该函数更快。
@njit
def monte_carlo_pi_with_numba(nsamples):
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x ** 2 + y ** 2) 1.0:
            acc += 1
    return 4.0 * acc / nsamples

在使用该方法时,我们只需要导入numba的一个装饰器(njit),剩下的都由它自己完成即可,可以说是非常方便。

我们运行两个版本的代码,并进行计时对比。显示numba比普通python快30倍

%timeit monte_carlo_pi_with_numba(100_000)
# 1.24 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit monte_carlo_pi_without_numba(100_000)
# 40.6 ms ± 814 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

一些注意事项

值得一提的是,numba确实有一些缺点:

  • 在首次运行numba装饰的函数时,有一定的时间开销。这是因为首次执行时,numba会试图找出参数的类型并编译函数,从而导致程序有一定的延迟。
  • 不是所有的Python代码都能用numba编译,例如,如果你对同一个变量或对列表元素使用混合数据类型,此种情况将会抛出异常。

加速 Pandas

Numba是专门为numpy设计的,对numpy数组非常友好。而 pandas 是建立在 numpy 之上的,这使得在使用用户定义的函数或甚至执行不同的Dataframe操作时,可以进行疯狂优化。

首先创建一个DataFrame数据集。

import numpy as np
import pandas as pd

n = 1_000_000

df = pd.DataFrame({
    'height'1 + 1.3 * np.random.random(n),
    'weight'40 + 260 * np.random.random(n),
    'hip_circumference'94 + 14 * np.random.random(n)
})

用户定义的函数

numba 的另一个重要的方法是 vectorize,使用该方法可以很容易的创建numpy通用函数(ufuncs[3]

通用函数(或简称ufunc)是以ndarrays逐个元素的方式运行的函数,支持数组广播、类型转换和其他几个标准特性。也就是说,ufunc 是一个函数的“矢量化”包装器,它接受固定数量的特定输入并产生固定数量的特定输出。

下面是计算数据集中列height的平方。

from numba import vectorize

def get_squared_height_without_numba(height):
  return  height ** 2

@vectorize
def get_squared_height_with_numba(height):
  return height ** 2

%timeit df['height'].apply(get_squared_height_without_numba)
# 279 ms ± 7.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


%timeit df['height'] ** 2
# 2.04 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# 我们首先将列转换为numpy数组,
# 因为numba与numpy兼容,与pandas并不兼容。
%timeit get_squared_height_with_numba(df['height'].to_numpy())
# 1.6 ms ± 51.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

基本操作

使用njit,并计算 BMI(身体质量指数)。

from numba import njit

@njit
def get_bmi(weight_col, height_col):
  n = len(weight_col)
  result = np.empty(n, dtype="float64")

  # 与python循环相比,Numba的循环非常快
  for i, (weight, height) in enumerate(zip(weight_col, height_col)):
    result[i] = weight / (height ** 2)
  return result

# 不要忘记将列转换为 numpy 
%timeit df['bmi'] = get_bmi(df['weight'].to_numpy(), 
                            df['height'].to_numpy())
# 6.77 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit df['bmi'] = df['weight']  / (df['height'] ** 2)
# 8.63 ms ± 316 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

你可以看到,即使是基本的操作,numba仍然比原始 pandas 花费的时间更少(6.77ms vs 8.63ms)。

写在最后

numba 是一种开箱即用的方法,可以轻而易举地 让你的 Python 代码变得更快。当然,在成功编译代码之前可能需要多几次尝试,你可以试试使用它。如果本文对你有用,那就点个赞和在看支持下云朵君吧

拓展阅读

[1]

rust: https://github.com/PyO3/pyo3

[2]

Numba: https://numba.pydata.org/

[3]

ufuncs: https://numpy.org/doc/stable/reference/ufuncs.html


往期精彩回顾




Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/150319
 
266 次点击