TPU Book Problems
Roofline and Scaling Practice
Source: How to Scale Your Model (Austin et al., 2025)
Ch 1: Intro to Rooflines
Read this chapter on the Scaling Book →
Question 1
int8 matmul
Say we want to do the matmul \(X[B, D] *_D Y[D, F] \rightarrow Z[B, F]\) in int8 precision (1 byte per parameter) instead of bfloat16 (2 bytes per parameter), since TPUs/GPUs can do matmuls faster in lower precision.
Notation: \(*_D\) indicates that the multiplication contracts over the \(D\) dimension (an abuse of einsum notation).
- How many bytes need to be loaded from memory? How many need to be written back to memory?
- How many total OPs are performed?
- What is the arithmetic intensity?
- What is a roofline estimate for \(T_\text{math}\) and \(T_\text{comms}\)? What are reasonable upper and lower bounds for the runtime of the whole operation?
Assume our HBM bandwidth is 8.2e11 bytes/s and our int8 peak OPs/s is 3.94e14 (about 2x bfloat16).
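To check answers for Question 1, the quantities can be tabulated in a few lines of Python. This is a minimal sketch that assumes the operands have shapes [B, D] and [D, F] with a [B, F] output and 1 byte per int8 element; the function name `int8_matmul_roofline` and the example sizes are illustrative and not from the source.

```python
HBM_BW = 8.2e11      # bytes/s, from the problem statement
INT8_PEAK = 3.94e14  # int8 OPs/s, from the problem statement


def int8_matmul_roofline(B, D, F, bytes_per_elem=1):
    """Roofline quantities for an assumed [B, D] x [D, F] -> [B, F] matmul."""
    bytes_loaded = bytes_per_elem * (B * D + D * F)  # both inputs are read once
    bytes_written = bytes_per_elem * B * F           # the output is written once
    ops = 2 * B * D * F                              # one multiply and one add per term
    t_math = ops / INT8_PEAK
    t_comms = (bytes_loaded + bytes_written) / HBM_BW
    return {
        "bytes_loaded": bytes_loaded,
        "bytes_written": bytes_written,
        "ops": ops,
        "intensity": ops / (bytes_loaded + bytes_written),
        "t_math": t_math,
        "t_comms": t_comms,
        # Perfect overlap of compute and memory traffic gives the lower bound;
        # no overlap at all gives the upper bound.
        "lower_bound": max(t_math, t_comms),
        "upper_bound": t_math + t_comms,
    }


# Hypothetical sizes, chosen only to exercise the function.
result = int8_matmul_roofline(B=512, D=4096, F=4096)
print(result["intensity"], result["lower_bound"], result["upper_bound"])
```

Returning both bounds makes it easy to see how much of the runtime depends on whether memory traffic overlaps with compute.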
Question 2
int8 + bf16 matmul
In practice we often do different weight vs. activation quantization, so we might store our weights in very low precision but keep activations (and compute) in a higher precision. Say we want to quantize our weights in int8 but keep activations (and compute) in bfloat16. At what batch size do we become compute bound? Assume 1.97e14 bfloat16 FLOPs/s.
Hint: this means specifically bf16[B, D] * int8[D, F] -> bf16[B, F], where B is the “batch size”.
Question 3
Taking the setup from Question 2, make a roofline plot of peak FLOPs/s vs. \(B\) for \(F = D = 4096\) and \(F = D = 1024\). Use the exact number of bytes loaded, not an approximation.
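For Questions 2 and 3, the exact byte count from the hint (bf16 activations and output at 2 bytes, int8 weights at 1 byte) can be turned into a small calculator. This is a sketch under stated assumptions: it reuses the 8.2e11 bytes/s HBM bandwidth from Question 1, and the names `intensity`, `critical_batch_size`, and `achievable_flops` are illustrative.

```python
HBM_BW = 8.2e11      # bytes/s, reused from Question 1 (assumption)
BF16_PEAK = 1.97e14  # bf16 FLOPs/s, from the problem statement


def intensity(B, D, F):
    """FLOPs per byte for bf16[B, D] * int8[D, F] -> bf16[B, F]."""
    flops = 2 * B * D * F
    bytes_moved = 2 * B * D + 1 * D * F + 2 * B * F
    return flops / bytes_moved


def critical_batch_size(D, F, max_B=100_000):
    """Smallest B that is compute-bound, or None if no B up to max_B is."""
    target = BF16_PEAK / HBM_BW
    for B in range(1, max_B + 1):
        if intensity(B, D, F) >= target:
            return B
    return None


def achievable_flops(B, D, F):
    """Roofline: the lesser of peak compute and bandwidth-limited compute."""
    return min(BF16_PEAK, intensity(B, D, F) * HBM_BW)


# Hypothetical sizes; one point per batch size gives the data for a roofline plot.
for B in (1, 16, 64, 256, 1024):
    print(B, achievable_flops(B, 4096, 4096))
print(critical_batch_size(4096, 4096))
```

Because the search uses the exact byte count, small D and F can fail to become compute-bound at any batch size, which is why the function may return None.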
Question 4
What if we wanted to perform \(\text{int8}[B, D] *_D \text{int8}[B, D, F] \rightarrow \text{int8}[B, F]\), where we imagine having a different matrix for each batch element? What is the arithmetic intensity of this operation?
Question 5
Memory Rooflines for GPUs
Using the spec sheet provided by NVIDIA for the H100 SXM, calculate the batch size at which a bfloat16 matrix multiplication will become compute-bound. Note that the Tensor Core FLOPs numbers are twice the true value since they’re only achievable with structured sparsity.
Ch 2: All About TPUs
Question 1
bounding LLM latency
Say you want to sample from a 200B parameter model in bf16 that’s split across 32 TPU v4p. How long would it take to load all the parameters from HBM into the systolic array? Hint: use the numbers above.
Question 2
TPU details
Consider a full TPU v5e pod. How many total CPU hosts are there? How many TPU TensorCores? What is the total FLOPs/s for the whole pod? What is the total HBM? Do the same exercise for a TPU v5p pod.
Question 3
PCIe operational intensity
Imagine we’re forced to store a big weight matrix \(A\) of type \(\text{bfloat16}[D, F]\), and a batch of activations \(x\) of type \(\text{bfloat16}[B, D]\), in host DRAM and want to do a matrix multiplication on them. This is running on a single host, and we’re using a single TPU v6e chip attached to it. You can assume \(B \ll D\) and \(F = 4D\) (we’ll see in future chapters why these are reasonable assumptions). What is the smallest batch size \(B\) we need to remain FLOPs bound over PCIe? Assume PCIe bandwidth of 1.6e10 bytes / second.
Question 4
general matmul latency
Let’s say we want to multiply a weight matrix int8[16384, 4096] by an activation matrix of size int8[B, 4096] where B is some unknown batch size. Let’s say we’re on 1 TPU v5e to start.
- How long will this multiplication take as a function of B? Hint: it may help to calculate how long it will take to load the arrays from HBM and how long the multiplication will actually take. Which is bottlenecking you?
- What if we wanted to run this operation out of VMEM? How long would it take as a function of B?
Question 5
ICI bandwidth
Let’s say we have a TPU v5e 4x4 slice. Let’s say we want to send an array of type bf16[8, 128, 8192] from TPU{0,0} to TPU{3, 3}. Let’s say the per-hop latency for TPU v5e is 1 μs.
- How soon will the first byte arrive at its destination?
- How long will the total transfer take?
Question 6
pulling it all together, hard
Imagine you have a big matrix A: int8[128 * 1024, 128 * 1024] sharded evenly across a TPU v5e 4x4 slice but offloaded to host DRAM on each chip. Let’s say you want to copy the entire array to TPU{0, 0} and multiply it by a vector bf16[8, 128 * 1024]. How long will this take? Hint: use the numbers above.
Pop Quiz
Calculating VPU throughput
Using the above information, calculate how many vector FLOPs/s a TPU v5p can perform. A TPU v5p has a clock speed of about 1.75GHz.
Ch 3: Sharded Matmuls
Question 1
replicated sharding
An array is sharded \(A[I_X, J, K, \ldots]\) (i.e., only sharded across \(X\)), with a mesh Mesh({'X': 4, 'Y': 8, 'Z': 2}). What is the ratio of the total number of bytes taken up by \(A\) across all chips to the size of one copy of the array?
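The bookkeeping behind this question can be sketched in Python: each chip holds one shard, and every mesh axis that does not shard the array replicates it. The helper `replication_ratio` is a hypothetical name, and the sketch assumes the array is sharded along X only.

```python
from math import prod


def replication_ratio(mesh, sharded_axes):
    """Total bytes across all chips divided by the bytes of one full copy."""
    n_chips = prod(mesh.values())
    n_shards = prod(mesh[axis] for axis in sharded_axes)
    # Each chip holds 1/n_shards of the array, so the chips together
    # hold n_chips / n_shards full copies.
    return n_chips // n_shards


mesh = {"X": 4, "Y": 8, "Z": 2}
print(replication_ratio(mesh, sharded_axes=["X"]))
```

Passing more axes in `sharded_axes` lowers the ratio, and sharding over every axis brings it to 1.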
Question 2
AllGather latency
How long should \(\text{AllGather}_Y([B_X, D_Y])\) take on a TPU v4p 4x4x4 slice with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}) if \(B = 1024\) and \(D = 4096\) in bfloat16? How about \(\text{AllGather}_{XY}([B_X, D_Y])\)? How about \(\text{AllReduce}_Z([B_X, D_Y]\{U_Z\})\)?
Question 3
latency-bound AllGather
Let’s say we’re performing an \(\text{AllGather}_X([B_X])\) but \(B\) is very small (say 128). How long should this take on a TPU v4p 4x4x4 slice with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}) in bfloat16? Hint: you’re probably latency bound.
Question 4
matmul strategies
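To perform \(X[B, D] *_D Y[D_X, F] \rightarrow Z[B, F]\), in this section we tell you to perform \(\text{AllGather}_X(Y[D_X, F])\) and multiply the fully replicated matrices (Case 2, Strategy 1). Instead, you could multiply the local shards like \(X[B, D_X] *_D Y[D_X, F] \rightarrow Z[B, F]\{U_X\}\) (Case 3, Strategy 2), and then \(\text{AllReduce}_X(Z[B, F]\{U_X\})\). How many FLOPs and comms does each of these perform? Which is better and why?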
Question 5
minimum latency
Let’s say I want to do a matmul on a TPU v4p 4x4x4 with the lowest possible latency. Assume the inputs can be sharded arbitrarily but the result should be fully replicated. How should my inputs be sharded? What is the total FLOPs and comms time?
Question 6
Let’s say we want to perform \(A[I_X, J_Y] *_J B[J_Y, K] \rightarrow C[I_X, K]\) on TPU v5e 4x4. What communication do we perform? How much time is spent on communication vs. computation?
- What about \(A[I_X, J] *_J B[J_X, K_Y] \rightarrow C[I_X, K_Y]\)? This is the most standard setting for training where we combine data, tensor, and ZeRO sharding.
- What about \(A[I_X, J] *_J B[J, K_Y] \rightarrow C[I_X, K_Y]\)? This is standard for inference, where we do pure tensor parallelism (+data).
Question 7
A typical Transformer block has two matrices \(W_\text{in}[D, F]\) and \(W_\text{out}[F, D]\) where \(F \gg D\). Say we have a batch size B. Then the full block is \(\text{In}[B, D] *_D W_\text{in}[D, F] *_F W_\text{out}[F, D]\). Let’s pick \(D = 8192\), \(F = 32768\), and \(B = 128\) and assume everything is in bfloat16. Assume we’re running on a TPU v5e 2x2 slice but let’s pretend each TPU only has 300MB of free memory. How should In, \(W_\text{in}\), \(W_\text{out}\), and Out be sharded to stay below the memory limit while minimizing overall time? How much time is spent on comms and FLOPs? Hint: the final output doesn’t need to be fully replicated, but it should be sharded the same as the input so the “layer” can be repeated.
Question 8
challenge
Using the short code snippet above as a template, allocate a sharded array and benchmark each of the 4 main communication primitives (AllGather, AllReduce, ReduceScatter, and AllToAll) using pmap or shard_map. You will want to use jax.lax.all_gather, jax.lax.psum, jax.lax.psum_scatter, and jax.lax.all_to_all. Do you understand the semantics of these functions? How long do they take?
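Before benchmarking, it can help to pin down what each primitive is supposed to produce. The following is a plain-Python model of the four collectives on a 1D ring of devices, where each device’s data is a Python list. It does not call JAX and says nothing about performance; the function names mirror the collective names only for readability and are not the `jax.lax` signatures.

```python
def all_gather(shards):
    """Every device ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_reduce(partials):
    """Every device ends up with the elementwise sum of all partial arrays."""
    total = [sum(column) for column in zip(*partials)]
    return [list(total) for _ in partials]


def reduce_scatter(partials):
    """Device i ends up with the i-th chunk of the elementwise sum."""
    n = len(partials)
    total = [sum(column) for column in zip(*partials)]
    chunk = len(total) // n
    return [total[i * chunk:(i + 1) * chunk] for i in range(n)]


def all_to_all(chunks):
    """chunks[i][j] is what device i holds for device j; devices swap chunks."""
    n = len(chunks)
    return [[chunks[src][dst] for src in range(n)] for dst in range(n)]


partials = [[1, 2], [3, 4]]  # two devices, each holding a partial sum
print(all_reduce(partials))
print(all_gather(reduce_scatter(partials)))
```

The last two lines print the same result, which illustrates that an AllReduce can be built from a ReduceScatter followed by an AllGather.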
Question 9
another strategy for sharded matmuls?
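Above we claimed that when only one input to a matmul is sharded along its contracting dimension, we should AllGather the sharded matrix and perform the resulting contraction locally. Another strategy you might think of is to perform the sharded matmul and then AllReduce the result (as if both inputs were sharded along the contracting dimension), i.e. \(A[I, J_X] *_J B[J, K] \rightarrow C[I, K]\) by way of a local matmul on the shards, \(C[I, K]\{U_X\} = A[I, J_X] \cdot B[J_X, K]\), followed by \(C[I, K] = \text{AllReduce}_X(C[I, K]\{U_X\})\).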
Answer the following:
- Explicitly write out this algorithm for matrices \(A[N, M]\) and \(B[M, K]\), using indices to show exactly what computation is done on what device. Assume \(A\) is sharded as \(A[I, J_X]\) across ND devices, and you want your output to be replicated across all devices.
- Now suppose you are ok with the final result not being replicated on each device, but instead sharded (across either the N or K dimension). How would the algorithm above change?
- Looking purely at the communication cost of the strategy above (in part 2, not 1), how does this communication cost compare to the communication cost of the algorithm in which we first AllGather A and then do the matmul?
Question 10
Fun with AllToAll
Reference (collective costs, throughput-bound regime):
| Operation | Description | Syntax | Runtime |
|---|---|---|---|
| AllGather | Gathers all shards of a sharded array along an axis, removing a subscript. | \([A_X, B] \rightarrow [A, B]\) | bytes / (bidirectional ICI bandwidth × num_axes) |
| ReduceScatter | Sums a partially-summed array along an axis and shards it along another axis (adding a subscript). | \([A, B]\{U_X\} \rightarrow [A_X, B]\) | Same as AllGather |
| AllReduce | Sums a partially-summed array along an axis (removes a \(\{U_Y\}\)). Combines an AllGather and a ReduceScatter. | \([A_X, B]\{U_Y\} \rightarrow [A_X, B]\) | 2 × AllGather |
| AllToAll | Gathers (replicates) an axis and shards a different dimension along the same axis. | \([A, B_X] \rightarrow [A_X, B]\) | AllGather / 4 (bidirectional ring) |
In the table above, it was noted that the time to perform an AllToAll is a factor of 4 lower than the time to perform an AllGather or ReduceScatter (in the regime where we are throughput-bound). In this problem we will see where that factor of 4 comes from, and also see how this factor would change if we only had single-direction ICI links, rather than bidirectional ICI links.
- Let’s start with the single-direction case first. Imagine we have D devices in a ring topology and want to do either an AllGather or a ReduceScatter on an N x N matrix (say D divides N for simplicity). Describe the comms involved in these two collectives, and calculate the total number of scalars (floats or ints) which are transferred across a single ICI link during the entirety of this algorithm.
- Now let’s think about an AllToAll, still in the single-directional ICI case. How is the algorithm different in this case than the all-gather case? Calculate the number of scalars that are transferred across a single ICI link in this algorithm.
- You should have found that the ratio between your answers to part (a) and part (b) is a nice number. Explain where this factor comes from in simple terms.
- Now let’s add bidirectional communication. How does this affect the total time needed in the all-gather case?
- How does adding bidirectional communication affect the total time needed in the AllToAll case?
- Now simply explain the ratio between AllGather time and AllToAll time in a bidirectional ring.
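One way to check the counting in this problem is to tally traffic per link directly. The sketch below is an illustration under simplifying assumptions, not the source’s derivation: every device starts with `shard_size` scalars, an AllGather forwards the whole shard hop by hop, an AllToAll sends a distinct 1/D slice to each peer along the shortest path, and in the bidirectional case the peer exactly halfway around the ring is reached clockwise. The function name `ring_link_traffic` is hypothetical.

```python
def ring_link_traffic(num_devices, shard_size, bidirectional=False):
    """Scalars crossing the busiest directed link for AllGather and AllToAll."""
    D = num_devices
    chunk = shard_size / D  # AllToAll slice destined for one peer
    # Farthest hop count in each direction (clockwise, counter-clockwise).
    reach = (D // 2, (D - 1) // 2) if bidirectional else (D - 1, 0)
    busiest = {"all_gather": 0.0, "all_to_all": 0.0}
    for max_dist in reach:
        gather = [0.0] * D  # traffic on the link leaving device i in this direction
        a2a = [0.0] * D
        for src in range(D):
            for hop in range(max_dist):
                link = (src + hop) % D
                # AllGather: the whole shard is forwarded across every hop.
                gather[link] += shard_size
                # AllToAll: only slices bound for devices past this hop cross it.
                a2a[link] += chunk * (max_dist - hop)
        busiest["all_gather"] = max(busiest["all_gather"], max(gather))
        busiest["all_to_all"] = max(busiest["all_to_all"], max(a2a))
    return busiest


for D in (8, 64, 256):
    uni = ring_link_traffic(D, shard_size=float(D))
    bi = ring_link_traffic(D, shard_size=float(D), bidirectional=True)
    print(D, uni["all_gather"] / uni["all_to_all"], bi["all_gather"] / bi["all_to_all"])
```

Under these assumptions the unidirectional ratio is fixed while the bidirectional ratio depends on D and grows toward the factor quoted in the table as the ring gets larger.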
Ch 4: Transformers
Question 1
How many parameters does a model with \(D = 4096\), \(F = 4 \cdot D\), \(V = 32{,}000\), and \(L = 64\) have? What fraction of these are attention parameters? How large are our KV caches per token? You can assume \(N \cdot H = D\) and multi-head attention with int8 KVs.
Question 2
How many total FLOPs are required to perform \(A[B_X, D_Y] *_D W[D_Y, F]\) on {'X': 4, 'Y': 8, 'Z': 4}? How many FLOPs are performed by each TPU?
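The accounting for this question can be sketched as follows, assuming the batch dimension is sharded over X, the contracting dimension over Y, and nothing is sharded over Z. The function name and the example sizes are hypothetical.

```python
from math import prod


def sharded_matmul_flops(B, D, F, mesh, batch_axis="X", contract_axis="Y"):
    """FLOP counts for a [B, D] x [D, F] matmul with B and D sharded on a mesh."""
    useful = 2 * B * D * F
    # Each chip multiplies its [B/X, D/Y] block of A by its [D/Y, F] block of W.
    per_chip = 2 * (B // mesh[batch_axis]) * (D // mesh[contract_axis]) * F
    # Mesh axes that shard neither input simply repeat the same work.
    executed = per_chip * prod(mesh.values())
    return {"useful": useful, "per_chip": per_chip, "executed": executed}


mesh = {"X": 4, "Y": 8, "Z": 4}
counts = sharded_matmul_flops(B=1024, D=4096, F=8192, mesh=mesh)  # hypothetical sizes
print(counts["per_chip"], counts["executed"] / counts["useful"])
```

Separating useful from executed FLOPs makes the cost of the unused Z axis visible.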
Question 3
How many FLOPs are involved in performing \(A[I, J, K, L] * B[I, J, M, N, O] \rightarrow C[K, L, M, N, O]\)?
Question 4
What is the arithmetic intensity of self-attention (ignoring the Q/K/V/O projections)? Give the answer as a function of the Q and KV lengths T and S. At what context length is attention FLOPs-bound? Given the HBM bandwidth of our TPUs, plot the effective relative cost of attention to the FFW block as the context length grows.
Question 5
At what sequence length are self-attention FLOPs equal to the QKVO projection FLOPs?
Question 6
Say we only save the output of each of the 7 main matmuls in a Transformer layer during our forward pass (Q, K, V, O + the three FFW matrices). How many extra FLOPs do we need to “rematerialize” during the backwards pass?
Question 7
DeepSeek v3 says it was trained for 2.79M H800 hours on 14.8T tokens (source). Given that it has 37B activated parameters, roughly what hardware utilization did they achieve? Hint: note that they used FP8 FLOPs without structured sparsity.
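A rough way to set up this estimate is the usual approximation of 6 FLOPs per activated parameter per training token, divided by the FLOPs the hardware could have delivered. In the sketch below the per-GPU peak is an assumption that must come from the spec sheet you trust (dense FP8, without structured sparsity); the value used here is a placeholder, and the function name is illustrative.

```python
def training_utilization(active_params, tokens, gpu_hours, peak_flops_per_gpu):
    """Model FLOPs (6 * params * tokens) over the FLOPs the hardware could deliver."""
    model_flops = 6 * active_params * tokens
    available_flops = gpu_hours * 3600 * peak_flops_per_gpu
    return model_flops / available_flops


# ASSUMPTION: dense FP8 peak per H800, not stated in the problem text.
H800_FP8_DENSE = 1.513e15

util = training_utilization(
    active_params=37e9,  # from the problem statement
    tokens=14.8e12,      # from the problem statement
    gpu_hours=2.79e6,    # from the problem statement
    peak_flops_per_gpu=H800_FP8_DENSE,
)
print(f"{util:.1%}")
```

The result scales inversely with the assumed peak, so a different spec-sheet figure changes the utilization proportionally.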
Question 8
Mixture of Experts (MoE) models have \(E\) copies of a standard dense MLP block, and each token activates \(k\) of these experts. What batch size in tokens is required to be compute-bound for an MoE with weights in int8 on TPU v5e? For DeepSeek, which has 256 (routed) experts and \(k = 8\), what is this number?
Ch 5: Training
Question 1
How many parameters does LLaMA-2 13B have (I know that’s silly but do the math)? Note that, as in Transformer Math, LLaMA-3 has 3 big FFW matrices, two up-projection and one down-projection. We ignored the two “gating” einsum matrices in this section, but they behave the same as \(W_\text{in}\) in this section.
Question 2
Let’s assume we’re training with BS=16M tokens and using Adam. Ignoring parallelism for a moment, how much total memory is used by the model’s parameters, optimizer state, and activations? Assume we store the parameters in bf16 and the optimizer state in fp32 and checkpoint activations three times per layer (after the three big matmuls).
Question 3
Assume we want to train with 32k sequence length and a total batch size of 3M tokens on a TPUv5p 16x16x16 slice. Assume we want to use bfloat16 weights and a float32 optimizer, as above.
- Can we use pure data parallelism? Why or why not?
- Can we use pure FSDP? Why or why not? With pure FSDP, how much memory will be used per device (assume we do gradient checkpointing only after the 3 big FFW matrices)?
- Can we use mixed FSDP + tensor parallelism? Why or why not? If so, what should \(X\) and \(Y\) be? How much memory will be stored per device? Using only roofline FLOPs estimates and ignoring attention, how long will each training step take at 40% MFU?
Ch 6: Training LLaMA
Question 1
From this table, can we calculate the LLaMA 3-70B parameter count? 🤫 Apply the content of Section 4 and see if you can get to 70B.
LLaMA 3-70B configuration:
| Hyperparameter | Symbol | Value |
|---|---|---|
| Layers | L | 80 |
| Model dimension | D (d_model) | 8,192 |
| FFW hidden dimension | F (d_ff) | 28,672 |
| Attention heads | N (n_heads) | 64 |
| KV heads | K (n_kv_heads) | 8 |
| Head dimension | H (d_qkv) | 128 |
| Vocabulary size | V (n_embeddings) | 128,256 |
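The count can be checked mechanically from the table. This is a minimal sketch that assumes three FFW matrices per layer, separate Q, K, V, and O projections with grouped KV heads, and untied input and output embeddings; small terms such as normalization scales are ignored. The function name is illustrative.

```python
def llama_param_count(n_layers, d_model, d_ff, n_heads, n_kv_heads, d_qkv, vocab):
    """Matmul and embedding parameters only; normalization scales are ignored."""
    ffw = 3 * d_model * d_ff                      # gate, up, and down projections
    attn_q_and_o = 2 * d_model * n_heads * d_qkv  # Q and O projections
    attn_k_and_v = 2 * d_model * n_kv_heads * d_qkv
    embeddings = 2 * vocab * d_model              # assumes untied input/output tables
    return n_layers * (ffw + attn_q_and_o + attn_k_and_v) + embeddings


total = llama_param_count(
    n_layers=80, d_model=8192, d_ff=28672,
    n_heads=64, n_kv_heads=8, d_qkv=128, vocab=128256,
)
print(f"{total / 1e9:.2f}B parameters")
```

Keeping the KV heads as a separate argument shows how much grouped-query attention shrinks the attention share of the total.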
Question 2
How many FLOPs does LLaMA-3 perform per token per training step? This helps us determine how expensive the whole training process will be.
Question 3
LLaMA 3 was trained for about 15 trillion tokens. How many FLOPs is that total?
Question 4
Let’s say we wanted to train on a full TPU v5p pod with 16x20x28 = 8960 chips. How long would this take to train at 40% MFU in bfloat16, assuming we are compute-bound?
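Questions 2 through 4 chain together, and the chain is short enough to write down. The sketch below assumes the common 6 × parameters FLOPs-per-token approximation (which ignores attention), a nominal 70e9 parameters, and a per-chip bf16 peak for TPU v5p that is not stated in the problem text; all three are assumptions, and the function names are illustrative.

```python
def training_flops(params, tokens):
    """Approximate training FLOPs: 6 per parameter per token."""
    return 6 * params * tokens


def training_days(total_flops, n_chips, peak_flops_per_chip, mfu):
    """Wall-clock days if the run is compute-bound at the given utilization."""
    seconds = total_flops / (n_chips * peak_flops_per_chip * mfu)
    return seconds / 86400


V5P_BF16_PEAK = 4.59e14  # ASSUMPTION: bf16 FLOPs/s per TPU v5p chip

total = training_flops(params=70e9, tokens=15e12)
days = training_days(total, n_chips=16 * 20 * 28,
                     peak_flops_per_chip=V5P_BF16_PEAK, mfu=0.40)
print(f"{total:.2e} FLOPs, about {days:.0f} days")
```

Since the estimate is linear in every input, halving the MFU or the chip count doubles the training time.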
Question 5
LLaMA 3-70B was pretrained with a batch size of about 4M tokens. How many TPUs do we need at minimum to train with this batch size? You can assume bfloat16 parameters and float32 optimizer state, and that you checkpoint gradients 4 times per layer.
Question 6
Under the same assumptions as the question above, if we use 8960 TPU v5p chips, how much memory will we use per-chip?
Question 7
Under the assumptions above, can we train our model with FSDP alone? To start, let’s say we can’t do any sequence/context parallelism. This should be the first idea you have, since it’s simple and will introduce no extra communication if it works.
Question 8
Let’s relax the requirement of not doing any sequence sharding. If we allow ourselves to do FSDP over both the batch and sequence axes, can we train LLaMA 3-70B with only FSDP on 8960 chips?
Question 9
Now let’s look at mixed tensor parallelism and FSDP. Does there exist some combination that lets us remain compute-bound? What amount of FSDP and tensor parallelism should we do if so?
Question 1
Scaling LLaMA 70B to more chips
Say we want to train LLaMA 3-70B on 4 pods with the same batch size. What parallelism scheme would we use? Would we be compute or communication bound? Roughly how long would it take to train? Make sure to use the correct roofline bound.
Question 2
LLaMA 405B
(a) Using the LLaMA 3-405B config, write a table with all the key hyperparameters as above. How many total parameters does this model have? How many FLOPs per training step? How many FLOPs do we perform if we train for 15T tokens?
(b) Assume we want to train on 8 TPU v5p pods. What parallelism scheme would we use? How long would training take? Would we be compute or comms bound?
Ch 7: Inference
Question 1
How many parameters does the above model have? How large are its KV caches per token in int8? You can assume we share the input and output projection matrices.
Question 2
Say we want to serve this model on a TPUv5e 4x4 slice and can fully shard our KV cache over this topology. What’s the largest batch size we can fit, assuming we use int8 for everything and want to support 128k sequences? What if we dropped the number of KV heads to 1?
Question 3
How long does it take to load all the parameters into the MXU from HBM assuming they’re fully sharded on a TPU v5e 4x4 slice? Assume int8 parameters. This is a good lower bound on the per-step latency.
Question 4
Let’s say we want to serve this model on a TPUv5e 4x4 slice using int8 FLOPs and parameters/activations. How would we shard it for both prefill and decode? Hint: maybe answer these questions first:
- What does ICI look like on a 4x4?
- What’s the roofline bound on tensor parallelism?
- How can we shard the KV caches?
For this sharding, what is the rough per-step latency for generation?
Question 5
Let’s pretend the above model is actually an MoE. An MoE model is effectively a dense model with E copies of the FFW block. Each token passes through k of the FFW blocks and these k are averaged to produce the output. Let’s use E=16 and k=2 with the above settings.
- How many total and activated parameters does it have? Activated means used by any given token.
- What batch size is needed to become FLOPs bound on TPU v5e?
- How large are its KV caches per token?
- How many FLOPs are involved in a forward pass with T tokens?
Question 6
With MoEs, we can do “expert sharding”, where we split our experts across one axis of our mesh. In our standard notation, our first FFW weight has shape \([E, D, F]\) and we shard it as \([E_Z, D_X, F_Y]\) where X is only used during training as our FSDP dimension. Let’s say we want to do inference on a TPU v5e:
- What’s the HBM weight loading time for the above model on a TPU v5e 8x16 slice with Y=8, Z=16? How much free HBM is available per TPU?
- What is the smallest slice we could fit our model on?
Question 7
2D model sharding
Here we’ll work through the math of what the ESTI paper calls 2D weight-stationary sharding. We describe this briefly in Appendix B, but try doing this problem first to see if you can work out the math. The basic idea of 2D weight stationary sharding is to shard our weights along both the \(D\) and \(F\) axes so that each chunk is roughly square. This reduces the comms load and allows us to scale slightly farther.
Here’s the algorithm for 2D weight stationary:
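- \(\text{In}[B, D_X] = \text{AllGather}_{YZ}(\text{In}[B, D_{XYZ}])\)
- \(\text{Tmp}[B, F_{YZ}]\{U_X\} = \text{In}[B, D_X] *_D W_\text{in}[D_X, F_{YZ}]\)
- \(\text{Tmp}[B, F_{YZ}] = \text{AllReduce}_X(\text{Tmp}[B, F_{YZ}]\{U_X\})\)
- \(\text{Out}[B, D_X]\{U_{YZ}\} = \text{Tmp}[B, F_{YZ}] *_F W_\text{out}[F_{YZ}, D_X]\)
- \(\text{Out}[B, D_{XYZ}] = \text{ReduceScatter}_{YZ}(\text{Out}[B, D_X]\{U_{YZ}\})\)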
Your goal is to work out \(T_\text{math}\) and \(T_\text{comms}\) for this algorithm and find when it will outperform traditional 3D model sharding.
Ch 8: Serving LLaMA
Question 1
How large are LLaMA 3-70B’s KV caches per token? You can assume we store them in int8. This determines how large our batch size can be on a given topology.
Question 2
Let’s say we want to serve L3 70B at batch size 32 and 8192 sequence length with everything (params and KVs) in int8. How much total memory will this use? What’s the smallest slice we could serve this on?
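The memory arithmetic for Questions 1 and 2 can be sketched directly from the LLaMA 3-70B configuration table in Ch 6 (80 layers, 8 KV heads, head dimension 128). The nominal 70e9 parameter count is an assumption used for illustration, and the function name is hypothetical.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_qkv, bytes_per_elem=1):
    """One key and one value vector per KV head per layer."""
    return 2 * n_layers * n_kv_heads * d_qkv * bytes_per_elem


per_token = kv_bytes_per_token(n_layers=80, n_kv_heads=8, d_qkv=128)  # int8
kv_total = per_token * 32 * 8192   # batch size 32, sequence length 8192
param_bytes = 70e9 * 1             # ASSUMPTION: nominal 70B parameters in int8
total_bytes = kv_total + param_bytes

print(per_token, kv_total / 1e9, total_bytes / 1e9)
```

Dividing `total_bytes` by the HBM per chip from the spec sheet then gives a lower bound on the slice size.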
Question 3
At this batch size and quantization on a TPU v5e 4x2, roughly what latency would we expect per decode step? What throughput (tokens / sec / chip)? What about a 4x4? Assume we perform our FLOPs in bfloat16 and everything is fully sharded.
Question 4
On TPU v5e, using bfloat16 weights and activations, how large do our batch sizes need to be for us to be compute-bound in our matmuls? What if we do int8 weights but perform our FLOPs in bfloat16? What about int8 weights with int8 FLOPs?
Question 5
What is the smallest TPU v5e topology we could serve LLaMA 3-70B on using bfloat16, int8, and int4 (both KVs and parameters) with 8k context? You can think of KV caches as negligibly small for this one.
Question 6
Assuming we use the largest batch size that fits on these topologies, what latency could we expect for each generate step?
Question 7
For each of these, what throughput per chip does this give us (in terms of queries / chip)? You can assume our median decode length is 512 tokens.
Question 8
How would our peak throughput change if we doubled our topology for each of the above examples?
Question 9
Now let’s dig into the question of sharding. Let’s say we wanted to serve in bfloat16 on a TPU v5e 4x8. What sharding would we use for our model on a TPU v5e 4x8 during generation? Can we avoid being communication bound?
Question 10
Assume we achieve a 40% FLOPs utilization during prefill. How long will a prefill of length 8192 take on 16 TPU v5e chips?
Question 11
Assume we have a median prefill length of 8192 tokens and a median decode length of 4096 tokens. Say we have a generate batch size of 32. On average how many sequences finish decoding per step? On average how many tokens are evicted from our KV cache each step?
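A steady-state estimate for this question takes only a few lines. The sketch assumes each sequence occupies a batch slot for exactly the median decode length and that a finished sequence frees its whole prefill plus decode history; treating the median as a mean is a simplification, and the function name is illustrative.

```python
def steady_state_turnover(batch_size, median_prefill, median_decode):
    """Average sequences finishing and KV-cache tokens evicted per decode step."""
    finished_per_step = batch_size / median_decode
    tokens_evicted_per_step = finished_per_step * (median_prefill + median_decode)
    return finished_per_step, tokens_evicted_per_step


finished, evicted = steady_state_turnover(
    batch_size=32, median_prefill=8192, median_decode=4096
)
print(finished, evicted)
```

The eviction rate grows with the prefill-to-decode ratio, since every finished sequence releases its entire context at once.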
Question 12
Assume we do disaggregated serving with a median prefill length of 8192 and a median decode length of 512. Assume the prefill and generate latencies calculated above in bfloat16. What ratio of prefill:generate servers will you need to keep both fully saturated?
Question 1
How many FLOPs does each forward pass for LLaMA 3-405B use per-token? Assuming we’re FLOPs bound, what is a lower bound on a single forward pass on N chips on TPU v5e? What if we’re comms bound? Ignore the fact that the model does not fit on a single chip.
Question 2
Assume we want to serve LLaMA 3-8B with BS=240 using int8 weights and int8 KV caches. How many bytes are used by (a) model parameters (b) KV caches and (c) peak working activations (roughly)? What’s the smallest topology we can run this on?
Question 3
How would you serve LLaMA 3-405B on TPU v5e? Assume int8 weights and bfloat16 FLOPs. Let’s say we have a firm limit of 15ms / token, what’s the highest throughput configuration we could achieve? What is the theoretical minimum step time?
Ch 9: Profiling
Question 1
Take a look at this Colab/profile and figure out what looks suspicious and what’s going on here. Can you tell me exactly what computations are happening and what each operation is doing? What are the true shapes of each matrix involved and how are they sharded? Try looking at the profile first without reading the code.
Question 2
The Transformer Colab from earlier implements a simple mock Transformer. Follow the instructions in the Colab and get a benchmark of the naive Transformer with GSPMD partitioning. How long does each part take? How long should it take? What sharding is being used? Try fixing the sharding! Hint: use jax.lax.with_sharding_constraint to constrain the behavior. With this fix, what’s the best MXU utilization you can get?
For reference, the initial version gets roughly 184ms / layer and the optimized profile gets 67ms / layer. Once you’ve done this, try staring at the profile and see if you can answer these questions purely from the profile:
- What sharding strategy is this?
- What is the batch size, \(d_\text{model}\), \(d_\text{ff}\)?
- What fraction of time is spent on attention vs. the MLP block?
- What fraction of time should be spent on each op at the roofline?
Note: since this problem was written, the XLA compiler has gotten better. The initial version is now at roughly 90ms / layer and the optimized profile is only about 10ms / layer better (80ms / layer). Still, it’s worth playing with and seeing if you can do better.
Ch 10: All About JAX
Question 1
Let A be an array of activations of shape \(\text{float32}[S_X, D_Y]\) with X * Y = N. Do the following:
- Write a function in JAX that computes the average within each (X, Y) shard, i.e. it returns an array of size [X, Y] where arr[i, j] is the average over shard (i, j). Do this with both jax.jit and shard_map. Profile each and see how long they took. Was there any communication added? Hint: there shouldn’t be, but sometimes XLA adds it anyway.
- Write a function in JAX that returns roll(x, shift, axis=0) - x for some shift within each shard along X. I’m not enough of a masochist to make you do this in jax.jit, so just do this with shard_map.
Question 2
Here we’ll make a basic “mixture of experts” model together. Let W: \(\text{float32}[E_X, D, F]\) be a set of E “expert” matrices. Let A: \(\text{float32}[S_X, D]\) (our activations) and let B: \(\text{int32}[S_X]\) be a set of “routing assignments” where B[i] is an integer in the range [0, E) telling us which matrix we want to process that activation. We want to write a function in JAX that returns Out[i] = A[i] @ W[B[i]].
- Let’s start by ignoring sharding altogether. Make all of these tensors small enough so they fit in one device. Write a local implementation of this function. Make sure you don’t materialize an array of shape [S, D, F]! Hint: try sorting the tokens into a new buffer of shape [E, S, D] with some attention to masking (why do we need the second dimension to have size S?).
- If you just jax.jit the above method, something will happen. Profile this and see what communication it decided to do. How long does it take?
- One problem you’ll notice with the above is that it likely gathers the full set of activations A locally, i.e. \(\text{AllGather}_X([S_X, D])\). Not only is this expensive communication-wise, it’s also incredibly expensive memory-wise if we can’t fit the full set of activations locally. Implement the above using shard_map and explicit communication.
  - For a first pass, it might be easiest to use a jax.lax.all_gather and reorder as in step 1.
  - For a second pass, try to avoid materializing any array of size [E, S, D], i.e. try to perform the computation in a ragged fashion using a jax.lax.all_to_all inside a jax.lax.while_loop. This way, you can avoid materializing the full activations and wasting compute on padding. How much faster is this than your original implementation?
- Most MoEs route to multiple (k) experts and then average the result. Refactor the above to implement this. Let B: \(\text{int32}[S_X, k]\) in this case for the k experts to route to.
Question 3
The collective matmul example above is actually super relevant for real LLMs. Let’s tweak the example to do the full Transformer stack.
- As an exercise, let’s start by implementing an AllReduce collective matmul, i.e. \(A[B_X, D_Y] *_D W[D_Y, F] \rightarrow \text{Out}[B_X, F]\). Note that the output isn’t replicated. The naive algorithm is discussed above, basically just a local matmul followed by an AllReduce. Try to make a comms overlapped “collective” version of this operation. Hint: tile over the output dimension and feel free to use jax.lax.psum (aka AllReduce). Note: due to the way XLA handles this, it may not actually be faster than the baseline.
- The complement to the AllReduce collective matmul above is a ReduceScatter collective matmul, as in \(\text{Tmp}[B_X, F_Y] *_F W_2[F_Y, D] \rightarrow \text{Out}[B_X, D_Y]\). This occurs in the down-projection matrix in a Transformer. Implement a collective, overlapped version of this in JAX. Be careful about passing only the minimal amount of data you need. Hint: try permuting the result as you accumulate it.
- Put these two together into an end-to-end Transformer block that performs \(\text{In}[B_X, D_Y] *_D W_\text{in}[D, F_Y] *_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B_X, D_Y]\) with overlapped communication. As before, we can’t do \(W_\text{in} \cdot W_\text{out}\) first because of a non-linearity we’ve omitted here. How much faster is this than a jax.jit implementation?
Question 4
All of the collective matmuls implemented above are unidirectional: they only permute in one direction. Rewrite the collective AllReduce matmul and the collective ReduceScatter matmuls to use bidirectional communication. How much faster are these?
Ch 12: GPUs
Question 1
CUDA cores
How many fp32 CUDA cores (ALUs) does an H100 have? B200? How does this compare to the number of independent ALUs in a TPU v5p?
Question 2
Vector FLOPs calculation
A single H100 has 132 SMs and runs at a clock speed of 1.59GHz (up to 1.98GHz boost). Assume it can do one vector op per cycle per ALU. How many vector fp32 FLOPs can be done per second? With boost? How does this compare to matmul FLOPs?
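The vector throughput estimate is a product of three numbers. The problem gives the SM count and clock speeds; the number of fp32 ALUs per SM is not stated here, so the value of 128 below is an assumption (it is the answer to Question 1 divided by the SM count). The function name is illustrative.

```python
def vector_flops_per_s(num_sms, clock_hz, alus_per_sm=128, ops_per_alu_per_cycle=1):
    """Peak vector throughput: SMs x ALUs per SM x ops per cycle x clock."""
    return num_sms * alus_per_sm * ops_per_alu_per_cycle * clock_hz


base = vector_flops_per_s(num_sms=132, clock_hz=1.59e9)   # ASSUMPTION: 128 ALUs per SM
boost = vector_flops_per_s(num_sms=132, clock_hz=1.98e9)
print(f"{base:.3e} FLOPs/s base, {boost:.3e} FLOPs/s boost")
```

Setting `ops_per_alu_per_cycle=2` counts a fused multiply-add as two FLOPs, which is how spec sheets usually report the figure.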
Question 3
GPU matmul intensity
What is the peak fp16 matmul intensity on an H100? A B200? What about fp8? By intensity we mean the ratio of matmul FLOPs/s to memory bandwidth.
Question 4
Matmul runtime
Using the answer to Question 3, how long would you expect an fp16[64, 4096] * fp16[4096, 8192] matmul to take on a single B200? How about fp16[512, 4096] * fp16[4096, 8192]?
Question 5
L1 cache capacity
What is the total L1/SMEM capacity for an H100? What about register memory? How does this compare to TPU VMEM capacity?
Question 6
Calculating B200 clock frequency
NVIDIA reports that a B200 can perform 80TFLOPs/s of vector fp32 compute. Given that each CUDA core can perform 2 FLOPs/cycle in a FMA (fused multiply add) op, estimate the peak clock frequency.
Question 7
Estimating H100 add runtime
Using the figures above, calculate how long it ought to take to add two fp32[N] vectors together on a single H100. Calculate both \(T_\text{math}\) and \(T_\text{comms}\). What is the arithmetic intensity of this operation? If you can get access, try running this operation in PyTorch or JAX as well for N = 1024 and N = 1024 * 1024 * 1024. How does this compare?
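The roofline for an elementwise add follows from counting bytes and FLOPs per element: two fp32 loads, one fp32 store, and one add. In the sketch below the H100 HBM bandwidth and vector throughput are assumptions passed in as parameters (they are not stated in this problem), and the function name is illustrative.

```python
def vector_add_roofline(n, hbm_bw, vector_flops, bytes_per_elem=4):
    """Roofline terms for out[i] = a[i] + b[i] over n elements."""
    flops = n                            # one add per element
    bytes_moved = 3 * bytes_per_elem * n  # read a, read b, write out
    return {
        "intensity": flops / bytes_moved,
        "t_math": flops / vector_flops,
        "t_comms": bytes_moved / hbm_bw,
    }


H100_HBM_BW = 3.35e12        # ASSUMPTION: bytes/s, from the H100 SXM spec sheet
H100_VECTOR_FLOPS = 2.69e13  # ASSUMPTION: non-boost estimate from Question 2

for n in (1024, 1024 * 1024 * 1024):
    print(n, vector_add_roofline(n, H100_HBM_BW, H100_VECTOR_FLOPS))
```

For small N the measured time will likely be dominated by kernel launch overhead, which this model does not include.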
Question 1
Total bandwidth for H100 node
How much total bandwidth do we have per node in an 8xH100 node with 4 switches? Hint: consider both the NVLink and NVSwitch bandwidth.
Question 2
Bisection bandwidth
Bisection bandwidth is defined as the smallest bandwidth available between any even partition of a network. In other words, if we split a network into two equal halves, how much bandwidth crosses between the two halves? Can you calculate the bisection bandwidth of an 8x H100 node? Hint: bisection bandwidth typically includes flow in both directions.
Question 3
AllGather cost
Given an array of B bytes, how long would a (throughput-bound) AllGather take on an 8xH100 node? Do the math for \(\text{bf16}[D_X, F]\) where D=4096, F=65,536. It’s worth reading the TPU collectives section before answering this. Think this through here but we’ll talk much more about collectives next.
Question 1
Fat tree topology
Using the DGX H100 diagram (see the GPUs chapter), calculate the bisection bandwidth of the entire 1024 GPU pod at the node level. Show that the bandwidth of each link is chosen to ensure full bisection bandwidth. Hint: make sure to calculate both the link bandwidth and switch bandwidth.
Question 2
Scaling to a larger DGX pod
Say we wanted to train on 2048 GPUs instead of 1024. What would be the simplest/best way to modify the above DGX topology to handle this? What about 4096? Hint: there’s no single correct answer, but try to keep costs down. Keep link capacity in mind. This documentation may be helpful.
Pop Quiz 3
Sharding along 2 axes
Say we want to perform \(\text{AllGather}_X(\text{bf16}[D_X, F_Y])\) where \(Y\) is the inner axis over a single SU (256 chips). How long will this take as a function of \(D\), \(F\), and \(Y\)?
Question 1
SU AllGather
Consider only a single SU with M nodes and N GPUs per node. Precisely how many bytes are ingressed and egressed by the node level switch during an AllGather? What about the top-level switch?
Question 2
Single-node SHARP AR
Consider a single node with N GPUs per node. Precisely how many bytes are ingressed and egressed by the switch during an AllReduce using SHARP (in-network reductions)?
Question 3
Cross-node SHARP AR
Consider an array \(\text{bf16}[D_X, F_Y]\) sharded over a single node of N GPUs. How long does \(\text{AllReduce}(\text{bf16}[D, F_Y]\{U_X\})\) take? You can assume we do in-network reductions. Explain how this differs if we have more than a single node.
Question 4
Spine level AR cost
Consider the same setting as above, but with \(Y = 256\) (so the AR happens at the spine level). How long does the AllReduce take? Again, feel free to assume in-network reductions.
Question 5
2-way AllGather cost
Calculate the precise cost of an AllGather of bytes over exactly 2 nodes. Make sure to calculate the precise cost and not the approximation, and consider both the intra-node and cross-node cost.
Question 1
B200 rooflines
A B200 DGX SuperPod (not GB200 NVL72) has 2x the bandwidth within a node (900GB/s egress) but the same amount of bandwidth in the scale-out network (400GB/s) (source). The total FLOPs are reported above. How does this change the model and data parallel rooflines?
Question 2
How to shard LLaMA-3 70B
Consider LLaMA-3 70B, training in bfloat16 with fp32 optimizer state with Adam.
-
At a minimum, how many H100s would we need simply to store the weights and optimizer?
-
Say we want to train on 4096 H100 GPUs for 15T tokens. Say we achieved 45% MFU (Model FLOPs Utilization). How long would it take to train?
-
LLaMA-3 70B has F = 28,672 and was trained with a batch size of about 4M tokens. What is the most model parallelism we could do without being comms-bound? With this plus pure DP, could we train LLaMA-3 while staying compute-bound on 4k chips? What about ZeRO-3? What about with 8-way pipelining? Note: consider both the communication cost and GPU memory usage.
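The first two parts are mostly arithmetic, so here is a hedged sketch of how one might set them up. The assumptions are mine, not the problem's: 70e9 parameters, 2 bytes per bf16 weight plus two fp32 Adam moments (8 bytes) and no separate fp32 master copy, 80 GB of HBM per H100, the usual 6 x parameters x tokens estimate for training FLOPs, and 9.9e14 FLOPs/s as the dense bf16 peak per H100.

```python
import math

PARAMS = 70e9
BYTES_PER_PARAM = 2 + 4 + 4   # bf16 weight + fp32 Adam first and second moments (assumption)
HBM_PER_GPU = 80e9            # H100 HBM capacity in bytes

state_bytes = PARAMS * BYTES_PER_PARAM
min_gpus = math.ceil(state_bytes / HBM_PER_GPU)

TOKENS = 15e12
NUM_GPUS = 4096
PEAK_FLOPS = 9.9e14           # dense bf16 FLOPs/s per H100 (assumption)
MFU = 0.45

train_flops = 6 * PARAMS * TOKENS                 # 6 * N * D rule of thumb
cluster_flops_per_s = NUM_GPUS * PEAK_FLOPS * MFU
train_days = train_flops / cluster_flops_per_s / 86_400

print(f"state: {state_bytes / 1e9:.0f} GB, at least {min_gpus} GPUs")
print(f"training time: about {train_days:.0f} days")
```

The memory figure ignores activations and any fp32 master weights, so it is a lower bound. In practice GPUs come in nodes of 8, so the count would be rounded up to whole nodes.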
Question 3
Megatron-LM hyperparams
Consider the MFU figure from the Megatron-LM repository (reproduced in the GPUs chapter).
Note that their sequence length is 4096 everywhere. For the 16B, 70B, and 314B models, what is the per-GPU token batch size? Assuming data parallelism is the outermost axis and assuming bfloat16 reductions, determine whether each of these is theoretically compute-bound or communication-bound, and whether a better configuration is available.
Reference: TPU spec sheet
These are the numbers that most of the problems above rely on when they say "use the numbers above". Source: All About TPUs; values are per chip.
Compute and memory
| Model | Pod size | Host size | HBM/chip | HBM BW/chip (B/s) | bf16 FLOPs/s | int8 OPs/s |
|---|---|---|---|---|---|---|
| TPU v3 | 32x32 | 4x2 | 32 GB | 9.0e11 | 1.4e14 | 1.4e14 |
| TPU v4p | 16x16x16 | 2x2x1 | 32 GB | 1.2e12 | 2.75e14 | 2.75e14 |
| TPU v5e | 16x16 | 4x2 | 16 GB | 8.2e11 | 1.97e14 | 3.94e14 |
| TPU v5p | 16x20x28 | 2x2x1 | 96 GB | 2.8e12 | 4.59e14 | 9.18e14 |
| TPU v6e (Trillium) | 16x16 | 4x2 | 32 GB | 1.6e12 | 9.20e14 | 1.84e15 |
| TPU7x | 4x4x576 | 2x2x1 | 192 GB | 7.4e12 | 2.30e15 | 4.61e15 |
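For roofline problems, the quantity usually needed from this table is the ratio of peak FLOPs/s to HBM bandwidth: an operation whose arithmetic intensity (FLOPs per byte moved) exceeds this ratio is compute-bound, otherwise it is bound by HBM bandwidth. A small sketch that derives the bf16 ratio for each row follows. The dictionary simply restates the table values, and the variable names are illustrative.

```python
# (HBM bandwidth in bytes/s, bf16 FLOPs/s), copied from the table above
TPU_SPECS = {
    "TPU v3": (9.0e11, 1.4e14),
    "TPU v4p": (1.2e12, 2.75e14),
    "TPU v5e": (8.2e11, 1.97e14),
    "TPU v5p": (2.8e12, 4.59e14),
    "TPU v6e": (1.6e12, 9.20e14),
    "TPU7x": (7.4e12, 2.30e15),
}

# Critical arithmetic intensity: FLOPs per byte at which compute time
# equals HBM transfer time.
critical_intensity = {
    name: flops / hbm_bw for name, (hbm_bw, flops) in TPU_SPECS.items()
}

for name, intensity in critical_intensity.items():
    print(f"{name}: ~{intensity:.0f} FLOPs/byte")
```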
Interconnect (ICI) bandwidth per link
| Model | One-way (B/s) | Bidirectional (B/s) |
|---|---|---|
| TPU v3 | 1.0e11 | 2.0e11 |
| TPU v4p | 4.5e10 | 9.0e10 |
| TPU v5e | 4.5e10 | 9.0e10 |
| TPU v5p | 9.0e10 | 1.8e11 |
| TPU v6e | 9.0e10 | 1.8e11 |
| TPU7x | 9.0e10 | 1.8e11 |
Other useful constants
- VMEM: 128 MiB on-chip scratchpad (v5e), with ~22x higher bandwidth to the MXU than HBM.
- MXU (systolic array): 128x128 on v3-v5 (one bf16[8,128] @ bf16[128,128] -> f32[8,128] per 8 cycles); 256x256 on v6e (~4x throughput).
- Clock speed: ~1.75 GHz (TPU v5p).
- PCIe / DCN: ~1.6e10 B/s per chip over PCIe (3.2e10 for v6e); DCN egress ~6.25e9 B/s per chip (3.125e9 for v5e, 12.5e9 for v6e/TPU7x).
- Per-hop ICI latency: ~1 microsecond (TPU v5e).
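The MXU and clock figures can be cross-checked against the bf16 FLOPs/s column. The sketch below does this for TPU v5p, assuming 8 MXUs per chip (the count in the H100 vs. TPU v5p table) and counting each multiply-accumulate as 2 FLOPs. Treat it as a sanity check under those assumptions rather than an official derivation.

```python
MXU_DIM = 128         # 128x128 systolic array
ROWS_PER_STEP = 8     # one bf16[8,128] @ bf16[128,128] matmul ...
CYCLES_PER_STEP = 8   # ... completes every 8 cycles
CLOCK_HZ = 1.75e9     # TPU v5p clock speed
MXUS_PER_CHIP = 8     # assumption: v5p MXU count per chip

# Each output element needs MXU_DIM multiply-adds, i.e. 2 * MXU_DIM FLOPs.
flops_per_step = 2 * ROWS_PER_STEP * MXU_DIM * MXU_DIM
flops_per_cycle = flops_per_step / CYCLES_PER_STEP

chip_flops_per_s = flops_per_cycle * CLOCK_HZ * MXUS_PER_CHIP
print(f"estimated v5p bf16 peak: {chip_flops_per_s:.3e} FLOPs/s")
```

The estimate lands within about 1% of the 4.59e14 quoted for TPU v5p in the compute table.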
Reference: GPU spec sheet
The numbers the GPU problems rely on. Source: How to Think About GPUs, values per chip unless noted.
Chip specifications
| GPU | Generation | Clock | SMs/chip | SMEM/SM | L2/chip | HBM/chip |
|---|---|---|---|---|---|---|
| V100 | Volta | 1.25 / 1.38 GHz | 80 | 96 kB | 6 MB | 32 GB |
| A100 | Ampere | 1.10 / 1.41 GHz | 108 | 192 kB | 40 MB | 80 GB |
| H100 | Hopper | 1.59 / 1.98 GHz | 132 | 256 kB | 50 MB | 80 GB |
| H200 | Hopper | 1.59 / 1.98 GHz | 132 | 256 kB | 50 MB | 141 GB |
| B200 | Blackwell | — | 148 | 256 kB | 126 MB | 192 GB |
FLOPs and bandwidth
| GPU | HBM BW/chip (B/s) | bf16/fp16 FLOPs/s | fp8/int8 OPs/s | fp4 OPs/s |
|---|---|---|---|---|
| V100 | 9.0e11 | — | — | — |
| A100 | 2.0e12 | 3.1e14 | 6.2e14 | — |
| H100 | 3.4e12 | 9.9e14 | 2.0e15 | — |
| H200 | 4.8e12 | 9.9e14 | 2.0e15 | — |
| B200 | 8.0e12 | 2.3e15 | 4.5e15 | 9.0e15 |
NVIDIA's spec sheets quote Tensor Core FLOPs with structured sparsity, which is twice the dense peak. The values in this table are already the dense figures.
GPU vs. TPU terminology
| GPU | TPU | What it is |
|---|---|---|
| Streaming Multiprocessor (SM) | Tensor Core | Core "cell" containing other units |
| Warp Scheduler | VPU | SIMD vector arithmetic unit |
| CUDA Core | VPU ALU | SIMD ALU |
| SMEM (L1 cache) | VMEM | Fast on-chip cache memory |
| Tensor Core | MXU | Matrix-multiplication unit |
| HBM (GMEM) | HBM | High-bandwidth, high-capacity memory |
H100 vs. TPU v5p (per chip)
| Component | H100 | TPU v5p |
|---|---|---|
| SMs / Tensor Cores | 132 | 2 |
| Warp Schedulers / VPUs | 528 | 8 |
| SMEM / VMEM (L1) | 32 MB | 128 MB |
| Registers | 32 MB | 256 kB |
| Tensor Cores / MXUs | 528 | 8 |
NVLink / NVSwitch
| NVLink Gen | NVSwitch Gen | GPU Gen | NVLink BW (GB/s, full-duplex) | Ports/GPU | Node GPU-GPU BW (GB/s) | Node size | NVSwitches/node |
|---|---|---|---|---|---|---|---|
| 3.0 | 2.0 | Ampere | 25 | 12 | 300 | 8 | 6 |
| 4.0 | 3.0 | Hopper | 25 | 18 | 450 | 8 | 4 |
| 5.0 | 4.0 | Blackwell | 50 | 18 | 900 | 8 / 72 | 2 / 18 |
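The "Node GPU-GPU BW" column appears to be the per-link bandwidth multiplied by the number of NVLink ports per GPU. A short sketch that recomputes it from the other two columns follows. The tuple layout is just a restatement of the table.

```python
# GPU generation -> (NVLink GB/s per link, ports per GPU, quoted node GPU-GPU GB/s)
NVLINK = {
    "Ampere": (25, 12, 300),
    "Hopper": (25, 18, 450),
    "Blackwell": (50, 18, 900),
}

derived_bw = {}
for gen, (link_gb_per_s, ports, quoted_gb_per_s) in NVLINK.items():
    derived_bw[gen] = link_gb_per_s * ports
    print(f"{gen}: {link_gb_per_s} GB/s x {ports} ports = {derived_bw[gen]} GB/s "
          f"(table: {quoted_gb_per_s} GB/s)")
```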
Network hierarchy (1024-GPU H100 SuperPod)
| Level | GPUs | Switches/unit | Switch type | BW/unit (TB/s) | GPU-GPU BW (GB/s) | Fat-tree BW (GB/s) |
|---|---|---|---|---|---|---|
| Node | 8 | 4 | NVL | 3.6 | 450 | 450 |
| Leaf | 256 | 8 | IB | 12.8 | 50 | 400 |
| Spine | 1024 | 16 | IB | 51.2 | 50 | 400 |
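One way to read the "BW/unit" column is as the number of GPUs under that level multiplied by the per-GPU bandwidth at that level. The sketch below checks that reading against the three rows. This interpretation is an assumption drawn from the numbers themselves, not a definition quoted from the source.

```python
# (level, GPUs per unit, per-GPU bandwidth in GB/s, quoted BW/unit in TB/s)
LEVELS = [
    ("Node", 8, 450, 3.6),
    ("Leaf", 256, 50, 12.8),
    ("Spine", 1024, 50, 51.2),
]

derived_tb_per_s = {}
for level, gpus, gpu_bw_gb_per_s, quoted_tb_per_s in LEVELS:
    derived_tb_per_s[level] = gpus * gpu_bw_gb_per_s / 1000  # GB/s -> TB/s
    print(f"{level}: {gpus} x {gpu_bw_gb_per_s} GB/s = "
          f"{derived_tb_per_s[level]} TB/s (table: {quoted_tb_per_s} TB/s)")
```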
Node egress bandwidth
| Node type | GPUs/node | GPU egress (B/s) | Node egress (B/s) |
|---|---|---|---|
| H100 | 8 | 450e9 | 400e9 |
| B200 | 8 | 900e9 | 400e9 |
| GB200 NVL72 | 72 | 900e9 | 3600e9 |