JAX,是用于变换数值函数的Python机器学习框架,它由Google开发并具有来自Nvidia的一些贡献[4][5][6]。它结合了修改版本的Autograd(自动通过函数的微分获得其梯度函数)[7],和OpenXLA的XLA(英语:Accelerated Linear Algebra)(加速线性代数)[8]。它被设计为尽可能的遵从NumPy的结构和工作流程,并协同工作于各种现存的框架如TensorFlow和PyTorch[9][10]。
JAX的主要功能是[4]:
下面的代码演示grad
函数的自动微分。
# 导入库fromjaximportgradimportjax.numpyasjnp# 定义logistic函数deflogistic(x):returnjnp.exp(x)/(jnp.exp(x)+1)# 获得logistic函数的梯度函数grad_logistic=grad(logistic)# 求值logistic函数在x = 1处的梯度grad_log_out=grad_logistic(1.0)print(grad_log_out)
最终的输出为:
下面的代码演示jit
函数的优化。
# 导入库fromjaximportjitimportjax.numpyasjnp# 定义cube函数defcube(x):returnx*x*x# 生成数据x=jnp.ones((10000,10000))# 创建cube函数的jit版本jit_cube=jit(cube)# 应用cube函数和jit_cube函数于相同数据来比较其速度cube(x)jit_cube(x)
可见jit_cube
的运行时间显著的短于cube
。
下面的代码展示vmap
函数的通过SIMD的向量化。
# 导入库fromfunctoolsimportpartialfromjaximportvmapimportjax.numpyasjnp# 定义函数defgrads(self,inputs):in_grad_partial=partial(self._net_grads,self._net_params)grad_vmap=vmap(in_grad_partial)rich_grads=grad_vmap(inputs)flat_grads=np.asarray(self._flatten_batch(rich_grads))assertflat_grads.ndim==2andflat_grads.shape[0]==inputs.shape[0]returnflat_grads
下面的代码展示pmap
函数的对矩阵乘法的并行化。
# 从JAX导入pmap和random;导入JAX NumPyfromjaximportpmap,randomimportjax.numpyasjnp# 生成2个维度为5000 x 6000的随机数矩阵,每设备一个random_keys=random.split(random.PRNGKey(0),2)matrices=pmap(lambdakey:random.normal(key,(5000,6000)))(random_keys)# 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法outputs=pmap(lambdax:jnp.dot(x,x.T))(matrices)# 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值means=pmap(jnp.mean)(outputs)print(means)
最终的输出为:
一些Python库使用JAX作为后端,这包括:
- ^jax/AUTHORS at main · jax-ml/jax.GitHub. [December 21, 2024].
- ^jax-v0.1.49.
- ^https://github.com/google/jax/releases/tag/jax-v0.4.24.
- ^4.04.1Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao,JAX: Autograd and XLA, Astrophysics Source Code Library (Google), 2022-06-18 [2022-06-18],Bibcode:2021ascl.soft11002B, (原始内容存档于2022-06-18)
- ^Frostig, Roy; Johnson, Matthew James; Leary, Chris.Compiling machine learning programs via high-level tracing(PDF). MLsys. 2018-02-02: 1–3. (原始内容存档(PDF)于2022-06-21). 引文格式1维护:日期与年 (link)
- ^Using JAX to accelerate our research. www.deepmind.com. [2022-06-18]. (原始内容存档于2022-06-18)(英语).
- ^autograd. [2023-09-23]. (原始内容存档于2022-07-18).
- ^XLA. [2023-09-23]. (原始内容存档于2022-09-01).
- ^Lynley, Matthew.Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta. Business Insider. [2022-06-21]. (原始内容存档于2022-06-21)(美国英语).
- ^Why is Google's JAX so popular?. Analytics India Magazine. 2022-04-25 [2022-06-18]. (原始内容存档于2022-06-18)(美国英语).
- ^Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29 [2022-07-29], (原始内容存档于2022-09-03)
- ^Kidger, Patrick,Equinox, 2022-07-29 [2022-07-29], (原始内容存档于2023-09-19)
- ^Kidger, Patrick,Diffrax, 2023-08-05 [2023-08-08], (原始内容存档于2023-08-10)
- ^Optax, DeepMind, 2022-07-28 [2022-07-29], (原始内容存档于2023-06-07)
- ^Lineax, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10)
- ^RLax, DeepMind, 2022-07-29 [2022-07-29], (原始内容存档于2023-04-26)
- ^Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08 [2023-08-08], (原始内容存档于2022-11-23)
- ^jaxtyping, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10)
- ^NumPyro - Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. [2022-08-31]. (原始内容存档于2022-08-31).
- ^Brax - Massively parallel rigidbody physics simulation on accelerator hardware. [2022-08-31]. (原始内容存档于2022-08-31).
可微分计算 |
---|
概论 | |
---|
概念 | |
---|
应用 | |
---|
硬件 | |
---|
软件库 | |
---|
架构 | |
---|
主题 分类
|