JAX and Backend

[1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

DESC uses JAX for faster execution through just-in-time (JIT) compilation, automatic differentiation, and other scientific computing tools. The purpose of backend.py is to determine whether DESC can take advantage of JAX and GPUs or must fall back to standard numpy on CPUs. To run DESC on a GPU, add the following code before you import anything else from DESC:

[2]:
# from desc import set_device
# set_device("gpu")

You can check whether DESC is running on a CPU or a GPU with print_backend_info(). This prints the DESC version, the JAX (or NumPy) version, and the device information.

[3]:
from desc.backend import print_backend_info

print_backend_info()
WARNING:2025-01-31 16:53:19,421:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
DESC version=0.13.0+1313.g2d561ee71.dirty.
Using JAX backend: jax version=0.4.37, jaxlib version=0.4.36, dtype=float64.
Using device: CPU, with 5.18 GB available memory.

JAX provides a numpy style API for array operations. In many cases, to take advantage of JAX, one only needs to replace calls to numpy with calls to jax.numpy. A convenient way to do this is with the import statement import jax.numpy as jnp.

[4]:
from desc.backend import jnp
import numpy as np
[5]:
# give some JAX examples
zeros_jnp = jnp.zeros(4)
zeros_np = np.zeros(4)

print(zeros_jnp)
print(zeros_np)
[0. 0. 0. 0.]
[0. 0. 0. 0.]

If such an import statement were used directly in DESC and DESC were run on a machine where JAX is not installed, an ImportError would be raised. We would prefer DESC to keep working on machines without JAX. To that end, functions that can benefit from JAX use the import statement from desc.backend import jnp instead: desc.backend.jnp is an alias for jax.numpy if JAX is installed and for numpy otherwise.
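The fallback behind such an alias can be sketched with a guarded import. This is a minimal, hypothetical sketch of the general pattern, not the actual contents of DESC's backend.py; the names use_jax and scaled_sum are illustrative only.

```python
# Hypothetical sketch of an optional-dependency backend alias
# (illustrative only; not DESC's backend.py).
try:
    import jax.numpy as jnp  # use JAX when it is installed

    use_jax = True  # hypothetical flag recording the selected backend
except ImportError:
    import numpy as jnp  # fall back to plain NumPy

    use_jax = False


def scaled_sum(x, scale):
    """Runs on either backend because it only uses the shared `jnp` API."""
    return scale * jnp.sum(jnp.asarray(x))


print("JAX backend:" if use_jax else "NumPy backend:", scaled_sum([1.0, 2.0, 3.0], 2.0))
```

Code written against jnp in this way never has to check which library is behind the alias, as long as it sticks to functions both libraries provide.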

While jax.numpy attempts to serve as a drop-in replacement for numpy, it imposes some constraints on how code is written. For example, jax.numpy arrays are immutable, so in-place updates of array elements are not possible. To update an element of a jax.numpy array, a new array containing the updated element has to be created, which requires allocating memory (inside jax.jit-compiled code, however, the compiler can often perform such an update in place when the original array is not reused). Similarly, JAX's JIT compilation requires control flow structures such as loops and conditionals to be written in a specific way.

The utility functions in desc.backend provide a simple interface to perform these operations.

[6]:
zeros_jnp = jnp.zeros(4)
# in-place assignment raises an error because JAX arrays are immutable
# zeros_jnp[0] = 1
# instead, use the .at[] property, which returns an updated copy
zeros_jnp = zeros_jnp.at[0].set(1)
print(zeros_jnp)
[1. 0. 0. 0.]
[7]:
# alternatively, desc.backend.put works with both the JAX and numpy backends
from desc.backend import put

zeros_jnp = put(zeros_jnp, 0, 2)
print(zeros_jnp)
[2. 0. 0. 0.]
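The control-flow constraint mentioned earlier shows up as soon as a function is wrapped in jax.jit: a Python if on a traced value fails, while jax.lax.cond and jax.lax.fori_loop express the same logic in a traceable way. The sketch below calls JAX directly rather than any desc.backend wrappers, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_python_if(x):
    # A Python `if` needs a concrete bool, but under jax.jit `x` is an
    # abstract tracer, so this raises an error when the function is traced.
    if x > 0:
        return x
    return -x


@jax.jit
def abs_lax_cond(x):
    # lax.cond takes the predicate as data and selects a branch at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


@jax.jit
def power(x, n):
    # lax.fori_loop replaces a Python for-loop whose bound `n` is traced.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))


try:
    abs_python_if(jnp.array(-3.0))
except jax.errors.ConcretizationTypeError as err:
    print("Python if under jit failed:", type(err).__name__)

print(abs_lax_cond(jnp.array(-3.0)))  # 3.0
print(power(jnp.array(2.0), 10))  # 1024.0
```

For simple element-wise selections, jnp.where is often an even simpler traceable alternative to lax.cond.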

Since the JAX documentation explains the similarities and differences between jax.numpy and numpy very well, we won't go into depth here; instead, we mention some of the major differences to get you started.

Technically, most operations can be written with numpy (as long as they are outside of jax.jit), but in most cases jax.numpy is faster and can run on both CPUs and GPUs without any code changes. JAX arrays can live on different devices and can take advantage of efficient, hardware-specific implementations of a function.
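To see which devices JAX can use and where an array lives, you can query JAX directly. This is a small sketch assuming a JAX version that exposes the jax.Array API (0.4 or later); on a machine without a CUDA-enabled jaxlib, only the CPU appears.

```python
import jax
import jax.numpy as jnp

# Devices visible to JAX (CPU only unless a GPU-enabled jaxlib is installed).
devices = jax.devices()
print(devices)
print(jax.default_backend())  # platform name, e.g. "cpu" or "gpu"

# Place an array explicitly on the first device; the same code runs on CPU or GPU.
x = jax.device_put(jnp.arange(4.0), devices[0])
print(x.devices())  # set of devices holding x
```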

It is still good practice to test both versions to see which one is faster (for functions outside of jit). When profiling, remember to call block_until_ready(), as explained in the JAX documentation, because JAX dispatches computations asynchronously. If you want to use the numpy version of a specific function, just import numpy as usual instead of switching the whole code to the numpy backend. There are a couple of places in the code where we deliberately use numpy functions, for different reasons: for example, because JAX arrays are immutable they sometimes behave unexpectedly in loops, and some jax.numpy functions have overhead that makes them slower than their numpy counterparts for a single call.
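A timing helper along these lines shows where block_until_ready() belongs. This is a sketch with a hypothetical helper name, time_call, and an arbitrary 500 x 500 matrix product; the absolute numbers depend entirely on your hardware.

```python
import time

import jax.numpy as jnp
import numpy as np


def time_call(fn, *args, repeats=10):
    """Best wall-clock time (in seconds) of fn(*args) over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        # JAX may return before the computation has finished (asynchronous
        # dispatch), so wait for the result before stopping the clock.
        # NumPy arrays have no such method and are already computed.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


x_np = np.random.default_rng(0).standard_normal((500, 500))
x_jnp = jnp.asarray(x_np)
jnp.matmul(x_jnp, x_jnp).block_until_ready()  # warm-up call, excluded from timing

print(f"numpy.matmul:     {time_call(np.matmul, x_np, x_np):.2e} s")
print(f"jax.numpy.matmul: {time_call(jnp.matmul, x_jnp, x_jnp):.2e} s")
```

Without the block_until_ready() call, the JAX timing would mostly measure how long it takes to enqueue the work, not to perform it.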

There are plans to remove the numpy backend, since some portions of the code use JAX or related functions that have no numpy equivalents, and code that relies on the numpy backend instead of JAX is not automatically tested for correctness by the GitHub CI.

Depending on the backend, DESC automatically chooses which method of differentiation to use: with JAX it uses automatic differentiation, and if JAX is not installed, it falls back to finite differences.
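To illustrate the difference between the two approaches, the sketch below compares jax.grad with a central finite difference on a hypothetical scalar test function. It does not use DESC's own derivative classes. It enables 64-bit floats, matching the dtype=float64 reported by print_backend_info() above, so that the finite-difference round-off error stays small.

```python
import jax
import jax.numpy as jnp

# Use 64-bit floats; in float32 the finite-difference round-off error dominates.
jax.config.update("jax_enable_x64", True)


def f(x):
    """Hypothetical smooth test function f(x) = x**2 * sin(x)."""
    return x**2 * jnp.sin(x)


def central_difference(fun, x, h=1e-5):
    """Second-order central finite-difference approximation of fun'(x)."""
    return (fun(x + h) - fun(x - h)) / (2 * h)


x0 = 1.3
exact = 2 * x0 * jnp.sin(x0) + x0**2 * jnp.cos(x0)  # analytic f'(x0)
ad = jax.grad(f)(x0)  # automatic differentiation: exact up to round-off
fd = central_difference(f, x0)  # O(h**2) truncation error plus round-off

print(f"autodiff error:          {abs(ad - exact):.1e}")
print(f"finite-difference error: {abs(fd - exact):.1e}")
```

The finite-difference result depends on the step size h, which trades truncation error against round-off error, while the automatic derivative has no such parameter to tune.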