jax-js note
feb 7
Current goals
"Tutorial mindset" — imagine making a workshop where I'm teaching people about array programming or functional programming in practice.
Try to just implement float32, grad(), and arithmetic.
Build out a quick code editor / REPL in the browser and import the library using Vite. Then run code to experiment a bit.
After that, play it by ear. PyTrees -> JSON, …
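To make "just implement grad() and arithmetic" concrete, here is a minimal sketch of one way to get a working grad() for scalar functions: forward-mode autodiff with dual numbers. The names (Dual, lift) are illustrative, not the jax-js API, and a real implementation would likely use reverse mode over arrays instead.

```typescript
// Toy forward-mode autodiff: enough to support grad() on scalar
// functions built from add/mul. Illustrative only, not the jax-js API.
class Dual {
  constructor(public val: number, public tan: number) {}
  add(other: Dual): Dual {
    return new Dual(this.val + other.val, this.tan + other.tan);
  }
  mul(other: Dual): Dual {
    // Product rule: d(uv) = u'v + uv'
    return new Dual(
      this.val * other.val,
      this.tan * other.val + this.val * other.tan,
    );
  }
  static lift(x: number): Dual {
    return new Dual(x, 0); // constant: zero tangent
  }
}

// grad(f)(x) = f'(x), computed by seeding the input tangent with 1.
function grad(f: (x: Dual) => Dual): (x: number) => number {
  return (x) => f(new Dual(x, 1)).tan;
}

const df = grad((x) => x.mul(x).add(Dual.lift(3).mul(x))); // f(x) = x² + 3x
console.log(df(2)); // f'(x) = 2x + 3, so df(2) = 7
```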
Description (future)
NumPy and JAX for the browser, running on CPU or GPU.
Machine learning and numerical computing in JavaScript with the JAX/NumPy API. Define arrays, then run arbitrary differentiable code on CPU, WASM, WebGL, or WebGPU backends.
Examples: fluid simulation, neural networks, computer vision, robotics, statistics.
import { grad, numpy as np } from "jax-js";
const y = grad(x => x.mul(2))(np.array([1, 2, 3]))
console.log(y.js())
Memory management
Refcount / ownership contract: every array passed to a function must be either consumed or disposed by the callee.
ref() to keep an array alive across a consuming call; dispose() to release it.
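A minimal sketch of how that ref()/dispose() contract could work, assuming a simple refcount per array. The NDArray name and the consuming double() function are hypothetical, chosen just to show the ownership handoff:

```typescript
// Toy refcounting sketch of the "used or disposed" contract.
// NDArray and double() are hypothetical names, not the jax-js API.
class NDArray {
  private refcount = 1;
  constructor(public data: Float32Array) {}

  /** Bump the refcount to keep the array alive past one consumer. */
  ref(): this {
    this.refcount++;
    return this;
  }

  /** Drop one reference; free the backing buffer at zero. */
  dispose(): void {
    if (--this.refcount === 0) {
      // A real backend would release GPU/WASM memory here.
      this.data = new Float32Array(0);
    }
  }

  get disposed(): boolean {
    return this.refcount <= 0;
  }
}

// A function that takes ownership of its argument: it either reuses
// the buffer or disposes it, never silently leaks.
function double(x: NDArray): NDArray {
  const out = new NDArray(x.data.map((v) => v * 2));
  x.dispose(); // consumed: the callee releases its reference
  return out;
}

const a = new NDArray(new Float32Array([1, 2, 3]));
const b = double(a.ref()); // caller keeps `a` alive via ref()
console.log(a.disposed, b.data); // `a` survives thanks to the extra ref
```

The upside of this scheme is deterministic frees (important for GPU buffers, which the JS garbage collector can't see); the cost is that every call site must think about ownership.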
who is your target? scientists, artists, anyone who uses numerical computing. maybe eventually porting ML models (but that’s better suited for ONNX probably)
why?
story: you can’t figure out how to use pip. oh look, here’s a webgpu version that’s fast and “just works” — shaders can do a lot
story: XLA is hard and very complicated. What if a minimal compiler toolchain could handle 80%+ of operator fusion directly in the browser?
story: you are an artist and want to write some numerical simulations, but you don’t want to invent a whole library like https://github.com/amandaghassaei/gpu-io yourself
