# Numerical precision & training stability — make it RUN, then stop it diverging

The mechanics of getting a DL run to compute *finite* numbers fast on a rented card, and of debugging it when the loss goes NaN or spikes. This layer owns **make-it-run + the mechanics of divergence**; it does NOT own *is the converged number real* / cuDNN-nondeterminism-as-a-metric-error — that is **verifying-dl-experiments** (cross-link **REQUIRED** at every "is this a bug or a real effect" fork).

To jump: `grep -in '' references/training/precision-stability.md` (e.g. `tf32`, `bf16`, `scaler`, `nan`, `anomaly`, `z-loss`, `clip`, `warmup`, `qk`, `deterministic`).

## Table of contents

- **Precision choice** — P1 fp32/tf32/fp16/bf16 decision · P2 TF32 default-off footgun · P3 H100/A100/V100 capability
- **AMP mechanics** — P4 autocast scope · P5 GradScaler (fp16 only) · P6 bf16 needs no scaler · P7 grad-clip under scaler
- **NaN / Inf** — P8 where NaNs come from · P9 anomaly detection · P10 fp16 overflow vs underflow · P11 bad-data NaN
- **Loss spikes / divergence** — P12 LR + warmup · P13 grad clipping · P14 skip-the-batch · P15 z-loss · P16 qk-norm · P17 init
- **Gradients** — P18 explosion/vanishing diagnosis
- **Repro** — P19 determinism knobs (cross-link)
- **Pointers** — gotchas_universal.md, multinode.md, spot-resilience.md

---

## Precision choice

### P1 — Which precision: fp32 / TF32 / fp16 / bf16

**Symptom**: unsure which `dtype` to train in; run is either slow (fp32) or NaN-prone (fp16).

**Root cause**: the four modes trade dynamic range against mantissa precision against tensor-core speed. fp16 has a 5-bit exponent (max ~65504) so it *overflows* and *underflows* easily; bf16 keeps fp32's 8-bit exponent (same range as fp32) but only 7 mantissa bits, so it never needs loss-scaling but is coarser per value. TF32 is an fp32-storage mode that runs matmuls at 10 mantissa bits on tensor cores.

**Fix — default ladder (PyTorch 2.x)**:

1. **bf16 autocast** on Ampere+ (A100/H100/4090/...) — the modern default; same range as fp32, no GradScaler, robust. `torch.autocast("cuda", dtype=torch.bfloat16)`.
2. **TF32** for the fp32 matmuls that remain (the non-autocast path) — `torch.set_float32_matmul_precision("high")`. An essentially free speedup with negligible convergence impact for most nets (P2).
3. **fp16 autocast + GradScaler** ONLY if stuck on a card with no bf16 tensor cores (V100/T4/2080Ti) — needs the scaler (P5) and is overflow-prone.
4. **Pure fp32** as the diagnostic fallback: if a run NaNs, *first* prove it's finite in fp32 before blaming the model. fp32 isolates "is this a numerics bug or a model bug."

bf16 handles large dot-products / attention logits better than fp16, which saturates and triggers scaler-step-skipping.

URLs: https://docs.pytorch.org/docs/2.12/amp.html · https://www.runpod.io/articles/guides/fp16-bf16-fp8-mixed-precision-speed-up-my-model-training

### P2 — TF32 is OFF by default for matmul since PyTorch 1.12 — the "why is my A100 slow" footgun

**Symptom**: an fp32 (or autocast-but-fp32-matmul-heavy) run on an A100/H100 is ~2–4× slower than expected; nothing is wrong with the code.

**Root cause**: `torch.backends.cuda.matmul.allow_tf32` defaulted **True in 1.7–1.11**, then flipped to **False in 1.12+** (precision-loss complaints from non-DL users). So a fresh PyTorch 2.x box runs fp32 matmuls at full fp32 precision, which bypasses the fast TF32 tensor-core path, unless TF32 is re-enabled. Convolutions' TF32 (`cudnn.allow_tf32`) is a separate knob, enabled by default.

**Fix**: opt back in once at startup —

```python
import torch

torch.set_float32_matmul_precision("high")   # preferred: enables TF32 (or bf16x3) for fp32 matmul
# legacy-equivalent, still works:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```

`"high"` = TF32; `"highest"` = true fp32 (default); `"medium"` = even coarser. HF Trainer exposes `--tf32 1`. Most nets converge identically with TF32 as with fp32.
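
To put a number on the mantissa precision that TF32 and bf16 give up, the following sketch emulates mantissa rounding in plain Python. It is an illustration under stated assumptions, not how PyTorch or the hardware implements it: `round_mantissa` is a hypothetical helper, and it rounds half away from zero on the raw bit pattern, which only approximates hardware rounding.

```python
import struct

def round_mantissa(x: float, keep_bits: int) -> float:
    """Round x to fp32, then keep only `keep_bits` explicit mantissa bits.

    Hypothetical helper for illustration: rounds half away from zero on the
    raw bit pattern, which approximates (but is not) the hardware rounding.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    drop = 23 - keep_bits                  # fp32 stores 23 explicit mantissa bits
    if drop > 0:
        bits += 1 << (drop - 1)            # add half of the dropped range
        bits &= ~((1 << drop) - 1)         # clear the dropped bits
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFFFFFF))[0]

x = 1.0 / 3.0
fp32 = round_mantissa(x, 23)   # all 23 bits kept: plain fp32
tf32 = round_mantissa(x, 10)   # TF32 matmul inputs: 10 mantissa bits
bf16 = round_mantissa(x, 7)    # bf16: 7 mantissa bits

for name, v in [("fp32", fp32), ("tf32", tf32), ("bf16", bf16)]:
    print(f"{name}: {v!r}  relative error vs fp32 = {abs(v - fp32) / fp32:.2e}")
```

With round-to-nearest, the per-value relative error is bounded by 2⁻¹¹ (about 5e-4) for 10 kept bits and by 2⁻⁸ (about 4e-3) for 7 kept bits, which is the sense in which bf16 is "coarser per value" than TF32.
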

URLs: https://github.com/pytorch/pytorch/pull/76509 · https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html · https://docs.pytorch.org/docs/2.12/notes/numerical_accuracy.html

### P3 — Card capability gates the choice: bf16 needs Ampere+; V100/T4 are fp16-only

**Symptom**: bf16 training is unexpectedly slow (no error), or a config picks bf16 on an old card and falls to a slow path.

**Root cause**: fast bf16 tensor cores arrived with **Ampere (A100, RTX 30xx)**; Hopper (H100/H200) adds native **FP8**. **V100/T4/RTX 20xx have fp16 tensor cores but no fast bf16** (runs emulated/slow). A rental hands you whatever card is free, so the right precision is a *per-rental* fact, not a constant.

**Fix**: branch on capability at runtime, never hardcode —

```python
import torch
major, _ = torch.cuda.get_device_capability()              # 8.x or higher = Ampere or newer
use_bf16 = major >= 8 and torch.cuda.is_bf16_supported()   # is_bf16_supported() alone may count slow emulated bf16 on some PyTorch versions
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
```

On V100/T4 use fp16+GradScaler (P5). FP8 (H100) is opt-in via Transformer Engine / `torchao`, not plain autocast (out of scope). Record the card alongside the `nvidia-smi` output in Phase 0.

URL: https://www.e2enetworks.com/blog/nvidia-a100-vs-h100-vs-h200-gpu-comparison

---

## AMP mechanics

### P4 — autocast: wrap ONLY forward + loss, never backward, never `.half()` the model

**Symptom**: dtype-mismatch errors, or AMP gives no speedup, or grads look wrong.

**Root cause**: autocast is a context that casts *eligible ops* per-op inside the region; manually `.half()`-ing the model or wrapping the backward pass fights it.

**Fix**:

```python
for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=amp_dtype):   # forward + loss ONLY
        out = model(x)
        loss = loss_fn(out, y)
    # backward is OUTSIDE autocast:
    loss.backward()                                  # (+ scaler for fp16, P5)
    optimizer.step()
```

Keep the model and optimizer in fp32; do NOT call `model.half()`.
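
One way to see why fp32 weights matter is to watch a small update disappear when the weight itself is stored in fp16. The sketch below is an illustration in plain Python, not PyTorch code: `to_fp16` is a hypothetical helper built on the standard library's half-precision `struct` format, the update size is an assumed value, and a Python double stands in for the fp32 weight.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision (fp16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

update = 1e-4            # lr * grad for one step (illustrative value)
w_half = to_fp16(1.0)    # weight stored in fp16, as after model.half()
w_full = 1.0             # weight kept in higher precision (double as a stand-in for fp32)

for _ in range(1000):
    w_half = to_fp16(w_half + update)   # result is rounded back to fp16 every step
    w_full = w_full + update

print(w_half, w_full)
```

Near 1.0 the spacing between adjacent fp16 values is 2⁻¹⁰ (about 1e-3), so each 1e-4 step rounds back to the same weight: `w_half` stays at exactly 1.0 while `w_full` reaches about 1.1.
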
Use the new `torch.amp.autocast("cuda", ...)` / `torch.amp.GradScaler("cuda")` API — `torch.cuda.amp.*` is **deprecated** in PyTorch 2.x. Autocast state is thread-local (re-enter it inside each DDP/DataParallel worker thread).

URL: https://docs.pytorch.org/docs/2.12/amp.html

### P5 — GradScaler: required for fp16 to stop gradient *underflow*

**Symptom (no scaler, fp16)**: loss looks fine but the model doesn't learn, because small gradients flush to 0 once they fall below fp16's smallest subnormal (~6e-8).

**Root cause**: fp16's narrow range underflows small gradients to zero. GradScaler multiplies the loss by a large factor before backward (pushing grads into representable range), then unscales before the step and **adapts the factor**: on any inf/NaN grad it *skips the optimizer step* and halves the scale (backoff 0.5); after `growth_interval` (default 2000) clean steps it doubles it (growth 2.0).

**Fix — canonical fp16 loop**:

```python
scaler = torch.amp.GradScaler("cuda")
for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)        # internally unscales; SKIPS step if inf/NaN found
    scaler.update()               # adapts the scale factor
```

A few skipped steps early in training, while the scaler calibrates, are **normal**; *persistent* skips every step = a real overflow (go to P10).

URLs: https://github.com/pytorch/pytorch/blob/main/docs/source/notes/amp_examples.rst · https://docs.pytorch.org/docs/2.12/amp.html

### P6 — bf16 needs NO GradScaler (adding one is pointless, not harmful)

**Symptom**: a copied fp16 recipe carries a GradScaler into a bf16 run — wasted overhead, not a crash or a wrong result.

**Root cause**: bf16 has fp32's exponent range, so gradients don't underflow → loss-scaling is unnecessary and the scaler's skip/backoff machinery is dead weight (scale-then-unscale cancels, and it never finds an overflow to skip).
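
The underflow argument can be checked with a few lines of plain Python. This is a sketch under stated assumptions, not PyTorch internals: `to_fp16` and `to_bf16` are hypothetical helpers, bf16 is emulated by truncating fp32 (real hardware rounds), and the gradient value and the scale of 2¹⁶ are chosen for illustration.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision (fp16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Emulate bf16 by truncating fp32 to its top 16 bits (real hardware rounds)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0] & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

grad = 1e-8          # a small but meaningful gradient (illustrative value)
scale = 65536.0      # 2**16, assumed here as the loss-scale factor

fp16_plain = to_fp16(grad)                    # underflows: below fp16's smallest subnormal
fp16_scaled = to_fp16(grad * scale) / scale   # scale, store in fp16, unscale in higher precision
bf16_plain = to_bf16(grad)                    # survives without any scaling

print(fp16_plain, fp16_scaled, bf16_plain)
```

The unscaled fp16 gradient becomes exactly 0.0, the scaled one is recovered to within fp16's relative precision, and the bf16 value survives unscaled with a relative error below 2⁻⁷. That is the whole case for using a scaler with fp16 and none with bf16.
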

**Fix**: for bf16, drop the scaler entirely — plain `loss.backward(); optimizer.step()`. Only fp16 (i.e. the V100/T4 path) uses GradScaler.

URL: https://docs.pytorch.org/docs/2.12/amp.html

### P7 — Gradient clipping under GradScaler: `unscale_` FIRST or you clip scaled grads

**Symptom**: `clip_grad_norm_` under fp16 AMP has no effect, or clips at the wrong magnitude.

**Root cause**: after `scaler.scale(loss).backward()` the grads are still multiplied by the (large) scale factor, so clipping to `max_norm=1.0` is really clipping to `1.0 × scale` — effectively never.

**Fix**: `scaler.unscale_(optimizer)` once, THEN clip, THEN `scaler.step`:

```python
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                       # grads now in true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                           # sees the grads are already unscaled
scaler.update()
```

Call `unscale_` only once per optimizer per step; a second call before `scaler.update()` raises an error. For bf16, just `clip_grad_norm_` directly — no unscale.

URL: https://github.com/pytorch/pytorch/blob/main/docs/source/notes/amp_examples.rst

---

## NaN / Inf

### P8 — Where NaNs come from: the five arithmetic origins

**Symptom**: loss prints `nan` (or `inf`) after N steps; everything was fine before.

**Root cause** — NaN/Inf is produced by a *finite* set of ops on bad inputs:

- `log(x)` with `x ≤ 0`: `log(0) = -inf`, `log(negative) = nan` (e.g. `log` of a `sigmoid` output that hit 0).
- `sqrt(x)` / `x ** 0.5` with `x < 0`, or its grad at `x = 0` (`d/dx sqrt = 1/(2√x) → inf`).
- division `a / b` with `b → 0` (un-epsilon'd normalization, variance ≈ 0 in BatchNorm/LayerNorm).
- `exp(x)` overflow → `inf`, then `inf − inf` / `inf / inf → nan`.
- fp16 overflow (P10): a value exceeds 65504 → `inf` → grads → NaN.

**Fix — make the op stable, don't paper over it**:

- Never hand-roll `log(softmax(x))` — use `F.log_softmax` / `F.cross_entropy` (fused, log-sum-exp-stable).
- Add epsilon *inside* the unstable op: `torch.log(x + 1e-8)`, `torch.sqrt(x + 1e-12)`, `a / (b + 1e-8)`.
- Clamp before the danger op: `x.clamp(min=1e-7)` before `log`; clamp logits before a manual softmax.
- Use `eps` in the optimizer/norm (AdamW `eps=1e-8`; raise modestly if `v` is tiny and steps explode).

URLs: https://docs.pytorch.org/docs/stable/generated/torch.log.html · https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1

### P9 — Find the exact op: anomaly detection + a cheap forward hook

**Symptom**: loss is NaN but the stack trace points at `loss.backward()`, not the op that caused it.

**Root cause**: by default the NaN surfaces wherever it's *consumed*, not where it was *born*.

**Fix — two tools, cheap → precise**:

- **Forward NaN hook (cheap, leave on)** — register on every module to catch the *first* layer to emit a non-finite output:

```python
for name, m in model.named_modules():
    m.register_forward_hook(
        lambda mod, i, o, n=name: print(f"non-finite output in {n}")
        if torch.is_tensor(o) and not torch.isfinite(o).all() else None)
```

- **`torch.autograd.set_detect_anomaly(True)` (expensive, debug-only)** — records the forward traceback of each backward op and raises at the first backward NaN, pointing at the *forward* line that created it. The forward pass has to run inside the context too, otherwise there is no forward traceback to report.

```python
with torch.autograd.detect_anomaly():   # or set_detect_anomaly(True, check_nan=True)
    loss = loss_fn(model(x), y)         # forward inside the context, so its traceback is recorded
    loss.backward()
```

The docs warn it "will slow down your program" (roughly an order of magnitude) — enable to *locate*, then turn OFF for the real run, never ship it on.

URL: https://docs.pytorch.org/docs/2.12/autograd.html

### P10 — fp16 overflow vs underflow: read the GradScaler signal

**Symptom (fp16)**: loss → inf/NaN; or the scaler skips *every* step and the scale factor collapses toward 0.

**Root cause**: a forward activation exceeds fp16's 65504 max → `inf` → NaN grads → the scaler can't find a scale that avoids overflow, so it backs off forever. Common in attention logits and large residual sums. (Distinct from underflow, which the scaler *fixes*; see P5.)
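
The difference between a calibrating scaler and a collapsing one shows up directly in the scale trajectory. The sketch below is a plain-Python simulation of the update rule described in P5 (backoff 0.5, growth 2.0, `growth_interval` 2000); it is not the `GradScaler` implementation, `simulate_scale` is a hypothetical helper, and the initial scale of 2¹⁶ is assumed here to match PyTorch's default.

```python
def simulate_scale(grads_finite, init_scale=65536.0, backoff=0.5,
                   growth=2.0, growth_interval=2000):
    """Replay a sequence of per-step 'were the grads finite?' flags.

    Returns the final scale and the number of skipped optimizer steps.
    """
    scale, clean_steps, skipped = init_scale, 0, 0
    for finite in grads_finite:
        if finite:
            clean_steps += 1
            if clean_steps == growth_interval:   # long clean streak: try a larger scale
                scale *= growth
                clean_steps = 0
        else:
            skipped += 1                          # optimizer step is skipped
            scale *= backoff                      # and the scale is halved
            clean_steps = 0
    return scale, skipped

# Healthy calibration: a few early overflows, then clean steps.
healthy = simulate_scale([False] * 3 + [True] * 2000)

# Forward-pass overflow: grads are non-finite at every step, whatever the scale.
collapsing = simulate_scale([False] * 40)

print("healthy:", healthy)
print("collapsing:", collapsing)
```

In the healthy run the scale settles after three skips and then grows again; in the collapsing run 40 consecutive skips drive it from 2¹⁶ down to 2⁻²⁴. Logging `scaler.get_scale()` during training gives the same signal on a real run.
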

**Fix**: switch fp16 → **bf16** (P1) — its fp32 range absorbs the large values; this is the single most effective fix. If bf16 is unavailable (V100/T4): keep the overflow-prone block (final logits, attention scores, the loss) in **fp32** via a nested `torch.autocast("cuda", enabled=False)` region (cast its inputs with `.float()`, since tensors arriving from the surrounding autocast region may still be fp16), and apply z-loss (P15) / qk-norm (P16) to stop the logits growing.

URL: https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1

### P11 — NaN from the *data*, not the math

**Symptom**: NaN appears at a specific, reproducible step (always step 4137), not gradually.

**Root cause**: a corrupt sample — NaN/Inf pixel, all-zero target, label outside `[0, C)`, empty sequence, divide-by-zero in a custom transform. The math is fine; the input is poison.

**Fix**: guard at the data boundary — `assert torch.isfinite(x).all(), f"non-finite input @ step {step}"` (fail loud, with the index). A reproducible-step NaN ⇒ inspect *that batch* (seed the loader, dump the index); a *step-varying* NaN ⇒ a numerics/LR problem (P12), not data. Smoke the data first — smoke *content* is owned by **verifying-dl-experiments** (cross-link **REQUIRED**).

URL: https://arxiv.org/pdf/2311.03938

---

## Loss spikes / divergence

### P12 — Loss spike / divergence: LR too high or warmup too short

**Symptom**: training is stable, then the loss jumps orders of magnitude (spike), sometimes recovering, sometimes diverging to NaN — most often early, or after a fast LR ramp.

**Root cause**: if the LR ramps too fast or starts too high, early updates land before activation norms and the optimizer's second moment (`v`) have stabilized, overshooting into sharp loss regions → gradient-norm blowup → spike. A sustained **grad-norm** rise typically *precedes* the loss spike by several steps.

**Fix — in order of cheapness**:

1. **Lengthen warmup** (linear ramp 0 → peak over e.g. 1–10% of steps); warmup is the single biggest lever on LR-sensitivity of final loss.
2. **Lower peak LR** by ~3–10× and re-check.
3. **Log grad-norm every step** as the early-warning signal — spikes are predictable from activation/grad-norm scaling before they hit.
4. Resume from the last good checkpoint *before* the spike (don't train through a diverged region).

URLs: https://arxiv.org/pdf/2309.14322 · https://apxml.com/courses/how-to-build-a-large-language-model/chapter-24-identifying-mitigating-training-instabilities/stabilization-techniques-revisited

### P13 — Gradient clipping: the standard guardrail (and what constant clipping means)

**Symptom**: occasional grad-norm spikes; or NaN right after a single bad batch.

**Root cause**: one pathological batch (rare embedding IDs, an outlier sample) produces an outsized global grad norm that overshoots.

**Fix**: clip global grad norm every step — `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)` with `max_norm` ∈ [0.5, 1.0] typical for transformers (under the scaler: P7).

**Diagnostic**: if clipping is active *every* step or needs an absurdly low threshold to stay stable, that's a symptom of a deeper problem (LR too high P12, bad init P17, architecture), not a fix — chase the cause. Global-norm clipping scales *all* grads down, so one embedding-heavy batch can throttle everything else that step — consider per-module clipping if embeddings dominate.

URL: https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1

### P14 — Skip-the-batch: drop the update when this step is non-finite

**Symptom**: a single bad batch every few thousand steps NaNs the whole run; restarting wastes hours.

**Root cause**: the optimizer applies a non-finite grad and permanently corrupts the weights.
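
Why the corruption is permanent can be shown with a scalar toy optimizer. This is a sketch under stated assumptions rather than PyTorch's Adam: `adam_step` and `train` are hypothetical helpers, bias correction is omitted for brevity, and the gradient sequence is made up to contain one poisoned batch.

```python
import math

def adam_step(w, m, v, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One simplified Adam update on a scalar weight (bias correction omitted)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    return w - lr * m / (math.sqrt(v) + eps), m, v

def train(grads, guard):
    """Run the updates; with guard=True, non-finite gradients are skipped."""
    w, m, v, skipped = 1.0, 0.0, 0.0, 0
    for g in grads:
        if guard and not math.isfinite(g):
            skipped += 1          # drop the update; weight and moments stay intact
            continue
        w, m, v = adam_step(w, m, v, g)
    return w, skipped

grads = [0.1, float("nan"), 0.1, 0.1]    # one poisoned batch, then clean ones
w_unguarded, _ = train(grads, guard=False)
w_guarded, n_skipped = train(grads, guard=True)

print(w_unguarded, w_guarded, n_skipped)
```

Without the guard, the NaN enters the weight and both moment estimates, and no later clean gradient can remove it, so the weight stays NaN. With the guard, one step is skipped and training continues from intact state.
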

**Fix**: gate the optimizer step on finiteness (fp16's GradScaler already does this internally, P5; bf16 needs it explicit):

```python
loss.backward()
gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if torch.isfinite(gnorm):
    optimizer.step()
else:
    optimizer.zero_grad(set_to_none=True)   # skip this batch, keep weights intact
    skipped += 1
```

Log a `skipped` counter — a *rising* skip rate means a systematic problem (P12/P10), not stray bad data. Adaptive spike-clipping (ZClip) and momentum-reset on spike (SPAM) automate this for large runs.

URLs: https://arxiv.org/pdf/2504.02507 · https://arxiv.org/pdf/2501.06842

### P15 — z-loss: stop softmax logits from drifting unbounded

**Symptom**: training is slowly destabilizing; the softmax normalizer / output logits grow over time and eventually overflow (acute in fp16/bf16); the "output logits diverge from log-probs" failure mode.

**Root cause**: nothing pins the absolute scale of pre-softmax logits, so they drift up; large logits cause numerical instability and (in low precision) overflow → collapse.

**Fix**: add an auxiliary **z-loss** = `1e-4 · (log Z)²` where `Z` is the softmax denominator (`log Z = logsumexp(logits)`), pulling `log Z → 0`:

```python
import torch.nn.functional as F
logits = model(x)
z = torch.logsumexp(logits, dim=-1)
loss = F.cross_entropy(logits, y) + 1e-4 * (z ** 2).mean()
```

Coefficient **1e-4** is the PaLM value (ST-MoE uses 1e-3 for its router z-loss); too large lets z-loss dominate. Standard in LLM pretraining; also the recommended fix for MoE router instability.

URLs: https://medium.com/dair-ai/papers-explained-50-palm-480e72fa3fd5 · https://arxiv.org/pdf/2202.08906 · https://arxiv.org/pdf/2309.14322

### P16 — qk-norm: kill attention-logit growth at high LR

**Symptom**: a transformer diverges only at higher LR; the instability traces to attention scores (Q·Kᵀ) growing large before the softmax.
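
To illustrate why large pre-softmax scores are a problem, the sketch below scales a fixed set of scores and watches the softmax output and its local gradient. It is a plain-Python illustration, not an attention implementation: `softmax` and `diag_grad` are hypothetical helpers, and the scores and gains are made-up values standing in for growing query/key norms.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def diag_grad(probs):
    """d p_i / d logit_i = p_i * (1 - p_i): how much each probability can still move."""
    return [p * (1.0 - p) for p in probs]

base_scores = [0.1, 0.2, 0.3]            # hypothetical q.k scores for one query
results = {}
for gain in (1, 10, 100, 1000):          # stand-in for growing query/key norms
    probs = softmax([gain * s for s in base_scores])
    results[gain] = (max(probs), max(diag_grad(probs)))
    print(gain, results[gain])
```

At gain 1 the distribution is close to uniform and every entry still has a usable gradient; at gain 1000 the softmax is effectively one-hot and the gradient through it is numerically zero. Normalizing queries and keys before the dot-product keeps the scores away from that regime.
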

**Root cause**: "growth of logits in attention layers" — one of the two dominant transformer instability modes (the other is output-logit divergence, P15). Unbounded attention logits saturate the softmax.

**Fix**: apply **QK-LayerNorm** — LayerNorm query and key per-head before the dot-product. Combined with z-loss (P15) + warmup (P12), it lets small models train to similar loss across *orders of magnitude* of LR, i.e. removes most LR-sensitivity.

URL: https://arxiv.org/pdf/2309.14322

### P17 — Initialization & normalization placement

**Symptom**: divergence in the first few hundred steps regardless of LR; or vanishing signal (P18) in deep stacks.

**Root cause**: residual streams accumulate variance with depth; default init can make early activations/grads too large (spike) or too small (vanish). Norm/embedding init scale matters.

**Fix**: scale residual-branch init by `1/√(2·n_layers)` (GPT-2-style); prefer pre-LN over post-LN for deep transformers; init embeddings at small std (~0.02). When unsure, copy a *known-good* config's init+norm scheme rather than tuning blind.

URL: https://arxiv.org/pdf/2309.14322

---

## Gradients

### P18 — Gradient explosion vs vanishing: diagnose by logging the norm

**Symptom**: loss NaN/diverges (explosion) OR loss plateaus and the model never learns (vanishing).

**Root cause**: per-layer grad norms blow up (explosion: deep nets, high LR, no clip) or decay to ~0 (vanishing: saturating activations, bad init P17, too-deep unnormalized stacks).

**Fix — measure first**:

```python
total = sum(p.grad.detach().norm()**2
            for p in model.parameters() if p.grad is not None) ** 0.5
# log `total` every step; also log per-layer norms when hunting the culprit layer
```

- **Explosion** (norm ↑↑): grad clipping (P13), lower LR (P12), longer warmup, bf16 over fp16 (P10).
- **Vanishing** (norm → 0): residual connections, normalization layers, better init (P17), non-saturating activations (GELU/SiLU over deep sigmoid/tanh stacks), check the LR isn't *too low*.

A grad-norm trace is the cheapest, highest-signal stability instrument — log it from step 1.

URL: https://apxml.com/courses/how-to-build-a-large-language-model/chapter-24-identifying-mitigating-training-instabilities/stabilization-techniques-revisited

---

## Reproducibility

### P19 — Deterministic / repro knobs — set them, but the *interpretation* is delegated

**Symptom**: same config + seed gives slightly different loss/metrics run-to-run.

**Root cause**: nondeterministic CUDA kernels + `cudnn.benchmark` autotuning pick different algorithms per run; TF32/AMP add low-order noise on top.

**Fix — the mechanical knobs (set these here)**:

```python
torch.manual_seed(s); np.random.seed(s); random.seed(s)
torch.use_deterministic_algorithms(True)        # may need CUBLAS_WORKSPACE_CONFIG=:4096:8
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False          # benchmark=True trades determinism for speed
```

**Whether a run-to-run delta is "a real effect vs cuDNN nondeterminism," and the full determinism methodology, is owned by verifying-dl-experiments (cross-link REQUIRED)** — catalogued as **U36** in `references/gotchas_universal.md`. This layer only ensures the knobs are *set and logged*. Determinism costs speed — enable for the datapoint that must be clean, not every throwaway run.

URL: https://docs.pytorch.org/docs/stable/notes/randomness.html

---

## Pointers — adjacent layers, do NOT restate here

- **`references/gotchas_universal.md`** — the *infra* failure modes that masquerade as numerics: **U6** disk-full crashes `torch.save`, **U9** cgroup-OOM (bare `Killed`, not a NaN), **U28** CUDA/driver/torch-build mismatch (`no kernel image` ≠ a precision bug), **U10/U11** VRAM OOM. Rule out infra before chasing a "numerics" ghost.
- **`verifying-dl-experiments`** (**REQUIRED** cross-link) — owns *is-the-number-real*: smoke **content**, cuDNN-nondeterminism-as-metric-error (U36), collapse/constant-output diagnosis, "bug vs real effect." This file makes training *run and stay finite*; that skill judges whether the converged result is *true*.
- **`references/spot-resilience.md`** — checkpoint cadence so a divergence-and-resume (P12) loses minimal work.
- **`references/multinode.md`** — NCCL/precision interactions in DDP (all-reduce dtype, loss-scale sync) for multi-node runs; single-box users skip.