{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JAX and Backend" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "\n", "sys.path.insert(0, os.path.abspath(\".\"))\n", "sys.path.append(os.path.abspath(\"../../../\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DESC uses JAX for faster execution times with just-in-time (JIT) compilation, automatic differentiation, and other scientific computing tools.\n", "The purpose of ``backend.py`` is to determine whether DESC may take advantage of JAX and GPUs or default to standard ``numpy`` and CPUs. To run DESC on a GPU, place the following lines (uncommented) before importing anything else from DESC:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# from desc import set_device\n", "# set_device(\"gpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can check whether DESC is running on a CPU or GPU with `print_backend_info()`.\n", "This will print the DESC and JAX or NumPy versions, and the device information." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:2025-01-31 16:53:19,421:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "DESC version=0.13.0+1313.g2d561ee71.dirty.\n", "Using JAX backend: jax version=0.4.37, jaxlib version=0.4.36, dtype=float64.\n", "Using device: CPU, with 5.18 GB available memory.\n" ] } ], "source": [ "from desc.backend import print_backend_info\n", "\n", "print_backend_info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX provides a ``numpy`` style API for array operations.\n", "In many cases, to take advantage of JAX, one only needs to replace calls to ``numpy`` with calls to ``jax.numpy``.\n", "A convenient way to do this is with the import statement ``import jax.numpy as jnp``." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from desc.backend import jnp\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 0. 0. 0.]\n", "[0. 0. 0. 
0.]\n" ] } ], "source": [ "# give some JAX examples\n", "zeros_jnp = jnp.zeros(4)\n", "zeros_np = np.zeros(4)\n", "\n", "print(zeros_jnp)\n", "print(zeros_np)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Of course, if such an import statement is used in DESC, and DESC is run on a machine where JAX is not installed, the import raises an error at runtime.\n", "We would prefer that DESC still work on machines where JAX is not installed.\n", "With that goal, in functions which can benefit from JAX, we use the following import statement: ``from desc.backend import jnp``.\n", "``desc.backend.jnp`` is an alias to ``jax.numpy`` if JAX is installed and ``numpy`` otherwise.\n", "\n", "While ``jax.numpy`` attempts to serve as a drop-in replacement for ``numpy``, it imposes some constraints on how the code is written.\n", "For example, ``jax.numpy`` arrays are immutable.\n", "This means in-place updates to array elements are not possible.\n", "To update elements of a ``jax.numpy`` array, a new array containing the updated elements has to be created.\n", "Similarly, JAX's JIT compilation requires control flow structures such as loops and conditionals to be written in a specific way.\n", "\n", "The utility functions in ``desc.backend`` provide a simple interface to perform these operations." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 0. 0. 0.]\n" ] } ], "source": [ "zeros_jnp = jnp.zeros(4)\n", "# this would raise an error, since JAX arrays are immutable\n", "# zeros_jnp[0] = 1\n", "# instead, use the at[] method, which returns a new array\n", "zeros_jnp = zeros_jnp.at[0].set(1)\n", "print(zeros_jnp)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2. 0. 0. 
0.]\n" ] } ], "source": [ "# or, to stay compatible with the numpy backend, use put from desc.backend\n", "from desc.backend import put\n", "\n", "zeros_jnp = put(zeros_jnp, 0, 2)\n", "print(zeros_jnp)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `JAX` documentation explains the similarities and differences between `jax.numpy` and `numpy` in detail, so we only mention some of the major differences here to get you started.\n", "\n", "Technically, most operations can be written with `numpy` (as long as they are outside of `jax.jit`), but in most cases `jax.numpy` is faster and runs on both CPUs and GPUs without any code changes. `jax.Array`s can live on different devices and can take advantage of hardware-specific implementations of a function." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is still good practice to test both versions to see which one is faster (for functions outside of `jax.jit`). When profiling, remember to call `block_until_ready()` on the result, as explained [here](https://jax.readthedocs.io/en/latest/async_dispatch.html); JAX dispatches work asynchronously, so timings taken without it can be misleading. If you want the `numpy` version of a particular function, just import `numpy` as usual rather than switching the whole code to the `numpy` backend. There are a couple of places in the code where we deliberately use `numpy` functions. The reasons vary: for example, since `jax.Array`s are immutable, they sometimes behave unexpectedly in loops, and some `jax.numpy` functions have overhead that makes them slower than their `numpy` counterparts for one-off calls." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is a plan to remove the `numpy` backend, since some portions of the code use `JAX` or related functions that have no `numpy` equivalent, and code that relies on the `numpy` backend instead of `JAX` is not automatically tested for correctness by the GitHub CI. 
Depending on the backend, DESC automatically chooses which method of differentiation to use. If `JAX` is not installed, it falls back to finite differences for derivatives. " ] } ], "metadata": { "kernelspec": { "display_name": "desc", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 }