1. 首页
  2. 数据库
  3. 其它
  4. 机器学习加速利器jax让numpy加速30倍

机器学习加速利器jax让numpy加速30倍

上传者: 2020-12-23 07:14:07上传 PDF文件 47.97KB 热度 12次
jax.numpy是CPU、GPU和TPU上的numpy,具有出色的自动差异化功能,可用于高性能机器学习研究。 我今天就来试一试到底多快。我在同一台bu带gpu的机器上进行试验 首先我们得安装jax pip install jax jaxlib 先试一下原生的numpy import numpy as np import time x = np.random.random([5000, 5000]).astype(np.float32) try: st=time.time() y=np.matmul(x, x) except Exception: print("e
用户评论