jax-js: Bringing JAX's Advanced Machine-Learning Engine to the Web with WebGPU and WebAssembly
The open-source library jax-js re-implements Google DeepMind's JAX framework entirely in JavaScript, enabling high-performance machine learning directly in the browser. By generating WebGPU shaders and WebAssembly kernels on the fly, it offers JAX-style automatic differentiation, just-in-time compilation, and GPU-accelerated linear algebra without a native runtime. The project demonstrates training neural networks on MNIST in seconds and performing real-time semantic search on large text corpora, all with pure frontend code.
jax-js is an ambitious open-source project that re-implements the Google DeepMind JAX framework in pure JavaScript, allowing developers to write numerical and machine-learning code that runs at near-native speed in any modern browser.
## Why a JavaScript JAX?
Python dominates the ML ecosystem thanks to libraries like PyTorch and JAX, but JavaScript remains the lingua franca of the web. Traditionally, numerical kernels written in plain JavaScript are slow because the engine's general-purpose JIT is not tuned for the tight loops that dominate numerical workloads. jax-js addresses this by leveraging two newer browser technologies, WebAssembly (Wasm) and WebGPU, which expose a low-level bytecode runtime and programmable GPU compute shaders, respectively. The result is a JAX-compatible API that generates kernels on the fly, compiles them to Wasm or GPU code, and executes them outside the JavaScript interpreter.
## Core Architecture
The library adopts the classic JAX front-end and back-end split:
* **Front-end**: The API mirrors JAX's NumPy-like syntax (e.g., `np.array`, `np.add`, `np.mul`). Automatic differentiation (`grad`), just-in-time (JIT) compilation (`jit`), vectorised mapping (`vmap`), and expression tracing (`jaxpr`) are all implemented in JavaScript as composable transformations (see the short `vmap` sketch after this list). Internally, front-end operations are represented as a graph of *Primitive* nodes, which are then lowered to a lightweight, stack-based intermediate representation called *Jaxpr*.
* **Back-end**: Kernels are generated as WebGPU shaders or WebAssembly bytecode. A `Backend` interface abstracts memory allocation and kernel dispatch; two concrete implementations exist: a WebGPU backend that emits WGSL compute shaders, and a small WebAssembly backend for browsers that lack GPU support.
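As a sketch of how these transformations compose, the snippet below maps a gradient over a batch. It assumes jax-js exports `vmap` from the main package, follows JAX's `vmap(f)` calling convention, and accepts nested arrays in `np.array`; none of these details are confirmed here, so treat it as illustrative only.

```js
import { numpy as np, grad, vmap, init, setDevice } from "@jax-js/jax";
await init("webgpu");
setDevice("webgpu");

// Per-example scalar loss and its gradient...
const f = (x) => np.square(x).sum();
const gradF = grad(f);

// ...lifted over the leading batch dimension without an explicit loop
const batchedGrad = vmap(gradF);
console.log(batchedGrad(np.array([[1, 2], [3, 4]])).js()); // [[2, 4], [6, 8]]
```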
The back-end uses a symbolic `AluExp` tree consisting of 28 mathematical primitives (add, mul, sin, sqrt, etc.). This tree is compiled into code that performs point-wise operations and, for reductions such as matrix multiplication and convolution, emits tightly tiled loops with manual unrolling and SIMD-aware optimisations.
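The following is a minimal illustrative sketch, not jax-js source code: it shows how an `AluExp`-style expression tree for `sqrt(x*x + 2)` could be lowered to the body of a WGSL point-wise compute shader. The tree layout and the `emit` helper are hypothetical, and the real backend also handles strides, broadcasting, reductions, and bounds checks.

```js
// Hypothetical expression tree for y = sqrt(x * x + 2)
const exp = {
  op: "sqrt",
  args: [{
    op: "add",
    args: [
      { op: "mul", args: [{ op: "load", name: "x" }, { op: "load", name: "x" }] },
      { op: "const", value: 2 },
    ],
  }],
};

// Recursively emit a WGSL expression string for one output element
function emit(e) {
  switch (e.op) {
    case "load":  return `${e.name}[i]`;
    case "const": return e.value.toFixed(1); // f32 literal
    case "mul":   return `(${emit(e.args[0])} * ${emit(e.args[1])})`;
    case "add":   return `(${emit(e.args[0])} + ${emit(e.args[1])})`;
    case "sqrt":  return `sqrt(${emit(e.args[0])})`;
  }
}

// Wrap the expression in a point-wise compute shader: one invocation per element
const wgsl = `
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  out[i] = ${emit(exp)};
}`;
console.log(wgsl);
```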
## Usage Example
```js
import { numpy as np, grad, jit, init, setDevice } from "@jax-js/jax";
await init("webgpu");
setDevice("webgpu");
// Create a simple array and perform an element-wise operation
const a = np.array([1, 5, 6, 7]);
console.log(a.mul(10).js()); // [10, 50, 60, 70]
// Automatic differentiation
const f = (x) => np.sqrt(np.square(x).sum());
const df = grad(f);
const x = np.array([1, 2, 3, 4]);
console.log(df(x).js()); // gradient of the L2 norm, x / ||x|| ≈ [0.183, 0.365, 0.548, 0.730]
// JIT-compiled matrix multiplication with fusion
const matmul = jit((x) => np.dot(np.square(x).add(2), Math.PI));
```
Arrays are move-owned; methods like `.ref` increment reference counts, and `.js()` converts a `jax.Array` back to a plain JavaScript array.
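As a short illustration of those ownership rules (a sketch only: exactly which operations consume their inputs is an assumption here, not verified library behaviour):

```js
const w = np.array([1, 2, 3]);
const doubled = w.ref.mul(2);  // .ref bumps the count so `w` survives this op (assumption)
const squared = np.square(w);  // uses the remaining reference to `w`
console.log(doubled.js(), squared.js()); // [2, 4, 6] [1, 4, 9]
```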
## Performance Highlights
* **Matrix multiplication**: A 4096 × 4096 matrix multiplication on an Apple M4 Pro chip achieves >3 TFLOP/s through the WebGPU backend, outperforming TensorFlow.js and ONNX Runtime in many scenarios.
* **Neural-network training**: A vanilla MLP trained on MNIST reaches >99% accuracy in a few seconds, with all data loading, forward passes, and back-propagation performed entirely in the browser.
* **Text embedding**: Running a CLIP transformer with batch size 16 processes a 180k-word text collection in under 200 ms per inference, yielding an estimated 485 GFLOP/s on an M1 Pro.
Benchmarking indicates that the WebGPU implementation is already comparable to hand-tuned C++ libraries, thanks to the compiler-generated kernels and the underlying GPU acceleration.
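For context, here is a quick back-of-the-envelope conversion of the matmul figure, using the standard 2·N³ FLOP count for a dense N × N multiply; this is simple arithmetic, not one of the project's benchmarks.

```js
const n = 4096;
const flops = 2 * n ** 3;            // ≈ 1.37e11 FLOPs per 4096 x 4096 matmul
const seconds = flops / 3e12;        // at the reported 3 TFLOP/s
console.log(`${(seconds * 1e3).toFixed(1)} ms per multiplication`); // ≈ 45.8 ms
```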
## Future Work & Open Challenges
* **Backend optimisation**: The current WebAssembly backend is functional but leaves large gains on the table; loop tiling, SIMD, and multi-threading could deliver speed-ups of up to 150× (see the loop-tiling sketch after this list). The WebGPU backend should likewise benefit from experimental WebGPU features and from further tuning of workgroup sizes and shared memory usage.
* **Expanded JAX API compatibility**: While most core primitives are present, higher-order features (e.g., `lax.while_loop`, advanced complex datatypes) and certain optimised kernels (e.g., tensor-core fusion for 16-bit multiply-accumulate) are yet to be added.
* **WebGL fallback**: To support older browsers without WebGPU, a lightweight WebGL backend could be implemented; a prototype already exists in roughly 700 lines of code.
* **Community contributions**: The project encourages issues and pull requests, especially for missing primitives such as FFTs, optimisers (e.g., AdamW), and `np.split`.
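To make the loop-tiling point concrete, here is a minimal sketch in plain JavaScript of the kind of cache-blocked loop nest a tuned Wasm backend could emit. It is not the project's generated code, and the block size is an arbitrary choice.

```js
// Illustrative only: cache-blocked C += A * B for square N x N row-major matrices.
// `c` must be zero-initialised by the caller; `block` is the tile size.
function tiledMatmul(a, b, c, n, block = 32) {
  for (let i0 = 0; i0 < n; i0 += block) {
    for (let k0 = 0; k0 < n; k0 += block) {
      for (let j0 = 0; j0 < n; j0 += block) {
        const iMax = Math.min(i0 + block, n);
        const kMax = Math.min(k0 + block, n);
        const jMax = Math.min(j0 + block, n);
        for (let i = i0; i < iMax; i++) {
          for (let k = k0; k < kMax; k++) {
            const aik = a[i * n + k]; // reused across the whole inner j loop
            for (let j = j0; j < jMax; j++) {
              c[i * n + j] += aik * b[k * n + j];
            }
          }
        }
      }
    }
  }
}

// Usage with plain typed arrays (independent of jax-js):
const n = 256;
const A = new Float32Array(n * n).fill(1);
const B = new Float32Array(n * n).fill(1);
const C = new Float32Array(n * n); // zero-initialised
tiledMatmul(A, B, C, n);
console.log(C[0]); // 256
```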
## Getting Started
```sh
npm install @jax-js/jax
```
Browse the interactive REPL on the official website or clone the repository from GitHub to experiment locally.
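As a starting point, a minimal ES-module entry file could look like the sketch below; it reuses only the calls shown earlier in this article and assumes a toolchain that supports top-level `await`.

```js
// main.js, after `npm install @jax-js/jax`
import { numpy as np, grad, init, setDevice } from "@jax-js/jax";

await init("webgpu");   // initialise the WebGPU backend
setDevice("webgpu");

// Differentiate a tiny loss and print the gradient of sum(w^2), i.e. 2w
const loss = (w) => np.square(w).sum();
console.log(grad(loss)(np.array([1, -2, 3])).js()); // [2, -4, 6]
```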
## Conclusion
jax-js demonstrates that a full JAX-style machine-learning stack can run efficiently on the web using only standard browser APIs. By generating WebGPU shaders and WebAssembly kernels at runtime, the library achieves competitive performance while staying pure JavaScript. The project opens up exciting possibilities for real-time model training, on-device inference, and collaborative, browser-based ML workflows.