Jan 6, 2026

jax‑js: Bringing JAX’s Advanced Machine‑Learning Engine to the Web with WebGPU and WebAssembly

The open‑source library jax‑js re‑implements Google DeepMind’s JAX framework entirely in JavaScript, enabling high‑performance machine learning directly in the browser. By generating WebGPU shaders and WebAssembly kernels on the fly, it offers JAX‑style automatic differentiation, just‑in‑time compilation, and GPU‑accelerated linear algebra without a native runtime. The project demonstrates training neural networks on MNIST in seconds and performing real‑time semantic search on large text corpora, all with pure frontend code.

jax‑js is an ambitious open‑source project that re‑implements the Google DeepMind JAX framework in pure JavaScript, allowing developers to write numerical and machine‑learning code that runs at near‑native speed in any modern browser.

## Why a JavaScript JAX?

Python dominates the ML ecosystem thanks to libraries like PyTorch and JAX, but JavaScript remains the lingua franca of the web. Traditionally, numerical kernels in the browser suffer from the overhead of a dynamically typed language and a JIT that is not tuned for tight numerical loops. jax‑js addresses this by leveraging the emerging browser technologies WebAssembly (Wasm) and WebGPU, which expose a low‑level bytecode runtime and programmable GPU shaders, respectively. The result is a JAX‑compatible API that generates and executes kernels compiled to Wasm or GPU code on the fly, bypassing the JavaScript interpreter.

## Core Architecture

The library adopts the classic JAX front‑end/back‑end split:

* **Front‑end** – The API mirrors JAX’s numpy‑like syntax (e.g., `np.array`, `np.add`, `np.mul`). Automatic differentiation (`grad`), just‑in‑time (JIT) compilation, and vectorised mapping (`vmap`) are all implemented in JavaScript. Internally, front‑end operations are represented as a graph of *Primitive* nodes, which are then lowered to a lightweight, stack‑based intermediate representation called *Jaxpr*.
* **Back‑end** – Kernels are generated as WebGPU shaders or Wasm bytecode. A `Backend` interface abstracts memory allocation and dispatch; two concrete implementations exist: a WebGPU backend that emits WGSL shaders, and a small WebAssembly backend for browsers that lack GPU support. The back‑end uses a symbolic `AluExp` tree consisting of 28 mathematical primitives (add, mul, sin, sqrt, etc.). This tree is compiled into code that performs point‑wise operations and, for reductions such as matrix multiplication and convolution, emits tightly tiled loops with manual unrolling and SIMD‑aware optimisations.

## Usage Example

```js
import { numpy as np, grad, jit, init, setDevice } from "@jax-js/jax";

await init("webgpu");
setDevice("webgpu");

// Create a simple array and perform an element‑wise operation
const a = np.array([1, 5, 6, 7]);
console.log(a.mul(10).js()); // → [10, 50, 60, 70]

// Automatic differentiation
const f = (x) => np.sqrt(np.square(x).sum());
const df = grad(f);
const x = np.array([1, 2, 3, 4]);
console.log(df(x).js());

// JIT‑compiled, fused element‑wise expression
const g = jit((x) => np.dot(np.square(x).add(2), Math.PI));
```

Arrays are move‑owned; methods like `.ref` increment reference counts, and `.js()` converts a `jax.Array` back to a plain JavaScript array.

## Performance Highlights

* **Matrix multiplication** – A 4096 × 4096 matrix multiplication on an Apple M4 Pro achieves >3 TFLOP/s through the WebGPU backend, outperforming TensorFlow.js and ONNX Runtime in many scenarios.
* **Neural‑network training** – A vanilla MLP trained on MNIST reaches >99 % accuracy in a few seconds, with all data loading, forward passes, and back‑propagation performed entirely in the browser.
* **Text embedding** – Running a CLIP transformer with batch size 16 processes a 180 k‑word text collection in under 200 ms per inference, yielding an estimated 485 GFLOP/s on an M1 Pro.

Benchmarking indicates that the WebGPU implementation is already comparable to hand‑tuned C++ libraries, thanks to the compiler‑generated kernels and the underlying GPU acceleration.
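As a rough illustration of what browser‑side training looks like with this API, here is a minimal sketch of a gradient‑descent loop built only from the calls shown in the usage example above (`np.array`, `.mul`, `.add`, `np.square`, `.sum`, `grad`, `jit`, `.js()`). It is a toy 1‑D regression rather than the project’s MNIST demo: the data and learning rate are invented, NumPy‑style broadcasting is assumed, and the move‑ownership rules (`.ref`) mentioned earlier are glossed over for brevity.

```js
import { numpy as np, grad, jit, init, setDevice } from "@jax-js/jax";

await init("webgpu");
setDevice("webgpu");

// Toy data standing in for a real training batch (made-up values).
const xs = np.array([1, 2, 3, 4]);
const ys = np.array([2, 4, 6, 8]); // targets of y = 2x

// Squared-error loss of a one-parameter linear model, expressed with the
// operations demonstrated above (subtraction written as .add(ys.mul(-1))).
const loss = (w) => np.square(xs.mul(w).add(ys.mul(-1))).sum();

// One JIT-compiled gradient-descent step (learning rate 0.01), assuming
// grad and jit compose here as they do in JAX.
const step = jit((w) => w.add(grad(loss)(w).mul(-0.01)));

let w = np.array([0]);
for (let i = 0; i < 200; i++) {
  w = step(w);
}
console.log(w.js()); // should converge towards [2]
```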
## Future Work & Open Challenges

* **Backend optimisation** – The current WebAssembly backend is functional, but loop tiling, SIMD usage, and multi‑threading could yield speed‑ups of up to 150×. The WebGPU backend will benefit from experimental WebGPU features such as shared array buffers and compute group sizes.
* **Expanded JAX API compatibility** – While most core primitives are present, higher‑order features (e.g., `lax.while_loop`, advanced complex datatypes) and certain optimised kernels (e.g., tensor‑core fusion for 16‑bit multiply‑accumulate) are yet to be added.
* **WebGL fallback** – To support legacy browsers, a lightweight WebGL backend could be implemented; a code base of roughly 700 lines already exists.
* **Community contributions** – The project welcomes issues and pull requests, especially for missing primitives such as FFTs, optimisers (AdamW), and `np.split`.

## Getting Started

```sh
npm install @jax-js/jax
```

Browse the interactive REPL on the official website, or clone the repository from GitHub to experiment locally.

## Conclusion

jax‑js demonstrates that a full JAX‑style machine‑learning stack can run efficiently on the web using only standard browser APIs. By generating WebGPU shaders and WebAssembly kernels at runtime, the library achieves competitive performance while remaining pure JavaScript. The project opens up exciting possibilities for real‑time model training, on‑device inference, and collaborative, browser‑based ML workflows.