Numerical precision & training stability — make it RUN, then stop it diverging

The mechanics of getting a DL run to compute finite numbers fast on a rented card, and of debugging it when the loss goes NaN or spikes. This layer owns "make it run" plus the mechanics of divergence. It does NOT own "is the converged number real" or cuDNN nondeterminism as a metric error; those belong to verifying-dl-experiments (cross-link REQUIRED at every "is this a bug or a real effect" fork).

To jump: grep -in '<keyword>' references/training/precision-stability.md (e.g. tf32, bf16, scaler, nan, anomaly, z-loss, clip, warmup, qk, deterministic).

Table of contents

  • Precision choice — P1 fp32/tf32/fp16/bf16 decision · P2 TF32 default-off footgun · P3 H100/A100/V100 capability
  • AMP mechanics — P4 autocast scope · P5 GradScaler (fp16 only) · P6 bf16 needs no scaler · P7 grad-clip under scaler
  • NaN / Inf — P8 where NaNs come from · P9 anomaly detection · P10 fp16 overflow vs underflow · P11 bad-data NaN
  • Loss spikes / divergence — P12 LR + warmup · P13 grad clipping · P14 skip-the-batch · P15 z-loss · P16 qk-norm · P17 init
  • Gradients — P18 explosion/vanishing diagnosis
  • Repro — P19 determinism knobs (cross-link)
  • Pointers — gotchas_universal.md, multinode.md, spot-resilience.md

Precision choice

P1 — Which precision: fp32 / TF32 / fp16 / bf16

Symptom: unsure which dtype to train in; run is either slow (fp32) or NaN-prone (fp16).

Root cause: the four modes trade dynamic range against mantissa precision against tensor-core speed. fp16 has a 5-bit exponent (max ~65504) so it overflows and underflows easily; bf16 keeps fp32's 8-bit exponent (same range as fp32) but only 7 mantissa bits, so it never needs loss-scaling but is coarser per value. TF32 is an fp32-storage mode that runs matmuls at 10 mantissa bits on tensor cores.
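The range-versus-precision trade-off above follows directly from the bit layouts. The sketch below derives the largest finite value and the spacing near 1.0 from the exponent and mantissa widths quoted in this section. It is plain Python for illustration only; the FloatFormat class is a hypothetical helper, not a PyTorch API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FloatFormat:
    """Hypothetical helper describing an IEEE-style binary float layout."""
    name: str
    exponent_bits: int
    mantissa_bits: int  # explicitly stored fraction bits

    @property
    def max_finite(self) -> float:
        # Largest usable exponent equals the bias; largest significand is 2 - 2^-m.
        bias = 2 ** (self.exponent_bits - 1) - 1
        return (2.0 - 2.0 ** -self.mantissa_bits) * 2.0 ** bias

    @property
    def epsilon(self) -> float:
        # Gap between 1.0 and the next representable value.
        return 2.0 ** -self.mantissa_bits


FORMATS = [
    FloatFormat("fp32", 8, 23),
    FloatFormat("tf32", 8, 10),
    FloatFormat("bf16", 8, 7),
    FloatFormat("fp16", 5, 10),
]

for f in FORMATS:
    print(f"{f.name}: max={f.max_finite:.4g}  eps={f.epsilon:.3g}")
```

fp16 tops out at 65504 while bf16 and TF32 share fp32's range; bf16 pays for that range with the coarsest spacing of the four.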

Fix — default ladder (PyTorch 2.x):

  1. bf16 autocast on Ampere+ (A100/H100/4090/...) — the modern default; same range as fp32, no GradScaler, robust. torch.autocast("cuda", dtype=torch.bfloat16).
  2. TF32 for the fp32 matmuls that remain (the non-autocast path): torch.set_float32_matmul_precision("high"). A near-free speedup with negligible convergence impact for most nets (P2).
  3. fp16 autocast + GradScaler ONLY if stuck on a card with no bf16 tensor cores (V100/T4/2080Ti) — needs the scaler (P5) and is overflow-prone.
  4. Pure fp32 as the diagnostic fallback: if a run NaNs, first prove it's finite in fp32 before blaming the model. fp32 isolates "is this a numerics bug or a model bug."

bf16 handles large dot-products / attention logits better than fp16, which saturates and triggers scaler-step-skipping. URLs: https://docs.pytorch.org/docs/2.12/amp.html · https://www.runpod.io/articles/guides/fp16-bf16-fp8-mixed-precision-speed-up-my-model-training

P2 — TF32 is OFF by default for matmul since PyTorch 1.12 — the "why is my A100 slow" footgun

Symptom: an fp32 (or autocast-but-fp32-matmul-heavy) run on an A100/H100 is ~2-4× slower than expected; nothing is wrong with the code.

Root cause: torch.backends.cuda.matmul.allow_tf32 defaulted to True in PyTorch 1.7 through 1.11, then flipped to False in 1.12+ (precision-loss complaints from non-DL users). So a fresh PyTorch 2.x box runs fp32 matmuls in full fp32 precision, off the fast TF32 tensor-core path, unless TF32 is re-enabled. Convolutions' TF32 (torch.backends.cudnn.allow_tf32) is a separate knob, enabled by default.

Fix: opt back in once at startup —

torch.set_float32_matmul_precision("high")   # preferred: enables TF32 (or bf16x3) for fp32 matmul
# legacy-equivalent, still works:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

"high" = TF32; "highest" = true fp32 (default); "medium" = coarser still (may use bf16 internally for fp32 matmuls). HF Trainer exposes --tf32 1. Most nets converge identically with TF32 as with fp32. URLs: https://github.com/pytorch/pytorch/pull/76509 · https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html · https://docs.pytorch.org/docs/2.12/notes/numerical_accuracy.html

P3 — Card capability gates the choice: bf16 needs Ampere+; V100/T4 are fp16-only

Symptom: bf16 training is unexpectedly slow (no error), or a config picks bf16 on an old card and falls to a slow path.

Root cause: fast bf16 tensor cores arrived with Ampere (A100, RTX 30xx); Hopper (H100/H200) adds native FP8. V100/T4/RTX 20xx have fp16 tensor cores but no fast bf16 (runs emulated/slow). A rental hands whatever card is free, so the right precision is a per-rental fact, not a constant.

Fix: branch on capability at runtime, never hardcode —

use_bf16 = torch.cuda.is_bf16_supported()    # True on Ampere+
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

On V100/T4 use fp16+GradScaler (P5). FP8 (H100) is opt-in via Transformer Engine / torchao, not plain autocast (out of scope). Record the card next to nvidia-smi in Phase 0. URL: https://www.e2enetworks.com/blog/nvidia-a100-vs-h100-vs-h200-gpu-comparison
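As a conservative cross-check on the runtime branch above, the dtype can also be chosen from the CUDA compute capability, which helps if the installed PyTorch's is_bf16_supported() counts emulated support (worth verifying on your version). The sketch below is plain Python; pick_amp_dtype_name and the CARDS table are hypothetical names introduced for illustration.

```python
def pick_amp_dtype_name(capability: tuple[int, int]) -> str:
    """Return "bfloat16" for compute capability >= 8.0 (Ampere+), else "float16"."""
    major, _minor = capability
    return "bfloat16" if major >= 8 else "float16"


# Compute capabilities of the card families named in this section.
CARDS = {
    "V100": (7, 0),
    "T4": (7, 5),
    "A100": (8, 0),
    "RTX 30xx": (8, 6),
    "H100": (9, 0),
}

for card, cap in CARDS.items():
    print(f"{card}: sm_{cap[0]}{cap[1]} -> {pick_amp_dtype_name(cap)}")
```

On a real box the tuple would come from torch.cuda.get_device_capability(), and the returned name can be mapped to a dtype with getattr(torch, name).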


AMP mechanics

P4 — autocast: wrap ONLY forward + loss, never backward, never .half() the model

Symptom: dtype-mismatch errors, or AMP gives no speedup, or grads look wrong.

Root cause: autocast is a context that casts eligible ops per-op inside the region; manually .half()-ing the model or wrapping the backward pass fights it.

Fix:

for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=amp_dtype):   # forward + loss ONLY
        out = model(x); loss = loss_fn(out, y)
    # backward is OUTSIDE autocast:
    loss.backward()                                 # (+ scaler for fp16, P5)
    optimizer.step()

Keep the model and optimizer in fp32; do NOT call model.half(). Use the new torch.amp.autocast("cuda", ...) / torch.amp.GradScaler("cuda") API — torch.cuda.amp.* is deprecated in PyTorch 2.x. autocast state is thread-local (re-enter it inside each DDP/DataParallel worker thread). URL: https://docs.pytorch.org/docs/2.12/amp.html

P5 — GradScaler: required for fp16 to stop gradient underflow

Symptom (no scaler, fp16): loss looks fine but the model doesn't learn, because small gradients fall below fp16's smallest subnormal (~6e-8) and flush to 0.

Root cause: fp16's narrow range underflows small gradients to zero. GradScaler multiplies the loss by a large factor before backward (pushing grads into representable range), then unscales before the step and adapts the factor: on any inf/NaN grad it skips the optimizer step and halves the scale (backoff 0.5); after growth_interval (default 2000) clean steps it doubles it (growth 2.0).

Fix — canonical fp16 loop:

scaler = torch.amp.GradScaler("cuda")
for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)     # internally unscales; SKIPS step if inf/NaN found
    scaler.update()            # adapts the scale factor

Early-training "skipped step" warnings as the scaler calibrates are normal; persistent skips every step = a real overflow (go to P10). URLs: https://github.com/pytorch/pytorch/blob/main/docs/source/notes/amp_examples.rst · https://docs.pytorch.org/docs/2.12/amp.html
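The backoff/growth rule described above can be traced with a toy model. This is a plain-Python sketch of the scale dynamics only, using the default values quoted in this section plus the default initial scale of 2**16; ScaleTracker is a hypothetical class, not torch.amp.GradScaler itself.

```python
class ScaleTracker:
    """Toy model of dynamic loss scaling (illustration, not the PyTorch class)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.clean_steps = 0   # consecutive steps without inf/NaN grads
        self.skipped = 0       # optimizer steps dropped so far

    def update(self, found_inf: bool) -> bool:
        """Advance one step; return True if the optimizer step would be applied."""
        if found_inf:
            self.scale *= self.backoff_factor
            self.clean_steps = 0
            self.skipped += 1
            return False
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= self.growth_factor
            self.clean_steps = 0
        return True


tracker = ScaleTracker()
tracker.update(found_inf=True)          # one overflowing step
after_overflow = tracker.scale          # halved
for _ in range(2000):
    tracker.update(found_inf=False)     # a full clean interval
after_growth = tracker.scale            # doubled back
print(after_overflow, after_growth, tracker.skipped)
```

A healthy run shows a handful of early skips followed by a slowly growing scale; a scale that keeps halving is the P10 signature.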

P6 — bf16 needs NO GradScaler (adding one is pointless, not harmful)

Symptom: a copied fp16 recipe carries a GradScaler into a bf16 run — wasted overhead, not a crash or a wrong result.

Root cause: bf16 has fp32's exponent range, so gradients don't underflow → loss-scaling is unnecessary and the scaler's skip/backoff machinery is dead weight (scale-then-unscale cancels, and it never finds an overflow to skip).

Fix: for bf16, drop the scaler entirely — plain loss.backward(); optimizer.step(). Only fp16 (and the V100/T4 path) uses GradScaler. URL: https://docs.pytorch.org/docs/2.12/amp.html

P7 — Gradient clipping under GradScaler: unscale_ FIRST or you clip scaled grads

Symptom: clip_grad_norm_ under fp16 AMP has no effect, or clips at the wrong magnitude.

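Root cause: before unscale_ the grads are still multiplied by the (large) scale factor, so clip_grad_norm_ measures scale × the true norm. A max_norm of 1.0 then behaves like a true-unit threshold of 1.0 / scale: nearly every step gets clipped, and at a magnitude that shifts whenever the scaler changes its scale.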

Fix: scaler.unscale_(optimizer) once, THEN clip, THEN scaler.step:

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                  # grads now in true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer); scaler.update()

Call unscale_ once per optimizer per step; a second call before scaler.update() raises a RuntimeError. For bf16, just clip_grad_norm_ directly, with no unscale. URL: https://github.com/pytorch/pytorch/blob/main/docs/source/notes/amp_examples.rst
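The effect of the ordering can be checked with plain arithmetic. The sketch below compares the gradient norm the optimizer ends up with, in true (unscaled) units, for both orderings. The function name and the numbers are hypothetical; a loss scale of 65536 is assumed.

```python
def clipped_true_norm(true_norm: float, max_norm: float,
                      loss_scale: float, unscale_first: bool) -> float:
    """Gradient norm the optimizer finally sees, in true (unscaled) units."""
    if unscale_first:
        # Correct order: clip operates on true-scale grads.
        return min(true_norm, max_norm)
    # Wrong order: clip sees scaled grads, then scaler.step divides by the scale.
    scaled_norm = true_norm * loss_scale
    return min(scaled_norm, max_norm) / loss_scale


true_norm, max_norm, loss_scale = 0.3, 1.0, 65536.0
right_order = clipped_true_norm(true_norm, max_norm, loss_scale, unscale_first=True)
wrong_order = clipped_true_norm(true_norm, max_norm, loss_scale, unscale_first=False)
print(right_order, wrong_order)
```

With unscale_ first, a 0.3-norm gradient passes untouched; without it, the same gradient is squeezed to about 1.5e-5, and that value changes every time the scale does.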


NaN / Inf

P8 — Where NaNs come from: the five arithmetic origins

Symptom: loss prints nan (or inf) after N steps; everything was fine before.

Root cause — NaN/Inf is produced by a finite set of ops on bad inputs:

  • log(x) / log_softmax with x ≤ 0 (e.g. log of a sigmoid output that hit 0).
  • sqrt(x) / x ** 0.5 with x < 0, or its grad at x = 0 (d/dx sqrt = 1/(2√x) → inf).
  • division a / b with b → 0 (un-epsilon'd normalization, variance ≈ 0 in BatchNorm/LayerNorm).
  • exp(x) overflow → inf, then inf − inf or inf / inf → nan.
  • fp16 overflow (P10): a value exceeds 65504 → inf → grads → NaN.

Fix — make the op stable, don't paper over it:

  • Never hand-roll log(softmax(x)) — use F.log_softmax / F.cross_entropy (fused, log-sum-exp-stable).
  • Add epsilon inside the unstable op: torch.log(x + 1e-8), torch.sqrt(x + 1e-12), a / (b + 1e-8).
  • Clamp before the danger op: x.clamp(min=1e-7) before log; clamp logits before a manual softmax.
  • Use eps in the optimizer/norm (AdamW eps=1e-8; raise modestly if v is tiny and steps explode).

URLs: https://docs.pytorch.org/docs/stable/generated/torch.log.html · https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1
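To see why the fused, log-sum-exp-stable form matters, here is a plain-Python sketch contrasting a hand-rolled log(softmax(x)) with the max-subtraction version. The function names are illustrative. Note one difference from tensor code: Python's math.exp raises OverflowError where a tensor library would silently return inf and later nan.

```python
import math


def naive_log_softmax(xs):
    # Hand-rolled log(softmax(x)): exp() overflows for large logits.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [math.log(e / total) for e in exps]


def stable_log_softmax(xs):
    # Log-sum-exp trick: subtract the max so the largest term is exp(0) = 1.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


logits = [1000.0, 999.0, 0.0]

try:
    naive_log_softmax(logits)
    naive_ok = True
except OverflowError:
    naive_ok = False          # tensor code would carry inf forward instead

stable = stable_log_softmax(logits)
inf_minus_inf = float("inf") - float("inf")   # how inf turns into nan downstream

print(naive_ok, stable, inf_minus_inf)
```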

P9 — Find the exact op: anomaly detection + a cheap forward hook

Symptom: loss is NaN but the stack trace points at loss.backward(), not the op that caused it.

Root cause: by default the NaN surfaces wherever it's consumed, not where it was born.

Fix — two tools, cheap → precise:

  • Forward NaN hook (cheap, leave on) — register on every module to catch the first layer to emit NaN:
    for name, m in model.named_modules():
        m.register_forward_hook(lambda mod, i, o, n=name:
            print(f"NaN in {n}") if torch.is_tensor(o) and not torch.isfinite(o).all() else None)
    
  • torch.autograd.set_detect_anomaly(True) (expensive, debug-only) — records the forward traceback of each backward op and raises at the first backward NaN, pointing at the forward line that created it.
    with torch.autograd.detect_anomaly():   # or set_detect_anomaly(True, check_nan=True)
        loss.backward()
    
    The docs warn it "will slow down your program" (roughly an order of magnitude) — enable to locate, then turn OFF for the real run, never ship it on. URL: https://docs.pytorch.org/docs/2.12/autograd.html

P10 — fp16 overflow vs underflow: read the GradScaler signal

Symptom (fp16): loss → inf/NaN; or the scaler skips every step and the scale factor collapses toward 0.

Root cause: a forward activation exceeds fp16's 65504 max → inf → NaN grads. The loss scale only acts on gradients, so no scale value can repair a forward overflow and the scaler backs off forever. Common in attention logits and large residual sums. (Distinct from gradient underflow, which the scaler fixes per P5.)

Fix: switch fp16 → bf16 (P1). Its fp32 range absorbs the large values; this is the single most effective fix. If bf16 is unavailable (V100/T4): keep the overflow-prone block (final logits, attention scores, the loss) in fp32 via a nested torch.autocast("cuda", enabled=False) region, casting its inputs with .float() because tensors produced under autocast may already be fp16, and apply z-loss (P15) / qk-norm (P16) to stop the logits growing. URL: https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1
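The overflow/underflow distinction can be reproduced without a GPU by round-tripping values through IEEE half precision with Python's struct module. This is a minimal sketch; to_fp16 is a hypothetical helper, and the sample values are made up.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE half precision."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses out-of-range values; hardware would produce +/-inf.
        return math.copysign(math.inf, x)


activation = 70000.0     # above fp16's 65504 max: overflow
tiny_grad = 1e-8         # below fp16's smallest subnormal: underflow
loss_scale = 65536.0     # assumed loss scale (the GradScaler default initial value)

overflowed = to_fp16(activation)
flushed = to_fp16(tiny_grad)
rescued = to_fp16(tiny_grad * loss_scale)

print(overflowed, flushed, rescued)
```

Scaling rescues the tiny gradient, but nothing on the gradient side helps the 70000 activation, which is why the fix for forward overflow is a wider-range dtype or an fp32 island.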

P11 — NaN from the data, not the math

Symptom: NaN appears at a specific, reproducible step (always step 4137), not gradually.

Root cause: a corrupt sample — NaN/Inf pixel, all-zero target, label outside [0, C), empty sequence, divide-by-zero in a custom transform. The math is fine; the input is poison.

Fix: guard at the data boundary — assert torch.isfinite(x).all(), f"non-finite input @ step {step}" (fail loud, with the index). A reproducible-step NaN ⇒ inspect that batch (seed the loader, dump the index); a step-varying NaN ⇒ a numerics/LR problem (P12), not data. Smoke the data first — smoke content is owned by verifying-dl-experiments (cross-link REQUIRED). URL: https://arxiv.org/pdf/2311.03938


Loss spikes / divergence

P12 — Loss spike / divergence: LR too high or warmup too short

Symptom: training is stable, then the loss jumps orders of magnitude (spike), sometimes recovering, sometimes diverging to NaN — most often early, or after a fast LR ramp.

Root cause: if the LR ramps too fast or starts too high, early updates land before activation norms and the optimizer's second moment (v) have stabilized, overshooting into sharp loss regions → gradient-norm blowup → spike. A sustained grad-norm rise typically precedes the loss spike by several steps.

Fix — in order of cheapness:

  1. Lengthen warmup (linear ramp 0 → peak over e.g. 1-10% of steps); warmup is the single biggest lever on LR-sensitivity of final loss.
  2. Lower peak LR ~3-10× and re-check.
  3. Log grad-norm every step as the early-warning signal — spikes are predictable from activation/grad-norm scaling before they hit.
  4. Resume from the last good checkpoint before the spike (don't train through a diverged region).

URLs: https://arxiv.org/pdf/2309.14322 · https://apxml.com/courses/how-to-build-a-large-language-model/chapter-24-identifying-mitigating-training-instabilities/stabilization-techniques-revisited
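The linear warmup in step 1 of the list above can be sketched as a pure function of the step index. The function name, the 5% default, and the hold-at-peak behavior after warmup are assumptions for illustration; real schedules usually decay afterwards.

```python
def warmup_lr(step: int, total_steps: int, peak_lr: float,
              warmup_frac: float = 0.05) -> float:
    """Linear ramp from 0 to peak_lr over warmup_frac of training, then hold."""
    warmup_steps = max(1, round(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr


total_steps, peak_lr = 10_000, 3e-4
for step in (0, 250, 500, 5_000):
    print(step, warmup_lr(step, total_steps, peak_lr))
```

Dividing the result by peak_lr gives a multiplicative factor, which is the form torch.optim.lr_scheduler.LambdaLR expects.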

P13 — Gradient clipping: the standard guardrail (and what constant clipping means)

Symptom: occasional grad-norm spikes; or NaN right after a single bad batch.

Root cause: one pathological batch (rare embedding IDs, an outlier sample) produces an outsized global grad norm that overshoots.

Fix: clip global grad norm every step — torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) with max_norm ∈ [0.5, 1.0] typical for transformers (under the scaler: P7). Diagnostic: if clipping is active every step or needs an absurdly low threshold to stay stable, that's a symptom of a deeper problem (LR too high P12, bad init P17, architecture), not a fix — chase the cause. Global-norm clipping scales all grads down, so one embedding-heavy batch can throttle everything else that step — consider per-module clipping if embeddings dominate. URL: https://medium.com/better-ml/loss-spikes-in-training-causes-detection-and-mitigations-ed66e591b1a1
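The throttling effect of global-norm clipping is easy to see with numbers. The sketch below mirrors the usual coefficient, min(1, max_norm / (total_norm + 1e-6)), in plain Python; the module names and gradient values are made up.

```python
import math


def global_grad_norm(grads: dict[str, list[float]]) -> float:
    """L2 norm over all gradient entries of all modules."""
    return math.sqrt(sum(g * g for vals in grads.values() for g in vals))


def clip_coefficient(grads: dict[str, list[float]], max_norm: float) -> float:
    """Factor applied to every gradient; 1.0 means no clipping happened."""
    total = global_grad_norm(grads)
    return min(1.0, max_norm / (total + 1e-6))


# One embedding-heavy batch: the embedding grads dwarf the MLP grads.
grads = {"embedding": [30.0, 40.0], "mlp": [0.3, 0.4]}

coef = clip_coefficient(grads, max_norm=1.0)
mlp_norm_before = math.sqrt(0.3 ** 2 + 0.4 ** 2)
mlp_norm_after = coef * mlp_norm_before

print(coef, mlp_norm_before, mlp_norm_after)
```

The MLP gradient was well under the threshold on its own (norm 0.5) yet is shrunk about 50× because the embedding set the global norm, which is the case for per-module clipping.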

P14 — Skip-the-batch: drop the update when this step is non-finite

Symptom: a single bad batch every few thousand steps NaNs the whole run; restarting wastes hours.

Root cause: the optimizer applies a non-finite grad and permanently corrupts the weights.

Fix: gate the optimizer step on finiteness (fp16's GradScaler already does this internally, P5; bf16 needs it explicit):

loss.backward()
gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if torch.isfinite(gnorm):
    optimizer.step()
else:
    optimizer.zero_grad(set_to_none=True)   # skip this batch, keep weights intact
    skipped += 1

Log a skipped counter — a rising skip rate means a systematic problem (P12/P10), not stray bad data. Adaptive spike-clipping (ZClip) and momentum-reset on spike (SPAM) automate this for large runs. URLs: https://arxiv.org/pdf/2504.02507 · https://arxiv.org/pdf/2501.06842
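A bare counter cannot tell stray bad batches from a rising skip rate, so a windowed rate is one way to make that call. The class below is a hypothetical sketch; the window size and alert threshold are arbitrary values to tune per run.

```python
from collections import deque


class SkipRateMonitor:
    """Tracks the fraction of skipped optimizer steps over a sliding window."""

    def __init__(self, window: int = 1000, alert_rate: float = 0.01):
        self.recent = deque(maxlen=window)
        self.alert_rate = alert_rate
        self.total_skipped = 0

    def record(self, skipped: bool) -> None:
        self.recent.append(skipped)
        self.total_skipped += int(skipped)

    @property
    def rate(self) -> float:
        return sum(self.recent) / len(self.recent) if self.recent else 0.0

    def systematic(self) -> bool:
        # Only judge once the window is full, to avoid early false alarms.
        return len(self.recent) == self.recent.maxlen and self.rate > self.alert_rate


monitor = SkipRateMonitor(window=100, alert_rate=0.05)
for step in range(100):
    monitor.record(skipped=(step == 10))        # one stray bad batch
stray = monitor.systematic()
for step in range(100):
    monitor.record(skipped=(step % 5 == 0))     # every fifth step is non-finite
rising = monitor.systematic()

print(stray, rising, monitor.total_skipped)
```

In the training loop, record(True) would go in the else branch of the finiteness gate and record(False) after a successful optimizer.step().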

P15 — z-loss: stop softmax logits from drifting unbounded

Symptom: training is slowly destabilizing; the softmax normalizer / output logits grow over time and eventually overflow (acute in fp16/bf16); the "output logits diverge from log-probs" failure mode.

Root cause: nothing pins the absolute scale of pre-softmax logits, so they drift up; large logits cause numerical instability and (in low precision) overflow → collapse.

Fix: add an auxiliary z-loss = 1e-4 · (log Z)² where Z is the softmax denominator (log Z = logsumexp(logits)), pulling log Z → 0:

logits = model(x)
z = torch.logsumexp(logits, dim=-1)
loss = F.cross_entropy(logits, y) + 1e-4 * (z ** 2).mean()

Coefficient 1e-4 is the PaLM/ST-MoE value; too large lets z-loss dominate. Standard in LLM pretraining; also the recommended fix for MoE router instability. URLs: https://medium.com/dair-ai/papers-explained-50-palm-480e72fa3fd5 · https://arxiv.org/pdf/2202.08906 · https://arxiv.org/pdf/2309.14322
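The "nothing pins the absolute scale" point can be checked numerically: shifting every logit by a constant leaves cross-entropy unchanged, so only the z-loss term notices the drift. This is a plain-Python sketch with made-up logits; the function names are illustrative.

```python
import math


def logsumexp(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))


def cross_entropy(logits, target: int) -> float:
    # -log softmax(logits)[target]
    return logsumexp(logits) - logits[target]


def z_loss(logits, coeff: float = 1e-4) -> float:
    # coeff * (log Z)^2, with log Z = logsumexp(logits)
    return coeff * logsumexp(logits) ** 2


logits = [2.0, 0.5, -1.0]
shifted = [v + 50.0 for v in logits]    # same softmax, drifted absolute scale

ce_a, ce_b = cross_entropy(logits, 0), cross_entropy(shifted, 0)
z_a, z_b = z_loss(logits), z_loss(shifted)

print(ce_a, ce_b)   # equal up to rounding: cross-entropy cannot see the shift
print(z_a, z_b)     # the z-loss grows with the drift
```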

P16 — qk-norm: kill attention-logit growth at high LR

Symptom: a transformer diverges only at higher LR; the instability traces to attention scores (Q·Kᵀ) growing large before the softmax.

Root cause: "growth of logits in attention layers" — one of the two dominant transformer instability modes (the other is output-logit divergence, P15). Unbounded attention logits saturate the softmax.

Fix: apply QK-LayerNorm — LayerNorm query and key per-head before the dot-product. Combined with z-loss (P15) + warmup (P12), it lets small models train to similar loss across orders of magnitude of LR, i.e. removes most LR-sensitivity. URL: https://arxiv.org/pdf/2309.14322
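Why normalizing Q and K bounds the logits: after a LayerNorm without learnable gain, each vector has norm at most √d, so the scaled dot product q·k/√d cannot exceed √d in magnitude. The plain-Python sketch below checks this on made-up vectors. It omits the learnable gain and bias that a real QK-LayerNorm carries, which would rescale the bound.

```python
import math


def layer_norm(vec, eps: float = 1e-5):
    """LayerNorm without affine parameters (simplifying assumption)."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def attention_logit(q, k, qk_norm: bool) -> float:
    """Scaled dot-product score for one query/key pair of one head."""
    if qk_norm:
        q, k = layer_norm(q), layer_norm(k)
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))


q = [120.0, -80.0, 45.0, 10.0]   # activations that have grown large
k = [95.0, -60.0, 30.0, 5.0]

raw = attention_logit(q, k, qk_norm=False)
normed = attention_logit(q, k, qk_norm=True)
bound = math.sqrt(len(q))

print(raw, normed, bound)
```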

P17 — Initialization & normalization placement

Symptom: divergence in the first few hundred steps regardless of LR; or vanishing signal (P18) in deep stacks.

Root cause: residual streams accumulate variance with depth; default init can make early activations/grads too large (spike) or too small (vanish). Norm/embedding init scale matters.

Fix: scale residual-branch init by 1/√(2·n_layers) (GPT-2-style); prefer pre-LN over post-LN for deep transformers; init embeddings at small std (~0.02). When unsure, copy a known-good config's init+norm scheme rather than tuning blind. URL: https://arxiv.org/pdf/2309.14322
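The 1/√(2·n_layers) factor can be motivated with a variance count. Assume, as a simplification, two residual branches per layer (attention and MLP) whose outputs are independent and have a standard deviation proportional to their init std; their variances then add along the residual stream. The function names below are hypothetical.

```python
import math


def residual_init_std(base_std: float, n_layers: int) -> float:
    """Init std for each residual branch's output projection (GPT-2-style scaling)."""
    return base_std / math.sqrt(2 * n_layers)


def stream_variance(branch_std: float, n_layers: int) -> float:
    # 2 branches per layer; independent contributions add in variance.
    return 2 * n_layers * branch_std ** 2


base_std, n_layers = 0.02, 12
unscaled = stream_variance(base_std, n_layers)
scaled = stream_variance(residual_init_std(base_std, n_layers), n_layers)

print(residual_init_std(base_std, n_layers))   # about 0.0041
print(unscaled, scaled)                        # grows with depth vs stays at base_std**2
```

Under these assumptions the scaled init keeps the accumulated variance at base_std² regardless of depth, while the unscaled init grows linearly with the number of layers.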


Gradients

P18 — Gradient explosion vs vanishing: diagnose by logging the norm

Symptom: loss NaN/diverges (explosion) OR loss plateaus and the model never learns (vanishing).

Root cause: per-layer grad norms blow up (explosion: deep nets, high LR, no clip) or decay to ~0 (vanishing: saturating activations, bad init P17, too-deep unnormalized stacks).

Fix — measure first:

total = sum(p.grad.detach().norm()**2 for p in model.parameters() if p.grad is not None) ** 0.5
# log `total` every step; also log per-layer norms when hunting the culprit layer
  • Explosion (norm ↑↑): grad clipping (P13), lower LR (P12), longer warmup, bf16 over fp16 (P10).
  • Vanishing (norm → 0): residual connections, normalization layers, better init (P17), non-saturating activations (GELU/SiLU over deep sigmoid/tanh stacks), check the LR isn't too low.

A grad-norm trace is the cheapest, highest-signal stability instrument — log it from step 1. URL: https://apxml.com/courses/how-to-build-a-large-language-model/chapter-24-identifying-mitigating-training-instabilities/stabilization-techniques-revisited


Reproducibility

P19 — Deterministic / repro knobs — set them, but the interpretation is delegated

Symptom: same config + seed gives slightly different loss/metrics run-to-run.

Root cause: nondeterministic CUDA kernels + cudnn.benchmark autotuning pick different algorithms per run; TF32/AMP add low-order noise on top.

Fix — the mechanical knobs (set these here):

import random
import numpy as np
import torch

torch.manual_seed(s); np.random.seed(s); random.seed(s)   # s = your chosen seed
torch.use_deterministic_algorithms(True)        # may need CUBLAS_WORKSPACE_CONFIG=:4096:8
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False          # benchmark=True trades determinism for speed

Whether a run-to-run delta is "a real effect vs cuDNN nondeterminism," and the full determinism methodology, is owned by verifying-dl-experiments (cross-link REQUIRED) — catalogued as U36 in references/gotchas_universal.md. This layer only ensures the knobs are set and logged. Determinism costs speed — enable for the datapoint that must be clean, not every throwaway run. URL: https://docs.pytorch.org/docs/stable/notes/randomness.html


Pointers — adjacent layers, do NOT restate here

  • references/gotchas_universal.md — the infra failure modes that masquerade as numerics: U6 disk-full crashes torch.save, U9 cgroup-OOM (bare Killed, not a NaN), U28 CUDA/driver/torch-build mismatch (no kernel image ≠ a precision bug), U10/U11 VRAM OOM. Rule out infra before chasing a "numerics" ghost.
  • verifying-dl-experiments (REQUIRED cross-link) — owns is-the-number-real: smoke content, cuDNN-nondeterminism-as-metric-error (U36), collapse/constant-output diagnosis, "bug vs real effect." This file makes training run and stay finite; that skill judges whether the converged result is true.
  • references/spot-resilience.md — checkpoint cadence so a divergence-and-resume (P12) loses minimal work.
  • references/multinode.md — NCCL/precision interactions in DDP (all-reduce dtype, loss-scale sync) for multi-node runs; single-box users skip.