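• Developer(s): Google
• Stable release: 0.4.24[1] (6 February 2024)
• Repository: github.com/google/jax
• Written in: Python, C++
• Operating system: Linux, macOS, Windows
• Platform: Python, NumPy
• Size: 9.0 MB
• Type: Machine learning
• License: Apache 2.0
• Website: jax.readthedocs.io/en/latest/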

Google JAX is a Python machine learning framework for transforming numerical functions.[2][3][4] It is described as bringing together a modified version of autograd (which automatically obtains the gradient function of a function through differentiation) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[5][6] The primary functions of JAX are:[2]

1. grad: automatic differentiation
2. jit: compilation
3. vmap: auto-vectorization
4. pmap: SPMD programming
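Because each of these is a function transformation (it takes a function and returns a new function), they can be composed. The sketch below assumes a recent JAX release; the function `f` is a hypothetical example chosen only because its derivative, cos(x)·x + sin(x), is easy to check by hand:

```python
import jax
import jax.numpy as jnp

# hypothetical scalar function used only for illustration
def f(x):
    return jnp.sin(x) * x

# derivative of f, compiled with jit
df = jax.jit(jax.grad(f))

# the same derivative, vectorized over a batch of inputs with vmap
batched_df = jax.vmap(jax.grad(f))

xs = jnp.linspace(0.0, 1.0, 5)
print(df(1.0))
print(batched_df(xs))
```

The order of composition matters only for what is produced, not for correctness: `jit(grad(f))` compiles the derivative, while `vmap(grad(f))` evaluates it for every element of `xs` in one call.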

## grad

The code below demonstrates the `grad` function's automatic differentiation.

```python
# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
```

The final line should output:

```
0.19661194
```
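As a sanity check, the derivative of the logistic function σ(x) = e^x / (e^x + 1) has the closed form σ(x)(1 - σ(x)), so the value above can be compared against it. A minimal sketch; the helper `logistic_derivative` is a hypothetical name introduced here. Note that `grad` expects floating-point inputs, which is why `1.0` rather than `1` is passed:

```python
from jax import grad
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

grad_logistic = grad(logistic)

# hypothetical helper: closed-form derivative of the logistic function, s * (1 - s)
def logistic_derivative(x):
    s = logistic(x)
    return s * (1 - s)

autodiff_value = grad_logistic(1.0)
closed_form_value = logistic_derivative(1.0)
print(autodiff_value, closed_form_value)
```

Both values agree to within float32 rounding, which illustrates that `grad` computes exact derivatives of the traced operations rather than finite-difference approximations.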

## jit

The code below demonstrates the `jit` function's optimization through fusion.

```python
# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison;
# JAX dispatches work asynchronously, so block_until_ready() waits for each result
cube(x).block_until_ready()
jit_cube(x).block_until_ready()
```

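The computation time for `jit_cube` should be noticeably shorter than that for `cube`, because `jit` compiles `x * x * x` into a single fused operation instead of evaluating each multiplication separately (the first call to `jit_cube` also includes compilation time). Increasing the dimensions of `x` will increase the difference.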
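A fair timing has to account for two things: JAX returns results asynchronously, so each call must be followed by `block_until_ready()`, and the first call to a jitted function triggers compilation, so a warm-up call should be excluded. A minimal timing sketch under those assumptions; the helper `time_call`, the repeat count and the smaller 2000 x 2000 array are illustrative choices, not part of the original example:

```python
import time
from jax import jit
import jax.numpy as jnp

def cube(x):
    return x * x * x

jit_cube = jit(cube)
x = jnp.ones((2000, 2000))

# hypothetical helper: average wall-clock time of fn(x) over several runs
def time_call(fn, x, repeats=5):
    fn(x).block_until_ready()  # warm-up call; for jit_cube this triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()  # wait for the asynchronous computation to finish
    return (time.perf_counter() - start) / repeats

print("cube:", time_call(cube, x))
print("jit_cube:", time_call(jit_cube, x))
```

The measured ratio depends on the hardware and array size, so the sketch prints both timings rather than asserting a particular speedup.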

## vmap

The code below demonstrates the `vmap` function's vectorization.

```python
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define function
```

The GIF on the right of this section illustrates the notion of vectorized addition.
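Vectorized addition can be sketched with `vmap` as follows. The function `add_vectors` is a hypothetical example written for a single pair of vectors; `vmap` maps it over the leading (batch) axis, `in_axes=(0, None)` maps the first argument while passing the second unchanged to every call, and `functools.partial` achieves the same effect by fixing the second argument before mapping:

```python
from functools import partial
from jax import vmap
import jax.numpy as jnp

# hypothetical function written for a single pair of vectors
def add_vectors(a, b):
    return a + b

batch_a = jnp.arange(12.0).reshape(4, 3)  # 4 vectors of length 3
batch_b = jnp.ones((4, 3))

# map over the leading axis of both arguments
batched_add = vmap(add_vectors)
print(batched_add(batch_a, batch_b).shape)  # (4, 3)

# add one fixed vector to every row: map over a, pass b unchanged
offset = jnp.array([10.0, 20.0, 30.0])
add_offset = vmap(add_vectors, in_axes=(0, None))
print(add_offset(batch_a, offset))

# the same effect, fixing b with functools.partial before mapping
add_offset_partial = vmap(partial(add_vectors, b=offset))
print(add_offset_partial(batch_a))
```

The function body never mentions a batch dimension; `vmap` adds it, which is what lets per-example code be reused for whole batches.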

## pmap

The code below demonstrates the `pmap` function's parallelization for matrix multiplication.

```python
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
# (this assumes at least two devices are available)
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
```

The final line should print the values:

```
[1.1566595 1.1805978]
```
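The example above maps over a leading axis of size 2, so it needs at least two devices; on a machine with a single CPU, `pmap` over that axis fails. A minimal sketch that sizes the leading axis with `jax.local_device_count()` so it also runs on one device; the smaller 500 x 600 shape is an illustrative choice, and the printed means will differ from the values above:

```python
import jax
from jax import pmap, random
import jax.numpy as jnp

# one PRNG key per available local device (1 on a plain CPU machine)
n_devices = jax.local_device_count()
keys = random.split(random.PRNGKey(0), n_devices)

# each device generates its own 500 x 600 matrix from its own key
matrices = pmap(lambda key: random.normal(key, (500, 600)))(keys)

# each device multiplies its matrix by its own transpose, then takes the mean
means = pmap(lambda m: jnp.mean(jnp.dot(m, m.T)))(matrices)
print(n_devices, means.shape)
```

Fusing the multiplication and the mean into one `pmap` call keeps each device's intermediate result local, in the same spirit as the "without data transfer" comments in the original example.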

## Libraries using JAX

Several Python libraries use JAX as a backend, including:

Some R libraries use JAX as a backend as well, including:

• fastrerandomize, a library that uses the linear-algebra optimized compiler in JAX to speed up selection of balanced randomizations in a design of experiments procedure known as rerandomization.[16]