<?xml version="1.0" encoding="UTF-8"?><rss xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:atom="http://www.w3.org/2005/Atom" version="2.0" xmlns:itunes="http://www.itunes.com/dtds/podcast-1.0.dtd" xmlns:googleplay="http://www.google.com/schemas/play-podcasts/1.0"><channel><title><![CDATA[eric makes software]]></title><description><![CDATA[this is where I write about making software outside of work! systems/pl, interaction design, open-source]]></description><link>https://ss.ekzhang.com</link><image><url>https://ss.ekzhang.com/img/substack.png</url><title>eric makes software</title><link>https://ss.ekzhang.com</link></image><generator>Substack</generator><lastBuildDate>Fri, 08 May 2026 11:19:51 GMT</lastBuildDate><atom:link href="https://ss.ekzhang.com/feed" rel="self" type="application/rss+xml"/><copyright><![CDATA[Eric]]></copyright><language><![CDATA[en]]></language><webMaster><![CDATA[ekzhang@substack.com]]></webMaster><itunes:owner><itunes:email><![CDATA[ekzhang@substack.com]]></itunes:email><itunes:name><![CDATA[Eric]]></itunes:name></itunes:owner><itunes:author><![CDATA[Eric]]></itunes:author><googleplay:owner><![CDATA[ekzhang@substack.com]]></googleplay:owner><googleplay:email><![CDATA[ekzhang@substack.com]]></googleplay:email><googleplay:author><![CDATA[Eric]]></googleplay:author><itunes:block><![CDATA[Yes]]></itunes:block><item><title><![CDATA[jax-js: an ML library for the web]]></title><description><![CDATA[JAX in pure JavaScript, as a flexible machine learning library and compiler.]]></description><link>https://ss.ekzhang.com/p/jax-js-an-ml-library-for-the-web</link><guid isPermaLink="false">https://ss.ekzhang.com/p/jax-js-an-ml-library-for-the-web</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Thu, 18 Dec 2025 15:02:32 GMT</pubDate><enclosure 
url="https://substackcdn.com/image/fetch/$s_!7bb1!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>I&#8217;m excited to release <a href="https://jax-js.com/">jax-js</a>, a machine learning library for the web.</p><div class="captioned-image-container"><figure><a class="image-link image2" target="_blank" href="https://substackcdn.com/image/fetch/$s_!7bb1!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!7bb1!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 424w, https://substackcdn.com/image/fetch/$s_!7bb1!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 848w, https://substackcdn.com/image/fetch/$s_!7bb1!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 1272w, https://substackcdn.com/image/fetch/$s_!7bb1!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!7bb1!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png" width="446" height="167.80936454849498" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:450,&quot;width&quot;:1196,&quot;resizeWidth&quot;:446,&quot;bytes&quot;:26596,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:false,&quot;topImage&quot;:true,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/179060245?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!7bb1!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 424w, https://substackcdn.com/image/fetch/$s_!7bb1!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 848w, https://substackcdn.com/image/fetch/$s_!7bb1!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 1272w, https://substackcdn.com/image/fetch/$s_!7bb1!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1466b600-4dc2-4a2c-9639-843f5a1c700a_1196x450.png 1456w" sizes="100vw" fetchpriority="high"></picture><div></div></div></a></figure></div><p>You can think of it as a reimplementation of Google DeepMind&#8217;s <a href="https://docs.jax.dev/en/latest/">JAX</a> framework (similar to PyTorch) in pure 
JavaScript.</p><p>jax-js runs completely in the browser by generating fast <a href="https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API">WebGPU</a> and <a href="https://webassembly.org/">Wasm</a> kernels.</p><h2>Numerical computing on the web</h2><p>Starting in February this year, I spent nights and weekends working on a new ML library for the browser. I wanted a cross-platform way to run numerical programs, including machine learning, on the frontend web.</p><p>Python and JavaScript are two of the <a href="https://survey.stackoverflow.co/2025/technology#most-popular-technologies-language">most popular languages</a> in the world:</p><ol><li><p><strong>JavaScript</strong> is the language of the web.</p></li><li><p><strong>Python</strong> is simple, expressive, and now ubiquitous in ML thanks to frameworks like PyTorch and JAX.</p></li></ol><p>But most developers would balk at doing any serious number crunching in JavaScript. While JavaScript JITs are really good, they&#8217;re not optimized for tight numerical loops. JavaScript doesn&#8217;t even have a fast, native integer data type! So how can you run fast numerical code on the web?</p><p>The answer is to rely on newer browser technologies, WebAssembly and WebGPU, which let you run programs at near-native speed. WebAssembly is a low-level, portable bytecode, and WebGPU brings GPU compute shaders to the web.</p><p>These native runtimes lend themselves to a programming model similar to JAX&#8217;s, where you <em>trace</em> programs and <em>JIT-compile</em> them to GPU kernels. Instead of targeting Nvidia CUDA, we write pure JavaScript that generates WebAssembly and WebGPU kernels. Running those kernels executes instructions at near-native speed, bypassing the JavaScript interpreter bottleneck.</p><p>That is what I ended up building in jax-js, and now it &#8220;just works.&#8221;</p><h2>Getting started</h2><p>You can install jax-js as a library. 
It has zero dependencies and is pure JS.</p><pre><code>npm install @jax-js/jax</code></pre><p>Then you can use it through an API that is almost identical to JAX&#8217;s.</p><pre><code>import { numpy as np } from "@jax-js/jax";

const ar = np.array([1, 5, 6, 7]);
console.log(ar.mul(10).js());  // -&gt; [10, 50, 60, 70]</code></pre><p>Under the hood, this generates a WebAssembly kernel and dispatches it.</p><blockquote><p><strong>Note:</strong> There are some surface-level syntax differences from JAX here:</p><ul><li><p>JavaScript doesn&#8217;t have operator overloading like Python. Instead of writing <code>ar * 10</code> as in Python, you call <code>ar.mul(10)</code>.</p></li><li><p>The <code>.js()</code> method converts a <code>jax.Array</code> object back into a plain JS array.</p></li><li><p>JS has no reference-counted destructors for freeing memory, so array values in jax-js have <a href="https://doc.rust-lang.org/rust-by-example/scope/move.html">move semantics</a> like Rust, and <code>.ref</code> increments an array&#8217;s reference count.</p></li></ul></blockquote><p>If you&#8217;d like to use WebGPU, just start your program with:</p><pre><code>import { init, setDevice } from "@jax-js/jax";

await init("webgpu");
setDevice("webgpu");</code></pre><p>You can use <code>grad</code>, <code>vmap</code>, and other JAX transformations. Here&#8217;s automatic differentiation with <code>grad()</code>:</p><pre><code>import { grad, numpy as np } from "@jax-js/jax";

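// x is used twice, so x.ref takes an extra reference; under move semantics,
// each use of an array consumes one reference.
const f = (x: np.Array) =&gt; np.sqrt(x.ref.mul(x).sum());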
const df = grad(f);

const x = np.array([1, 2, 3, 4]);
console.log(df(x).js());  // -&gt; approx. [0.183, 0.365, 0.548, 0.730], i.e. x / sqrt(30)</code></pre><p>And here&#8217;s an example of the compiler fusing operations with <code>jit()</code>. The following function gets translated into a compiled GPU compute kernel:</p><pre><code>import { jit, numpy as np } from "@jax-js/jax";

const f = jit((x: np.Array) =&gt; {
  return np.sqrt(x.add(2).mul(Math.PI)).sum();
});</code></pre><h2>Machine learning</h2><p>With these simple building blocks, you can implement most machine learning algorithms and backpropagate through them.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!8Tz2!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!8Tz2!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 424w, https://substackcdn.com/image/fetch/$s_!8Tz2!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 848w, https://substackcdn.com/image/fetch/$s_!8Tz2!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 1272w, https://substackcdn.com/image/fetch/$s_!8Tz2!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!8Tz2!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png" width="1456" height="554" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:554,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:366314,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/179060245?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!8Tz2!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 424w, https://substackcdn.com/image/fetch/$s_!8Tz2!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 848w, https://substackcdn.com/image/fetch/$s_!8Tz2!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 1272w, https://substackcdn.com/image/fetch/$s_!8Tz2!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F7a346a77-cf0a-407d-aecc-ec882cd2ebe0_3442x1310.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>Here is <a href="https://jax-js.com/mnist">a runnable example</a> of training a neural network from scratch on the MNIST dataset in your browser. It reaches &gt;99% accuracy in seconds, and everything from dataset loading to matmul kernels is <em>pure frontend JavaScript code</em>.</p><p>It&#8217;s remarkable to be able to write ML programs with hot module reloading. You can edit code in real time <em>while</em> the model is training!</p><p>&#8212;</p><p>You can also build applications. 
<a href="https://jax-js.com/mobileclip">Here&#8217;s a demo I built yesterday</a>: download the whole text of <em>Great Expectations</em> (180,000 words), run it through a CLIP-based embedding model, and search it semantically in real time, all from your browser.</p><div class="native-video-embed" data-component-name="VideoPlaceholder" data-attrs="{&quot;mediaUploadId&quot;:&quot;e2749108-67b6-417e-a78d-9a55826b1f1b&quot;,&quot;duration&quot;:null}"></div><p><em>(The text embedding actually runs at a respectable ~500 GFLOP/s on my M1 Pro with just jax.jit(), even though I haven&#8217;t tried to optimize it at all yet. Not bad: that&#8217;s 500,000,000,000 calculations per second in the browser, on a 4-year-old laptop!)</em></p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!bRAB!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!bRAB!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 424w, https://substackcdn.com/image/fetch/$s_!bRAB!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 848w, https://substackcdn.com/image/fetch/$s_!bRAB!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 1272w, 
https://substackcdn.com/image/fetch/$s_!bRAB!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!bRAB!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png" width="610" height="246.7651098901099" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:589,&quot;width&quot;:1456,&quot;resizeWidth&quot;:610,&quot;bytes&quot;:263760,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/179060245?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!bRAB!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 424w, https://substackcdn.com/image/fetch/$s_!bRAB!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 848w, 
https://substackcdn.com/image/fetch/$s_!bRAB!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 1272w, https://substackcdn.com/image/fetch/$s_!bRAB!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0bad2a9e-a8ab-4041-a13c-79f5f6abf67d_2520x1020.png 1456w" sizes="100vw" loading="lazy"></picture></div></a><figcaption class="image-caption">Running with batch size 16 (77-token context), each CLIP transformer inference takes 200 ms, for an estimated 485 
GFLOP/s end-to-end.</figcaption></figure></div><p>For many inference use cases, you might reach for a &#8220;model runtime&#8221; like <a href="https://onnxruntime.ai/docs/tutorials/web/">ONNX</a> to add prebuilt ML models to your web app, where ML developers hand off pre-packaged weights to be used in the product. jax-js is a bit different: I&#8217;m imagining how a full ML framework, usually relegated to the backend, can run in the browser.</p><p>As for performance, it hasn&#8217;t been my primary focus so far, since just &#8220;getting the ML framework working&#8221; comes first. I have checked that jax-js&#8217;s generated kernels for <a href="https://jax-js.com/bench/matmul">matmuls are fast</a> (&gt;3 TFLOP/s on a MacBook M4 Pro). But there&#8217;s a lot of room to improve (e.g., <a href="https://jax-js.com/bench/conv2d">conv2d</a> is slow), and I haven&#8217;t done much optimization work on transformer inference in particular yet. There&#8217;s plenty of low-hanging fruit.</p><h2>Project release</h2><p>I am open-sourcing jax-js today at <a href="https://github.com/ekzhang/jax-js">ekzhang/jax-js</a>.</p><p>There are rough edges in this initial release, but it&#8217;s ready to try out now.</p><p>Links:</p><ul><li><p><a href="https://jax-js.com">Website</a></p><ul><li><p><a href="https://jax-js.com/repl">Try it out! (REPL)</a></p></li><li><p><a href="https://jax-js.com/docs/">API reference</a></p></li></ul></li><li><p><a href="https://github.com/ekzhang/jax-js">GitHub repository</a></p></li></ul><p>I look forward to seeing what you create. &#129392;</p><div><hr></div><h2>Appendix</h2><p>This is a personal project and not related to Thinking Machines Lab. I started working on jax-js before I started my current job, and in a way, it&#8217;s partly how I ended up working in ML. 
Turns out this stuff is kind of fun!</p><p>If you&#8217;re still reading, hello! I have a bunch more details to share.</p><h3>Acknowledgements</h3><p>Thanks to:</p><ul><li><p>The authors of <a href="https://docs.jax.dev/en/latest/">JAX</a> for making an important ML library that&#8217;s a joy to use.</p><ul><li><p>Thanks to Matthew Johnson, Dougal Maclaurin, and others for <a href="https://docs.jax.dev/en/latest/autodidax.html">Autodidax</a>, an instructive implementation of the JAX core from scratch.</p></li><li><p>And thanks to the authors of all the JAX ecosystem libraries as well.</p></li></ul></li><li><p><a href="https://github.com/tinygrad/tinygrad">Tinygrad</a>, for an excellent autograd library: you showed that code-generating kernels from scratch doesn&#8217;t have to be that <em>intrinsically</em> complex!</p><ul><li><p>Many of jax-js&#8217;s backend internals follow Tinygrad&#8217;s design closely. The biggest example is <a href="https://github.com/ekzhang/jax-js/blob/jax/v0.0.5/src/shape.ts">ShapeTracker</a>, which was ported directly.</p></li></ul></li><li><p>Chrome, Safari, and Firefox for WebGPU, now <a href="https://chromestatus.com/metrics/feature/timeline/popularity/3888">used in 2% of all websites</a>.</p></li><li><p>The open-source community, for inspiration and for showing that ML on the web is actually possible!</p><ul><li><p><a href="https://www.tensorflow.org/js">TensorFlow.js</a></p></li><li><p><a href="https://www.npmjs.com/package/onnxruntime-web">onnxruntime-web</a></p></li><li><p><a href="https://github.com/praeclarum/webgpu-torch">webgpu-torch</a>, <a href="https://github.com/zanussbaum/surfgrad">surfgrad</a>, and <a href="https://jott.live/markdown/mm_wasm">wasmblr</a></p></li><li><p><a href="https://github.com/mrdoob/three.js/wiki/Three.js-Shading-Language">Three.js Shading Language</a> (<a href="https://github.com/holtsetio/flow/blob/master/src/mls-mpm/mlsMpmSimulator.js">example</a>)</p></li></ul></li><li><p><a 
href="https://pytorch.org/">PyTorch</a>, <a href="https://github.com/ml-explore/mlx">MLX</a>, and <a href="https://numpy.org/">NumPy</a></p></li></ul><h3>How it works: An overview of internals</h3><p>In general, I think an ML library has roughly two parts:</p><ol><li><p><strong>&#8220;Frontend&#8221; (think JAX):</strong> The interface for creating and manipulating arrays, the autograd engine, JIT, typing, and transformations. This is also where you deal with the sync/async boundary and track memory allocations.</p></li><li><p><strong>&#8220;Backend&#8221; (think XLA):</strong> The actual kernels that execute operations. The frontend builds some representation of a kernel and dispatches it to the backend, which optimizes it, compiles it down to native code (CPU or GPU), and runs it efficiently.</p></li></ol><p>This dichotomy obviously isn&#8217;t perfect (e.g., where do <a href="https://github.com/triton-lang/triton">Triton</a>/<a href="https://docs.jax.dev/en/latest/pallas/">Pallas</a> fit in? How about warp-specialized <a href="https://docs.nvidia.com/cuda/cutile-python/">cuTile</a>?), and there are certainly concerns that span both parts. But it&#8217;s how jax-js is organized.</p><p><strong>Let&#8217;s start with the backend and build our way up.</strong> In jax-js, the backend code is quite self-contained: each backend implements the <code>Backend</code> interface (abridged):</p><pre><code>/** A device backend. */
export interface Backend {
  /** Allocate a new slot with reference count 1. */
  malloc(size: number, initialData?: Uint8Array): Slot;

  /** Increment the reference count of the slot. */
  incRef(slot: Slot): void;

  /**
   * Decrement the reference count of the slot. If the reference count reaches
   * zero, it is freed. This should throw if the slot was already freed.
   */
  decRef(slot: Slot): void;

  /** Read a range of bytes from a buffer. */
  read(
    slot: Slot,
    start?: number,
    count?: number,
  ): Promise&lt;Uint8Array&lt;ArrayBuffer&gt;&gt;;

  /** Prepare an expression to be executed later. */
  prepare(kernel: Kernel): Promise&lt;Executable&gt;;

  /**
   * Run a backend operation that was previously prepared.
   *
   * The operation may not run immediately, but operations are guaranteed to run
   * in the dispatch order. Also, `read()` will wait for all pending operations
   * on that slot to finish.
   */
  dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
}</code></pre><p>In other words, a backend needs to allocate and free chunks of memory for tensors and to execute <code>Kernel</code> objects. A <code>Kernel</code> contains:</p><ul><li><p>A pointwise operation on one or more tensors,</p></li><li><p>Lazy shape-tracking information describing how to index those tensors, and</p></li><li><p>An optional reduction.<br><em>Reductions can use any associative operation (add/multiply/max/min), and they can optionally have a fused <a href="https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html#cutlass-gemm-model">epilogue</a> as well.</em></p></li></ul><p>The pointwise operation is built from a pure expression tree, an <code>AluExp</code>, where each node is a symbolic <code>AluOp</code>. There are only 28 AluOps: you don&#8217;t need many distinct operations when you can rely on kernel fusion!</p><p>Note that no automatic differentiation happens at this level. These are pure low-level operations with no derivative rules attached, so we can introduce arbitrary building blocks here.</p><pre><code>/** Symbolic form for each mathematical operation. */
export enum AluOp {
  Add = "Add",
  Sub = "Sub",
  Mul = "Mul",
  Idiv = "Idiv",
  Mod = "Mod",
  Min = "Min",
  Max = "Max",

  Sin = "Sin",
  Cos = "Cos",
  Asin = "Asin",
  Atan = "Atan",
  Exp = "Exp",
  Log = "Log",
  Erf = "Erf",
  Erfc = "Erfc",
  Sqrt = "Sqrt",
  Reciprocal = "Reciprocal",
  Cast = "Cast",
  Bitcast = "Bitcast",

  Cmplt = "Cmplt",
  Cmpne = "Cmpne",
  Where = "Where", // Ternary operator: `cond ? a : b`

  Threefry2x32 = "Threefry2x32", // PRNG operation, arg = 'xor' | 0 | 1

  // Const is a literal constant, while GlobalIndex takes data from an array
  // buffer. Special and Variable are distinguished since the former is for
  // indices like the global invocation, while the latter is a value.
  Const = "Const", // arg = value
  Special = "Special", // arg = [variable, n]
  Variable = "Variable", // arg = variable
  GlobalIndex = "GlobalIndex", // arg = [gid, len]; src = [bufidx]
  GlobalView = "GlobalView", // arg = [gid, ShapeTracker], src = [indices...]
}</code></pre><p>Auto-generating GPU kernels is fairly simple for pointwise ops. The tricky part is reductions such as <a href="https://en.wikipedia.org/wiki/Tensor_contraction">tensor contractions</a>, which show up most commonly in matmuls and convolutions. These can be optimized pretty well on the web by unrolling judiciously and tiling the loads and stores.</p><p>Below is an example WebGPU matmul kernel that jax-js generates for <code>float32[4096,4096]</code> matrices. Each of the 1,048,576 invocations accumulates a 4&#215;4 tile of the output, reading four elements along the shared dimension per loop iteration.</p><pre><code>@group(0) @binding(0) var&lt;storage, read&gt; in0 : array&lt;f32&gt;;
@group(0) @binding(1) var&lt;storage, read&gt; in1 : array&lt;f32&gt;;
@group(0) @binding(2) var&lt;storage, read_write&gt; result : array&lt;f32&gt;;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id : vec3&lt;u32&gt;) {
  if (id.x &gt;= 1048576) { return; }
  let gidx: i32 = i32(id.x);
  var acc0: f32 = f32(0);
  var acc1: f32 = f32(0);
  var acc2: f32 = f32(0);
  var acc3: f32 = f32(0);
  var acc4: f32 = f32(0);
  var acc5: f32 = f32(0);
  var acc6: f32 = f32(0);
  var acc7: f32 = f32(0);
  var acc8: f32 = f32(0);
  var acc9: f32 = f32(0);
  var acc10: f32 = f32(0);
  var acc11: f32 = f32(0);
  var acc12: f32 = f32(0);
  var acc13: f32 = f32(0);
  var acc14: f32 = f32(0);
  var acc15: f32 = f32(0);
  for (var ridx: i32 = 0; ridx &lt; 1024; ridx++) {
    let x0: i32 = ((gidx / 8192) * 131072) + ((((gidx / 8) % 8) * 16384) + (ridx * 4));
    let x1: f32 = in0[x0];
    let x2: i32 = (((gidx / 64) % 128) * 32) + (((gidx % 8) * 4) + (ridx * 16384));
    let x3: f32 = in1[x2];
    let x4: f32 = in0[x0 + 1];
    let x6: f32 = in0[x0 + 2];
    let x8: f32 = in0[x0 + 3];
    let x10: f32 = in0[x0 + 4096];
    let x11: f32 = in0[x0 + 4097];
    let x12: f32 = in0[x0 + 4098];
    let x13: f32 = in0[x0 + 4099];
    let x14: f32 = in0[x0 + 8192];
    let x15: f32 = in0[x0 + 8193];
    let x16: f32 = in0[x0 + 8194];
    let x17: f32 = in0[x0 + 8195];
    let x18: f32 = in0[x0 + 12288];
    let x19: f32 = in0[x0 + 12289];
    let x20: f32 = in0[x0 + 12290];
    let x21: f32 = in0[x0 + 12291];
    let x22: f32 = in1[x2 + 1];
    let x26: f32 = in1[x2 + 2];
    let x30: f32 = in1[x2 + 3];
    let x5: f32 = in1[x2 + 4096];
    let x23: f32 = in1[x2 + 4097];
    let x27: f32 = in1[x2 + 4098];
    let x31: f32 = in1[x2 + 4099];
    let x7: f32 = in1[x2 + 8192];
    let x24: f32 = in1[x2 + 8193];
    let x28: f32 = in1[x2 + 8194];
    let x32: f32 = in1[x2 + 8195];
    let x9: f32 = in1[x2 + 12288];
    let x25: f32 = in1[x2 + 12289];
    let x29: f32 = in1[x2 + 12290];
    let x33: f32 = in1[x2 + 12291];
    acc0 += x1 * x3 + x4 * x5 + x6 * x7 + x8 * x9;
    acc1 += x10 * x3 + x11 * x5 + x12 * x7 + x13 * x9;
    acc2 += x14 * x3 + x15 * x5 + x16 * x7 + x17 * x9;
    acc3 += x18 * x3 + x19 * x5 + x20 * x7 + x21 * x9;
    acc4 += x1 * x22 + x4 * x23 + x6 * x24 + x8 * x25;
    acc5 += x10 * x22 + x11 * x23 + x12 * x24 + x13 * x25;
    acc6 += x14 * x22 + x15 * x23 + x16 * x24 + x17 * x25;
    acc7 += x18 * x22 + x19 * x23 + x20 * x24 + x21 * x25;
    acc8 += x1 * x26 + x4 * x27 + x6 * x28 + x8 * x29;
    acc9 += x10 * x26 + x11 * x27 + x12 * x28 + x13 * x29;
    acc10 += x14 * x26 + x15 * x27 + x16 * x28 + x17 * x29;
    acc11 += x18 * x26 + x19 * x27 + x20 * x28 + x21 * x29;
    acc12 += x1 * x30 + x4 * x31 + x6 * x32 + x8 * x33;
    acc13 += x10 * x30 + x11 * x31 + x12 * x32 + x13 * x33;
    acc14 += x14 * x30 + x15 * x31 + x16 * x32 + x17 * x33;
    acc15 += x18 * x30 + x19 * x31 + x20 * x32 + x21 * x33;
  }
  let x34: i32 = ((gidx / 8192) * 131072) + ((((gidx / 64) % 128) * 32) + ((((gidx / 8) % 8) * 16384) + ((gidx % 8) * 4)));
  result[x34] = acc0;
  result[x34 + 4096] = acc1;
  result[x34 + 8192] = acc2;
  result[x34 + 12288] = acc3;
  result[x34 + 1] = acc4;
  result[x34 + 4097] = acc5;
  result[x34 + 8193] = acc6;
  result[x34 + 12289] = acc7;
  result[x34 + 2] = acc8;
  result[x34 + 4098] = acc9;
  result[x34 + 8194] = acc10;
  result[x34 + 12290] = acc11;
  result[x34 + 3] = acc12;
  result[x34 + 4099] = acc13;
  result[x34 + 8195] = acc14;
  result[x34 + 12291] = acc15;
}</code></pre><p>If you&#8217;re writing a native library, this isn&#8217;t good enough. For example, on Nvidia GPUs you at least have to use tensor cores (via PTX instructions like <code>mma.sync.aligned.*</code>)! But on the web, it gets pretty comparable performance to the best open-source libraries, and it seems that <a href="https://github.com/google/dawn">Dawn</a> is alright at bridging any gaps with optimization.</p><p><strong>Onto the frontend.</strong> This is the core of the library, where the actual autograd and tracing happen. We follow the JAX design quite closely: there is a set of primitives along with an ambient <em>interpreter stack</em>. This is&#8230; quite difficult, magical, and took me a while to figure out. To learn more, see:</p><ul><li><p><a href="https://docs.jax.dev/en/latest/autodidax.html">Autodidax: JAX core from scratch</a> (2021)</p></li><li><p><a href="https://arxiv.org/abs/1804.00746">The simple essence of automatic differentiation</a> (Elliott 2018)</p></li></ul><p><em>(One particularly cool thing about this way of building an ML library is that you get reverse-mode AD &#8220;for free&#8221; by transposing the forward-mode rules. I found this really beautiful once I wrapped my head around it; it&#8217;s quite mathematically pleasing. Another cool moment is when you first get arbitrary 2nd, 3rd, &#8230; n-th order derivatives after implementing only the first-order derivative rules. GradientTape could never!)</em></p><p>Honestly, this is probably the most lost I&#8217;ve ever felt while writing code. It&#8217;s like, nested mutually recursive interpreters to model functors in the &#8220;category of tensors.&#8221;</p><p>Anyway, once I reviewed my differential geometry notes from college and dusted off my understanding of tangents, pulling back cotangents, functors, and so on, I think I eventually figured it out. Though I still had tiny bugs for the next 6 months. 
&#128514;</p><p>Here is the list of high-level <code>Primitive</code> operations in jax-js:</p><pre><code>/**
 * Frontend primitive operations, which are lowered into Kernel objects before
 * being dispatched to the backend.
 *
 * Any operation between arrays can be described in these parts. This is also
 * the set of primitives that can occur in Jaxpr programs, and the level at
 * which transformations like vmap, grad, and jvp occur. They are loosely based
 * on [XLA](https://openxla.org/xla/operation_semantics).
 *
 * All n-ary operations support broadcasting, with NumPy semantics.
 */
export enum Primitive {
  Add = "add",
  Mul = "mul",
  Idiv = "idiv",
  Neg = "neg",
  Reciprocal = "reciprocal",
  StopGradient = "stop_gradient",
  Cast = "cast",
  Bitcast = "bitcast",
  RandomBits = "random_bits",
  Sin = "sin",
  Cos = "cos",
  Asin = "asin",
  Atan = "atan",
  Exp = "exp",
  Log = "log",
  Erf = "erf",
  Erfc = "erfc",
  Sqrt = "sqrt",
  Min = "min",
  Max = "max",
  Reduce = "reduce",
  Dot = "dot", // sum(x*y, axis=-1)
  Conv = "conv", // see lax.conv_general_dilated
  Pool = "pool",
  PoolTranspose = "pool_transpose",
  Compare = "compare",
  Where = "where",
  Transpose = "transpose",
  Broadcast = "broadcast",
  Reshape = "reshape",
  Flip = "flip",
  Shrink = "shrink",
  Pad = "pad",
  Gather = "gather",
  JitCall = "jit_call",
}</code></pre><p>Notice that many of these are similar to the backend operations above, but some are different. In particular, convolutions and matrix multiplications appear here as primitives. They are useful to keep in the frontend IR (and for autograd), but they can be lowered to simpler operations before kernels are generated on the backend.</p><p>By default, an operation is lowered directly to a backend kernel after passing through any necessary transformations (<code>vmap</code>, <code>jvp</code>, <code>grad</code>). But if you use <code>jit</code>, jax-js traces your program into a &#8220;Jaxpr&#8221; (a list of operations) and then applies <a href="https://substack.com/home/post/p-163548742">automatic kernel fusion</a> to generate kernels specialized to each input shape.</p><h3>Bugs</h3><p>Building an ML framework is hard, and it&#8217;s a long task! So far, jax-js has implemented a lot of JAX&#8217;s core functionality, but there&#8217;s still much more to do. If there&#8217;s an API or operation you want to see (for example, np.split, FFT, or AdamW), please consider adding it or filing an issue.</p><p>I have a pretty varied, portable test suite that runs fast:</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!E_s1!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!E_s1!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 424w, 
https://substackcdn.com/image/fetch/$s_!E_s1!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 848w, https://substackcdn.com/image/fetch/$s_!E_s1!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 1272w, https://substackcdn.com/image/fetch/$s_!E_s1!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!E_s1!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png" width="566" height="395.8580060422961" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/ee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:926,&quot;width&quot;:1324,&quot;resizeWidth&quot;:566,&quot;bytes&quot;:252710,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/179060245?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" 
srcset="https://substackcdn.com/image/fetch/$s_!E_s1!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 424w, https://substackcdn.com/image/fetch/$s_!E_s1!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 848w, https://substackcdn.com/image/fetch/$s_!E_s1!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 1272w, https://substackcdn.com/image/fetch/$s_!E_s1!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fee5101ba-af11-4ba0-b041-0b5fd0624d32_1324x926.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>So we are in a good position to find and fix bugs. But making an ML library is quite difficult, and WebGPU is a nascent technology (e.g., I somehow gave my MacBook <a href="https://x.com/ekzhang1/status/1957107092868727225">kernel panics</a>), so there will be bugs! Please report them.</p><h3>Technical: Performance</h3><p>We haven&#8217;t spent a ton of time optimizing yet, but performance is generally pretty good. <code>jit</code> is very helpful for fusing operations together, and on the web it&#8217;s a feature only available in jax-js. The default kernel-tuning heuristics reach about 3000 GFLOP/s for matrix multiplication on an M4 Pro chip (<a href="https://jax-js.com/bench/matmul">try it</a>).</p><p>On that specific benchmark, it&#8217;s actually more GFLOP/s than both <a href="https://github.com/tensorflow/tfjs">TensorFlow.js</a> and <a href="https://www.npmjs.com/package/onnxruntime-web">ONNX Runtime Web</a>, which both use handwritten libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).</p><p>Some particularly useful / low-hanging fruit to look at:</p><ul><li><p>The WebAssembly backend is currently quite simple, since I didn&#8217;t spend much time optimizing it, but by my estimate it could be &gt;150x faster on my MacBook Pro. The speedup comes from a few factors that multiply together:</p><ul><li><p>Don&#8217;t recompute loop indices on every iteration: ~1-3x more FLOP/s.</p></li><li><p>Do loop unrolling/tiling: ~2-3x more FLOP/s.</p></li><li><p>Use SIMD instructions: 4x more FLOP/s.</p></li><li><p>Add multi-threading to use all available cores (10x on my laptop). 
Requires SharedArrayBuffer (<a href="https://developer.mozilla.org/en-US/docs/Web/API/Window/crossOriginIsolated">crossOriginIsolated</a>), and there are some caveats around sync/async handling, so it needs to be done carefully.</p></li></ul></li><li><p>Running the forward pass of the MobileCLIP2 transformer model only reaches about 1/3 the FLOP/s of a pure 4096x4096 matmul. Maybe we can improve this, especially in the causal self-attention layer.</p></li><li><p>Although WebGPU is rapidly gaining popularity <a href="https://caniuse.com/webgpu">and support</a>, it&#8217;s probably worth having a WebGL backend as well, as a fallback that&#8217;s guaranteed to work in pretty much all browsers and is still pretty fast. This isn&#8217;t a huge amount of work; for example, the WebGPU backend is &lt;700 lines of code.</p></li></ul><h3>Technical: Feature parity</h3><p>jax-js strives for <em>approximate</em> API compatibility with the JAX Python library (and through that, NumPy). But some features differ, for a few reasons:</p><ol><li><p><strong>Data model:</strong> jax-js arrays have <em>ownership</em> semantics via the <code>.ref</code> system, which obviates the need for APIs like <code>jit()</code>&#8217;s <code>donate_argnums</code> and <code>numpy.asarray()</code>.</p></li><li><p><strong>Language primitives:</strong> JavaScript has no named arguments, so method call signatures may take objects instead of Python&#8217;s keyword arguments. Also, PyTrees are translated in spirit to &#8220;JsTree&#8221; in jax-js, but their specification is different.</p></li><li><p><strong>Maturity:</strong> JAX has various types like <code>complex64</code>, advanced functions like <code>hessenberg()</code>, and advanced higher-order features like <code>lax.while_loop()</code> that we haven&#8217;t implemented. Some of these are not easy to implement on GPU.</p></li></ol><p>Other features just aren&#8217;t implemented yet. 
But those can probably be added easily!</p><p>I&#8217;ve made a table of every JAX library feature and its implementation status in jax-js, <a href="https://github.com/ekzhang/jax-js/blob/main/FEATURES.md">see here</a>. There are a couple of big ones that stand out.</p><p>You&#8217;re welcome to contribute, though I&#8217;d also love it if you could try using jax-js. :D</p>]]></content:encoded></item><item><title><![CDATA[ssh-hypervisor: "SimCity for VMs"]]></title><description><![CDATA[Tackling a larger systems programming project with AI tools.]]></description><link>https://ss.ekzhang.com/p/ssh-hypervisor-simcity-for-vms</link><guid isPermaLink="false">https://ss.ekzhang.com/p/ssh-hypervisor-simcity-for-vms</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Tue, 23 Sep 2025 05:11:28 GMT</pubDate><enclosure url="https://substackcdn.com/image/fetch/$s_!f9Sm!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>This weekend I tried to make <a href="https://github.com/ekzhang/ssh-hypervisor">a hypervisor hooked up to SSH</a>. It&#8217;s like:</p><pre><code>ssh &lt;YOUR_NAME&gt;@vmcity.ekzhang.com</code></pre><p>But every time you log in with a different name, instead of becoming a user on the host machine, it greets you and then spins up a <em>virtual machine</em> with <a href="https://firecracker-microvm.github.io/">Firecracker</a>.</p><pre><code>$ ssh eric@vmcity.ekzhang.com
<strong>
Hello, eric! &#127800;</strong>

Today is Sunday. It's your first time here.

Recent logins:
&#9484;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9516;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9488;
&#9474;  User   &#9474;  Last login  &#9474;
&#9500;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9532;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9508;
&#9474; matthew &#9474; 2 hours ago  &#9474;
&#9474; kathy   &#9474; 4 hours ago  &#9474;
&#9474; linus   &#9474; 16 hours ago &#9474;
&#9474; sen     &#9474; 4 days ago   &#9474;
&#9492;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9524;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9496;

<strong>Booting up your fresh VM:</strong>
&#128161; &#9646;&#9646;&#9646;&#9647;&#9647;&#9647;&#9647;&#9647;&#9647;&#9647;&#9647;&#9647; 25%</code></pre><p>If you haven&#8217;t logged in for a while, we store your VM in a <a href="https://github.com/firecracker-microvm/firecracker/blob/main/docs/snapshotting/snapshot-support.md">snapshot</a>.</p><p><em>This isn&#8217;t an original idea, by the way!</em> I&#8217;d seen someone online show off their tiny OS running in Firecracker microVMs over public SSH. Unfortunately I don&#8217;t remember where I saw it, but I wanted to take the idea and make it a bit more whimsical, while adding a couple of toy features.</p><blockquote><p>Update: A commenter shared the project <a href="https://github.com/nuta/kerla">https://github.com/nuta/kerla</a></p></blockquote><p>Back in high school and college, I used to make a lot of smaller, fun projects <a href="https://www.ekzhang.com/projects">over the weekend</a> and share them with people. I don&#8217;t do this as much now that I have a job. These tiny projects also became less interesting as I grew more familiar with systems: they felt more implementation-heavy, with fewer new ideas.</p><p>I think that&#8217;s sad, though. This project might have taken me 1-2 weeks in the past, so I was hoping that with AI tools I could do it in just a weekend (<a href="https://matklad.github.io/2025/08/31/vibe-coding-terminal-editor.html">inspiration</a>). Then I could spend time on more frivolous projects. I still get ideas all the time, and this is one of them. Let&#8217;s just build it, see where it goes, and let my creative side take control!</p><h2>What is a hypervisor?</h2><p>I saw <a href="https://seiya.me/blog/hypervisor-in-1000-lines">this quote</a> recently that sums it up well:</p><blockquote><p><strong>Hypervisor is essentially a hardware-assisted catch block<br><br></strong>This is all what I want you to learn from this book. Hardware-assisted hypervisors are event handlers. 
They are not like a CPU emulator.<br><br>In JavaScript, the life of a hypervisor looks like this:</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!XW68!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!XW68!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 424w, https://substackcdn.com/image/fetch/$s_!XW68!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 848w, https://substackcdn.com/image/fetch/$s_!XW68!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 1272w, https://substackcdn.com/image/fetch/$s_!XW68!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!XW68!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png" width="606" height="242.2335164835165" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/b700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:582,&quot;width&quot;:1456,&quot;resizeWidth&quot;:606,&quot;bytes&quot;:156659,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/174157700?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!XW68!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 424w, https://substackcdn.com/image/fetch/$s_!XW68!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 848w, https://substackcdn.com/image/fetch/$s_!XW68!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 1272w, https://substackcdn.com/image/fetch/$s_!XW68!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb700df60-5515-4025-90fe-21b5859c8c4e_2020x808.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>A hypervisor runs the guest OS in a <code>try</code> block, <code>catch</code>es events (VM exits), and goes back to the guest mode again.</p></blockquote><p>I want to keep this in mind while working through the project. Firecracker is a very lightweight hypervisor that spins up &#8220;microVMs.&#8221; Since hypervisors are catch blocks, that essentially means the catch block is small: Firecracker <a href="https://unixism.net/2019/10/how-aws-firecracker-works-a-deep-dive/#:~:text=Firecracker%20is%20a%20VM%20environment,()%20%2C%20cgroups%20and%20seccomp%20rules.">only emulates a few devices</a> and relies on host features as much as possible. This makes it really fast to boot compared to QEMU.</p><p>However, this <em>doesn&#8217;t mean</em> that Firecracker is any simpler to set up than other hypervisors. 
You still need to hook up all the parts of a virtual computer in the right places to get things working! For instance:</p><ul><li><p>Bring your own init system, like OpenRC or systemd.</p></li><li><p>Attach a kernel, a ramfs, and a disk at startup.</p></li><li><p>Want networking? Set up a MAC address, TAP device, bridge, IP routing rules, firewall filters, packet forwarding, and so on.</p><ul><li><p>Want multiple VMs? Create a network bridge, set that bridge as the controller of each TAP device, allocate private IPs from a pool, and dynamically configure iptables.</p></li></ul></li><li><p>Want serial logs? Edit your kernel boot arguments to send them at a baud rate over the /dev/console TTY.</p></li></ul><p>It&#8217;s a good reminder that VMs are tiny little computers living inside your own. When you start up a VM, you&#8217;re building your own computer from scratch!</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!0mH3!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!0mH3!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 424w, https://substackcdn.com/image/fetch/$s_!0mH3!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 848w, https://substackcdn.com/image/fetch/$s_!0mH3!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 1272w, 
https://substackcdn.com/image/fetch/$s_!0mH3!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!0mH3!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg" width="6296" height="2226" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:2226,&quot;width&quot;:6296,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:2102976,&quot;alt&quot;:&quot;M4 MacBook Air Teardown: Apple, When Will MacBooks Finally Get Repair  Upgrades? - iFixit&quot;,&quot;title&quot;:null,&quot;type&quot;:&quot;image/jpeg&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="M4 MacBook Air Teardown: Apple, When Will MacBooks Finally Get Repair  Upgrades? - iFixit" title="M4 MacBook Air Teardown: Apple, When Will MacBooks Finally Get Repair  Upgrades? 
- iFixit" srcset="https://substackcdn.com/image/fetch/$s_!0mH3!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 424w, https://substackcdn.com/image/fetch/$s_!0mH3!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 848w, https://substackcdn.com/image/fetch/$s_!0mH3!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 1272w, https://substackcdn.com/image/fetch/$s_!0mH3!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F111fb6f5-abdb-4cac-b67c-50a892287e6e_6296x2226.jpeg 1456w" sizes="100vw" loading="lazy"></picture></div></a><figcaption class="image-caption">Apple M4. Look at that wild chip. Virtualization is taming the beast. (<a href="https://www.ifixit.com/News/108697/m4-macbook-air-teardown-apple-when-will-macbooks-finally-get-repair-upgrades">source</a>)</figcaption></figure></div><p>What I found in this project is that, while AI tools made coding a lot faster (thousands of lines in minutes), they didn&#8217;t speed up debugging all that much. Understandably, debugging ended up being a big part of the work. &#128517;</p><h2>Setting up the project</h2><p>The first step is to decide what we&#8217;re building with. I&#8217;m going to use Go because:</p><ul><li><p>It is a simple, relatively easy systems programming language that compiles fast.</p><ul><li><p>This makes AI very good at writing Go code.</p></li></ul></li><li><p>It has good support for Firecracker via the <a href="https://github.com/firecracker-microvm/firecracker-go-sdk">official Go SDK</a>.</p></li><li><p>It compiles to native, statically linked executables (without <a href="https://go.dev/wiki/cgo">cgo</a>). I <em>could</em> deal with glibc / musl issues, but that&#8217;s just one more dimension added to an already tricky setup with kernels and virtualization.</p></li></ul><p>So we initialize the project, and I give the AI some basic direction. To start up the VM, we&#8217;ll embed a static firecracker binary + vmlinux inside our own binary during the build process; at runtime, these can be spooled to temp files and the Firecracker binary <a href="https://en.wikipedia.org/wiki/Exec_(system_call)">execv</a>&#8217;d.</p><p>This will live in a <code>Manager</code> on the Go side that creates VMs and sets up networking. 
Then, we build a custom SSH server (not relying on OpenSSH) to forward connections to a VM. Here&#8217;s what the AI spit back out at me:</p><pre><code>Architecture Overview

ssh-hypervisor is a Go-based SSH server that dynamically provisions Firecracker microVMs. When users SSH in, they get a fresh or restored VM instance.

Core Components

1. Network Architecture
- Host SSH server (using Wish library) on configurable port (default 2222)
- TAP devices provide VM network connectivity (configurable CIDR, default 192.168.100.0/24)
- SSH channel proxying: Host SSH channels &#8594; TAP network &#8594; Guest SSH server
- Health check: Poll guest SSH port (22) to detect VM boot completion

2. VM Lifecycle
- Provisioning: Extract embedded Firecracker binary, configure VM (memory/CPU), boot with TAP networking
- Snapshots: Idle VMs suspended to disk, restored on user reconnection
- Guest OS: Standard Linux with sshd, handles all SSH protocol complexity (SFTP, port forwarding, etc.)

3. Code Structure
cmd/main.go              # CLI entry point, Linux+amd64/arm64 only
internal/
  config.go              # Configuration &amp; validation
  server/server.go       # SSH server implementation (TODO)
scripts/download-firecracker.sh  # Downloads &amp; embeds Firecracker binary
binaries/firecracker     # Embedded static binary

4. Key Features
- Static binary: Firecracker embedded via go:embed, no external dependencies
- Build constraints: Prevents compilation on unsupported platforms
- Configuration: CLI flags for port, VM specs, network CIDR, data directory
- Channel transparency: Host becomes invisible proxy between client SSH and guest SSH

Next Steps

1. Integrate Wish SSH server
2. Implement VM management &amp; Firecracker integration
3. Build TAP networking &amp; health checks
4. Add snapshot/restore logic</code></pre><p>I think some things in this summary are a bit overcomplicated, but overall the direction seems fine for now.</p><p>It sets up the scaffold, and I manually rearrange some files to get it working and to fix the <code>//go:embed</code> paths. Alright, this seems like it will work.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!Ls7c!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!Ls7c!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 424w, https://substackcdn.com/image/fetch/$s_!Ls7c!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 848w, https://substackcdn.com/image/fetch/$s_!Ls7c!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 1272w, https://substackcdn.com/image/fetch/$s_!Ls7c!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!Ls7c!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png" width="514" height="386.80677966101695" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:888,&quot;width&quot;:1180,&quot;resizeWidth&quot;:514,&quot;bytes&quot;:164061,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/174157700?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!Ls7c!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 424w, https://substackcdn.com/image/fetch/$s_!Ls7c!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 848w, https://substackcdn.com/image/fetch/$s_!Ls7c!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 1272w, https://substackcdn.com/image/fetch/$s_!Ls7c!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5d476370-c94a-4dc8-998c-301ef431037d_1180x888.png 1456w" sizes="100vw" loading="lazy"></picture><div class="image-link-expand"><div class="pencraft pc-display-flex pc-gap-8 pc-reset"><button tabindex="0" type="button" class="pencraft pc-reset pencraft icon-container restack-image"><svg role="img" width="20" height="20" 
viewBox="0 0 20 20" fill="none" stroke-width="1.5" stroke="var(--color-fg-primary)" stroke-linecap="round" stroke-linejoin="round" xmlns="http://www.w3.org/2000/svg"><g><title></title><path d="M2.53001 7.81595C3.49179 4.73911 6.43281 2.5 9.91173 2.5C13.1684 2.5 15.9537 4.46214 17.0852 7.23684L17.6179 8.67647M17.6179 8.67647L18.5002 4.26471M17.6179 8.67647L13.6473 6.91176M17.4995 12.1841C16.5378 15.2609 13.5967 17.5 10.1178 17.5C6.86118 17.5 4.07589 15.5379 2.94432 12.7632L2.41165 11.3235M2.41165 11.3235L1.5293 15.7353M2.41165 11.3235L6.38224 13.0882"></path></g></svg></button><button tabindex="0" type="button" class="pencraft pc-reset pencraft icon-container view-image"><svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-maximize2 lucide-maximize-2"><polyline points="15 3 21 3 21 9"></polyline><polyline points="9 21 3 21 3 15"></polyline><line x1="21" x2="14" y1="3" y2="10"></line><line x1="3" x2="10" y1="21" y2="14"></line></svg></button></div></div></div></a></figure></div><h2>Getting Firecracker to run &amp; SSH</h2><p>The hardest part of this project will be getting the VM to run and be accessible over SSH. So before adding more complexity, we should figure this out first. I ask the AI to comment out the server startup code in the CLI entrypoint and temporarily have the entrypoint just start up a machine. This works.</p><p>I need to make some manual changes to get logging with logrus working, as well as a few other things that the AI can&#8217;t figure out:</p><ul><li><p>Adding <code>syscall.SysProcAttr{Setpgid: true}</code> to the process and providing the <code>firecracker.WithProcessRunner()</code> option. 
The latter points the SDK at the spooled Firecracker binary, and the former puts that process in its own process group, so that pressing Ctrl+C in the server&#8217;s <a href="https://www.man7.org/linux/man-pages/man2/setpgid.2.html">controlling terminal</a> doesn&#8217;t also interrupt the Firecracker subprocess before the Go code can shut it down gracefully.</p></li><li><p>Setting up the machine&#8217;s network using auto-generated MAC addresses, and passing an ip=&#8230; option on the kernel command line to configure its eth0 interface with the proper gateway and netmask. This was based on <a href="https://gist.github.com/jvns/9b274f24cfa1db7abecd0d32483666a3">Julia Evans&#8217;s gist</a>.</p></li><li><p>Creating a network bridge on startup and attaching the VM&#8217;s TAP device to that bridge. Also, setting up CI in GitHub Actions (with KVM support) and figuring out the CAP_NET_ADMIN permissions needed to run the aforementioned setup.</p></li></ul><p>So now it&#8217;s working. I can run the binary with a provided rootfs (built from a customized Alpine image), and it starts a full VM! 
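</p><p>The <code>ip=</code> option in the second item of the list above follows the kernel&#8217;s <code>ip=client-ip:server-ip:gw-ip:netmask:hostname:device:autoconf</code> field order (optionally followed by DNS and NTP fields). A small sketch of building it, along with a MAC address, might look like the following. The addresses match the ones seen later in this post (guest 192.168.100.2, bridge gateway 192.168.100.1, a /24). The MAC scheme, a locally administered <code>02:fc</code> prefix followed by the guest&#8217;s IPv4 bytes, is a hypothetical choice for illustration and not necessarily how ssh-hypervisor generates its addresses; the function names are hypothetical too.</p><pre><code>package main

import (
	"fmt"
	"net"
	"net/netip"
)

// kernelIPArg builds the kernel's ip= boot parameter, whose fields are
// client-ip:server-ip:gw-ip:netmask:hostname:device:autoconf.
// server-ip and hostname are left empty, and autoconf is "off" (static config).
func kernelIPArg(guest, gateway netip.Addr, prefixLen int, device string) string {
	mask := net.IP(net.CIDRMask(prefixLen, 32)).String()
	return fmt.Sprintf("ip=%s::%s:%s::%s:off", guest, gateway, mask, device)
}

// guestMAC derives a MAC address from the guest's IPv4 address (hypothetical
// scheme). The leading 0x02 marks it as a locally administered unicast
// address; the 0xfc byte is arbitrary. guest must be an IPv4 address, since
// As4 panics on a plain IPv6 one.
func guestMAC(guest netip.Addr) string {
	b := guest.As4()
	return fmt.Sprintf("02:fc:%02x:%02x:%02x:%02x", b[0], b[1], b[2], b[3])
}

func main() {
	guest := netip.MustParseAddr("192.168.100.2")
	gateway := netip.MustParseAddr("192.168.100.1")

	fmt.Println(kernelIPArg(guest, gateway, 24, "eth0"))
	// ip=192.168.100.2::192.168.100.1:255.255.255.0::eth0:off
	fmt.Println(guestMAC(guest))
	// 02:fc:c0:a8:64:02
}
</code></pre><p>Deriving the MAC from the IP keeps both values unique per VM without any extra bookkeeping, as long as each VM gets its own address in the bridge subnet.</p><p>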
I can see the serial logs too.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!f9Sm!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!f9Sm!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 424w, https://substackcdn.com/image/fetch/$s_!f9Sm!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 848w, https://substackcdn.com/image/fetch/$s_!f9Sm!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 1272w, https://substackcdn.com/image/fetch/$s_!f9Sm!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!f9Sm!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png" width="1456" height="562" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:562,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:256554,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/174157700?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!f9Sm!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 424w, https://substackcdn.com/image/fetch/$s_!f9Sm!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 848w, https://substackcdn.com/image/fetch/$s_!f9Sm!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 1272w, https://substackcdn.com/image/fetch/$s_!f9Sm!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F18ec47a4-0854-4ab0-bee6-c9562bdfb29e_1988x768.png 1456w" sizes="100vw" loading="lazy"></picture><div class="image-link-expand"><div class="pencraft pc-display-flex pc-gap-8 pc-reset"><button tabindex="0" type="button" class="pencraft pc-reset pencraft icon-container restack-image"><svg role="img" width="20" height="20" 
viewBox="0 0 20 20" fill="none" stroke-width="1.5" stroke="var(--color-fg-primary)" stroke-linecap="round" stroke-linejoin="round" xmlns="http://www.w3.org/2000/svg"><g><title></title><path d="M2.53001 7.81595C3.49179 4.73911 6.43281 2.5 9.91173 2.5C13.1684 2.5 15.9537 4.46214 17.0852 7.23684L17.6179 8.67647M17.6179 8.67647L18.5002 4.26471M17.6179 8.67647L13.6473 6.91176M17.4995 12.1841C16.5378 15.2609 13.5967 17.5 10.1178 17.5C6.86118 17.5 4.07589 15.5379 2.94432 12.7632L2.41165 11.3235M2.41165 11.3235L1.5293 15.7353M2.41165 11.3235L6.38224 13.0882"></path></g></svg></button><button tabindex="0" type="button" class="pencraft pc-reset pencraft icon-container view-image"><svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-maximize2 lucide-maximize-2"><polyline points="15 3 21 3 21 9"></polyline><polyline points="9 21 3 21 3 15"></polyline><line x1="21" x2="14" y1="3" y2="10"></line><line x1="3" x2="10" y1="21" y2="14"></line></svg></button></div></div></div></a></figure></div><p><strong>Unfortunately, I can&#8217;t SSH into the machine.</strong> Something is wrong. And there are no serial logs for the sshd daemon with OpenRC. <strong>&#128557;</strong></p><pre><code>$ ssh root@192.168.100.2
ssh: connect to host 192.168.100.2 port 22: Connection refused</code></pre><p>So we begin the network debugging, yet again. Let&#8217;s see if packets are reaching the destination at least.</p><pre><code>$ ping 192.168.100.2
PING 192.168.100.2 (192.168.100.2) 56(84) bytes of data.
64 bytes from 192.168.100.2: icmp_seq=1 ttl=64 time=0.401 ms
64 bytes from 192.168.100.2: icmp_seq=2 ttl=64 time=0.359 ms

$ tcptraceroute 192.168.100.2
Selected device sshvm-br0, address 192.168.100.1, port 52619 for outgoing packets
Tracing the path to 192.168.100.2 on TCP port 80 (http), 30 hops max
 1  192.168.100.2 [closed]  0.430 ms  0.243 ms  0.287 ms</code></pre><p>Ok&#8230; so that works. It looks like ICMP and TCP packets are both reaching the VM&#8217;s IP address on the bridge, but port 22 still isn&#8217;t accepting connections.</p><p>At this point, the issue breaks down into two possibilities:</p><ol><li><p><strong>sshd is not starting in the guest VM.</strong> We have no logs, so it might simply not be listening at all, in which case connections to port 22 would of course be refused.</p></li><li><p><strong>There is some kind of networking problem.</strong> I find this less likely because &#8220;Connection refused&#8221; (ECONNREFUSED) usually means the remote host actively rejected the connection attempt (replying with a TCP reset) because nothing is listening on that port; network issues usually show up as a timeout or &#8220;no route to host&#8221;.</p></li></ol><p>Still, both are possible, so we should first figure out which category this falls into, starting with whether sshd is actually running.</p><p>At this point, I copy-pasted everything back into Claude Code and had it try to figure things out. It flailed around for a while; luckily, it&#8217;s able to undo its own work.</p><p>Okay, let&#8217;s go back to the basics. How do we get sshd to run on init? The immediate problem for observability is that I can&#8217;t get into the Firecracker machine at all to figure out why sshd isn&#8217;t running. This turns out to be a whole rabbit hole:</p><ul><li><p>The next step is to debug why &#8220;<a href="https://www.man7.org/linux/man-pages/man8/agetty.8.html">agetty</a>&#8221; is not running on startup, but I can&#8217;t figure this out either.</p></li><li><p>So I add agetty to my VM&#8217;s inittab instead of running it as an OpenRC service.</p><pre><code>cat &gt; /etc/inittab &lt;&lt;'EOF'
ttyS0::respawn:/sbin/agetty -L 115200 ttyS0 linux
EOF</code></pre></li><li><p>But even after that, I can&#8217;t log in as root on the serial console! It won&#8217;t accept the password I set with chpasswd. I look this up: maybe the issue is BusyBox&#8217;s login, so I switch to util-linux-login. No dice, though.</p></li><li><p>I&#8217;m not really a sysadmin. The only thing that has worked so far is opening a &#8220;rescue shell&#8221; by setting init=/bin/sh on the kernel command line, so I&#8217;ll just try that next.</p></li><li><p>This does work! I still can&#8217;t run sshd, but running <code>nc -l -p 42</code> in the guest shell and <code>nc 192.168.100.2 42</code> on the host establishes communication between the host and guest, so that&#8217;s great. <strong>The problem is no longer the network</strong>, but sshd itself. &#129395;</p></li><li><p>I&#8217;m kind of tired of working in a VM. So I see whether I can get sshd working in a plain Alpine container instead, since that&#8217;s what this VM is based on.</p><pre><code>$ docker run -it --rm -p 1022:22 alpine sh

/ # apk add --no-cache util-linux util-linux-login openssh
...
OK: 21 MiB in 63 packages

/ # apk add --no-cache openrc
...
OK: 23 MiB in 71 packages

/ # echo "root:root" | chpasswd
chpasswd: password for 'root' changed

/ # sed -i 's/^#PermitRootLogin.*/PermitRootLogin yes/' /etc/ssh/sshd_config

/ # ssh-keygen -A
ssh-keygen: generating new host keys: RSA ECDSA ED25519 

/ # /usr/sbin/sshd -D -e
Server listening on 0.0.0.0 port 22.
Server listening on :: port 22.</code></pre><p>And yes, it works: I can SSH into the container on port 1022. So something about the VM is very different, and <strong>it is causing sshd to hang</strong>.</p></li><li><p>I&#8217;m very confused, but the very next thing I try fixes the issue. <strong>The issue is entropy.</strong> If I run <code>cat /proc/sys/kernel/random/entropy_avail</code>, it reports only a few bits of entropy available! So reading random data for initial state (which sshd needs for its cryptography) blocks. This is because Firecracker does not provide a virtio-rng device by default.</p><ul><li><p>I first validated this by installing rngd and running it manually in the background, which gets sshd working.</p></li><li><p>But it&#8217;s probably better to provide real entropy, so I&#8217;ll add virtio-rng now.</p></li><li><p>Never mind, the Firecracker Go SDK <a href="https://github.com/firecracker-microvm/firecracker-go-sdk/issues/505">doesn&#8217;t support adding an entropy device</a> during machine creation. Maybe there&#8217;s a raw way to do it?</p></li><li><p>Eh&#8230; this isn&#8217;t worth spending more time on; I&#8217;ll just use rngd.</p></li></ul></li></ul><p>It works now! Yes! Turns out that running VMs isn&#8217;t like Docker at all; it&#8217;s like being your own sysadmin, but even more difficult than usual. :&#8217;)</p><p><em>During this whole debugging session (5+ hours), I asked ChatGPT a lot of stuff. I gave up on Claude Code since it kept making changes. The AI very confidently guided me toward directions that didn&#8217;t work, and it gave me a lot of false hope. But it did eventually find the issue, the lack of random entropy causing silent blocking, which I otherwise wouldn&#8217;t have found without Google or strace. I think it probably saved time overall?</em></p><p>Then I spent another hour trying to get this working with OpenRC. It does not work. 
I&#8217;m just going to call it quits and use <a href="https://unix.stackexchange.com/questions/34462/why-does-linux-allow-init-bin-bash">bash as my init process</a>, oh well.</p><p>And then! It&#8217;s working now, but SSH still takes 6 seconds to start up.</p><h2>virtio-rng and building my own vmlinux</h2><p>Remember the entropy device from earlier? I still have this rngd hack in my init script, which feeds the output of /dev/urandom back into the kernel as fake entropy:</p><pre><code>rngd -f -r /dev/urandom &amp;</code></pre><p>It&#8217;s become clear that this is a bad idea: rngd also starts a &#8220;jitter&#8221; entropy generator, which adds exactly 5 seconds to VM startup and makes the time between boot and getting a shell ~4x slower. I work out how to add an entropy device by calling the Firecracker HTTP API directly, but it&#8217;s still not appearing as <code>/dev/hwrng</code> in the guest.</p><p>I think this is because the guest kernel I&#8217;m using comes from the <a href="https://s3.amazonaws.com/spec.ccfc.min/img/quickstart_guide/x86_64/kernels/vmlinux.bin">quickstart_guide</a> public S3 bucket, and it&#8217;s a very old Linux 4.14 image without many devices. Or maybe not. In any case, with a newer Linux version the <code>random.trust_cpu</code> boot option (introduced in Linux 4.19) is available, so the kernel can seed its RNG from the CPU&#8217;s hardware RNG instructions, and I shouldn&#8217;t have a problem either way.</p><p>So I try to build an image <a href="https://github.com/firecracker-microvm/firecracker/tree/main/resources/guest_configs">based on Linux 6.1</a>, and I run into&#8230; problems!</p><pre><code>[   13.203150] clk: Disabling unused clocks
[   13.207664] /dev/root: Can't open blockdev
[   13.209652] VFS: Cannot open root device "vda" or unknown-block(0,0): error -6
[   13.212952] Please append a correct "root=" boot option; here are the available partitions:
[   13.216454] Kernel panic - not syncing: VFS: Unable to mount root fs on unknown-block(0,0)
[   13.219839] CPU: 0 PID: 1 Comm: swapper/0 Not tainted 6.1.153 #2
[   13.220411] Call Trace:
[   13.220411]  &lt;TASK&gt;
[   13.220411]  show_stack+0x3a/0x40
[   13.220411]  dump_stack_lvl+0x3d/0x51
[   13.220411]  dump_stack+0x10/0x16
[   13.220411]  panic+0x100/0x297
[   13.220411]  mount_block_root+0x13e/0x1d9
[   13.220411]  mount_root+0x117/0x138
[   13.220411]  prepare_namespace+0x135/0x16a
[   13.220411]  kernel_init_freeable+0x166/0x188
[   13.220411]  ? rest_init+0xc0/0xc0
[   13.220411]  kernel_init+0x15/0x120
[   13.220411]  ret_from_fork+0x1f/0x30
[   13.220411]  &lt;/TASK&gt;
[   13.220411] Kernel Offset: disabled
[   13.220411] Rebooting in 1 seconds..
2025-09-23T00:36:38.075124657 [anonymous-instance:main] Vmm is stopping.
2025-09-23T00:36:38.075549933 [anonymous-instance:main] Vmm is stopping.
2025-09-23T00:36:38.090112389 [anonymous-instance:main] Firecracker exiting successfully. exit_code=0</code></pre><p>This is a giant pain. The &#8220;Cannot open root device&#8221; error is completely useless: it could mean any number of things, whether APIC issues, uninitialized modules, or even Firecracker bugs. The AI is equally confused.</p><p>I spend about an hour stuck on this, trying different Linux and Firecracker versions and <a href="https://www.kernelconfig.io/index.html">flipping kernel configs</a> on and off.</p><p>At this point, it&#8217;s Monday. So I go to work. And while I&#8217;m on the train there, I do some Googling and find this bit from <a href="https://github.com/firecracker-microvm/firecracker/blob/main/docs/kernel-policy.md">kernel-policy.md</a>:</p><blockquote><p>We use these configurations to build microVM-specific kernels vended by <strong>Amazon Linux</strong> &#8230; As a result, kernel configurations found in this repo <strong>should be used to build exclusively the aforementioned Amazon Linux kernels</strong>. We do not guarantee that using these configurations to build upstream kernels, will work or produce usable kernel images.</p></blockquote><p>&#128557;</p><p>Okay, so that&#8217;s it. I need to use Amazon Linux, and then it will work, right? Of course the people at Amazon would use their own Linux fork. 
So it&#8217;s back to the AI.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!6NRe!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!6NRe!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 424w, https://substackcdn.com/image/fetch/$s_!6NRe!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 848w, https://substackcdn.com/image/fetch/$s_!6NRe!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 1272w, https://substackcdn.com/image/fetch/$s_!6NRe!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!6NRe!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png" width="508" height="423.92957746478874" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/c7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:948,&quot;width&quot;:1136,&quot;resizeWidth&quot;:508,&quot;bytes&quot;:163059,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/174157700?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!6NRe!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 424w, https://substackcdn.com/image/fetch/$s_!6NRe!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 848w, https://substackcdn.com/image/fetch/$s_!6NRe!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 1272w, https://substackcdn.com/image/fetch/$s_!6NRe!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc7c66803-7b12-46f8-a054-3d49d0afe41f_1136x948.png 1456w" sizes="100vw" loading="lazy"></picture><div class="image-link-expand"><div class="pencraft pc-display-flex pc-gap-8 pc-reset"><button tabindex="0" type="button" class="pencraft pc-reset pencraft icon-container restack-image"><svg role="img" width="20" height="20" 
viewBox="0 0 20 20" fill="none" stroke-width="1.5" stroke="var(--color-fg-primary)" stroke-linecap="round" stroke-linejoin="round" xmlns="http://www.w3.org/2000/svg"><g><title></title><path d="M2.53001 7.81595C3.49179 4.73911 6.43281 2.5 9.91173 2.5C13.1684 2.5 15.9537 4.46214 17.0852 7.23684L17.6179 8.67647M17.6179 8.67647L18.5002 4.26471M17.6179 8.67647L13.6473 6.91176M17.4995 12.1841C16.5378 15.2609 13.5967 17.5 10.1178 17.5C6.86118 17.5 4.07589 15.5379 2.94432 12.7632L2.41165 11.3235M2.41165 11.3235L1.5293 15.7353M2.41165 11.3235L6.38224 13.0882"></path></g></svg></button><button tabindex="0" type="button" class="pencraft pc-reset pencraft icon-container view-image"><svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-maximize2 lucide-maximize-2"><polyline points="15 3 21 3 21 9"></polyline><polyline points="9 21 3 21 3 15"></polyline><line x1="21" x2="14" y1="3" y2="10"></line><line x1="3" x2="10" y1="21" y2="14"></line></svg></button></div></div></div></a></figure></div><p>Let&#8217;s run this build again, using <a href="https://docs.orbstack.dev/machines/">Orbstack</a> for its seamless Linux VMs on macOS. Now that we&#8217;re building from Amazon Linux, it should work with Firecracker, right?</p><p>And it fails again. But then I <a href="https://github.com/firecracker-microvm/firecracker/issues/4881">remove the pci=off acpi=off boot options</a>, and that, combined with Amazon Linux, finally lets it boot. Hooray.</p><p>Even better, I&#8217;m no longer on an ancient Linux version. Timidly, I decide to try returning to OpenRC despite my issues from earlier. And yes: OpenRC works, sshd is running, and even agetty is finally no longer stalling. Everything is blissful. Yay! It all makes sense again; definitely worth the debugging.</p><p>Now that things boot, everything gets a lot easier. 
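</p><p>Once VMs boot reliably, the host also needs to know when the guest&#8217;s sshd is actually ready, which is what the AI&#8217;s plan earlier meant by polling guest port 22. Below is a minimal sketch of such a check that also encodes the &#8220;refused versus unreachable&#8221; distinction from the debugging above. It is an assumption about how such a check could look, not the project&#8217;s actual code, and <code>probeSSH</code> and <code>waitForSSH</code> are hypothetical names.</p><pre><code>package main

import (
	"errors"
	"fmt"
	"net"
	"syscall"
	"time"
)

// probeSSH makes one TCP connection attempt to addr and classifies the result
// the same way as the debugging above: "open", "refused" (a TCP reset came
// back, so the host is reachable but nothing is listening yet), or
// "unreachable" (timeouts, no route to host, etc., pointing at the network).
func probeSSH(addr string, timeout time.Duration) (string, error) {
	conn, err := net.DialTimeout("tcp", addr, timeout)
	if err == nil {
		conn.Close()
		return "open", nil
	}
	if errors.Is(err, syscall.ECONNREFUSED) {
		return "refused", err
	}
	return "unreachable", err
}

// waitForSSH polls addr until it accepts a TCP connection or maxWait elapses.
// Accepting a connection only shows that something is listening, not that a
// full SSH handshake will succeed.
func waitForSSH(addr string, maxWait time.Duration) error {
	stop := time.Now().Add(maxWait)
	for {
		state, err := probeSSH(addr, time.Second)
		if state == "open" {
			return nil
		}
		if time.Now().After(stop) {
			return fmt.Errorf("%s not ready (%s): %w", addr, state, err)
		}
		time.Sleep(100 * time.Millisecond)
	}
}

func main() {
	// Hypothetical guest address, matching the examples in this post.
	if err := waitForSSH("192.168.100.2:22", 5*time.Second); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("guest sshd is accepting connections")
}
</code></pre><p>Polling on every error, rather than failing fast on &#8220;unreachable&#8221;, matters during boot: the guest&#8217;s interface may not be up yet, so early attempts can time out even when nothing is wrong.</p><p>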
I also build the vmlinux kernel for ARM64, just for fun, again inside an Orbstack VM. :)</p><h2>Hooking Firecracker up to an SSH server</h2><p>We have VMs working! It&#8217;s time to hook them up to SSH and build our app.</p><p>I&#8217;m relying heavily on the AI for the implementation here, and it&#8217;s going swimmingly. It worked out the SSH protocol with no issues at all, and it&#8217;s especially good at making cute interactive terminal output, like animated progress bars.</p><p>For things like session management and architecture, it also works well to operate in broad strokes as we make changes.</p><p>(At some point my VM ran out of RAM and started thrashing.)</p><p>But for the most part, this was pretty easy to code, since nothing on the application side was too tricky to debug. I just kept iterating: trying it out and fixing things that didn&#8217;t quite look right.</p><p>On the systems side, I worked out some <code>iptables</code> rules, which are added only when <code>-allow-internet</code> is passed in, so that the VMs get Internet access.</p><pre><code>iptables -A FORWARD -i sshvm-br0 ! -o sshvm-br0 -j ACCEPT -m comment --comment "ssh-hypervisor"
iptables -A FORWARD ! -i sshvm-br0 -o sshvm-br0 -j ACCEPT -m comment --comment "ssh-hypervisor"
iptables -t nat -A POSTROUTING -s &lt;VM_CIDR&gt; ! -o sshvm-br0 -j MASQUERADE -m comment --comment "ssh-hypervisor"</code></pre><h2>The end result</h2><p>It works! And it is very cute :)</p><div class="native-video-embed" data-component-name="VideoPlaceholder" data-attrs="{&quot;mediaUploadId&quot;:&quot;a244cba9-58e2-492b-8ad8-a2650b1b5940&quot;,&quot;duration&quot;:null}"></div><p>I am still hosting this at vmcity.ekzhang.com for now, but I will stop at some point, earlier if I notice any crypto miners or other unscrupulous folk.</p><pre><code>ssh &lt;YOUR_NAME&gt;@vmcity.ekzhang.com  # try it now!</code></pre><p>This was lots of fun, and we ended up with a static binary that runs VMs hooked up to on-demand SSH.</p><p>You can get the code here: <a href="https://github.com/ekzhang/ssh-hypervisor">https://github.com/ekzhang/ssh-hypervisor</a></p>]]></content:encoded></item><item><title><![CDATA[Abridged notes on the LLM scaling book]]></title><description><![CDATA[Essays on ML systems, looking at LLMs on TPUs]]></description><link>https://ss.ekzhang.com/p/abridged-notes-on-the-llm-scaling</link><guid isPermaLink="false">https://ss.ekzhang.com/p/abridged-notes-on-the-llm-scaling</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Mon, 25 Aug 2025 17:14:26 GMT</pubDate><enclosure url="https://substackcdn.com/image/fetch/$s_!8ws9!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>In February, folks at Google DeepMind published a book on <a href="https://jax-ml.github.io/scaling-book/">LLM scaling</a>.</p><p>The book focuses on how you can model LLM scaling with math. I&#8217;m a big fan of stuff like this. 
Intuitive reasoning about systems is really important; it lets you visualize their shape and behaviors at a glance.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!8ws9!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!8ws9!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 424w, https://substackcdn.com/image/fetch/$s_!8ws9!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 848w, https://substackcdn.com/image/fetch/$s_!8ws9!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 1272w, https://substackcdn.com/image/fetch/$s_!8ws9!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!8ws9!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png" width="580" height="411.8956043956044" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:1034,&quot;width&quot;:1456,&quot;resizeWidth&quot;:580,&quot;bytes&quot;:579964,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:false,&quot;topImage&quot;:true,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/171607012?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!8ws9!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 424w, https://substackcdn.com/image/fetch/$s_!8ws9!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 848w, https://substackcdn.com/image/fetch/$s_!8ws9!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 1272w, https://substackcdn.com/image/fetch/$s_!8ws9!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F9cfc2249-41fd-4755-b58f-786132725af9_1802x1280.png 1456w" sizes="100vw" fetchpriority="high"></picture></div></a><figcaption class="image-caption"><em>How to Scale Your Model</em> (2025)</figcaption></figure></div><p>I thought it would be a good time to read, learn, and add some personal commentary (as someone working in the industry at <a href="https://modal.com/">Modal</a>). Expect these notes to be super abridged and not a book replacement &#8212; I will skip things, and any errors are my fault.</p><p><em>(I&#8217;m still working on jax-js, by the way. It&#8217;s going well! Since the <a href="https://substack.com/home/post/p-163548742">last update in May</a>, 3 months ago, we&#8217;ve gone from matrix multiplication to full neural networks, complex operations like softmax + convolution, optimizers, and kernel tuning / optimization. 
The project is now over 10,000 lines of code.)</em></p><p>Without further ado, let&#8217;s begin!</p><h2><a href="https://jax-ml.github.io/scaling-book/">Part 0: Intro</a></h2><p>This book is about <em>scaling LLMs on TPUs</em>. In the past, ML researchers didn&#8217;t think so much about performance. But today, research takes a lot of compute.</p><p>ML systems are complex enough that you can&#8217;t just fiddle with parameters until they become fast. You need a deep understanding of <em>how long it takes to run</em> LLMs, based on compute, memory, and network factors. This informs the fundamental research you do, as well as systems design and efficiency.</p><p>The book then covers:</p><ul><li><p>Transformer architecture, and FLOP math for forward and backward passes.</p></li><li><p>Parallelism strategies (<strong>data, tensor, pipeline, expert</strong>) and other tricks (<strong>FSDP, host offload, gradient accumulation</strong>) for scaling LLM training and inference across more chips and nodes, ideally with performance that scales linearly.</p></li><li><p>Practical examples in JAX and with the LLaMA-3 model.</p></li><li><p>The final chapter is about Nvidia GPUs.</p></li></ul><h2><a href="https://jax-ml.github.io/scaling-book/roofline/">Part 1: Rooflines</a></h2><p>The roofline model considers <strong>communication time</strong> and <strong>computation time</strong>:</p><p>(Note: Communication can happen within a single chip, like loading from <a href="https://modal.com/gpu-glossary/device-software/global-memory">global memory</a> on a GPU, or across chips and nodes over links like PCIe, NVLink, InfiniBand, RoCEv2, &#8230;)</p><div class="latex-rendered" data-attrs="{&quot;persistentExpression&quot;:&quot;\\max(T_\\text{math}, T_\\text{comms}) \\leq T \\leq T_\\text{math} + T_\\text{comms}.&quot;,&quot;id&quot;:&quot;HTHLKRXSTO&quot;}" data-component-name="LatexBlockToDOM"></div><p>Typically, we use the maximum of communication and computation, since you can overlap 
them in most cases. But even if you can&#8217;t overlap them, it&#8217;s a good approximation, since it&#8217;s off by at most a factor of 2.</p><p>Since there&#8217;s a max() here, we have two regimes:</p><ul><li><p><strong>Compute-bound:</strong> T_math &gt; T_comms. You are getting full utilization from your hardware, and the link is not saturated.</p></li><li><p><strong>Comms-bound:</strong> T_comms &gt; T_math. You&#8217;re wasting at least some of the FLOPs/s from your hardware, waiting on the saturated link.</p></li></ul><p>You want to be <em>compute-bound</em>, since that&#8217;s what you&#8217;re paying for &#8212; <a href="https://modal.com/blog/gpu-utilization-guide#what-is-model-flops-utilization-mfu">FLOPs</a>.</p><p>Assuming a well-written kernel, you can estimate whether it will be compute-bound from its arithmetic intensity AI = W/Q: the work (FLOPs) divided by the memory traffic (bytes). On the <a href="https://cloud.google.com/tpu/docs/v5e">TPU v5e MXU</a>, you want &#8805;<strong>240 FLOPs/byte for bfloat16</strong>, which is the chip&#8217;s peak FLOPs/s divided by its HBM bandwidth in bytes/s.</p><p>For a matmul of activations [B, D] with weights [D, F], this translates to a <strong>batch size of ~240</strong>. When B is small, memory traffic is dominated by the weights at 2 bytes per element in bfloat16, while each weight takes part in B multiply-adds at 2 FLOPs each. The two factors of 2 cancel, so the intensity is roughly B.</p><p>You can do the same roofline analysis for communication between chips, e.g. tensor parallelism that splits a matmul along its reduction axis. This gives a critical length for the reduction axis, in the thousands, above which the split pays off: each device needs to do enough FLOPs per byte sent over the interconnect to hide the communication.</p><p>Roofline analysis is the book&#8217;s main tool for evaluating parallelism strategies.</p><h2><a href="https://jax-ml.github.io/scaling-book/tpus/">Part 2: How to Think About TPUs</a></h2><p>At heart, TPUs are <a href="https://modal.com/gpu-glossary/device-hardware/tensor-core">tensor cores</a> attached to <a href="https://en.wikipedia.org/wiki/High_Bandwidth_Memory">high-bandwidth memory (HBM)</a>. They do matrix multiplications fast with systolic arrays, which means lots of FLOPs for matmuls.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!QaWO!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!QaWO!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 424w, https://substackcdn.com/image/fetch/$s_!QaWO!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 848w, https://substackcdn.com/image/fetch/$s_!QaWO!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 1272w, 
https://substackcdn.com/image/fetch/$s_!QaWO!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!QaWO!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png" width="1456" height="730" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:730,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!QaWO!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 424w, https://substackcdn.com/image/fetch/$s_!QaWO!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 848w, https://substackcdn.com/image/fetch/$s_!QaWO!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 1272w, 
https://substackcdn.com/image/fetch/$s_!QaWO!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F510aad71-5777-43fc-8b9b-bb31390430f2_1600x802.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>How does it work? There are <a href="https://fleetwood.dev/posts/domain-specific-architectures#google-tpu">some animations</a> about the pipelining and systolic array architecture on the hardware level. 
Basically, the MXU (the systolic array) does an 8x128 x 128x128 &#8594; 8x128 matmul every 8 cycles; it&#8217;s very fast but needs a bit of ramp-up.</p><p>There are two kinds of memory on a TPU chip:</p><ul><li><p>HBM is the main memory, similar to GPUs. This is ~16-95 GB, with roughly 0.8-2.8 TB/s of bandwidth.</p></li><li><p>VMEM is a much smaller working memory / cache, about 0.1 GB. Its bandwidth is much higher than HBM&#8217;s, so a matmul reading from VMEM only needs an arithmetic intensity of about 10-20 FLOPs/byte to hit peak. That makes it good for fast small-batch inference, if the weights fit.</p></li></ul><p>On TPU v4 and v5p, each chip is a &#8220;megacore&#8221;: two TensorCores sharing HBM that act as one logical core. Four chips are exposed on a single TPU-VM host over PCIe (~200 Gbps NIC).</p><div class="captioned-image-container"><figure><a class="image-link image2" target="_blank" href="https://substackcdn.com/image/fetch/$s_!8-4B!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!8-4B!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 424w, https://substackcdn.com/image/fetch/$s_!8-4B!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 848w, https://substackcdn.com/image/fetch/$s_!8-4B!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 1272w, https://substackcdn.com/image/fetch/$s_!8-4B!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 
1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!8-4B!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png" width="559" height="219.22321428571428" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/b16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:571,&quot;width&quot;:1456,&quot;resizeWidth&quot;:559,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!8-4B!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 424w, https://substackcdn.com/image/fetch/$s_!8-4B!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 848w, https://substackcdn.com/image/fetch/$s_!8-4B!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 1272w, https://substackcdn.com/image/fetch/$s_!8-4B!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb16140cc-8d44-4287-a2aa-7cc58ad5ef7e_1484x582.png 1456w" sizes="100vw" loading="lazy"></picture><div></div></div></a></figure></div><p>TPUs are connected to each other in 2D/3D torus configurations with ICI, 
<em>inter-chip interconnects</em>. These are ~1.6-4.8 Tbps, and each chip has 4 or 6 of them (one per neighbor in a 2D or 3D torus). Compare to Nvidia&#8217;s <a href="https://docs.nvidia.com/dgx/dgxb200-user-guide/introduction-to-dgxb200.html#dgx-b200-component-descriptions">3.2 Tbps InfiniBand</a> cluster networking.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!gYC7!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!gYC7!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 424w, https://substackcdn.com/image/fetch/$s_!gYC7!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 848w, https://substackcdn.com/image/fetch/$s_!gYC7!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 1272w, https://substackcdn.com/image/fetch/$s_!gYC7!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!gYC7!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png" width="594" height="295.36813186813185" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/b5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:724,&quot;width&quot;:1456,&quot;resizeWidth&quot;:594,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!gYC7!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 424w, https://substackcdn.com/image/fetch/$s_!gYC7!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 848w, https://substackcdn.com/image/fetch/$s_!gYC7!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 1272w, https://substackcdn.com/image/fetch/$s_!gYC7!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb5c8db66-d7c7-45f3-bc91-a829371b7645_1516x754.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>A torus is cheaper and more scalable than the <a href="https://en.wikipedia.org/wiki/Fat_tree">fat tree</a>-style networks that Nvidia uses for InfiniBand, but it probably makes collective communication trickier. 
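</p><p>To make the wraparound concrete, here is a toy Python model of a torus slice. It&#8217;s my own sketch rather than anything from the book: each chip links to its -1 and +1 neighbor along every axis (which gives the 4 or 6 links per chip), and the hop count takes the short way around each ring. It ignores real ICI routing, link bandwidth, and slices without wraparound, and the function names are hypothetical.</p><pre><code>from itertools import product


def torus_neighbors(coord, shape):
    """Chips with a direct link to `coord` in a wraparound torus.

    Toy model: one link to the -1 and +1 neighbor along every axis,
    assuming every axis has at least 3 chips (so the two are distinct).
    """
    neighbors = []
    for axis, size in enumerate(shape):
        for step in (-1, 1):
            n = list(coord)
            n[axis] = (n[axis] + step) % size  # wrap around at the edges
            neighbors.append(tuple(n))
    return neighbors


def torus_hops(a, b, shape):
    """Fewest link hops between chips a and b (short way round each ring)."""
    total = 0
    for x, y, size in zip(a, b, shape):
        d = abs(x - y)
        total += min(d, size - d)
    return total


shape = (4, 4, 4)  # a hypothetical 4x4x4 slice of 64 chips
print(len(torus_neighbors((0, 0), (4, 4))))     # 2D torus: 4 links per chip
print(len(torus_neighbors((0, 0, 0), shape)))   # 3D torus: 6 links per chip
print(torus_hops((0, 0, 0), (3, 0, 0), shape))  # 1 hop, via the wraparound link
print(max(torus_hops((0, 0, 0), c, shape) for c in product(range(4), repeat=3)))  # 6
</code></pre><p>Without the wraparound links (a plain grid), chip (3, 0, 0) would be 3 hops away from (0, 0, 0) instead of 1. In this model, no chip in the 4x4x4 slice is more than 6 hops from any other.</p><p>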
When you purchase TPUs from Google, you buy a <em>slice</em> of the topology.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!KULK!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!KULK!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 424w, https://substackcdn.com/image/fetch/$s_!KULK!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 848w, https://substackcdn.com/image/fetch/$s_!KULK!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 1272w, https://substackcdn.com/image/fetch/$s_!KULK!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!KULK!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png" width="592" height="299.25274725274727" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:736,&quot;width&quot;:1456,&quot;resizeWidth&quot;:592,&quot;bytes&quot;:200896,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/171607012?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!KULK!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 424w, https://substackcdn.com/image/fetch/$s_!KULK!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 848w, https://substackcdn.com/image/fetch/$s_!KULK!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 1272w, https://substackcdn.com/image/fetch/$s_!KULK!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F1e150416-6956-4ce7-815d-af1c2a99dfc9_2378x1202.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>The smallest slice is a single 2x2x1 host, which checks out with each host being connected to 4 TPU chips.</p><h2><a href="https://jax-ml.github.io/scaling-book/sharding/">Part 3: Sharded Matrices and How to Multiply Them</a></h2><p>The book introduces a notation for device sharding. When you have multiple devices (which matters especially for TPUs, given their topology), they live on a <em>mesh</em> with named axes. For example, a 2x2 mesh of 4 TPUs, with axes (X, Y):</p><pre><code>Mesh(devices=((0, 1), (2, 3)), axis_names=('X', 'Y'))</code></pre><p>A sharded array is then written like A[Ix, Jy]. 
The indices can be subscripted by mesh axes, which tells us how each mesh axis corresponds to a tensor axis.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!nxjN!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!nxjN!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 424w, https://substackcdn.com/image/fetch/$s_!nxjN!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 848w, https://substackcdn.com/image/fetch/$s_!nxjN!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 1272w, https://substackcdn.com/image/fetch/$s_!nxjN!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!nxjN!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png" width="457" height="474.8907967032967" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:1513,&quot;width&quot;:1456,&quot;resizeWidth&quot;:457,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!nxjN!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 424w, https://substackcdn.com/image/fetch/$s_!nxjN!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 848w, https://substackcdn.com/image/fetch/$s_!nxjN!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 1272w, https://substackcdn.com/image/fetch/$s_!nxjN!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8caeb38a-ba03-4fe1-a2b2-aaf018433c20_4148x4309.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>There are a few rules to this notation. Internalizing them will help you reason about device sharding:</p><ul><li><p><strong>Not all mesh axes need to be mentioned in a sharding.</strong> For example, A[I, J] fully replicates the array on every device. A[Ix, J] shards the first tensor axis across X and replicates the array along Y.</p></li><li><p><strong>Each mesh axis can be mentioned at most once.</strong> A[Ix, Jx] is invalid: the device at position x along X would hold only block (x, x), so the off-diagonal blocks would not live on any device.</p></li><li><p><strong>The order of axes matters.</strong> A[Ixy, J] shards the first tensor axis over both mesh axes, splitting it across X first (the major axis) and then Y. Contrast this with A[Iyx, J], which reverses the order, so some devices end up holding different blocks.</p></li></ul><p>This notation lets us talk about tensor sharding over TPU devices in a torus. 
Devices along each mesh axis can run full <a href="https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html">collective operations</a> like AllReduce, AllGather, and ReduceScatter. Next, we ask the big question.</p><p><strong>Question: How long does it take to do matmul on sharded arrays?</strong></p><p>Matrix multiplication is a <em><a href="https://en.wikipedia.org/wiki/Tensor_contraction">tensor contraction</a></em> (&#8220;like numpy.<a href="https://openxla.org/xla/operation_semantics#dotgeneral">dot</a>&#8221;) op. When you compute A[I, J] * B[J, K] &#8594; C[I, K], you&#8217;re contracting along the J axis.</p><p>Depending on how the inputs are sharded, you may need to use one of the collective operations:</p><blockquote><ol><li><p><strong><a href="https://jax-ml.github.io/scaling-book/sharding/#case-1-neither-multiplicand-has-a-sharded-contracting-dimension">Case 1</a>:</strong> neither input is sharded along the contracting dimension. <em>We can multiply local shards without any communication.</em></p></li><li><p><strong><a href="https://jax-ml.github.io/scaling-book/sharding/#case-2-one-multiplicand-has-a-sharded-contracting-dimension">Case 2</a>:</strong> one input has a sharded contracting dimension. <em>We typically &#8220;AllGather&#8221; the sharded input along the contracting dimension.</em></p></li><li><p><strong><a href="https://jax-ml.github.io/scaling-book/sharding/#case-3-both-multiplicands-have-sharded-contracting-dimensions">Case 3</a>:</strong> both inputs are sharded along the contracting dimension. <em>We can multiply the local shards, then &#8220;AllReduce&#8221; the result.</em></p></li><li><p><strong><a href="https://jax-ml.github.io/scaling-book/sharding/#case-4-both-multiplicands-have-a-non-contracting-dimension-sharded-along-the-same-axis">Case 4</a>:</strong> both inputs have a non-contracting dimension sharded along the same axis. 
We cannot proceed without AllGathering one of the two inputs first.</p></li></ol></blockquote><p>I think Case 3 is probably the most illustrative one, since it&#8217;s the AllReduce that you typically see when you shard computations along a contracting dimension and need to aggregate the results.</p><p>The book works through the derivations with very nice animations. Here&#8217;s a summary of the communication primitives and their effects:</p><ul><li><p><strong>AllGather:</strong> [Ix, J] &#8594; [I, J]. Costs |I|*|J| comms/device.</p></li><li><p><strong>ReduceScatter:</strong> [I, J]{Ux} &#8594; [I, Jx]. Costs |I|*|J| comms/device.</p></li><li><p><strong>AllToAll:</strong> [I, Jx] &#8594; [Ix, J]. Costs 0.5*|I|*|J| comms/device (assuming 2D torus).</p></li><li><p><strong>AllReduce:</strong> [Ix, J]{Uy} &#8594; [Ix, J]. Costs 2*|I|*|J| comms/device.</p><ul><li><p>This is the same as ReduceScatter + AllGather.</p></li></ul></li></ul><p>Notably, the AllToAll primitive is suited for toroidal TPU topologies. The other collective operations use a standard ring algorithm.</p><h2><a href="https://jax-ml.github.io/scaling-book/transformers/">Part 4: All the Transformer Math You Need to Know</a></h2><p>We start with some tensor math. When you have a dot product, some axes are <em>contracting</em>, some are <em>batching</em>, and others are just broadcast. Cheat sheet:</p><ul><li><p><strong>Contracting:</strong> I * I &#8594; &#8709;. This is a reduction axis.</p></li><li><p><strong>Batching:</strong> I * I &#8594; I. This axis is mapped / vectorized over in both tensors.</p></li><li><p><strong>Others:</strong> I * J &#8594; IJ. The axes are broadcast like an outer product.</p></li></ul><p>The total number of FLOPs equals 2x the product of the sizes of all distinct axes, counting an axis only once even if it appears in both operands.</p><p>The <em>reverse pass</em> (backprop) takes twice as many FLOPs as the forward pass. 
This isn&#8217;t exactly the case for all operations (e.g., elementwise ones), but since most FLOPs in a transformer are in matmuls, a good rule of thumb is to <strong>multiply the total FLOPs by 3</strong> (= 1 + 2) when thinking about training.</p><p>Anyway, if you apply this trick to every layer, you get all the transformer FLOPs.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!mYQ3!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!mYQ3!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 424w, https://substackcdn.com/image/fetch/$s_!mYQ3!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 848w, https://substackcdn.com/image/fetch/$s_!mYQ3!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 1272w, https://substackcdn.com/image/fetch/$s_!mYQ3!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!mYQ3!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png" width="1456" height="623" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/b96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:623,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:167495,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/171607012?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!mYQ3!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 424w, https://substackcdn.com/image/fetch/$s_!mYQ3!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 848w, https://substackcdn.com/image/fetch/$s_!mYQ3!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 1272w, https://substackcdn.com/image/fetch/$s_!mYQ3!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb96a0ee2-0631-4afb-b1a8-3e44c4215599_2322x994.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div>
<p>Or, in a nutshell: multiply params by 6BT (BT = #tokens), and some of the multi-head attention layers scale by 3BT&#178;/D instead.</p><p>Great! This will also be useful for thinking about the KV cache later.</p><h2><a href="https://jax-ml.github.io/scaling-book/training/">Part 5: How to Parallelize a Transformer for Training</a></h2><p>This chapter is about <em>train-time scaling</em>. Assume a big but fixed batch size (too big slows down convergence), large enough that each chip is compute-bound rather than bottlenecked on HBM bandwidth. You want to use more chips to speed up each iteration.</p><p>My tl;dr is that there are three strategies: data parallelism (or FSDP), tensor parallelism, and pipeline parallelism. They split along different axes: batch, model, and layers, respectively. 
<strong>Each of them stresses a different communication overhead,</strong> so if you combine all of them, you can run very small batch sizes per device without being comms-bound! (i.e., run lots of devices, train fast)</p><p>That means there&#8217;s no &#8220;best&#8221; parallelism strategy. You apply all of them as needed, since they multiply together. Start with data parallelism though (it&#8217;s easy).</p><p>I&#8217;m going to ignore the TPU numbers here and speak more generally, since I only use GPUs in my job anyway. The TPU numbers look quite different from GPUs because of the interconnect bandwidth (here, in-and-out):</p><ul><li><p>TPUs have more ICI bandwidth (v5p = 4 * 0.8 Tbps / v5e = 6 * 2.4 Tbps) in bigger &#8220;pods&#8221; of up to 8960 chips, but wide-diameter torus connectivity, and</p></li><li><p>GPUs have NVLink (7.2 Tbps, fully-connected) within nodes of 8 GPUs each and InfiniBand (0.125 * 3.2 Tbps, switched tree) connections between nodes. Also, each GPU has more FLOPs than each TPU.</p></li></ul><p>This means TPU clusters can skip pipeline parallelism, which is complicated, but GPU clusters need it across nodes to reduce comms overhead.</p><h3>Data parallelism</h3><p>This is the simplest method.</p><ol><li><p>Split the batch across X devices and do the forward and backward passes independently.</p></li><li><p><em>(Interleaved)</em> When gradients are ready for a layer, do an AllReduce, then update optimizer state with the gradients accumulated across all devices.</p></li></ol><p>You become bottlenecked on comms when B/X &lt; C/W_ici. In other words, the per-device batch size has to exceed the chip&#8217;s FLOPs divided by its bandwidth. 
(The constants cancel out for fp16.)</p><p>Within an <a href="https://www.nvidia.com/en-us/data-center/h100/">8x H100 node</a>, data parallelism needs a batch size of about <strong>(1979 TFLOPs) / (900 GB/s) ~ 2200</strong> per GPU to max out compute with sparsity, or <strong>~1100 without</strong>.</p><p>But between nodes, your bandwidth per GPU is 18x lower. So you&#8217;d probably want to run ReduceScatter first between GPUs within each node, followed by AllReduce inter-node (8x less comms) and another AllGather within each node.</p><h3>FSDP / ZeRO-3</h3><p>Ah yes, the famous fully-sharded data parallelism. It&#8217;s like DDP, but model weights &amp; optimizer states are sharded. Each device stores 1/X of the params. See the &#8220;experiences&#8221; <a href="https://www.vldb.org/pvldb/vol16/p3848-huang.pdf">FSDP paper</a> for details on this, including how to interleave compute and comms within the framework.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!jaa4!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!jaa4!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 424w, https://substackcdn.com/image/fetch/$s_!jaa4!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 848w, 
https://substackcdn.com/image/fetch/$s_!jaa4!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 1272w, https://substackcdn.com/image/fetch/$s_!jaa4!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!jaa4!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png" width="656" height="296.46153846153845" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:658,&quot;width&quot;:1456,&quot;resizeWidth&quot;:656,&quot;bytes&quot;:null,&quot;alt&quot;:&quot;Getting Started with Fully Sharded Data Parallel (FSDP2) &#8212; PyTorch  Tutorials 2.8.0+cu128 documentation&quot;,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="Getting Started with Fully Sharded Data Parallel (FSDP2) &#8212; PyTorch  Tutorials 2.8.0+cu128 documentation" title="Getting Started with Fully Sharded Data Parallel (FSDP2) &#8212; PyTorch  Tutorials 2.8.0+cu128 documentation" srcset="https://substackcdn.com/image/fetch/$s_!jaa4!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 424w, 
https://substackcdn.com/image/fetch/$s_!jaa4!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 848w, https://substackcdn.com/image/fetch/$s_!jaa4!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 1272w, https://substackcdn.com/image/fetch/$s_!jaa4!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6a6cc667-cb0d-4281-ae5e-efa8e79f62c0_4372x1975.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div>
<p>Compared to DDP, you incur 1.5x comms cost in mixed-precision training (&#8220;full sharding&#8221; at least; they also have &#8220;hybrid sharding&#8221;, which is partial DDP), since you have to AllGather weights in addition to the AllReduce of gradients.</p><p>It&#8217;s said that the FSDP backward pass is &#8220;free&#8221; though, in the sense that you shard optimizer state and weights with AllGather + ReduceScatter (which together equal an AllReduce), while reducing FLOPs. The ZeRO stages differ in how much is sharded: ZeRO-1 shards optimizer state, ZeRO-2 also shards gradients, and ZeRO-3 = FSDP also shards weights.</p><p>FSDP lets you scale up to model sizes that don&#8217;t fit in a single GPU&#8217;s memory.</p><h3>Tensor parallelism (Megatron)</h3><p>Let&#8217;s switch our mesh axis from X to Y. Tensor parallelism shards both the weights &amp; activations across devices. It makes each layer run faster because we don&#8217;t have to do as much work on each device, but we do need to insert AllGather / ReduceScatter ops in the forward and reverse passes.</p><p>Generally, this becomes worth it when the dimension of the MLP hidden layer exceeds C/W_ici * Y.</p><p>In other words, with tensor parallelism on H100 GPUs, you can start splitting apart the model once the model dimension is over ~250 * Y or so (assuming a 4x MLP expansion ratio). This becomes very useful for large models, and it combines well with FSDP. Generally the tensor parallel factor is 8 or 16.</p><h3>Pipeline parallelism</h3><p>The book doesn&#8217;t really talk about pipeline parallelism since it&#8217;s mostly used on GPUs. It saves comms overhead by only requiring you to transfer activations between pipeline stages, while also partitioning the model and speeding up training. 
This way, you don&#8217;t actually have to send O(weights) data, only O(activations).</p><div class="captioned-image-container"><figure><a class="image-link image2" target="_blank" href="https://substackcdn.com/image/fetch/$s_!SQIc!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!SQIc!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 424w, https://substackcdn.com/image/fetch/$s_!SQIc!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 848w, https://substackcdn.com/image/fetch/$s_!SQIc!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 1272w, https://substackcdn.com/image/fetch/$s_!SQIc!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!SQIc!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png" width="1456" height="250" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/ff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:250,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!SQIc!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 424w, https://substackcdn.com/image/fetch/$s_!SQIc!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 848w, https://substackcdn.com/image/fetch/$s_!SQIc!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 1272w, https://substackcdn.com/image/fetch/$s_!SQIc!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fff015207-babc-4d1f-ac40-dc18e5e72eaf_1954x336.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>The hard thing here is avoiding pipeline bubbles, the idle time while stages wait on each other. Honestly, pipeline schedules make my head hurt. But people have figured it out, likely at great engineering cost!</p><h3>Takeaways</h3><p>With enough parallelism strategies, we can achieve near-linear scaling over many nodes while keeping the batch size per device low, nearing ~100. 
Each choice requires a lot of engineering effort though, especially for interleaving operations and staying reliable in the face of failures.</p><p>You can combine the strategies (DP+TP+PP) for a multiplicative benefit. It requires you to do some math, though.</p><h2><a href="https://jax-ml.github.io/scaling-book/applied-training/">Part 6: Training LLaMA 3 on TPUs</a></h2><p>This was just an applied exercise on the previous section.</p><p>One interesting thing was that they introduced &#8220;sequence parallelism&#8221; here, which is similar to data parallelism but over the sequence axis. This happened when they ran out of &#8220;batch&#8221; to FSDP over. I guess this introduces a bit more comms overhead, but not too much since you&#8217;re just syncing activations ahead of attention.</p><p>I was also curious about the cost of this hypothetical training run of LLaMA 3 70B with 40% MFU on TPUs. Assuming 3-year commitment prices, you get:</p><p><strong>8960 chips * $1.89/chip/hr * 1056 hours = $18 million</strong></p><p>That&#8217;s just a 70B model. It makes sense that the big labs are raising billions of dollars for their frontier models with trillions of parameters.</p><h2><a href="https://jax-ml.github.io/scaling-book/inference/">Part 7: All About Transformer Inference</a></h2><p>Inference is very different from training. You have a <strong>latency-throughput tradeoff</strong> curve, since bigger batches make each step slower but vastly improve throughput: higher arithmetic intensity means you&#8217;re less memory-bound.</p><p>(This section is also relevant in post-training, since you do rollouts for RL.)</p><h3>Basics of transformer inference</h3><p>Alright, as we know: to sample from a transformer, you run prior tokens through the model layers. 
This produces logits; you sample the next token from their softmax at some temperature, and then repeat the process for each subsequent token.</p><p>You also need a <strong>KV cache</strong> (usually paged), so you don&#8217;t recompute the keys and values of earlier tokens each time. Instead, you reuse the previous ones. KV cache size is proportional to the sequence length, number of layers (L), number of KV heads (K), and head dimension (H). Example: <a href="https://github.com/Dao-AILab/flash-attention/blob/v2.8.3/flash_attn/flash_attn_interface.py#L1464-L1484">flash_attn_with_kvcache()</a>.</p><p>Given this KV cache, there are two phases to inference:</p><ol><li><p><strong>Prefill.</strong> Compute the KV cache for the entire prompt, plus the first set of logits. This initializes the cache.</p></li><li><p><strong>Generation (also &#8220;decode&#8221;).</strong> Using the KV cache of all previous tokens in the sequence, sample one token at a time and compute its logits. Each step appends one token to the cache.</p></li></ol><p>That said, engines like vLLM can <a href="https://docs.vllm.ai/en/latest/configuration/optimization.html#chunked-prefill_1">run both simultaneously (&#8220;chunked prefill&#8221;)</a>, and other inference systems may <a href="https://arxiv.org/abs/2401.09670">split them across separate machines (&#8220;disaggregated prefill&#8221;)</a>.</p><p>Anyway, here&#8217;s the tl;dr on the two main components of the model from a performance lens:</p><ul><li><p><strong>MLP:</strong> To be compute-bound, you need token batch size &#8805; FLOPs / HBM bandwidth.</p><ul><li><p>For TPU v5e, ~240. 
For H100, this is ~600 (with sparsity) or ~300 (without).</p></li><li><p>Critical batch size decreases with param quantization (fewer bytes to load), but increases if FLOPs are done in lower precision, since they become faster.</p></li><li><p>This batch size is trivial to reach during prefill, where every prompt token counts toward it, but harder during decode, where you need to batch up many concurrent requests.</p></li></ul></li><li><p><strong>Attention:</strong> With S cached tokens and T new tokens, AI ~ ST / (S+T).</p><ul><li><p>During prefill, T &#8776; S, so AI ~ T/2 grows linearly with prompt length. It&#8217;s easy to saturate compute, so attention is not the bottleneck.</p></li><li><p>During decode, T = 1, so AI ~ 1: each step loads the entire KV cache. You&#8217;re bottlenecked on memory bandwidth, since each KV cache entry is <em>only used once</em> in attention.</p></li><li><p>So yeah, that&#8217;s sad. Once you increase your decode batch size enough, you&#8217;ll get diminishing returns: each forward pass gets slower because the <strong>KV cache dominates the model weights in size</strong>.</p></li></ul></li></ul><p>This observation about the memory-bound nature of attention is fundamentally why we have a latency-throughput tradeoff. 
You can&#8217;t actually run transformer inference (decode) at the MLP critical batch size of ~300 without spending lots of time loading the KV cache, which <em>slows down inference</em> <em>(inter-token latency)</em>.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!K3sy!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!K3sy!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 424w, https://substackcdn.com/image/fetch/$s_!K3sy!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 848w, https://substackcdn.com/image/fetch/$s_!K3sy!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 1272w, https://substackcdn.com/image/fetch/$s_!K3sy!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!K3sy!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png" width="624" height="487.98668885191347" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/d4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:940,&quot;width&quot;:1202,&quot;resizeWidth&quot;:624,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!K3sy!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 424w, https://substackcdn.com/image/fetch/$s_!K3sy!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 848w, https://substackcdn.com/image/fetch/$s_!K3sy!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 1272w, https://substackcdn.com/image/fetch/$s_!K3sy!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd4929ccf-64f9-46c3-acb6-6aaf9402a02c_1202x940.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div>
<p>If you squint a bit, this is also the graph you get <a href="https://modal.com/llm-almanac/advisor?metric=itl&amp;filters=model%3DLlama+3.1+70B%2Ctokens%3D128%3B1024%2Cttft_p95%3C10&amp;aggregate=p50">when benchmarking vLLM / SGLang</a>.</p><h3>Tricks to improve latency / throughput</h3><p>At scale, we do actually want to run big batch sizes and not waste all our time on memory bandwidth loading the KV cache! It would also be nice to make the KV cache smaller on a per-token basis, since that saves memory. So we have two mutually reinforcing reasons to shrink the KV cache.</p><p>Here are common tricks people use in service of this goal:</p><ul><li><p><em>Grouped Query Attention (GQA)</em> reduces the number of KV heads and shares each with multiple Q heads. You can pick anywhere on a sliding scale from 1 KV head per Q head (standard multi-head attention) down to just 1 KV head in total (multi-query attention). 
Seems like not all of the KV heads are needed for performance!</p><div class="captioned-image-container"><figure><a class="image-link image2" target="_blank" href="https://substackcdn.com/image/fetch/$s_!nfeL!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!nfeL!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 424w, https://substackcdn.com/image/fetch/$s_!nfeL!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 848w, https://substackcdn.com/image/fetch/$s_!nfeL!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 1272w, https://substackcdn.com/image/fetch/$s_!nfeL!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!nfeL!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png" width="551" height="218.12241653418124" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:498,&quot;width&quot;:1258,&quot;resizeWidth&quot;:551,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!nfeL!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 424w, https://substackcdn.com/image/fetch/$s_!nfeL!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 848w, https://substackcdn.com/image/fetch/$s_!nfeL!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 1272w, https://substackcdn.com/image/fetch/$s_!nfeL!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F88723164-c71d-493b-8ab7-59a189cb75b3_1258x498.png 1456w" sizes="100vw" loading="lazy"></picture><div></div></div></a></figure></div></li><li><p><em>Mixing local attention layers</em> is done by some models. For example, <a href="https://arxiv.org/pdf/2503.19786">Gemma 3</a> uses 5 local layers between each global attention layer. <a href="https://magazine.sebastianraschka.com/p/from-gpt-2-to-gpt-oss-analyzing-the">GPT-OSS</a> alternates between local and global attention in a 1:1 ratio. 
The idea is that most layers only need nearby context, and since a local layer only caches a fixed window of recent tokens, this shrinks the KV cache.</p></li><li><p><em>Sharing KVs across layers:</em> You can go even further than GQA and share KVs across layers, not just across query heads. This reduces KV cache size, but it doesn&#8217;t reduce memory bandwidth, since the shared KVs still need to be read in each layer.</p></li><li><p><em>Quantization</em> saves on both memory bandwidth and memory size, for params and KV cache alike. Even if you keep activations at higher precision, loading fewer bytes per value raises arithmetic intensity, which helps memory-bound decoding get closer to the roofline.</p></li><li><p><em>Paged attention</em> uses ragged reads into sections (&#8220;pages&#8221;) of HBM that are allocated on demand as the sequence grows. It adds a bunch of complexity around memory allocation, preemption and interleaving, but it&#8217;s used by almost every inference engine to save memory. See <a href="https://github.com/GeeeekExplorer/nano-vllm/blob/main/nanovllm/engine/scheduler.py">nano-vllm&#8217;s scheduler</a>.</p></li></ul><h3>Distributing inference</h3><p>If you&#8217;re scaling to multiple accelerators, you can now explore various parallelism strategies. The default is to simply replicate the model, with a full copy of its weights in each instance, which is simple and doesn&#8217;t need any comms / syncing.</p><p>But you might want to speed up the model, or fit large models that are too big for a single chip&#8217;s HBM. Then you have some choices.</p><ul><li><p><strong>Prefill:</strong> This is almost identical to training, since you process the whole sequence length dimension at once. You can shard prefill with model parallelism (~4-8 shards, as determined by ICI bandwidth) and then add sequence parallelism on top.</p><ul><li><p>Here, sequence parallelism doesn&#8217;t incur much overhead because you just AllGather activations. 
Note that this is different from <em><a href="https://docs.vllm.ai/en/v0.4.2/models/performance.html">chunked prefill</a></em>, which splits a long prefill into chunks and batches them together with decode steps on a single device.</p></li></ul></li><li><p><strong>Generation:</strong> FSDP-style weight sharding is bad here: decoding is already bottlenecked on memory bandwidth, and gathering weights from other chips over the interconnect is even slower than reading them from local HBM. So your option is model parallelism (or maybe pipeline / expert parallelism at scale?). You can also shard the KV cache while doing this, which reduces the per-chip memory &#8220;bandwidth cost&#8221; if done along the right axis.</p></li></ul><p>Anyway, at this point the book returns to the basic principles of inference engine scheduling. Most of these ideas apply to single-node serving as well. You can see them better in engines like <a href="https://docs.vllm.ai/en/v0.7.3/index.html">vLLM</a> and <a href="https://docs.sglang.ai/">SGLang</a>, so I&#8217;ll just summarize.</p><ul><li><p>Typically you <em>interleave prefill and generation</em>: you run prefill requests with priority and at a smaller-than-max batch size to reduce time-to-first-token (TTFT). The smaller batch size also keeps prefill from blocking in-flight generations for too long.</p></li><li><p>The natural next step at sufficient scale is to <em>disaggregate prefill and generation</em>, since dedicated prefill machines can specialize in that workload without blocking the generation step for other queries. This requires transmitting the KV cache over the network.</p></li><li><p><em>Continuous batching</em> is an obvious optimization for generation: you run one decode step at a time, and between steps you add newly arrived requests to the running batch and retire finished ones. Don&#8217;t wait for a full batch to be ready before starting. 
This also means inter-token latency naturally degrades as your load (~continuous batch size) increases, which makes it a nice global measure of system load.</p></li><li><p>Requests often share a prefix with earlier ones, especially in chat applications, so <em>prefix caching</em> and <em>sticky routing</em> are essential, probably using some kind of <a href="https://en.wikipedia.org/wiki/Consistent_hashing">consistent hashing</a> + <a href="https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_Recently_Used_(LRU)">LRU</a> scheme.</p></li></ul><p>The book links their <a href="https://github.com/AI-Hypercomputer/JetStream">JetStream library</a> as an implementation example for inference at scale on TPUs. Some exercises analyze &#8220;expert sharding&#8221; in MoE models.</p><p><em>Some industry commentary:</em> This section focuses on TPUs, but the common industry standard by far is Nvidia GPUs. These typically come 8 per node, and 8x B200 GPUs are enough to serve all but the very largest open models like <a href="https://kimik2.com/">Kimi K2</a> (which can use expert parallelism). Mostly you can get away with model parallelism on a single node, running 8x Nvidia H100/A100 (perhaps A10 if small) and NVLink. This reduces the <em>engineering complexity</em> a lot (a big forcing function for whether something actually gets built), and indeed most companies, including specialized inference ones like Baseten/Fireworks, only have nascent multi-node offerings, if any. 
Edit: Feedback I&#8217;ve gotten is that frontier labs all use expert parallelism and multi-node inference, though, which makes sense given their much larger models.</p><p>The appendices cover other methods and considerations, specifically for low-latency inference (inter-token latency).</p><ul><li><p>As device count increases, you may implement <a href="https://jax-ml.github.io/scaling-book/inference/#appendix-b-2d-weight-stationary-sharding">2D weight sharding</a> for MLP weights, along both the hidden and input axes. This becomes useful once sharding along the hidden axis alone would make the per-device hidden dimension smaller than the input dimension, so you balance the two axes to reduce comms cost.</p></li><li><p>The book mentions that during inference, AllGathers can actually become <a href="https://jax-ml.github.io/scaling-book/inference/#appendix-c-latency-bound-communications">latency-bound</a> because so little data is moved, so the cost of sending it around the ring comes from per-hop latency, not bandwidth. This seems like a TPU-specific problem arising from toroidal topologies.</p></li><li><p>The book briefly discusses <a href="https://jax-ml.github.io/scaling-book/inference/#appendix-d-speculative-sampling">speculative decoding</a>, which uses a cheaper draft model to &#8220;guess&#8221; the next several tokens and then verifies them with the full model, using rejection sampling or MCMC. 
It trades away some throughput (compute spent on rejected guesses) for lower latency, i.e. more tokens/sec per request.</p></li></ul><h2><a href="https://jax-ml.github.io/scaling-book/applied-inference/">Part 8: Serving LLaMA 3-70B on TPUs</a></h2><p>The first thing I notice is this comparison of devices / cost per hour on GCP.</p><div class="captioned-image-container"><figure><a class="image-link image2" target="_blank" href="https://substackcdn.com/image/fetch/$s_!E19y!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!E19y!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 424w, https://substackcdn.com/image/fetch/$s_!E19y!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 848w, https://substackcdn.com/image/fetch/$s_!E19y!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 1272w, https://substackcdn.com/image/fetch/$s_!E19y!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!E19y!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png" width="1456" height="282" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:282,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:96953,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/171607012?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!E19y!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 424w, https://substackcdn.com/image/fetch/$s_!E19y!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 848w, https://substackcdn.com/image/fetch/$s_!E19y!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 1272w, https://substackcdn.com/image/fetch/$s_!E19y!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F15d5c12d-ce70-4caf-ad33-df43bd6df126_2614x506.png 1456w" sizes="100vw" loading="lazy"></picture><div></div></div></a></figure></div><p>As a principle, FLOPs / $ is right. But the price for H100 GPUs is off by quite a lot. 
For instance, Modal offers <em>serverless</em> H100s (boot in &lt;2s, premium offering) for $3.95/hr. A quick search shows you can get H100s for much cheaper than even that if you&#8217;re willing to run your own servers. <em>Anyway, just take these prices with a grain of salt; the authors and GCP have a business incentive to make TPUs look good.</em></p><p>Some quick takeaways from this chapter:</p><ul><li><p><strong>KV cache takes up a lot of space!</strong> Each token&#8217;s KV cache is about 1/440k the total size of the model in memory, assuming int8 quantization for both. With 32k-token context windows, each sequence&#8217;s KV cache is roughly 7% of the model&#8217;s size (32k / 440k), which limits your batch size a lot. A quick-and-dirty explanation of the &#8220;440k&#8221; number: it&#8217;s roughly how much bigger the (mostly MLP) params in each layer are than the KV values one token stores in that layer.</p></li><li><p><strong>Consider doing int8 quantization but bf16 FLOPs.</strong> You don&#8217;t pay for the extra precision in FLOPs, because the batch size isn&#8217;t high enough for the matmuls to become compute-bound. Low-precision arithmetic may also hurt model quality.</p></li><li><p><strong>Below a certain point, you pay a lot for lower latency.</strong> <a href="https://jax-ml.github.io/scaling-book/applied-inference/#visualizing-the-latency-throughput-tradeoff">These graphs</a> are quite dramatic and show that if you have a very small batch size, you get to run super fast, but throughput sucks.</p></li></ul><h2><a href="https://jax-ml.github.io/scaling-book/profiling/">Part 9: How to Profile TPU Programs</a></h2><p>We&#8217;ve finally started writing JAX. I really like JAX, so I&#8217;m familiar with the framework and its functional (even <em>functorial</em>?) style. What&#8217;s new for me is the JAX profiler, a tool for understanding TPU traces.</p><p>Review of the compiler pipeline: Jaxpr &#8594; StableHLO &#8594; HLO &#8594; TPU LLO. 
Alternatively, you can write custom kernels in <a href="https://docs.jax.dev/en/latest/pallas/index.html">Pallas</a>.</p><p>The key thing to remember is that you can wrap code in <code>jax.profiler.trace()</code> contexts (as well as named scopes / calls) to generate timeline traces, profiles, and XLA graph views that you can open in TensorBoard.</p><pre><code>import jax

with jax.profiler.trace("/tmp/tensorboard"):
    # Everything executed inside this block is recorded as a trace in /tmp/tensorboard.
    key = jax.random.key(0)
    x = jax.random.normal(key, (1024, 1024))
    y = x @ x
    # JAX dispatches work asynchronously, so wait for the matmul to finish
    # to make sure it is captured inside the traced region.
    y.block_until_ready()</code></pre><p>The visualization suite reminds me of <a href="https://github.com/google/pprof">pprof</a>, which I guess makes sense as another Google product. You get graph views, a look at how various parts of the code are lowered, and, first and foremost, the trace timeline.</p><p>Profiling is never super fun, but the general rules apply here. It becomes easier as you get fluent in the domain (e.g., the author just happens to know &#8220;AllReduce + dynamic slice = ReduceScatter&#8221;), and it&#8217;s definitely much more approachable if you come in with an idea of what you <em>think</em> the profile will look like and then validate it.</p><h2><a href="https://jax-ml.github.io/scaling-book/jax-stuff/">Part 10: Programming TPUs in JAX</a></h2><p>There are three modes: fully automatic, explicit sharding (via the type system), and manual sharding with shard_map().</p><p>It&#8217;s pretty cool to see these automatic sharding modes in JAX. All of the methods are pretty cutting-edge work. I first read about auto-sharding in 2022 via <a href="https://www.usenix.org/conference/osdi22/presentation/zheng-lianmin">Alpa (OSDI &#8217;22)</a>, which also included inter-operator (pipeline) parallelism in its scope. Probably too much, but hey, it was a research prototype.</p><p><a href="https://jax-ml.github.io/scaling-book/jax-stuff/#auto-sharding-mode">Auto sharding mode:</a></p><ol><li><p>You create a device mesh with <code>jax.make_mesh()</code>, giving it axis shapes and names. Each array gets a <code>jax.NamedSharding</code> set as its device on construction, which specifies how the array is sharded across devices.</p></li><li><p>After that, <code>jax.jit()</code> lets you specify input and output shardings, and the shardings of all intermediates are then inferred automatically (via heuristics) by <a href="https://openxla.org/shardy">Shardy</a> (XLA).</p></li><li><p>You can then profile it. 
If you see an issue, you can give the compiler a hint with <code>jax.lax.with_sharding_constraint()</code> to change its behavior.</p></li></ol><p><a href="https://jax-ml.github.io/scaling-book/jax-stuff/#explicit-sharding-mode">Explicit sharding mode:</a></p><ol><li><p>You create a mesh as before but pass in the &#8220;Explicit&#8221; axis type.</p></li><li><p>Now, every array you create carries its sharding in its metadata / type. Each operation checks whether its output sharding can be inferred unambiguously from its inputs. If it can&#8217;t, you have to resolve the ambiguity by providing an <code>out_sharding</code> kwarg.</p></li></ol><p><a href="https://jax-ml.github.io/scaling-book/jax-stuff/#manual-sharding-mode-via-shard-map">Manual sharding mode with shard_map:</a> (see also <a href="https://docs.jax.dev/en/latest/notebooks/shard_map.html">tutorial</a>)</p><ol><li><p>You write a program from the perspective of a single device, in <a href="https://en.wikipedia.org/wiki/Single_program%2C_multiple_data">SPMD</a> style, like in PyTorch.</p></li><li><p>Decorate it with <code>jax.shard_map()</code>, and it will run on all devices in parallel, with each device receiving its own shard of the inputs.</p></li><li><p>Insert <a href="https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators">collective operations</a> as needed, such as <code>jax.lax.ppermute()</code>, <code>jax.lax.pmean()</code>, <code>jax.lax.all_gather()</code>, etc.</p></li></ol><h2><a href="https://jax-ml.github.io/scaling-book/gpus/">Part 12: How to Think About GPUs</a></h2><p>(Slightly out of order: I&#8217;m leaving Part 11 until the end since it&#8217;s the conclusion.)</p><p>We have one last &#8220;addon&#8221; chapter, which compares TPUs and GPUs. Their description of GPUs is short and witty.</p><blockquote><p>A modern ML GPU (e.g. 
H100, B200) is basically a bunch of compute cores that specialize in matrix multiplication (called <strong>Streaming Multiprocessors</strong> or <strong>SMs</strong>) connected to a stick of fast memory (called <strong>HBM</strong>).</p></blockquote><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!aPSy!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!aPSy!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 424w, https://substackcdn.com/image/fetch/$s_!aPSy!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 848w, https://substackcdn.com/image/fetch/$s_!aPSy!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 1272w, https://substackcdn.com/image/fetch/$s_!aPSy!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!aPSy!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png" width="664" height="308.74175824175825" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/d2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:677,&quot;width&quot;:1456,&quot;resizeWidth&quot;:664,&quot;bytes&quot;:null,&quot;alt&quot;:&quot;&quot;,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" title="" srcset="https://substackcdn.com/image/fetch/$s_!aPSy!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 424w, https://substackcdn.com/image/fetch/$s_!aPSy!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 848w, https://substackcdn.com/image/fetch/$s_!aPSy!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 1272w, https://substackcdn.com/image/fetch/$s_!aPSy!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd2ba96a9-16c6-49b3-970a-cb49d376d2f1_1910x888.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>I&#8217;m pretty familiar with GPUs, having done some CUDA programming, so I already know a bunch of the concepts here: warp = 32 threads, warp scheduler, divergence, shared memory, warpgroups (SM) and so on. GPUs are more general-purpose than TPUs, but they also devote a huge chunk of silicon to tensor cores that do matmul.</p><p>As I mentioned before, networking is very different on GPUs versus TPUs. Among nodes in a <em>Scalable Unit (SU)</em>, GPUs get full bisection bandwidth: every node can reach every other node through InfiniBand switches in a <a href="https://en.wikipedia.org/wiki/Fat_tree">fat tree</a> topology. 
You use RDMA to communicate.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!5ySO!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!5ySO!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 424w, https://substackcdn.com/image/fetch/$s_!5ySO!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 848w, https://substackcdn.com/image/fetch/$s_!5ySO!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 1272w, https://substackcdn.com/image/fetch/$s_!5ySO!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!5ySO!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png" width="640" height="284.83516483516485" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:648,&quot;width&quot;:1456,&quot;resizeWidth&quot;:640,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!5ySO!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 424w, https://substackcdn.com/image/fetch/$s_!5ySO!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 848w, https://substackcdn.com/image/fetch/$s_!5ySO!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 1272w, https://substackcdn.com/image/fetch/$s_!5ySO!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F05831512-176d-4ddf-8e8a-01e36606bac0_1936x862.png 1456w" sizes="100vw" loading="lazy"></picture></div></a></figure></div><p>Each node itself has a few NVLink switches that offer very fast, 1-hop networking between its 8 GPUs, in an all-to-all fashion. 
This is a pretty nice diagram of how the connectivity changed over time!</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!_4wB!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!_4wB!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 424w, https://substackcdn.com/image/fetch/$s_!_4wB!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 848w, https://substackcdn.com/image/fetch/$s_!_4wB!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 1272w, https://substackcdn.com/image/fetch/$s_!_4wB!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!_4wB!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png" width="598" height="326.10714285714283" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/ab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:794,&quot;width&quot;:1456,&quot;resizeWidth&quot;:598,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!_4wB!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 424w, https://substackcdn.com/image/fetch/$s_!_4wB!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 848w, https://substackcdn.com/image/fetch/$s_!_4wB!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 1272w, https://substackcdn.com/image/fetch/$s_!_4wB!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fab2bcf20-d4dd-45ea-8b05-dd908f412916_1999x1090.png 1456w" sizes="100vw" loading="lazy"></picture>
</div></a></figure></div><p>Beyond the level of a scalable unit, customers and cloud providers connect GPUs in various ways, though the most popular option seems to be InfiniBand. As the number of GPUs and nodes grows, you need more high-bandwidth switches to support the fat tree. The extra switching layers may add latency, but you also get full connectivity, unlike TPUs.</p><p>Of course, everything changes with the &#8220;GB200 NVL72 SuperPod&#8221; system (what a name&#8230;). Instead of 8 devices on NVLink, you have 72. Great. 
Not going to think about that one for a while, haha.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!cyAi!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!cyAi!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 424w, https://substackcdn.com/image/fetch/$s_!cyAi!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 848w, https://substackcdn.com/image/fetch/$s_!cyAi!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 1272w, https://substackcdn.com/image/fetch/$s_!cyAi!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!cyAi!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png" width="618" height="456.4654300168634" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:876,&quot;width&quot;:1186,&quot;resizeWidth&quot;:618,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!cyAi!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 424w, https://substackcdn.com/image/fetch/$s_!cyAi!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 848w, https://substackcdn.com/image/fetch/$s_!cyAi!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 1272w, https://substackcdn.com/image/fetch/$s_!cyAi!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F57a39373-2a9f-46bc-97b3-6aec191f95ee_1186x876.png 1456w" sizes="100vw" loading="lazy"></picture>
</div></a></figure></div><p>Full connectivity changes up some collectives. AllReduce, ReduceScatter, and AllGather mostly stay the same, though an optimized implementation can reduce latency. AllToAll gets significantly faster because each device can send its shards directly to every other device.</p><p>One fun gadget is <a href="https://developer.nvidia.com/blog/advancing-performance-with-nvidia-sharp-in-network-computing/">SHARP</a>, which lets you do reductions within the network switch itself. In theory this halves the reduction time, but the authors only saw about a 30% improvement in practice. 
May slightly change up calculations!</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!-yoD!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!-yoD!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 424w, https://substackcdn.com/image/fetch/$s_!-yoD!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 848w, https://substackcdn.com/image/fetch/$s_!-yoD!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 1272w, https://substackcdn.com/image/fetch/$s_!-yoD!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!-yoD!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png" width="601" height="324.02815934065933" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/e79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:785,&quot;width&quot;:1456,&quot;resizeWidth&quot;:601,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!-yoD!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 424w, https://substackcdn.com/image/fetch/$s_!-yoD!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 848w, https://substackcdn.com/image/fetch/$s_!-yoD!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 1272w, https://substackcdn.com/image/fetch/$s_!-yoD!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe79e406b-8a18-4f23-b2be-0da4cae7a6b7_1650x890.png 1456w" sizes="100vw" loading="lazy"></picture>
</div></a></figure></div><p>The second half of the chapter follows with a bunch of lore on GPU training and inference, basically going through the same roofline calculations as before. It&#8217;s slightly different because you need to consider the inter-node bandwidth as well.</p><p>Remember that <em>pipeline parallelism does not play well with FSDP</em>, since every pipeline microbatch would need to re-gather the sharded weights.</p><p>During training, you need a local batch size of about 2500 tokens per GPU to stay compute-bound. Beyond that, you can combine model parallelism / expert parallelism (8-64 GPUs), then pipeline parallelism, and finally ZeRO-1 data parallelism.</p><h2><a href="https://jax-ml.github.io/scaling-book/conclusion/">Part 11: Conclusions and Further Reading</a></h2><p>This was a great read. I&#8217;m happy that so many people came together to write this; the mental models and visuals are on point. 
Also a nice way to advertise TPUs (haha), at least in spreading awareness of their programming model.</p><p>It&#8217;s a &#8220;textbook&#8221; but definitely one of the better-written textbooks I&#8217;ve seen, and it presents a nice calculus of sorts for reasoning about these systems. Kind of like a math book, where it introduces new complexity, helps the reader get used to it, simplifies, and then moves on to more difficult challenges.</p><p>I went through my notes again with a friend and realized that they&#8217;re quite sparse if you aren&#8217;t already familiar with parallelism. Sorry! Consider this a band-pass filter over the book&#8217;s content, made for Eric. :)</p>]]></content:encoded></item><item><title><![CDATA[How the jax.jit() JIT compiler works in jax-js]]></title><description><![CDATA[A lightweight compiler for a new numerical computing / ML library.]]></description><link>https://ss.ekzhang.com/p/how-the-jaxjit-jit-compiler-works</link><guid isPermaLink="false">https://ss.ekzhang.com/p/how-the-jaxjit-jit-compiler-works</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Wed, 14 May 2025 14:02:13 GMT</pubDate><enclosure url="https://substackcdn.com/image/fetch/$s_!nQCt!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>Since the start of this year, I&#8217;ve been working on a version of <strong>JAX in pure JavaScript</strong>.</p><p>For this, I need to build a deep learning compiler from scratch, and I want to keep it lightweight (e.g., JAX uses XLA as its compiler, which is 200 KLoC: too much bundle size for the web!). This is a note about the trickiest fundamental problem I&#8217;ve run into, and how I&#8217;m going about solving it.</p><h2>What is jax-js?</h2><p><a href="https://docs.jax.dev/en/latest/">JAX</a> is a great library. 
It takes the numerical computing properties of NumPy, shoves in GPU + Autograd, then packages it all up in a convenient API.</p><pre><code>import jax.numpy as jnp
from jax import grad

a = jnp.array([1.0, 2.0, 3.0])  # grad() requires floating-point inputs
a * 10  # [10., 20., 30.]
grad(lambda x: (x * x).sum())(a)  # [2., 4., 6.]</code></pre><p>By writing JAX in pure JS, using web APIs, we solve two problems:</p><ol><li><p><strong>How do you do <a href="https://github.com/numpy/numpy">numerical compute</a> in the browser?</strong> Like taking the mean of some numbers, or applying an image filter. There are lots of applications (statistics, data science, classical ML, CV, etc.), but right now it&#8217;s pretty hard to do well.</p></li><li><p><strong>How do you run <a href="https://github.com/amandaghassaei/gpu-io">GPU compute</a> in the browser?</strong> There are technologies like WebGPU if you want to write your own shaders, which is great if you&#8217;re making a video game. But this is tricky if you just want to do something simple. After all, a lot more people use PyTorch/JAX than write CUDA kernels!</p></li></ol><p>No other library, ported to JS directly, would solve both problems at the same time. JAX hits the sweet spot since it&#8217;s useful for ML, and it also matches <a href="https://data-apis.org/array-api/latest/">NumPy&#8217;s API</a>.</p><pre><code>import { grad, numpy as np } from "@jax-js/jax";

const a = np.array([1, 2, 3]);  // note: type is np.Array
a.mul(10);  // [10, 20, 30]
grad((x) =&gt; x.mul(x).sum())(a);  // [2, 4, 6]</code></pre><p>If you just want numerical computing features, import <code>numpy as np</code>. If you need anything else, you can pull it in as needed.</p><h2>Optimistic dispatch</h2><p>So how do you implement this? If your operations are individual CPU calls and you&#8217;re following NumPy, you would dispatch them one by one to a kernel. That might be a Wasm kernel, for instance, and you could implement core operations like:</p><pre><code>function neg(a: Array) {  // a =&gt; -a
  const output = arrayLike(a);
  wasmBackend.dispatch("NEG:1", [a.buffer], [output.buffer]);
  return output;
}

function mul(a: Array, b: Array) {  // a, b =&gt; a * b
  [a, b] = broadcast(a, b);
  const output = arrayLike(a);
  wasmBackend.dispatch("MUL:2", [a.buffer, b.buffer], [output.buffer]);
  return output;
}</code></pre><p>And then you&#8217;d have optimized Wasm kernels for each of these core operations. This is what <a href="https://www.npmjs.com/package/@tensorflow/tfjs-backend-wasm">tfjs-backend-wasm</a> does, for instance.</p><p>But for deep learning workloads, you often want to fuse operations together. For example, let&#8217;s say you want to compute <strong>norm(x * 3 + 2)</strong> for a vector <strong>x</strong>. Doing this naively might take <strong>4 data round-trips</strong> to the GPU or other device:</p><ol><li><p>Compute <strong>x * 3</strong>, store the result in <strong>a</strong>.</p></li><li><p>Compute <strong>a + 2</strong>, store the result in <strong>b</strong>.</p></li><li><p>Compute <strong>b * b</strong>, store the result in <strong>c</strong>.</p></li><li><p>Compute <strong>sum(c)</strong>, store the result in <strong>d</strong>.</p></li><li><p>Return <strong>sqrt(d)</strong>.</p></li></ol><p>For experimenting on small data, a few round trips won&#8217;t hurt anyone. But this can get painfully slow for more complex math, especially when you add in <a href="https://docs.jax.dev/en/latest/autodidax.html">JAX-style autograd</a> via transformations, which can greatly increase the number of generated operations.</p><p>So we&#8217;d like a way to make the operations more efficient, especially for repeated operations. This way, your browser simulation doesn&#8217;t skip frames, and your LLM produces more output tokens per second.</p><h2>Understanding the machine learning JIT</h2><p>The inspiration for this JIT, or <a href="https://en.wikipedia.org/wiki/Just-in-time_compilation">just-in-time compiler</a>, comes from XLA, JAX&#8217;s backend, which originated in the <a href="https://www.tensorflow.org/">TensorFlow</a> project. XLA represents computations as directed acyclic graphs (DAGs) of core primitives. 
Some examples:</p><ul><li><p>Exponential computes e<sup>x</sup>.</p></li><li><p>Multiply multiplies two numbers.</p></li><li><p>Broadcast expands the axes of its input by repeating it.</p></li><li><p>Reduce(Subcomputation:add) takes the sum of a tensor along some axes.</p></li></ul><p>Then XLA transforms the graph on the left into the graph on the right through a series of optimization passes. In this case, several operations are turned into <em>fused expressions</em>, which reduces the number of round-trips and makes the overall computation <strong>~50x faster</strong> on a T4 GPU.</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!nQCt!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!nQCt!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 424w, https://substackcdn.com/image/fetch/$s_!nQCt!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 848w, https://substackcdn.com/image/fetch/$s_!nQCt!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 1272w, https://substackcdn.com/image/fetch/$s_!nQCt!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 1456w" sizes="100vw"><img 
src="https://substackcdn.com/image/fetch/$s_!nQCt!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png" width="1040" height="1488" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/bcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:1488,&quot;width&quot;:1040,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:null,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:null,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!nQCt!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 424w, https://substackcdn.com/image/fetch/$s_!nQCt!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 848w, https://substackcdn.com/image/fetch/$s_!nQCt!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 1272w, https://substackcdn.com/image/fetch/$s_!nQCt!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbcbb046e-d61b-4f17-9040-282a6230ec28_1040x1488.png 1456w" sizes="100vw" loading="lazy"></picture>
</div></a></figure></div><p>Well, you might say that XLA is a really high-caliber compiler, at the state of the art for ML compilation. jax-js doesn&#8217;t <em>need to</em> achieve top-level performance like this, since it&#8217;s running in the browser. People&#8217;s hardware and platforms vary, so jax-js can tolerate some slack. But not <em><strong>50x</strong></em>(!!); I think getting within 3-5x of optimal would be reasonable, so we need the JIT compiler.</p><p><em>(Aside: A lot of the performance difference in this case is because jax.jit() saves the graph and avoids dynamic tracing on each run, which is also relevant for us. 
Ignoring the dynamic tracing, I would guess the compiler alone accounts for only ~10x, maybe.)</em></p><h2>How do you build an ML JIT?</h2><p>So you need a compiler, and compilers need an <em>intermediate representation (IR)</em> to represent the computation internally. The plan is to take an input, pass it through the frontend to create an IR, then optimize that IR and produce an output.</p><p>To make this work, I&#8217;m basing my IR on <a href="https://github.com/tinygrad/tinygrad">tinygrad</a>, which is a very small deep learning library. The key difference between tinygrad and XLA is that tinygrad has far fewer primitive operations. For example, to represent a 2048x2048 matmul, XLA&#8217;s HLO would be:</p><pre><code>HloModule jit_matmul, entry_computation_layout={(f32[2048,2048]{1,0}, f32[2048,2048]{1,0})-&gt;f32[2048,2048]{1,0}}

ENTRY main.4 {
  Arg_0.1 = f32[2048,2048]{1,0} parameter(0)
  Arg_1.2 = f32[2048,2048]{1,0} parameter(1)
  <strong>ROOT dot.3 = f32[2048,2048]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}</strong>
}</code></pre><p>The last line uses the primitive <strong>dot</strong> operation, which is literally just a matmul.</p><p>In contrast, tinygrad produces something more like:</p><pre><code>a1 = a.reshape([2048, 1, 2048])
b1 = b.transpose().reshape([1, 2048, 2048])
# a1 * b1 broadcasts to shape [2048, 2048, 2048]; summing the last axis gives a @ b
return (a1 * b1).sum(axis=2)</code></pre><p>The first two lines are &#8220;movement operations&#8221; that just produce views of the data, and crucially, the views can be tracked within a single kernel without actually making copies. They call this <em>laziness</em>, but honestly I think the core thing that makes it work is not the laziness, but rather their algebra of tracking views.</p><p>So I&#8217;ve adopted this view-tracking system for jax-js, and it&#8217;s been working great. jax-js has an IR defined by the <strong>AluExp</strong> class, which is a (very) simplified version of tinygrad&#8217;s <strong>UOp</strong> and looks like:</p><pre><code>/** Mathematical expression on scalar values. */
export class AluExp {
  constructor(
    readonly op: AluOp,
    readonly dtype: DType,
    readonly src: AluExp[],
    readonly arg: any = undefined,
  ) {}
  // ...
}</code></pre><p>Expressions are fused and placed into a <em>kernel</em>, where each kernel contains at most one reduction.</p><pre><code>/**
 * Description of a kernel to be compiled.
 *
 * Each of these can be processed by a backend into some lower-level
 * representation. It consists of one or more fused operations, optionally
 * indexing into a buffer.
 */
export class Kernel {
  constructor(
    /** Number of global arguments / arrays. */
    readonly nargs: number,
    /** Size of the result array in element count. */
    readonly size: number,
    /** Expression to be evaluated. */
    readonly exp: AluExp,
    /** Optional reduction to be performed. */
    readonly reduction?: Reduction,
  ) {
    this.exp = exp.simplify();
  }
  // ...
}</code></pre><p>This gives us everything we need to implement compiler optimizations and lower IR expressions into optimized WebGPU or WebAssembly code.</p><h2>jax.jit() &#8211; joining the frontend with the IR</h2><p>Now that the IR is done, let&#8217;s return to the actual library frontend. Recall that we&#8217;ve been generating graphs of operations through JAX, which supports combinators like <strong>grad()</strong> and <strong>jvp()</strong> for <a href="https://docs.jax.dev/en/latest/automatic-differentiation.html">automatic differentiation</a>. So you could write an operation like <strong>log(2*x)</strong>, and taking its gradient would produce the computation graph for <strong>2/(2*x)</strong> after applying the chain rule.</p><p>These graphs are almost what we need, but we need to decide when to dispatch them to the backend via <strong>Kernel</strong> objects, knowing that:</p><ol><li><p>Each kernel fuses a subgraph of operations into one expression and then runs it on the GPU.</p></li><li><p>A kernel can have at most 1 reduction (for technical reasons; reductions are the starting point for optimizations).</p></li></ol><p>A motivating example is the matmul operation, which we can try porting over:</p><pre><code>function matmul(a: Array, b: Array) {
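  // for clarity, assume a, b are square matrices of shape (n, n)
  const n = a.shape[0];
  // JS has no operator overloading, so multiply with .mul(); c has shape (n, n, n)
  const c = a.reshape([n, 1, n]).mul(b.transpose().reshape([1, n, n]));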
  return c.sum({ axis: 2 });
}</code></pre><p>There&#8217;s a tradeoff with this approach: the intermediate c has n<sup>3</sup> elements. tinygrad doesn&#8217;t actually <em>do anything</em> until you call the <strong>realize()</strong> function, which kicks off work. So it&#8217;s fine to multiply these matrices and produce c there, since c never actually gets realized.</p><p>jax-js tries to be a general-purpose library, though, so this lazy behavior might be a bit confusing to people used to NumPy.</p><p>Luckily, we can borrow another primitive from JAX: the jit() function. It traces a function, produces a &#8220;Jaxpr&#8221; (a DAG of operations), and then passes it down to the ML compiler.</p><pre><code><strong>const matmul = jit(</strong>function matmul(a: Array, b: Array) {
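  // for clarity, assume a, b are square matrices of shape (n, n)
  const n = a.shape[0];
  // JS has no operator overloading, so multiply with .mul(); c has shape (n, n, n)
  const c = a.reshape([n, 1, n]).mul(b.transpose().reshape([1, n, n]));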
  return c.sum({ axis: 2 });
}<strong>);</strong></code></pre><p>This <em>opts into</em> kernel fusion and optimization. Now, whenever the function is called with inputs of a given shape, we get the full DAG and can run a graph algorithm to break it into fusible subgraphs, each lowered into a <strong>Kernel</strong> object containing a fused <strong>AluExp</strong>.</p><p>With this, I think I can offer a really fast, optimized matrix multiplication while doing minimal work on the compiler side and keeping in line with the &#8220;spirit&#8221; of JAX: composable function transformations. There&#8217;s no need to write new primitives for every ML operation, like pad, fused batch normalization, and so on.</p><h2>Conclusion</h2><p>I started this project at the beginning of the year, so it&#8217;s been about 3-4 months now. At the beginning, I never thought I would actually be implementing an ML compiler, but here we are. What made it manageable was a combination of:</p><ol><li><p><strong>Relying on JAX&#8217;s built-in JIT tracing.</strong> Composite operations like matmul(), and anything from norm() to einsum(), can be implemented in terms of smaller parts. Tracing gives us a clean DAG, after autograd and any combinators, to hand off to the compiler backend.</p></li><li><p><strong>Borrowing tinygrad&#8217;s &#8220;view&#8221; system.</strong> This drastically simplifies the IR (compare <a href="https://openxla.org/xla/operation_semantics">XLA&#8217;s IR</a>, for instance) and the amount of work needed to build a working library.</p></li></ol><p>So that&#8217;s how jax-js is going. We&#8217;ll soon have jax.jit() support, and then some demos.</p><h3>What comes next?</h3><p>On the performance front, jax-js is already looking pretty good: its matmul benchmarks already beat TensorFlow.js. I think landing jit() will be enough on that front for now.</p><p>There are some unresolved questions related to memory:</p><ul><li><p>How do you free memory? 
JS doesn&#8217;t have a deterministic destructor like Python&#8217;s reference-counted <code>__del__()</code> method. Maybe use linear types.</p></li><li><p>For the WebAssembly backend, how do you allocate buffers in Wasm linear memory? Generally you want to avoid fragmentation, so maybe there&#8217;s a simple way to do memory allocation here, like relying on a <a href="https://en.wikipedia.org/wiki/Buddy_memory_allocation">buddy allocator</a> for tracking free chunks of pages.</p></li></ul><p>But I&#8217;m excited about what&#8217;s coming up, since it&#8217;s almost fully usable as a numerical computing library. Some stuff I want to put in the browser soon: <a href="https://www.webgpuaudio.com/">audio visualizers</a>, <a href="https://www.ekzhang.com/webgl-julia-viewer/">fractals</a>, <a href="https://apps.amandaghassaei.com/gpu-io/examples/fluid/">fluid simulation</a>, and <a href="https://github.com/facebookresearch/encodec">neural audio coding</a>. After that, I&#8217;ll open-source the library for others to try out.</p><p>If you want to keep up to date, feel free to follow me at <a href="https://twitter.com/ekzhang1">@ekzhang1</a>.</p><p>Hope you learned something about compilers. &#128062;</p><p></p><div><hr></div><p><strong>EDIT (May 24):</strong> Since writing this post, I&#8217;ve implemented jax.jit() in the library, and now have a working auto-fusing, composable, GPU kernel-tuned ML compiler in JavaScript. I think that&#8217;s the first of its kind! As a simple example, here&#8217;s an implementation of matmul() in terms of jit().
:D</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!kP1q!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!kP1q!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 424w, https://substackcdn.com/image/fetch/$s_!kP1q!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 848w, https://substackcdn.com/image/fetch/$s_!kP1q!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 1272w, https://substackcdn.com/image/fetch/$s_!kP1q!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!kP1q!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png" width="1128" height="836" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/e33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:836,&quot;width&quot;:1128,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:188830,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/png&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:&quot;https://ekzhang.substack.com/i/163548742?img=https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png&quot;,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!kP1q!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 424w, https://substackcdn.com/image/fetch/$s_!kP1q!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 848w, https://substackcdn.com/image/fetch/$s_!kP1q!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 1272w, https://substackcdn.com/image/fetch/$s_!kP1q!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fe33cf466-4e18-4d11-9602-97ac786200fb_1128x836.png 1456w" sizes="100vw" loading="lazy"></picture>
</div></a></figure></div><p></p>]]></content:encoded></item><item><title><![CDATA[jax-js devlog feb 17]]></title><description><![CDATA[staging interpreters and a swirl of info]]></description><link>https://ss.ekzhang.com/p/jax-js-devlog-feb-17</link><guid isPermaLink="false">https://ss.ekzhang.com/p/jax-js-devlog-feb-17</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Tue, 18 Feb 2025 01:20:12 GMT</pubDate><content:encoded><![CDATA[<p>So a lot has happened since last week. Today is a holiday, and besides the nice lunch meetup with an old acquaintance, my focus today is just to make progress on this side project.</p><h3>personal reflections</h3><p>it&#8217;s cool that so far this year, my (non-work) GitHub activity has been higher than in any past year, at least in terms of rate. this is pretty good!
I wonder if I&#8217;ll be able to keep that up.</p><p>but that&#8217;s a good reminder that, no, I&#8217;m not getting lazier or less inspired or something with software work, even though it seems that way. I&#8217;m just pushing myself to do bigger things!</p><blockquote><p><em>it matters</em> -A</p></blockquote><p>when you&#8217;re small, your elementary school looks so big, and then you come back as an adult and marvel at how small everything was. this is kind of like that</p><h3>relevant to jax-js plans</h3><p>um so we had Chenyu from tinygrad come over to NYSRG, and after reading the codebase and also <a href="https://github.com/ekzhang/nysrg-notes/blob/main/2025/02-gpu-kernel-programming.md">taking notes on related cuda things</a>, I think:</p><ul><li><p>i understand the picture of compiling operations into kernels a lot better now</p><ul><li><p>the rewrite rules / lazy pattern matchers are just a less PL-jargon-infused way (or should I say, less PL-aware way) of talking about staging, like what JAX does with HLO/XLA</p></li><li><p>my takeaway from the tinygrad paper is that you can get a <em>long</em> way with pretty simple kernels and just a couple of hand-rolled heuristics</p></li><li><p>in retrospect this should be pretty obvious. like, automatic heuristics should certainly be better than a static library of a couple of compiled kernels. that approach is smaller, and more flexible when development resources are limited</p></li><li><p>and gpus can&#8217;t be <em>that</em> complicated. there are memory hierarchies, but even complex problems tend to have fairly parsimonious solutions</p></li></ul></li><li><p>this means that I am pretty confident (overconfident??)
in being able to get rid of the dependency on tfjs-core at some point in the future</p></li><li><p>which is huge, since then I&#8217;m not limited to a couple dtypes and can also optimize any operations of my choice, and extend the project arbitrarily to support even more operations or algorithms to achieve numpy API-compatibility</p><ul><li><p>you want a QR decomposition from numpy.linalg.qr()? sure, have it</p></li></ul></li></ul><pre><code>$ TC=0 DEBUG=4 python3 test.py

# ... stuff

 0: (64, 32, 8, 16, 1, 4, 4, 1) float.ptr(4194304)   (65536, 64, 8192, 4, 0, 1, 2048, 0)
 1: (64, 32, 8, 16, 512, 4, 4, 4) float.ptr(4194304)   (0, 64, 0, 4, 8192, 1, 0, 2048)
 2: (64, 32, 8, 16, 512, 4, 4, 4) float.ptr(4194304)   (65536, 0, 8192, 0, 4, 0, 2048, 1)
[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=1, arg=16)]
#include &lt;metal_stdlib&gt;
using namespace metal;
kernel void r_64_32_8_16_512_4_4_4(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 32 */
  int gidx1 = gid.y; /* 64 */
  int lidx0 = lid.x; /* 8 */
  int lidx1 = lid.y; /* 16 */
  int alu0 = (gidx0&lt;&lt;6);
  int alu1 = (gidx1&lt;&lt;16);
  int alu2 = (lidx0&lt;&lt;13);
  int alu3 = (lidx1&lt;&lt;2);
  float acc0 = 0.0f;
  float acc1 = 0.0f;
  float acc2 = 0.0f;
  float acc3 = 0.0f;
  float acc4 = 0.0f;
  float acc5 = 0.0f;
  float acc6 = 0.0f;
  float acc7 = 0.0f;
  float acc8 = 0.0f;
  float acc9 = 0.0f;
  float acc10 = 0.0f;
  float acc11 = 0.0f;
  float acc12 = 0.0f;
  float acc13 = 0.0f;
  float acc14 = 0.0f;
  float acc15 = 0.0f;
  for (int ridx0 = 0; ridx0 &lt; 512; ridx0++) {
    int alu4 = (alu1+alu2+(ridx0&lt;&lt;2));
    float4 val0 = *((device float4*)((data1+alu4)));
    float4 val1 = *((device float4*)((data1+(alu4+2048))));
    float4 val2 = *((device float4*)((data1+(alu4+4096))));
    float4 val3 = *((device float4*)((data1+(alu4+6144))));
    int alu5 = (alu0+alu3+(ridx0&lt;&lt;13));
    float4 val4 = *((device float4*)((data2+alu5)));
    float4 val5 = *((device float4*)((data2+(alu5+2048))));
    float4 val6 = *((device float4*)((data2+(alu5+4096))));
    float4 val7 = *((device float4*)((data2+(alu5+6144))));
    acc0 = (acc0+(val0.x*val4.x)+(val0.y*val5.x)+(val0.z*val6.x)+(val0.w*val7.x));
    acc1 = (acc1+(val1.x*val4.x)+(val1.y*val5.x)+(val1.z*val6.x)+(val1.w*val7.x));
    acc2 = (acc2+(val2.x*val4.x)+(val2.y*val5.x)+(val2.z*val6.x)+(val2.w*val7.x));
    acc3 = (acc3+(val3.x*val4.x)+(val3.y*val5.x)+(val3.z*val6.x)+(val3.w*val7.x));
    acc4 = (acc4+(val0.x*val4.y)+(val0.y*val5.y)+(val0.z*val6.y)+(val0.w*val7.y));
    acc5 = (acc5+(val1.x*val4.y)+(val1.y*val5.y)+(val1.z*val6.y)+(val1.w*val7.y));
    acc6 = (acc6+(val2.x*val4.y)+(val2.y*val5.y)+(val2.z*val6.y)+(val2.w*val7.y));
    acc7 = (acc7+(val3.x*val4.y)+(val3.y*val5.y)+(val3.z*val6.y)+(val3.w*val7.y));
    acc8 = (acc8+(val0.x*val4.z)+(val0.y*val5.z)+(val0.z*val6.z)+(val0.w*val7.z));
    acc9 = (acc9+(val1.x*val4.z)+(val1.y*val5.z)+(val1.z*val6.z)+(val1.w*val7.z));
    acc10 = (acc10+(val2.x*val4.z)+(val2.y*val5.z)+(val2.z*val6.z)+(val2.w*val7.z));
    acc11 = (acc11+(val3.x*val4.z)+(val3.y*val5.z)+(val3.z*val6.z)+(val3.w*val7.z));
    acc12 = (acc12+(val0.x*val4.w)+(val0.y*val5.w)+(val0.z*val6.w)+(val0.w*val7.w));
    acc13 = (acc13+(val1.x*val4.w)+(val1.y*val5.w)+(val1.z*val6.w)+(val1.w*val7.w));
    acc14 = (acc14+(val2.x*val4.w)+(val2.y*val5.w)+(val2.z*val6.w)+(val2.w*val7.w));
    acc15 = (acc15+(val3.x*val4.w)+(val3.y*val5.w)+(val3.z*val6.w)+(val3.w*val7.w));
  }
  int alu23 = (alu0+alu1+alu2+alu3);
  *((device float4*)((data0+alu23))) = float4(acc0,acc4,acc8,acc12);
  *((device float4*)((data0+(alu23+2048)))) = float4(acc1,acc5,acc9,acc13);
  *((device float4*)((data0+(alu23+4096)))) = float4(acc2,acc6,acc10,acc14);
  *((device float4*)((data0+(alu23+6144)))) = float4(acc3,acc7,acc11,acc15);
}
*** METAL      9 r_64_32_8_16_512_4_4_4                    arg  3 mem  0.05 GB tm     19.86ms/    22.58ms (   864.86 GFLOPS    2.5|865.7   GB/s) ['__matmul__']</code></pre><p>but right now the milestones still look like:</p><p>- [ ] It works!<br>- [ ] Demos: Navier-Stokes, neural networks, statistics<br>- [ ] We figure out the `dispose()` / linear types stuff<br>- [ ] Device switching with `.to()` between webgl/webgpu/cpu/wasm<br>- [ ] First custom kernel<br>- [ ] numpy/jax API compatibility table<br>- [ ] Convert Jaxprs into a tree data structure<br><strong>- [ ] Pattern matchers for kernel fusion<br>- [ ] Kernel codegen, or synthesis</strong></p><p>in particular I think the pattern matchers, scheduling, and codegen components (equivalent of ExecItem in tinygrad) will probably end up fitting into the equivalent of the `xla_call` operation in JAX. so we&#8217;ll have two separate parts of the codebase, one for compilation and one for non-jitted code.</p><p>this sounds kind of weird at first, but I think it&#8217;s the right choice given the design tradeoffs we&#8217;re making. we want it to be fast, but we don&#8217;t need to squeeze out every drop of performance &#8212; after all, we don&#8217;t even know what hardware we&#8217;re running on since it&#8217;s a javascript in-browser library.</p><p>the other advantage of jitting this is that we can auto-manage memory (er, we have to predict static memory patterns anyway, so we get this for free) and that&#8217;s important given that javascript has no reliable GC dispose hook (destructor)</p><p>anyway this seems pretty solid</p><h3>development</h3><p>tests continue to pass and reveal their utility over time. also vitest&#8217;s <a href="https://vitest.dev/guide/snapshot.html">inline snapshot testing</a> is quite fast &amp; awesome.</p><p>anyway, it&#8217;s 8 PM right now, here&#8217;s what we got from today</p><pre><code>git --no-pager diff --stat "@{1 day ago}"

 README.md            |   3 +
 src/core.ts          | 501 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 src/index.ts         |  13 ++-
 src/pprint.test.ts   |  68 +++++++++++++
 src/pprint.ts        |  57 +++++++++++
 src/utils.ts         |   2 +-
 test/tracing.test.ts |  48 +++++++++
 7 files changed, 688 insertions(+), 4 deletions(-)</code></pre><p>basically just finished implementing jaxpr logic and tracing. I understand how xla_call works now too: the jaxpr is placed into the parameters, and it composes in some interesting ways. mental exercises:</p><ul><li><p>what happens when you jit() a jit()</p></li><li><p>what happens when you jvp() a jit()</p></li><li><p>what happens when you jit() a jvp()</p></li><li><p>what happens when you makeJaxpr() a jit()</p></li></ul><p>I also understand (I think?) this gem of a quote laden with PL terminology, lmao</p><blockquote><p>There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:</p><ol><li><p><strong>On-the-fly processing, where </strong><code>bind</code><strong> takes a Python callable as an argument.</strong> We defer forming a jaxpr until as late as possible, namely until we&#8217;re running the final interpreter at the bottom of the interpreter stack. That way we can swap a <code>JaxprTrace</code> in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it&#8217;s as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow.
We refer to this approach as using a &#8220;final-style higher-order primitive&#8221; employing the discharge-at-tracing-time &#8220;final-style transformations&#8221; we&#8217;ve used so far.</p></li><li><p><strong>Staged processing, where </strong><code>bind</code><strong> takes a jaxpr as an argument.</strong> Before we call <code>bind</code>, in the primitive wrapper we can just use <code>make_jaxpr</code> to form a jaxpr up-front and be done with the Python callable entirely. In this case, <code>make_jaxpr</code> puts its <code>JaxprTrace</code> at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive&#8217;s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can&#8217;t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an &#8220;initial-style higher-order primitive&#8221;, and say that its jaxpr-processing transformation rules are &#8220;initial-style transformation rules.&#8221;</p></li></ol><p>The latter approach fits for <code>jit</code> because we don&#8217;t need to support data-dependent Python control flow in the user-provided Python callable, as the whole purpose of <code>jit</code> is to stage computation out of Python to be executed by XLA. 
(In contrast, <code>custom_jvp</code> is a higher-order primitive in which we want to support data-dependent Python control flow.)</p><p>Historically, we started using the &#8220;initial-style&#8221; and &#8220;final-style&#8221; terminology after reading the <a href="http://okmij.org/ftp/tagless-final/index.html">typed tagless final interpreters</a> paper, and jokingly referring to JAX as an implementation of &#8220;untyped tagful final interpreters.&#8221; We don&#8217;t claim to carry over (or understand) any deep meaning behind these terms; we loosely use &#8220;initial style&#8221; to mean &#8220;build an AST and then transform it&#8221;, and we use &#8220;final style&#8221; to mean &#8220;transform as we trace.&#8221; But it&#8217;s just imprecise yet sticky jargon.</p></blockquote><p>next up is linearize / vjp, which I&#8217;m excited about. finally getting a glimpse into <a href="http://conal.net/papers/essence-of-ad/">conal elliott&#8217;s mind</a></p><p>anyway, we&#8217;re getting there a bit at a time!</p><h3>concluding</h3><p>i think side projects are hard, but I&#8217;m reminded that like a lot of things in life, you just need to make a routine. discipline is hard, but routines are easy</p><p>if you write 200 lines of code each day for a month, you&#8217;ll have written 6000 lines of code in that month</p><p>that&#8217;s pretty substantial. like <a href="https://sshx.io/">sshx.io</a>-sized! the difference is that sshx.io took nearly 2 years, lol &#8212; but to be fair, you&#8217;re oftentimes debugging or removing code too. 
in any case, routines make everything easier, whether it&#8217;s organizing meetups like nysrg or running or learning to play an instrument, and let&#8217;s try to find one.</p><p>it can be pretty hard to work on something so difficult by yourself, but at the same time, it&#8217;s true that I find it really cool, and i enjoy this kind of creative work :)</p><div class="native-video-embed" data-component-name="VideoPlaceholder" data-attrs="{&quot;mediaUploadId&quot;:&quot;a9df7d9a-70ff-4155-b0d0-ae7ac7d58dad&quot;,&quot;duration&quot;:null}"></div><p></p>]]></content:encoded></item><item><title><![CDATA[jax-js progress]]></title><description><![CDATA[feb 8 &#8212; at val]]></description><link>https://ss.ekzhang.com/p/jax-js-progress</link><guid isPermaLink="false">https://ss.ekzhang.com/p/jax-js-progress</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Sat, 08 Feb 2025 16:49:24 GMT</pubDate><enclosure url="https://substackcdn.com/image/fetch/$s_!nDbS!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic" length="0" type="image/jpeg"/><content:encoded><![CDATA[<div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!nDbS!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!nDbS!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 424w, 
https://substackcdn.com/image/fetch/$s_!nDbS!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 848w, https://substackcdn.com/image/fetch/$s_!nDbS!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 1272w, https://substackcdn.com/image/fetch/$s_!nDbS!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!nDbS!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic" width="1456" height="1092" data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/d0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:1092,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:1903103,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/heic&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:false,&quot;topImage&quot;:true,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!nDbS!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 424w, 
https://substackcdn.com/image/fetch/$s_!nDbS!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 848w, https://substackcdn.com/image/fetch/$s_!nDbS!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 1272w, https://substackcdn.com/image/fetch/$s_!nDbS!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd0fa168b-428a-40e9-a831-d87bad4d7a6e_4032x3024.heic 1456w" sizes="100vw" fetchpriority="high"></picture>
</div></a></figure></div><p></p><p>currently doing this at the val town office, hosted by justin-b. it&#8217;s a chilly day</p><p>on the backlog: tabletop.js and Jute. lots of stuff, but let me make progress on this for the morning! it sometimes feels like my personal work is a stack: recency bias toward what I work on. but that&#8217;s fine, exploration is hard work. let&#8217;s have fun</p><p></p><h2>progress</h2><p>I just implemented jacfwd(). got distracted by type signatures a bit, but I think things are going roughly smoothly. some folks next to me are collaborating on an API.</p><p>one important thing to remember is that some of these type signatures aren&#8217;t exactly right, but what matters is that they <em>work</em></p><pre><code>// Convert a subtype of JsTree&lt;A&gt; into a JsTree&lt;B&gt;, with the same structure.
type MapJsTree&lt;T, A, B&gt; = T extends A
  ? B
  : T extends globalThis.Array&lt;infer U&gt;
    ? MapJsTree&lt;U, A, B&gt;[]
    : { [K in keyof T]: MapJsTree&lt;T[K], A, B&gt; };

// Assert that a function's arguments are a subtype of the given type.
type WithArgsSubtype&lt;F extends (...args: any[]) =&gt; any, T&gt; =
  Parameters&lt;F&gt; extends T ? F : never;

/** Compute the forward-mode Jacobian-vector product for a function. */
export const jvp = core.jvp as &lt;F extends (...args: any[]) =&gt; JsTree&lt;Array&gt;&gt;(
  f: WithArgsSubtype&lt;F, JsTree&lt;ArrayLike&gt;&gt;,
  primals: MapJsTree&lt;Parameters&lt;F&gt;, Array, ArrayLike&gt;,
  tangents: MapJsTree&lt;Parameters&lt;F&gt;, Array, ArrayLike&gt;
) =&gt; [ReturnType&lt;F&gt;, ReturnType&lt;F&gt;];

/** Vectorize an operation on a batched axis for one or more inputs. */
export const vmap = core.vmap as &lt;F extends (...args: any[]) =&gt; JsTree&lt;Array&gt;&gt;(
  f: WithArgsSubtype&lt;F, JsTree&lt;ArrayLike&gt;&gt;,
  inAxes: MapJsTree&lt;Parameters&lt;F&gt;, Array, number&gt;
) =&gt; F;
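
// Hypothetical sketch, not jax-js code: the idea behind jvp() and the column-by-column
// jacfwd() below, using scalar dual numbers. Each value carries a primal and a tangent;
// all names in this sketch (Dual, DualFn, mulD, sinD, jvpSketch, jacfwdSketch) are illustrative.
type Dual = { val: number; tan: number };
interface DualFn { (xs: Dual[]): Dual[] }
function dual(val: number, tan: number): Dual { return { val, tan }; }
function mulD(a: Dual, b: Dual): Dual { // product rule: d(ab) = a db + b da
  return dual(a.val * b.val, a.val * b.tan + a.tan * b.val);
}
function sinD(a: Dual): Dual { // chain rule: d sin(a) = cos(a) da
  return dual(Math.sin(a.val), Math.cos(a.val) * a.tan);
}
// Push one tangent vector through f; returns [outputs, output tangents].
function jvpSketch(f: DualFn, primals: number[], tangents: number[]): [number[], number[]] {
  const out = f(primals.map(function (p, i) { return dual(p, tangents[i]); }));
  return [out.map(function (o) { return o.val; }), out.map(function (o) { return o.tan; })];
}
// Column j of the Jacobian is the JVP along basis vector e_j; transposing the columns
// gives J[i][j] = d f_i / d x_j. E.g. f(x) = [x0 * x1, sin(x0)] at (2, 3) gives
// [[3, 2], [cos(2), 0]].
function jacfwdSketch(f: DualFn, x: number[]): number[][] {
  const cols = x.map(function (_, j) {
    return jvpSketch(f, x, x.map(function (_, i) { return i === j ? 1 : 0; }))[1];
  });
  return cols[0].map(function (_, i) { return cols.map(function (col) { return col[i]; }); });
}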

/** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
export const jacfwd = core.jacfwd as &lt;F extends (x: Array) =&gt; Array&gt;(
  f: F,
  x: Array
) =&gt; F;</code></pre><p>i&#8217;m feeling a bit annoyed by vs code freezing every so often, as well as by my lack of tests. let me restart my computer for perf, and then I can add some utilities to assert arrays are close to each other, which will be needed for tests.</p><p>I&#8217;ll skip the utility for printing arrays for now &#8212; we&#8217;ll get there later!</p><p></p><p>&#8212;my computer has restarted.</p><h2>today&#8217;s idea</h2><p>you cannot tell other people how to work. you can only show them, make it easy to change, give suggestions. but it&#8217;s ultimately their choice, and learning how to work together with others takes time. relationships are damaged when we aren&#8217;t patient and generous with each other.</p><p>sometimes, you just need to sit down with the other party and hear them out. there should always be a place for honest listening</p><p></p><h2>&#8230;and back to tests</h2><p>i got tests to work, and my tests promptly discovered several bugs (lol). in the process of debugging futilely, I learned that bind() is the core place where you need to intercept stuff, and also the necessity of having a toString() LOL.</p><p>but yeah it turns out a one-line typo in fullLower() caused all this mess. now you know &#8212; experiences!</p><p></p><h2>and now it&#8217;s noon</h2><p>so they&#8217;re doing demos.
a nice group of people</p><div class="captioned-image-container"><figure><a class="image-link image2 is-viewable-img" target="_blank" href="https://substackcdn.com/image/fetch/$s_!P99e!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic" data-component-name="Image2ToDOM"><div class="image2-inset"><picture><source type="image/webp" srcset="https://substackcdn.com/image/fetch/$s_!P99e!,w_424,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 424w, https://substackcdn.com/image/fetch/$s_!P99e!,w_848,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 848w, https://substackcdn.com/image/fetch/$s_!P99e!,w_1272,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 1272w, https://substackcdn.com/image/fetch/$s_!P99e!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 1456w" sizes="100vw"><img src="https://substackcdn.com/image/fetch/$s_!P99e!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic" width="1456" height="1092" 
data-attrs="{&quot;src&quot;:&quot;https://substack-post-media.s3.amazonaws.com/public/images/13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic&quot;,&quot;srcNoWatermark&quot;:null,&quot;fullscreen&quot;:null,&quot;imageSize&quot;:null,&quot;height&quot;:1092,&quot;width&quot;:1456,&quot;resizeWidth&quot;:null,&quot;bytes&quot;:2051084,&quot;alt&quot;:null,&quot;title&quot;:null,&quot;type&quot;:&quot;image/heic&quot;,&quot;href&quot;:null,&quot;belowTheFold&quot;:true,&quot;topImage&quot;:false,&quot;internalRedirect&quot;:null,&quot;isProcessing&quot;:false,&quot;align&quot;:null,&quot;offset&quot;:false}" class="sizing-normal" alt="" srcset="https://substackcdn.com/image/fetch/$s_!P99e!,w_424,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 424w, https://substackcdn.com/image/fetch/$s_!P99e!,w_848,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 848w, https://substackcdn.com/image/fetch/$s_!P99e!,w_1272,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 1272w, https://substackcdn.com/image/fetch/$s_!P99e!,w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F13da6893-2291-4ad3-b7d4-1c7b9ba64639_4032x3024.heic 1456w" sizes="100vw" loading="lazy"></picture>
</div></a></figure></div><p>when you&#8217;re working on a team that works fast, you often find yourself in a situation where you don&#8217;t understand everything. a primitive capability. context, understanding, some representation. design docs take energy to write</p><p>expressing your ideas is a lot of the work; perhaps it&#8217;s the productive part</p><p><a href="https://protomaps.com/">protomaps</a>! a nice file format for serving full-world maps via byte ranges. seems like a cool system. <a href="https://maplibre.org/">maplibre</a> is the client library.
it can replace mapbox gl :)</p>]]></content:encoded></item><item><title><![CDATA[jax-js note]]></title><description><![CDATA[feb 7]]></description><link>https://ss.ekzhang.com/p/jax-js-note</link><guid isPermaLink="false">https://ss.ekzhang.com/p/jax-js-note</guid><dc:creator><![CDATA[Eric]]></dc:creator><pubDate>Fri, 07 Feb 2025 16:43:38 GMT</pubDate><content:encoded><![CDATA[<ul><li><p>Current goals</p><ul><li><p>"Tutorial mindset" &#8212; imagine making a workshop where I'm teaching people about array programming or functional programming in practice.</p></li><li><p>Try to just implement float32, grad() and arithmetic.</p></li><li><p><a href="https://jax.readthedocs.io/en/latest/autodidax.html">https://jax.readthedocs.io/en/latest/autodidax.html</a></p></li><li><p>Build out a quick code editor / REPL in the browser and import the library using Vite. Then run code to experiment a bit.</p><ul><li><p>After that, play it by ear. PyTrees -&gt; JSON, &#8230;</p></li></ul></li></ul></li><li><p>Description (future)</p><ul><li><p>NumPy and JAX for the browser, running on CPU or GPU.</p></li><li><p>Machine learning and numerical computing in JavaScript with the JAX/NumPy API. Define arrays, then run arbitrary differentiable code on CPU, WASM, WebGL, or WebGPU backends.</p></li><li><p>Examples: fluid simulation, neural networks, computer vision, robotics, statistics.</p></li></ul></li></ul><pre><code>import { grad, numpy as np } from "jax-js";

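// grad() needs a scalar-valued function (as in JAX), so reduce with sum() (assumed Array method);
// the gradient of sum(2x) with respect to x is 2 for every element
const y = grad((x) =&gt; x.mul(2).sum())(np.array([1, 2, 3]));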
console.log(y.js());</code></pre><ul><li><p>Memory management</p><ul><li><p>Refcount / ownership contract: all arguments to functions must be used or disposed.</p></li><li><p>ref() and dispose()</p></li></ul></li></ul><p><strong>who is your target?</strong> scientists, artists, anyone who uses numerical computing. maybe eventually people porting ML models (though ONNX is probably better suited for that)</p><p><strong>why?</strong></p><ul><li><p>story: you can&#8217;t figure out how to use pip. oh look, here&#8217;s a webgpu version that&#8217;s fast and &#8220;just works&#8221; &#8212; shaders can do a lot</p></li><li><p>story: XLA is hard and very complicated. what if you set up a minimal compiler toolchain that handles the 80%+ case of fusing operations, directly in the browser?</p></li><li><p>story: you are an artist and want to write some numerical simulations, but you don&#8217;t want to build a whole library like <a href="https://github.com/amandaghassaei/gpu-io">gpu-io</a> yourself</p></li></ul>]]></content:encoded></item></channel></rss>