机器学习加速利器jax让numpy加速30倍
jax.numpy是CPU、GPU和TPU上的numpy,具有出色的自动差异化功能,可用于高性能机器学习研究。 我今天就来试一试到底多快。我在同一台bu带gpu的机器上进行试验 首先我们得安装jax pip install jax jaxlib 先试一下原生的numpy import numpy as np import time x = np.random.random([5000, 5000]).astype(np.float32) try: st=time.time() y=np.matmul(x, x) except Exception: print("e
用户评论