1. 首页
  2. 数据库
  3. 其它
  4. sympy2jax:将SymPy表达式转换为JAX函数 源码

sympy2jax:将SymPy表达式转换为JAX函数 源码

上传者: 2021-04-07 01:45:03上传 ZIP文件 10.32KB 热度 13次
sympy2jax 将SymPy表达式转换为参数化,可微分,可矢量化的JAX函数。 所有SymPy浮点数都将成为可训练的输入参数。 SymPy符号成为传递矩阵的列。 安装 pip install git+https://github.com/MilesCranmer/sympy2jax.git 例子 import sympy from sympy import symbols import jax import jax . numpy as jnp from jax import random from sympy2jax import sympy2jax 让我们在SymPy中创建一个表达式: x , y = symbols ( 'x y' ) expression = 1.0 * sympy . cos ( x ) + 3.2 * y 让我们获取JAX版本。 我们传递方程式和所
下载地址
用户评论