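Is something this awesome hard to use? No, no, no. It's easy: you don't need to replace the Python interpreter, you don't need a separate compilation step, and you don't even need a C/C++ compiler installed. Just put one of Numba's decorators on top of a Python function and Numba takes care of the rest. A simple example: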

    from numba import jit
    import random

    @jit(nopython=True)
    def monte_carlo_pi(nsamples):
        acc = 0
        for i in range(nsamples):
            x = random.random()
            y = random.random()
            # Count the points that land inside the unit quarter circle
            if (x ** 2 + y ** 2) < 1.0:
                acc += 1
        return 4.0 * acc / nsamples

    import numba
    import numpy as np

    @numba.jit(nopython=True, parallel=True)
    def logistic_regression(Y, X, w, iterations):
        for i in range(iterations):
            w -= np.dot(((1.0 / (1.0 + np.exp(-Y * np.dot(X, w))) - 1.0) * Y), X)
        return w

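Now let's compare the performance of the same code before and after Numba, and against C++. Say we want to find all the primes below 10 million. The algorithm is the same in every version: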

    import math
    import time

    def is_prime(num):
        if num == 2:
            return True
        if num <= 1 or not num % 2:
            return False
        for div in range(3, int(math.sqrt(num) + 1), 2):
            if not num % div:
                return False
        return True

    def run_program(N):
        total = 0
        for i in range(N):
            if is_prime(i):
                total += 1
        return total

    if __name__ == "__main__":
        N = 10000000
        start = time.time()
        total = run_program(N)
        end = time.time()
        print(f"total prime num is {total}")
        print(f"cost {end - start}s")

Run time:

    total prime num is 664579
    cost 47.386465072631836s

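Before comparing timings, it's worth a quick check that the count itself is right. The sketch below uses a different technique from the article's trial division, a Sieve of Eratosthenes in pure Python, and the name `count_primes_below` is made up for this example. It counts primes strictly below `n`, which matches `range(N)` in the code above.

```python
import math

def count_primes_below(n):
    """Count primes p with p < n using a Sieve of Eratosthenes."""
    if n < 2:
        return 0
    sieve = bytearray([1]) * n          # sieve[i] == 1 means "i is still a prime candidate"
    sieve[0] = sieve[1] = 0
    for p in range(2, math.isqrt(n) + 1):
        if sieve[p]:
            # Cross out p*p, p*p + p, ... in one slice assignment
            sieve[p * p::p] = bytes(len(range(p * p, n, p)))
    return sum(sieve)

print(count_primes_below(10_000_000))   # 664579, same count as the program above
```

The sieve agrees with the trial-division result, so the versions being compared are at least computing the same answer.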
The C++ code:

    #include <iostream>
    #include <cmath>
    #include <ctime>
    using namespace std;

    bool isPrime(int num) {
        if (num == 2) return true;
        if (num <= 1 || num % 2 == 0) return false;
        double sqrt_num = sqrt(double(num));
        for (int div = 3; div <= sqrt_num; div += 2) {
            if (num % div == 0) return false;
        }
        return true;
    }

    int run_program(int N) {
        int total = 0;
        for (int i = 0; i < N; i++) {
            if (isPrime(i)) total++;
        }
        return total;
    }

    int main() {
        int N = 10000000;
        clock_t start, end;
        start = clock();
        int total = run_program(N);
        end = clock();
        cout << "total prime num is " << total << endl;
        cout << "cost " << (double)(end - start) / CLOCKS_PER_SEC << "s" << endl;
        return 0;
    }

Compile and run:

    $ g++ isPrime.cpp -o isPrime
    $ ./isPrime
    total prime num is 664579
    cost 2.36221s

Now the Numba version. It is the same Python code with the @njit decorator added:

    import math
    import time
    from numba import njit

    # @njit is equivalent to @jit(nopython=True)
    @njit
    def is_prime(num):
        if num == 2:
            return True
        if num <= 1 or not num % 2:
            return False
        for div in range(3, int(math.sqrt(num) + 1), 2):
            if not num % div:
                return False
        return True

    @njit
    def run_program(N):
        total = 0
        for i in range(N):
            if is_prime(i):
                total += 1
        return total

    if __name__ == "__main__":
        N = 10000000
        start = time.time()
        total = run_program(N)
        end = time.time()
        print(f"total prime num is {total}")
        print(f"cost {end - start}s")

Run it, and you can see the time has dropped from 47.39 seconds to about 3 seconds.

    total prime num is 664579
    cost 3.0948808193206787s

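That's still a little slower than C++'s 2.3 seconds, and you might say Python still can't cut it. Hold on, there's more room to optimize: the Python for loop, which runs 10 million iterations. For this, Numba provides prange, a parallel counterpart of range that spreads the loop iterations across threads (the total += 1 accumulation is a reduction that Numba supports inside such a loop). Just change range to prange and pass parallel=True to the decorator. Everything else stays the same. The modified code: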

    import math
    import time
    from numba import njit, prange

    @njit
    def is_prime(num):
        if num == 2:
            return True
        if num <= 1 or not num % 2:
            return False
        for div in range(3, int(math.sqrt(num) + 1), 2):
            if not num % div:
                return False
        return True

    @njit(parallel=True)
    def run_program(N):
        total = 0
        for i in prange(N):
            if is_prime(i):
                total += 1
        return total

    if __name__ == "__main__":
        N = 10000000
        start = time.time()
        total = run_program(N)
        end = time.time()
        print(f"total prime num is {total}")
        print(f"cost {end - start}s")

Now run it:

    $ python isPrime.py
    total prime num is 664579
    cost 1.4398791790008545s

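Just 1.43 seconds, faster than the C++ version. Numba really is awesome! I ran it two more times to make sure I wasn't seeing things, and the average is about 1.4 seconds. To be fair, this isn't an apples-to-apples comparison: the C++ program above is single-threaded and was compiled without any optimization flags, while the Numba version runs its loop in parallel.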
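One thing to keep in mind about these numbers: Numba compiles a function the first time it is called, so a single timed run like the ones above includes that one-off compilation cost. Here is a minimal sketch that times the first call separately from repeat calls. The helper `time_first_and_repeat` is a made-up name for this example, and the `try/except` fallback is my own addition so the snippet still runs as plain Python on a machine without Numba.

```python
import math
import time

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: without Numba, fall back to a no-op decorator
    def njit(func):
        return func

@njit
def is_prime(num):
    if num == 2:
        return True
    if num <= 1 or not num % 2:
        return False
    for div in range(3, int(math.sqrt(num) + 1), 2):
        if not num % div:
            return False
    return True

@njit
def run_program(N):
    total = 0
    for i in range(N):
        if is_prime(i):
            total += 1
    return total

def time_first_and_repeat(func, arg, repeats=3):
    """Return (first_call_seconds, best_repeat_seconds, result)."""
    start = time.perf_counter()
    result = func(arg)                 # with Numba, this call also triggers compilation
    first = time.perf_counter() - start
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(arg)                      # later calls reuse the already compiled version
        best = min(best, time.perf_counter() - start)
    return first, best, result

if __name__ == "__main__":
    first, best, total = time_first_and_repeat(run_program, 200_000)
    print(f"total prime num is {total}")
    print(f"first call {first:.4f}s, best repeat {best:.4f}s")
```

With Numba installed, the first call should come out noticeably slower than the repeats, and the repeat time is the fairer number to put next to a precompiled C++ binary.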
Seeing this, Numba rekindled my passion for Python. I'm not switching to C++ after all. Python is enough.
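How does Numba pull this off? The official documentation describes it like this: Numba reads the Python bytecode of the decorated function, combines it with information about the types of the function's input arguments, analyzes and optimizes the code, and finally uses the LLVM compiler library to generate machine code tailored to your CPU. That compiled version is then used every time the function is called. Pretty awesome, right?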
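To make that description a bit more concrete, here is a toy sketch in pure Python. It is not how Numba is implemented and it compiles nothing. It only imitates the shape of the idea: look at the function's bytecode once, then keep one "specialized" entry per combination of argument types and reuse it on later calls. The names `toy_jit`, `bytecode_ops`, and `specializations` are all made up for this illustration.

```python
import dis
from functools import wraps

def toy_jit(func):
    """Toy stand-in for a JIT decorator: one cache entry per argument-type signature."""
    # Step 1: read the function's bytecode (here we only record the opcode names)
    bytecode_ops = [ins.opname for ins in dis.get_instructions(func)]
    specializations = {}

    @wraps(func)
    def dispatcher(*args):
        # Step 2: combine the bytecode with the types of the actual arguments
        signature = tuple(type(a).__name__ for a in args)
        if signature not in specializations:
            # A real JIT would generate machine code here; this toy just reuses func
            specializations[signature] = func
        # Step 3: every later call with the same types reuses the cached entry
        return specializations[signature](*args)

    dispatcher.bytecode_ops = bytecode_ops
    dispatcher.specializations = specializations
    return dispatcher

@toy_jit
def add(a, b):
    return a + b

add(1, 2)        # first call with (int, int): creates a cache entry
add(3, 4)        # same types: reuses it
add(1.0, 2.0)    # new types (float, float): creates a second entry
print(sorted(add.specializations))   # [('float', 'float'), ('int', 'int')]
```

The takeaway is just the caching pattern: the work is done once per set of argument types, not once per call, which is why repeat calls are cheap.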