# Numerical precision & training stability — make it RUN, then stop it diverging
The mechanics of getting a DL run to compute *finite* numbers fast on a rented card, and of debugging it
when the loss goes NaN or spikes. This layer owns **make-it-run + the mechanics of divergence**; it does
NOT own *is the converged number real* / cuDNN-nondeterminism-as-a-metric-error — that is
**verifying-dl-experiments** (cross-link **REQUIRED** at every "is this a bug or a real effect" fork).
To jump: `grep -in '<keyword>' references/training/precision-stability.md` (e.g. `tf32`, `bf16`, `scaler`,
`nan`, `anomaly`, `z-loss`, `clip`, `warmup`, `qk`, `deterministic`).
## Table of contents
- **Precision choice** — P1 fp32/tf32/fp16/bf16 decision · P2 TF32 default-off footgun · P3 H100/A100/V100 capability
- **AMP mechanics** — P4 autocast scope · P5 GradScaler (fp16 only) · P6 bf16 needs no scaler · P7 grad-clip under scaler
- **NaN / Inf** — P8 where NaNs come from · P9 anomaly detection · P10 fp16 overflow vs underflow · P11 bad-data NaN
- **Loss spikes / divergence** — P12 LR + warmup · P13 grad clipping · P14 skip-the-batch · P15 z-loss · P16 qk-norm · P17 init
- **Gradients** — P18 explosion/vanishing diagnosis
- **Repro** — P19 determinism knobs (cross-link)
- **Pointers** — gotchas_universal.md, multinode.md, spot-resilience.md
---
## Precision choice
### P1 — Which precision: fp32 / TF32 / fp16 / bf16
**Symptom**: unsure which `dtype` to train in; run is either slow (fp32) or NaN-prone (fp16).
**Root cause**: the four modes trade dynamic range against mantissa precision against tensor-core speed.
fp16 has a 5-bit exponent (max ~65504) so it *overflows* and *underflows* easily; bf16 keeps fp32's 8-bit
exponent (same range as fp32) but only 7 mantissa bits, so it never needs loss-scaling but is coarser per
value. TF32 is an fp32-storage mode that runs matmuls at 10 mantissa bits on tensor cores.
**Fix — default ladder (PyTorch 2.x)**:
1. **bf16 autocast** on Ampere+ (A100/H100/4090/...) — the modern default; same range as fp32, no GradScaler, robust. `torch.autocast("cuda", dtype=torch.bfloat16)`.
2. **TF32** for the fp32 matmuls that remain (the non-autocast path): `torch.set_float32_matmul_precision("high")`. A near-free speedup with negligible convergence impact for most nets (P2).
3. **fp16 autocast + GradScaler** ONLY if stuck on a card with no bf16 tensor cores (V100/T4/2080Ti) — needs the scaler (P5) and is overflow-prone.
4. **Pure fp32** as the diagnostic fallback: if a run NaNs, *first* prove it's finite in fp32 before blaming the model. fp32 isolates "is this a numerics bug or a model bug."
bf16 handles large dot-products / attention logits better than fp16, which saturates and triggers
scaler-step-skipping. URLs: https://docs.pytorch.org/docs/2.12/amp.html ·
https://www.runpod.io/articles/guides/fp16-bf16-fp8-mixed-precision-speed-up-my-model-training
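You can see the range-versus-mantissa trade-off directly by round-tripping a few values. This sketch uses NumPy's `float16` for fp16. NumPy has no bf16 type, so `to_bf16_trunc` (a hypothetical helper) approximates bf16 by zeroing the low 16 bits of the fp32 bit pattern. That is truncation, which is cruder than the round-to-nearest PyTorch uses, so treat its outputs as illustrative.

```python
import numpy as np


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE fp16 (5-bit exponent, 10-bit mantissa)."""
    with np.errstate(over="ignore"):  # we want to see the inf, not an overflow warning
        return float(np.float16(x))


def to_bf16_trunc(x: float) -> float:
    """Approximate bf16 (8-bit exponent, 7-bit mantissa) by zeroing the low 16 bits of the fp32 pattern."""
    bits = np.array([x], dtype=np.float32).view(np.uint32)
    bits &= np.uint32(0xFFFF0000)
    return float(bits.view(np.float32)[0])


print("fp16 max:", float(np.finfo(np.float16).max))  # fp16 max: 65504.0
for v in (70000.0, 1e-8, 1.001):
    print(f"{v:>8g}  fp16={to_fp16(v):<12g}  bf16~={to_bf16_trunc(v):g}")
```

fp16 turns 70000 into `inf` and `1e-8` into 0. The bf16 approximation keeps both finite and non-zero, but it cannot tell 1.001 apart from 1.0. bf16 buys range by giving up precision.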
### P2 — TF32 is OFF by default for matmul since PyTorch 1.12 — the "why is my A100 slow" footgun
**Symptom**: an fp32 (or autocast-but-fp32-matmul-heavy) run on an A100/H100 is ~2-4× slower than expected;
nothing is wrong with the code.
**Root cause**: `torch.backends.cuda.matmul.allow_tf32` defaulted to **True in 1.7 through 1.11**, then flipped to
**False in 1.12+** (precision-loss complaints from non-DL users). So a fresh PyTorch 2.x box runs fp32
matmuls at full fp32 precision on the CUDA cores, bypassing the tensor cores entirely, unless TF32 is
re-enabled. Convolutions' TF32 (`cudnn.allow_tf32`) is a separate knob and is still enabled by default.
**Fix**: opt back in once at startup —
```python
torch.set_float32_matmul_precision("high") # preferred: enables TF32 (or bf16x3) for fp32 matmul
# legacy-equivalent, still works:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```
`"high"` = TF32 (or bf16x3); `"highest"` = true fp32 (the default); `"medium"` = bf16-based internal math, coarser still. HF Trainer exposes `--tf32 1`.
Most nets converge identically with TF32 as with fp32. URLs:
https://github.com/pytorch/pytorch/pull/76509 ·
https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html ·
https://docs.pytorch.org/docs/2.12/notes/numerical_accuracy.html
### P3 — Card capability gates the choice: bf16 needs Ampere+; V100/T4 are fp16-only
**Symptom**: bf16 training is unexpectedly slow (no error), or a config picks bf16 on an old card and falls
to a slow path.
**Root cause**: fast bf16 tensor cores arrived with **Ampere (A100, RTX 30xx)**; Hopper (H100/H200) adds
native **FP8**. **V100/T4/RTX 20xx have fp16 tensor cores but no fast bf16** (runs emulated/slow). A rental
hands whatever card is free, so the right precision is a *per-rental* fact, not a constant.
**Fix**: branch on capability at runtime, never hardcode —
```python
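major, _minor = torch.cuda.get_device_capability()  # e.g. (7, 0) V100, (7, 5) T4, (8, 0) A100, (9, 0) H100
use_bf16 = major >= 8  # Ampere+; recent torch.cuda.is_bf16_supported() can return True on older cards via slow emulation
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16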
```
On V100/T4 use fp16+GradScaler (P5). FP8 (H100) is opt-in via Transformer Engine / `torchao`, not plain
autocast (out of scope). Record the card next to `nvidia-smi` in Phase 0.
URL: https://www.e2enetworks.com/blog/nvidia-a100-vs-h100-vs-h200-gpu-comparison
---
## AMP mechanics
### P4 — autocast: wrap ONLY forward + loss, never backward, never `.half()` the model
**Symptom**: dtype-mismatch errors, or AMP gives no speedup, or grads look wrong.
**Root cause**: autocast is a context that casts *eligible ops* per-op inside the region; manually
`.half()`-ing the model or wrapping the backward pass fights it.
**Fix**:
```python
for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=amp_dtype):  # forward + loss ONLY
        out = model(x)
        loss = loss_fn(out, y)
    # backward is OUTSIDE autocast:
    loss.backward()  # (+ scaler for fp16, P5)
    optimizer.step()
```
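Keep the model and optimizer in fp32; do NOT call `model.half()`. Use the device-generic `torch.autocast("cuda", ...)`
(the same class as `torch.amp.autocast`) and `torch.amp.GradScaler("cuda")` API; `torch.cuda.amp.*` is **deprecated** in
PyTorch 2.x. autocast state is thread-local, so you must re-enter it inside any thread that runs a forward pass. This affects
`DataParallel` and multi-GPU-per-process DDP. The usual one-GPU-per-process DDP is unaffected.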
URL: https://docs.pytorch.org/docs/2.12/amp.html
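Here is a CPU-runnable check of these rules. It assumes a toy `nn.Linear` classifier and uses CPU bf16 autocast as a stand-in for CUDA. The matmul output comes back in bf16, while the parameters and their `.grad` stay fp32. That happens because autocast casts per-op rather than converting the model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 4)  # parameters stay fp32: no model.half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))

optimizer.zero_grad(set_to_none=True)
with torch.autocast("cpu", dtype=torch.bfloat16):  # forward + loss only
    out = model(x)  # linear is on autocast's lower-precision list, so `out` is bf16
    loss = F.cross_entropy(out, y)
loss.backward()  # outside autocast
optimizer.step()

print(out.dtype, model.weight.dtype, model.weight.grad.dtype)
```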
### P5 — GradScaler: required for fp16 to stop gradient *underflow*
**Symptom (no scaler, fp16)**: loss looks fine but the model doesn't learn. Gradients smaller than fp16's smallest
subnormal (~6e-8) flush to 0.
**Root cause**: fp16's narrow range underflows small gradients to zero. GradScaler multiplies the loss by a
large factor before backward (pushing grads into representable range), then unscales before the step and
**adapts the factor**: on any inf/NaN grad it *skips the optimizer step* and halves the scale (backoff 0.5);
after `growth_interval` (default 2000) clean steps it doubles it (growth 2.0).
**Fix — canonical fp16 loop**:
```python
scaler = torch.amp.GradScaler("cuda")
for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # internally unscales; SKIPS step if inf/NaN found
    scaler.update()  # adapts the scale factor
```
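Skipped steps early in training, while the scaler backs off from its default 2^16 initial scale, are **normal**. Often the
only visible trace is a warning that `lr_scheduler.step()` was called before `optimizer.step()`. *Persistent* skips every
step = a real overflow (go to P10). URLs: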
https://github.com/pytorch/pytorch/blob/main/docs/source/notes/amp_examples.rst ·
https://docs.pytorch.org/docs/2.12/amp.html
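The adaptive rule is easiest to reason about as a small state machine. `ToyLossScale` below is a hypothetical pure-Python model of that rule, using the defaults quoted above: initial scale 2^16, backoff 0.5, and growth 2.0 after every 2000 clean steps. The real `GradScaler` keeps this state on the GPU and applies it inside `update()`.

```python
class ToyLossScale:
    """Hypothetical pure-Python model of GradScaler's dynamic loss-scale rule (illustration only)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def step(self, found_inf: bool) -> bool:
        """Adapt the scale; return True if the optimizer step should be applied."""
        if found_inf:
            self.scale *= self.backoff_factor  # overflow: shrink the scale ...
            self._clean_steps = 0
            return False                       # ... and skip this optimizer step
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor   # long clean streak: try a larger scale
            self._clean_steps = 0
        return True


s = ToyLossScale()
print(s.step(found_inf=True), s.scale)   # False 32768.0
applied = [s.step(found_inf=False) for _ in range(2000)]
print(all(applied), s.scale)             # True 65536.0
```

If a forward overflow recurs on every step, this rule halves the scale on every step. That is the "scale collapses toward 0" signature described in P10.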
### P6 — bf16 needs NO GradScaler (adding one is pointless, not harmful)
**Symptom**: a copied fp16 recipe carries a GradScaler into a bf16 run — wasted overhead, not a crash or a wrong result.
**Root cause**: bf16 has fp32's exponent range, so gradients don't underflow → loss-scaling is unnecessary
and the scaler's skip/backoff machinery is dead weight: scale-then-unscale cancels out, and it essentially never finds an
overflow to skip.
**Fix**: for bf16, drop the scaler entirely — plain `loss.backward(); optimizer.step()`. Only fp16 (and the
V100/T4 path) uses GradScaler.
URL: https://docs.pytorch.org/docs/2.12/amp.html
### P7 — Gradient clipping under GradScaler: `unscale_` FIRST or you clip scaled grads
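**Symptom**: `clip_grad_norm_` under fp16 AMP clips on essentially every step, and the updates become vanishingly small.
**Root cause**: before unscaling, the grads are still multiplied by the (large) scale factor `S`, so `max_norm=1.0` is
compared against `S ×` the true norm. Clipping therefore fires almost always. Once the scaler unscales, the true grad norm is
capped at `1.0 / S` (about 1.5e-5 at the default initial scale of 65536).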
**Fix**: `scaler.unscale_(optimizer)` once, THEN clip, THEN `scaler.step`:
```python
scaler.scale(loss).backward()
scaler.unscale_(optimizer) # grads now in true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer); scaler.update()
```
Call `unscale_` at most once per optimizer per step; a second call before `scaler.update()` raises a `RuntimeError`. For bf16, call `clip_grad_norm_` directly (no unscale).
URL: https://github.com/pytorch/pytorch/blob/main/docs/source/notes/amp_examples.rst
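You can reproduce the wrong-order effect without a real scaler. Multiply a gradient by the default initial scale `S = 2**16` by hand, clip, then divide by `S`, which is what `scaler.step` does afterwards. The parameter and gradient values are arbitrary.

```python
import torch

S = 2.0**16                                # GradScaler's default initial scale
true_grad = torch.tensor([0.3, 0.4, 0.0])  # true norm 0.5, already below max_norm
p = torch.nn.Parameter(torch.zeros(3))

# Wrong order: clip while the grad is still multiplied by S, then unscale.
p.grad = true_grad * S
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
wrong = (p.grad / S).norm().item()

# Right order: unscale first (what scaler.unscale_ does), then clip.
p.grad = true_grad * S
p.grad.div_(S)
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
right = p.grad.norm().item()

print(f"clip then unscale: {wrong:.2e}   unscale then clip: {right:.2f}")
```

In the wrong order, a gradient that needed no clipping at all ends up with norm ~1/S. In the right order, it passes through untouched.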
---
## NaN / Inf
### P8 — Where NaNs come from: four arithmetic origins, plus fp16 overflow
**Symptom**: loss prints `nan` (or `inf`) after N steps; everything was fine before.
**Root cause** — NaN/Inf is produced by a *finite* set of ops on bad inputs:
- `log(x)` with `x ≤ 0` (e.g. `log` of a `sigmoid`/`softmax` output that underflowed to 0; `F.log_softmax` itself is stable).
- `sqrt(x)` / `x ** 0.5` with `x < 0`, or its grad at `x = 0` (`d/dx sqrt = 1/(2√x) → inf`).
- division `a / b` with `b → 0` (un-epsilon'd normalization, variance ≈ 0 in BatchNorm/LayerNorm).
- `exp(x)` overflow → `inf`, then `inf - inf` or `inf / inf` → `nan`.
- fp16 overflow (P10): a value exceeds 65504 → `inf` → grads → NaN.
**Fix — make the op stable, don't paper over it**:
- Never hand-roll `log(softmax(x))` — use `F.log_softmax` / `F.cross_entropy` (fused, log-sum-exp-stable).
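- Add epsilon *inside* the unstable op: `torch.log(x + 1e-8)`, `torch.sqrt(x + 1e-12)`, `a / (b + 1e-8)`. These particular epsilons round to 0 in fp16 (smallest subnormal ~6e-8), so run such ops in fp32 or use a larger eps.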
- Clamp before the danger op: `x.clamp(min=1e-7)` before `log`; clamp logits before a manual softmax.
- Use `eps` in the optimizer/norm (AdamW `eps=1e-8`; raise modestly if `v` is tiny and steps explode).
URLs: https://docs.pytorch.org/docs/stable/generated/torch.log.html ·
https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1
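Two of these failure modes fit in a few lines of CPU code, using arbitrary logits. First, a hand-rolled `log(softmax(x))` returns `-inf` once one probability underflows to 0, while `F.log_softmax` stays finite. Second, an epsilon of `1e-8` is already 0 when stored in fp16.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.0, 200.0])

naive = torch.log(torch.softmax(logits, dim=-1))  # exp(-200) underflows to 0 in fp32, so log gives -inf
stable = F.log_softmax(logits, dim=-1)             # evaluated stably as x - logsumexp(x): stays finite
print(naive, stable)

# fp16's smallest subnormal is ~6e-8: a 1e-8 epsilon is stored as exactly 0, while 1e-6 survives.
eps_1e8 = torch.tensor(1e-8, dtype=torch.float16).item()
eps_1e6 = torch.tensor(1e-6, dtype=torch.float16).item()
print(eps_1e8, eps_1e6)
```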
### P9 — Find the exact op: anomaly detection + a cheap forward hook
**Symptom**: loss is NaN but the stack trace points at `loss.backward()`, not the op that caused it.
**Root cause**: by default the NaN surfaces wherever it's *consumed*, not where it was *born*.
**Fix — two tools, cheap → precise**:
- **Forward NaN hook (cheap next to anomaly mode, though each check forces a host-device sync)**: register it on every module to catch the *first* layer that emits NaN/Inf:
```python
for name, m in model.named_modules():
    m.register_forward_hook(
        lambda mod, inp, out, n=name:
            print(f"NaN/Inf in {n}") if torch.is_tensor(out) and not torch.isfinite(out).all() else None
    )
```
- **`torch.autograd.set_detect_anomaly(True)` (expensive, debug-only)** — records the forward traceback of
each backward op and raises at the first backward NaN, pointing at the *forward* line that created it.
```python
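with torch.autograd.detect_anomaly():  # or torch.autograd.set_detect_anomaly(True, check_nan=True)
    loss = loss_fn(model(x), y)  # forward must run inside too, or no forward traceback is recorded
    loss.backward()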
```
The docs warn it "will slow down your program" (roughly an order of magnitude) — enable to *locate*, then
turn OFF for the real run, never ship it on. URL: https://docs.pytorch.org/docs/2.12/autograd.html
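Sometimes the one-liner is not enough, for example when you want to stop at the first offender or remove the hooks afterwards. A slightly longer version can collect module names in execution order and return the hook handles. `install_nonfinite_hooks` is a hypothetical helper, and the demo plants a NaN in the last layer's bias to show what gets reported.

```python
import torch
import torch.nn as nn


def install_nonfinite_hooks(model: nn.Module):
    """Register a forward hook on every module; return (offenders, handles).

    `offenders` fills with module names, in execution order, whose tensor output has NaN/Inf.
    """
    offenders, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and not torch.isfinite(output).all():
                offenders.append(name or "<root>")
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return offenders, handles


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    model[2].bias[0] = float("nan")  # plant a NaN in the last layer

offenders, handles = install_nonfinite_hooks(model)
model(torch.randn(3, 4))
print(offenders)
for h in handles:
    h.remove()  # detach the hooks once the culprit is known
```

Forward hooks fire after a module's `forward` returns, so a child always reports before its parent. That makes `offenders[0]` the earliest module, in execution order, whose output was non-finite: here `"2"`, followed by the root `Sequential`.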
### P10 — fp16 overflow vs underflow: read the GradScaler signal
**Symptom (fp16)**: loss → inf/NaN; or the scaler skips *every* step and the scale factor collapses toward 0.
**Root cause**: a forward activation exceeds fp16's 65504 max → `inf` → NaN grads → the scaler can't find a
scale that avoids overflow, so it backs off forever. Common in attention logits and large residual sums.
(Distinct from underflow, which the scaler *fixes* by P5.)
**Fix**: switch fp16 → **bf16** (P1) — its fp32 range absorbs the large values; this is the single most
effective fix. If bf16 is unavailable (V100/T4), keep the overflow-prone block (final logits, attention
scores, the loss) in **fp32** via a nested `torch.autocast("cuda", enabled=False)` region. Cast its inputs back with `.float()`
inside that region, because tensors produced under autocast stay fp16. Then apply z-loss (P15) / qk-norm (P16) to stop the logits growing.
URL: https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1
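Here is a minimal sketch of such an fp32 island. It assumes the overflow-prone piece is the output projection, and `FP32Head` is a hypothetical name. The demo uses CPU bf16 autocast as a runnable stand-in for CUDA fp16; the `enabled=False` mechanism is the same.

```python
import torch
import torch.nn as nn


class FP32Head(nn.Module):
    """Hypothetical output head whose matmul always runs in fp32, even inside an autocast region."""

    def __init__(self, d_model: int, n_classes: int):
        super().__init__()
        self.proj = nn.Linear(d_model, n_classes)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        with torch.autocast(device_type=h.device.type, enabled=False):
            return self.proj(h.float())  # h arrives in low precision; cast it back by hand


torch.manual_seed(0)
body, head = nn.Linear(8, 8), FP32Head(8, 3)
x = torch.randn(2, 8)
with torch.autocast("cpu", dtype=torch.bfloat16):  # stand-in for CUDA fp16 autocast
    h = body(x)
    logits = head(h)
print(h.dtype, logits.dtype)
```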
### P11 — NaN from the *data*, not the math
**Symptom**: NaN appears at a specific, reproducible step (e.g. always step 4137), not gradually.
**Root cause**: a corrupt sample — NaN/Inf pixel, all-zero target, label outside `[0, C)`, empty sequence,
divide-by-zero in a custom transform. The math is fine; the input is poison.
**Fix**: guard at the data boundary — `assert torch.isfinite(x).all(), f"non-finite input @ step {step}"`
(fail loud, with the index). A reproducible-step NaN ⇒ inspect *that batch* (seed the loader, dump the
index); a *step-varying* NaN ⇒ a numerics/LR problem (P12), not data. Smoke-test the data first; what that smoke test
should check (its *content*) is owned by **verifying-dl-experiments** (cross-link **REQUIRED**).
URL: https://arxiv.org/pdf/2311.03938
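Here is a sketch of that boundary guard for a classification batch. It assumes batch-first inputs and integer labels, and `check_batch` is a hypothetical helper. It reports *which rows* are bad, so you can trace the sample index back through the loader.

```python
import torch


def check_batch(x: torch.Tensor, y: torch.Tensor, num_classes: int, step: int) -> None:
    """Hypothetical fail-loud guard for batch-first inputs `x` and integer class labels `y`."""
    finite_rows = torch.isfinite(x).reshape(x.shape[0], -1).all(dim=1)
    if not finite_rows.all():
        bad_rows = torch.nonzero(~finite_rows).flatten().tolist()
        raise ValueError(f"non-finite input @ step {step}, batch rows {bad_rows}")
    if y.numel() > 0 and (y.min() < 0 or y.max() >= num_classes):
        raise ValueError(
            f"label outside [0, {num_classes}) @ step {step}: "
            f"min={y.min().item()} max={y.max().item()}"
        )


x = torch.randn(4, 3)
x[2, 1] = float("nan")  # one poisoned sample
try:
    check_batch(x, torch.tensor([0, 1, 2, 0]), num_classes=3, step=4137)
except ValueError as err:
    print(err)
```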
---
## Loss spikes / divergence
### P12 — Loss spike / divergence: LR too high or warmup too short
**Symptom**: training is stable, then the loss jumps orders of magnitude (spike), sometimes recovering,
sometimes diverging to NaN — most often early, or after a fast LR ramp.
**Root cause**: if the LR ramps too fast or starts too high, early updates land before activation norms and
the optimizer's second moment (`v`) have stabilized, overshooting into sharp loss regions → gradient-norm
blowup → spike. A sustained **grad-norm** rise typically *precedes* the loss spike by several steps.
**Fix — in order of cheapness**:
1. **Lengthen warmup** (linear ramp 0 → peak over e.g. 1-10% of steps); warmup is the single biggest lever on LR-sensitivity of final loss.
2. **Lower peak LR** ~3-10× and re-check.
3. **Log grad-norm every step** as the early-warning signal — spikes are predictable from activation/grad-norm scaling before they hit.
4. Resume from the last good checkpoint *before* the spike (don't train through a diverged region).
URLs: https://arxiv.org/pdf/2309.14322 ·
https://apxml.com/courses/how-to-build-a-large-language-model/chapter-24-identifying-mitigating-training-instabilities/stabilization-techniques-revisited
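You can express a linear warmup with `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base LR by the lambda's value. The peak LR of 3e-4 and the 100 warmup steps are placeholders, not recommendations. In a real run, the loop body would also do the forward, backward, clipping, and grad-norm logging.

```python
import torch

peak_lr, warmup_steps = 3e-4, 100  # placeholders

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr)  # base lr = the peak
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda s: min(1.0, (s + 1) / warmup_steps)  # linear ramp, then flat
)

lrs = []
for step in range(150):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... forward, backward, clip, log grad-norm ...
    optimizer.step()
    scheduler.step()
```

A decay phase after warmup would go in the same lambda. It is left out here because the decay shape is a separate choice from the warmup length.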
### P13 — Gradient clipping: the standard guardrail (and what constant clipping means)
**Symptom**: occasional grad-norm spikes; or NaN right after a single bad batch.
**Root cause**: one pathological batch (rare embedding IDs, an outlier sample) produces an outsized global
grad norm that overshoots.
**Fix**: clip global grad norm every step — `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`
with `max_norm` ∈ [0.5, 1.0] typical for transformers (under the scaler: P7). **Diagnostic**: if clipping is
active *every* step or needs an absurdly low threshold to stay stable, that's a symptom of a deeper problem
(LR too high P12, bad init P17, architecture), not a fix — chase the cause. Global-norm clipping scales
*all* grads down, so one embedding-heavy batch can throttle everything else that step — consider per-module
clipping if embeddings dominate.
URL: https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1
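`clip_grad_norm_` returns the global norm measured *before* clipping, so the "clipping is active every step" diagnostic costs only one comparison. `ClipMonitor` below is a hypothetical helper. What clip rate counts as worrying is a judgment call for your run.

```python
import math

import torch


class ClipMonitor:
    """Hypothetical helper: clip the global grad norm and track how often clipping actually fires."""

    def __init__(self, max_norm: float = 1.0):
        self.max_norm = max_norm
        self.steps = 0
        self.clipped = 0

    def clip(self, parameters) -> float:
        norm = float(torch.nn.utils.clip_grad_norm_(parameters, self.max_norm))  # pre-clip norm
        self.steps += 1
        if math.isfinite(norm) and norm > self.max_norm:
            self.clipped += 1
        return norm

    @property
    def clip_rate(self) -> float:
        return self.clipped / max(self.steps, 1)


monitor = ClipMonitor(max_norm=1.0)
p = torch.nn.Parameter(torch.zeros(2))
for g in ([3.0, 4.0], [0.3, 0.4]):  # one step above max_norm, one below
    p.grad = torch.tensor(g)
    monitor.clip([p])
print(f"clip rate: {monitor.clip_rate:.0%}")  # clip rate: 50%
```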
### P14 — Skip-the-batch: drop the update when this step is non-finite
**Symptom**: a single bad batch every few thousand steps NaNs the whole run; restarting wastes hours.
**Root cause**: the optimizer applies a non-finite grad and permanently corrupts the weights.
**Fix**: gate the optimizer step on finiteness (fp16's GradScaler already does this internally, P5; bf16
needs it explicit):
```python
loss.backward()
gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # returns the pre-clip global norm
if torch.isfinite(gnorm):
    optimizer.step()
else:
    optimizer.zero_grad(set_to_none=True)  # skip this batch, keep weights intact
    skipped += 1  # counter initialized to 0 before the loop
```
Log a `skipped` counter — a *rising* skip rate means a systematic problem (P12/P10), not stray bad data.
Adaptive spike-clipping (ZClip) and momentum-reset on spike (SPAM) automate this for large runs. URLs:
https://arxiv.org/pdf/2504.02507 · https://arxiv.org/pdf/2501.06842
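This self-contained check confirms that the gate really leaves the weights untouched, using a toy `nn.Linear` and a batch with a planted NaN. `guarded_step` is a hypothetical wrapper around the snippet above. It always clears grads, which the next iteration would do anyway.

```python
import torch


def guarded_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer, max_norm: float = 1.0) -> bool:
    """Call after loss.backward(); apply the update only if the global grad norm is finite."""
    gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    applied = bool(torch.isfinite(gnorm))
    if applied:
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # clear grads either way
    return applied


torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
w0 = model.weight.detach().clone()

model(torch.tensor([[float("nan"), 1.0]])).sum().backward()  # poisoned batch
applied_bad = guarded_step(model, optimizer)
unchanged_after_bad = torch.equal(model.weight.detach(), w0)

model(torch.tensor([[1.0, 1.0]])).sum().backward()  # clean batch
applied_good = guarded_step(model, optimizer)
print(applied_bad, unchanged_after_bad, applied_good)
```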
### P15 — z-loss: stop softmax logits from drifting unbounded
**Symptom**: training is slowly destabilizing; the softmax normalizer / output logits grow over time and
eventually overflow (acute in fp16/bf16); the "output logits diverge from log-probs" failure mode.
**Root cause**: nothing pins the absolute scale of pre-softmax logits, so they drift up; large logits cause
numerical instability and (in low precision) overflow → collapse.
**Fix**: add an auxiliary **z-loss** = `1e-4 · (log Z)²` where `Z` is the softmax denominator
(`log Z = logsumexp(logits)`), pulling `log Z → 0`:
```python
logits = model(x)
z = torch.logsumexp(logits, dim=-1)
loss = F.cross_entropy(logits, y) + 1e-4 * (z ** 2).mean()
```
Coefficient **1e-4** is PaLM's value; too large a coefficient lets z-loss dominate. Standard in LLM pretraining. A router z-loss (ST-MoE
uses coefficient 1e-3) is also the recommended fix for MoE router instability. URLs:
https://medium.com/dair-ai/papers-explained-50-palm-480e72fa3fd5 · https://arxiv.org/pdf/2202.08906 ·
https://arxiv.org/pdf/2309.14322
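Language-model logits of shape `(batch, seq, vocab)` need flattening before the same loss applies. The sketch assumes integer targets of shape `(batch, seq)`, and `ce_with_z_loss` is a hypothetical name. The fp32 cast is a defensive choice for low-precision runs. The function returns the z-loss term separately so you can log it.

```python
import math

import torch
import torch.nn.functional as F


def ce_with_z_loss(logits: torch.Tensor, targets: torch.Tensor, z_coef: float = 1e-4):
    """Cross-entropy + z-loss for logits (..., vocab) and integer targets (...).

    Returns (total_loss, z_loss) so the z-loss term can be logged on its own.
    """
    flat = logits.float().reshape(-1, logits.size(-1))  # fp32 for the log-sum-exp
    log_z = torch.logsumexp(flat, dim=-1)                 # log of the softmax normalizer per position
    ce = F.cross_entropy(flat, targets.reshape(-1))
    z_loss = z_coef * (log_z ** 2).mean()
    return ce + z_loss, z_loss.detach()


# With all-zero logits over a 10-token vocab, log Z = log(10) at every position.
logits = torch.zeros(2, 3, 10)
targets = torch.zeros(2, 3, dtype=torch.long)
total, z_term = ce_with_z_loss(logits, targets)
print(total.item(), z_term.item(), 1e-4 * math.log(10) ** 2)
```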
### P16 — qk-norm: kill attention-logit growth at high LR
**Symptom**: a transformer diverges only at higher LR; the instability traces to attention scores (Q·Kᵀ)
growing large before the softmax.
**Root cause**: "growth of logits in attention layers" — one of the two dominant transformer instability
modes (the other is output-logit divergence, P15). Unbounded attention logits saturate the softmax.
**Fix**: apply **QK-LayerNorm** — LayerNorm query and key per-head before the dot-product. Combined with
z-loss (P15) + warmup (P12), it lets small models train to similar loss across *orders of magnitude* of LR,
i.e. removes most LR-sensitivity. URL: https://arxiv.org/pdf/2309.14322
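Below is a minimal causal self-attention block with QK-LayerNorm. It sketches where the two norms sit rather than serving as a reference implementation. `QKNormSelfAttention` is a hypothetical name, and it uses PyTorch's `F.scaled_dot_product_attention`. Details such as biases and the norms' affine parameters vary between codebases.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QKNormSelfAttention(nn.Module):
    """Causal self-attention with LayerNorm on each head's query and key vectors (sketch)."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.q_norm = nn.LayerNorm(self.head_dim)  # normalizes each head's query vector
        self.k_norm = nn.LayerNorm(self.head_dim)  # ... and each head's key vector

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        q, k, v = self.qkv(x).split(D, dim=-1)
        # (B, T, D) -> (B, heads, T, head_dim)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)  # bounds the scale of Q·Kᵀ before the softmax
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(y.transpose(1, 2).reshape(B, T, D))


torch.manual_seed(0)
attn = QKNormSelfAttention(d_model=32, n_heads=4)
x = 1e4 * torch.randn(2, 5, 32)  # deliberately huge activations
y = attn(x)
print(y.shape, torch.isfinite(y).all().item())
```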
### P17 — Initialization & normalization placement
**Symptom**: divergence in the first few hundred steps regardless of LR; or vanishing signal (P18) in deep
stacks.
**Root cause**: residual streams accumulate variance with depth; default init can make early
activations/grads too large (spike) or too small (vanish). Norm/embedding init scale matters.
**Fix**: scale residual-branch init by `1/√(2·n_layers)` (GPT-2-style); prefer pre-LN over post-LN for deep
transformers; init embeddings at small std (~0.02). When unsure, copy a *known-good* config's init+norm
scheme rather than tuning blind. URL: https://arxiv.org/pdf/2309.14322
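Here is a GPT-2-style initializer as a sketch. It assumes nanoGPT-like naming, where every residual-branch output projection has a parameter name ending in `c_proj.weight`. `gpt2_style_init` and its `residual_suffix` parameter are hypothetical, so adapt the name match to your model.

```python
import math

import torch
import torch.nn as nn


def gpt2_style_init(model: nn.Module, n_layers: int, std: float = 0.02,
                    residual_suffix: str = "c_proj.weight") -> None:
    """N(0, std) for Linear/Embedding weights, zero Linear biases, and N(0, std / sqrt(2 * n_layers))
    for residual-branch output projections (matched by parameter-name suffix)."""
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
    residual_std = std / math.sqrt(2 * n_layers)
    for name, param in model.named_parameters():
        if name.endswith(residual_suffix):
            nn.init.normal_(param, mean=0.0, std=residual_std)


torch.manual_seed(0)
toy = nn.ModuleDict({
    "wte": nn.Embedding(1000, 64),
    "c_fc": nn.Linear(64, 256),
    "c_proj": nn.Linear(512, 512),
})
gpt2_style_init(toy, n_layers=8)  # residual std = 0.02 / sqrt(16) = 0.005
print(toy["wte"].weight.std().item(), toy["c_proj"].weight.std().item())
```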
---
## Gradients
### P18 — Gradient explosion vs vanishing: diagnose by logging the norm
**Symptom**: loss NaN/diverges (explosion) OR loss plateaus and the model never learns (vanishing).
**Root cause**: per-layer grad norms blow up (explosion: deep nets, high LR, no clip) or decay to ~0
(vanishing: saturating activations, bad init P17, too-deep unnormalized stacks).
**Fix — measure first**:
```python
total = sum(p.grad.detach().norm()**2 for p in model.parameters() if p.grad is not None) ** 0.5
# log `total` every step; also log per-layer norms when hunting the culprit layer
```
- **Explosion** (norm ↑↑): grad clipping (P13), lower LR (P12), longer warmup, bf16 over fp16 (P10).
- **Vanishing** (norm → 0): residual connections, normalization layers, better init (P17), non-saturating
activations (GELU/SiLU over deep sigmoid/tanh stacks), check the LR isn't *too low*.
A grad-norm trace is the cheapest, highest-signal stability instrument — log it from step 1.
URL: https://apxml.com/courses/how-to-build-a-large-language-model/chapter-24-identifying-mitigating-training-instabilities/stabilization-techniques-revisited
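The helper below returns both the global norm and a per-parameter breakdown. `grad_norms` is a hypothetical name. Each `.item()` forces a host-device sync, so reserve the per-parameter dict for hunting the culprit layer, as the comment in the snippet above suggests.

```python
import torch
import torch.nn as nn


def grad_norms(model: nn.Module) -> tuple[float, dict[str, float]]:
    """Return (global L2 grad norm, {param_name: L2 grad norm}) for parameters that have a grad."""
    per_param = {
        name: p.grad.detach().float().norm().item()  # .item() syncs once per parameter
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    total = sum(n * n for n in per_param.values()) ** 0.5
    return total, per_param


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
model(torch.randn(16, 4)).pow(2).mean().backward()
total, per_param = grad_norms(model)
print(f"total grad norm {total:.4f}")
for name, norm in per_param.items():
    print(f"  {name:<9} {norm:.4f}")
```

Summing the squared per-parameter norms gives the same global L2 norm that `clip_grad_norm_` reports, so you can cross-check the two.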
---
## Reproducibility
### P19 — Deterministic / repro knobs — set them, but the *interpretation* is delegated
**Symptom**: same config + seed gives slightly different loss/metrics run-to-run.
**Root cause**: nondeterministic CUDA kernels + `cudnn.benchmark` autotuning pick different algorithms per
run; TF32/AMP add low-order noise on top.
**Fix — the mechanical knobs (set these here)**:
```python
import random
import numpy as np
import torch
s = 0  # placeholder for the run's configured seed
torch.manual_seed(s); np.random.seed(s); random.seed(s)
torch.use_deterministic_algorithms(True)  # may need CUBLAS_WORKSPACE_CONFIG=:4096:8
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # benchmark=True trades determinism for speed
```
**Whether a run-to-run delta is "a real effect vs cuDNN nondeterminism," and the full determinism
methodology, is owned by verifying-dl-experiments (cross-link REQUIRED)** — catalogued as **U36** in
`references/gotchas_universal.md`. This layer only ensures the knobs are *set and logged*. Determinism costs
speed — enable for the datapoint that must be clean, not every throwaway run.
URL: https://docs.pytorch.org/docs/stable/notes/randomness.html
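This layer's job is to set *and log* the knobs. One option is to wrap them in a function that returns what it set, ready for the run's config log. `make_deterministic` is a hypothetical helper. Setting `CUBLAS_WORKSPACE_CONFIG` from Python only takes effect if no CUDA work has happened yet, so exporting it in the launch environment is the safer habit.

```python
import os
import random

import numpy as np
import torch


def make_deterministic(seed: int) -> dict:
    """Set the repro knobs above and return them for the run log. Call before any CUDA work."""
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # read when cuBLAS initializes
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds every CUDA device
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return {
        "seed": seed,
        "deterministic_algorithms": torch.are_deterministic_algorithms_enabled(),
        "cudnn.deterministic": torch.backends.cudnn.deterministic,
        "cudnn.benchmark": torch.backends.cudnn.benchmark,
        "CUBLAS_WORKSPACE_CONFIG": os.environ["CUBLAS_WORKSPACE_CONFIG"],
    }


info = make_deterministic(1234)
print(info)
```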
---
## Pointers — adjacent layers, do NOT restate here
- **`references/gotchas_universal.md`** — the *infra* failure modes that masquerade as numerics:
**U6** disk-full crashes `torch.save`, **U9** cgroup-OOM (bare `Killed`, not a NaN), **U28** CUDA/driver/
torch-build mismatch (`no kernel image` ≠ a precision bug), **U10/U11** VRAM OOM. Rule out infra before
chasing a "numerics" ghost.
- **`verifying-dl-experiments`** (**REQUIRED** cross-link) — owns *is-the-number-real*: smoke **content**,
cuDNN-nondeterminism-as-metric-error (U36), collapse/constant-output diagnosis, "bug vs real effect." This
file makes training *run and stay finite*; that skill judges whether the converged result is *true*.
- **`references/spot-resilience.md`** — checkpoint cadence so a divergence-and-resume (P12) loses minimal work.
- **`references/multinode.md`** — NCCL/precision interactions in DDP (all-reduce dtype, loss-scale sync) for
multi-node runs; single-box users skip.