playbook/antigravity-awesome-skills/skills/remote-gpu-trainer/references/training/oom-memory.md

22 KiB
Raw Permalink Blame History

OOM & fitting a model that doesn't — VRAM + host-RAM out-of-memory during training

How to read a CUDA OOM trace, understand what fills the card (params vs optimizer vs gradients vs activations vs fragmentation), and apply the fixes in cost order — from a free batch-size cut to ZeRO-3/QLoRA sharding. This layer owns making training RUN and fit; verifying-dl-experiments owns is the resulting number correct. Cross-link it (REQUIRED) wherever a "fix" risks changing the science (shrinking the one variable under test, swapping precision, changing seq-len).

To jump: grep -in '<keyword>' references/training/oom-memory.md (e.g. expandable, checkpoint, zero, validation, snapshot, lora, empty_cache, longest, fragment).

Table of contents

  • Read it first — M1 anatomy (where VRAM goes) · M2 reading the OOM trace · M3 VRAM-OOM vs host-RAM-OOM
  • Fixes in order — M4 the ladder (do these top-down) · M5 batch/grad-accum · M6 bf16 mixed precision · M7 activation/gradient checkpointing · M8 expandable_segments · M9 FSDP / ZeRO sharding · M10 CPU/NVMe offload · M11 seq-len/resolution · M12 8-bit & paged optimizers · M13 LoRA/QLoRA
  • OOM at a specific step — M14 first backward · M15 validation/eval · M16 the longest batch · M17 step-2 (optimizer alloc)
  • Debugging — M18 memory_summary · M19 the snapshot + visualizer · M20 empty_cache & "leak" myths
  • Pointers — host-RAM cgroup-OOM → gotchas_universal.md U9 · VRAM-vs-cgroup → U10 · zombie-VRAM → U11

Read it first

M1 — Anatomy: where the VRAM actually goes (a model that runs inference OOMs in training)

Training memory is not just weights. For Adam mixed-precision, per parameter:

Bucket Bytes/param Note
Weights (fp16/bf16 + fp32 master) 6 2 B working copy + 4 B fp32 master for stable updates
Optimizer states (Adam m, v, fp32) 8 momentum 4 B + variance 4 B
Gradients (fp32) 4 one per param, backward pass
Subtotal (fixed, batch-independent) ~18 B/param a 4 B-param model ≈ 72 GB before any activation
Forward activations (cached for backward) varies scales with batch × seq_len × depth × hidden; the part that explodes
Temporary buffers (softmax, matmul scratch) spikes a single peak op can OOM even when steady-state fits

Source: HF model-memory-anatomy (https://huggingface.co/docs/transformers/en/model_memory_anatomy) gives the 6+8+4 split and the "4B params, batch 16 ≈ 85 GB" worked figure. Why a model that infers in 16 GB OOMs in training: inference is just the 2 B/param working copy + small activations; training adds the +12 B/param of grads+optimizer and keeps the entire forward activation graph alive for backward. The fixed 18 B/param is attacked by M9/M12/M13; the activation term by M5/M6/M7/M11.

M2 — Reading the CUDA OOM trace (the numbers tell you which fix)

Symptom: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate X MiB (GPU 0; Y GiB total capacity; Z GiB already allocated; A GiB free; B GiB reserved in total by PyTorch ...).

Root cause — decode the four numbers:

  • Tried to allocate X — the size of the single failing request. Large X = a big tensor (long-seq attention score matrix, the longest batch M16); tiny X failing with GBs "free" = fragmentation.
  • reserved B vs allocated Zreserved = total the caching allocator grabbed from the driver; allocated = live tensors. reservedallocated with a small failing X ⇒ fragmentation (free blocks exist but none is contiguous enough). This is the explicit PyTorch diagnostic: "if reserved but unallocated is large, set expandable_segments:True" (M8).
  • free A — driver-visible free on the card; if A is large but the alloc still fails, suspect another process (M3) or a zombie holding VRAM (gotchas_universal.md U11).

Sources: PyTorch forums thread on the trace fields (https://discuss.pytorch.org/t/torch-outofmemoryerror-cuda-out-of-memory/217669); the reserved-vs-allocated → fragmentation rule is from the allocator docs (M8 URL).

M3 — VRAM OOM is not host-RAM OOM (distinct failure, distinct fix)

A torch.OutOfMemoryError: CUDA out of memory (a Python traceback) is VRAM exhaustion. A bare Killed / exit 137 with no traceback is the Linux kernel killing the process for host-RAM (cgroup memory.max) exhaustion — almost always num_workers × a big in-RAM object. These have opposite fixes and live in the universal catalog:

  • host-RAM cgroup-OOM (Killed, exit 137, dataloader workers) → gotchas_universal.md U9.
  • VRAM-OOM distinct from cgroup, fragmentation, concurrent-job sizing → gotchas_universal.md U10.
  • "empty GPU" still OOMs (a zombie holds VRAM nvidia-smi can't attribute) → gotchas_universal.md U11.

Confirm which one before "fixing": dmesg | grep -iE 'killed process|out of memory' non-empty ⇒ host-RAM kernel kill (U9), not a CUDA OOM. Do not shrink the model to "fix" a host-RAM kill.


Fixes, in order (cheapest / least-science-disturbing first)

M4 — The ladder: apply top-down, stop when it fits

Each rung costs more (speed, complexity, or risk to the result). Climb only as far as needed:

  1. Reduce micro-batch + grad-accumulation (M5) — free, exact same effective batch, zero accuracy change.
  2. bf16 mixed precision (M6) — halves activations, usually a speedup; bf16 needs no loss scaling.
  3. Activation / gradient checkpointing (M7) — trades ~2030% compute for a large activation cut.
  4. PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (M8) — free; fixes fragmentation OOM only.
  5. SDPA / FlashAttention — stop materializing the full O(seq²) attention matrix (M11).
  6. Shorter seq-len / lower resolution (M11) — cheap but changes the science → verify (REQUIRED).
  7. 8-bit / paged optimizer (M12) — cuts the 8 B/param optimizer state to ~2 B with ~no accuracy loss.
  8. FSDP / DeepSpeed ZeRO-1→2→3 (M9) — shard optimizer→grads→params across GPUs (needs ≥2 GPUs).
  9. CPU / NVMe offload (M10) — last resort on a single GPU; trades a large speed hit for fit.
  10. LoRA / QLoRA (M13) — for finetuning: freeze base, train adapters; QLoRA quantizes base to 4-bit.

Rungs 14 and 7 do not alter the model/optimization math; rung 6 does (declare it, re-verify per verifying-dl-experiments). Rungs 810 change where state lives, not the math (LoRA changes capacity).

M5 — Reduce micro-batch + gradient accumulation (the free first move)

Symptom: OOM scales with batch size; effective batch must stay fixed for the result to hold.

Fix: drop per_device_train_batch_size to what fits, raise gradient_accumulation_steps to keep the effective batch identical (effective = micro_batch × accum × world_size). Gradients accumulate over sub-batches before one optimizer step — same math, lower peak activation memory. Keep micro-batch as large as fits (batch 4 × accum 16 beats batch 1 × accum 64 — better GPU utilization). Source: https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one (gradient accumulation). Caveat: with token-level loss + a custom loop, naive accumulation can mis-weight the loss across uneven sub-batch token counts — a correctness issue owned by verifying-dl-experiments (REQUIRED).

M6 — bf16 mixed precision (prefer bf16 over fp16 on Ampere+)

Symptom: fp32 training; activations dominate; the GPU is Ampere (A100/30xx) or newer.

Fix: bf16=True (HF TrainingArguments) or torch.autocast("cuda", dtype=torch.bfloat16). The main win is activations stored in 16-bit. bf16 over fp16: bf16 has fp32's exponent range, so it needs no loss-scaling and won't overflow/underflow — fewer NaN failures. Note fp16 can increase memory at small batch (it keeps both fp16 and fp32 weight copies); bf16 is the safer default where supported. Source: https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one (mixed precision; bf16 needs Ampere+). NaN/divergence after switching precision = a numerics question → verifying-dl-experiments.

M7 — Activation / gradient checkpointing (trade compute for activation memory)

Symptom: the 18 B/param fixed cost fits but activations OOM (deep model, long seq, big batch).

Fix: gradient_checkpointing=True (HF), model.gradient_checkpointing_enable(), or torch.utils.checkpoint.checkpoint(...) manually. Only a subset of activations is stored; the rest are recomputed during backward — cuts activation memory substantially at ~2030% slower training. Source: https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one ("~20% slower"). Gotcha: with HF generate/caching set model.config.use_cache=False when checkpointing, or it warns and ignores; with DDP, reentrant checkpointing can break — use use_reentrant=False.

M8 — expandable_segments:True (the free fragmentation fix)

Symptom: OOM where the failing alloc is small yet reservedallocated and GBs look "free" (M2); common with variable shapes (changing batch/seq-len, dynamic padding).

Fix: launch with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (env var, set before the process starts; the modern alias is PYTORCH_ALLOC_CONF). It backs segments with CUDA VMM so they grow/shrink instead of each cudaMalloc being an unmergeable block — which is the root of fragmentation. Source: PyTorch CUDA notes (https://docs.pytorch.org/docs/stable/notes/cuda.html) and the allocator devlog (https://docs.pytorch.org/devlogs/eager/2026-06-01-cuda-caching-allocator/). Alternative knob if fragmenting on large blocks: max_split_size_mb:<N> (stops the allocator splitting blocks above N MiB). This is the same knob referenced in gotchas_universal.md U10 — set it as a default on the box, it is nearly free. Version note: expandable_segments is still flagged experimental; it has known interop edges with some VMM allocators (e.g. NCCL ncclMemAlloc, pytorch/pytorch#165419) — if a custom-allocator stack misbehaves, drop it.

M9 — FSDP / DeepSpeed ZeRO sharding (≥2 GPUs: shard the 18 B/param)

Symptom: the fixed 18 B/param state alone won't fit one card; multiple GPUs available.

Fix — shard training state across the data-parallel group, escalating by stage:

  • ZeRO-1 / optim_state_dict shard — partition optimizer states (the 8 B/param). Smallest comms change; start here if DP OOMs on optimizer state.
  • ZeRO-2 / FSDP SHARD_GRAD_OP — also partition gradients (the 4 B). Good memory/comms balance.
  • ZeRO-3 / FSDP FULL_SHARD — also partition parameters (the 6 B). Maximum savings; needs high-bandwidth interconnect (NVLink/NVSwitch) because params are all-gathered per layer. Source: DeepSpeed ZeRO tutorial (https://www.deepspeed.ai/tutorials/zero/) and HF DeepSpeed integration (https://huggingface.co/docs/transformers/en/deepspeed); FSDP's ShardingStrategy maps 1:1 to these stages. Multi-GPU launch + NCCL fabric gotchas (wrong NIC, timeout, MTU) → references/multinode.md.

M10 — CPU / NVMe offload (single-GPU last resort)

Symptom: even ZeRO-3 (or single GPU) can't hold params/optimizer; willing to trade large speed for fit.

Fix: offload optimizer states (and with ZeRO-3, parameters) to CPU RAM or NVMe. DeepSpeed offload_optimizer: {device: cpu|nvme} (valid ZeRO-1/2/3), offload_param: {device: nvme} (ZeRO-3 only); ZeRO-Infinity offloads to both for huge models. QLoRA's paged optimizers (M12) are a lighter form — optimizer state pages to CPU only on memory spikes. Source: DeepSpeed ZeRO docs (https://deepspeed.readthedocs.io/en/stable/zero3.html); ZeRO-Infinity (https://www.deepspeed.ai/2021/03/07/zero3-offload.html). Cost: PCIe/NVMe bandwidth becomes the bottleneck — expect a multi-× slowdown; on a metered box, weigh it against renting a bigger card (principle #1, #9).

M11 — Reduce sequence length / resolution / attention footprint

Symptom: activations (esp. the O(seq²) attention score matrix) dominate; OOM grows super-linearly with seq-len/resolution.

Fix (cheapest variant first):

  • Use SDPA / FlashAttention to avoid materializing the full seq² attention matrix — attn_implementation="sdpa" (default in PyTorch 2.1.1+) or "flash_attention_2". No accuracy change.
  • Only then shorten seq-len / lower image resolution / patchify — this changes the task/science; declare it and re-verify (the resolution-change-broke-training failure mode is owned by verifying-dl-experiments, REQUIRED). Source: https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one (SDPA, attention backends) and model-memory-anatomy (attention score matrix grows with seq²).

M12 — 8-bit & paged optimizers (cut the 8 B/param optimizer state to ~2 B)

Symptom: optimizer states are the largest single bucket (M1); want the cut with ~no accuracy loss.

Fix: swap AdamW for a quantized optimizer — HF optim="adamw_bnb_8bit" / "paged_adamw_8bit" (bitsandbytes 8-bit Adam, states held in 8-bit, dequantized per step → ~2 B/param vs 8 B), or optim="adafactor" (stores row/column moments instead of per-element → much less memory, slower convergence). Paged variants additionally page optimizer state to CPU on spikes to survive transient peaks. Source: https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one (optimizers) and model-memory-anatomy ("quantized Adam → 2 bytes/param"). Adafactor's convergence change is a science question → verifying-dl-experiments (REQUIRED) before trusting its ablation deltas.

M13 — LoRA / QLoRA (finetuning only: don't train the full model)

Symptom: finetuning a large pretrained model; full-finetune state won't fit.

Fix: LoRA freezes the base weights and trains small low-rank adapters → grads+optimizer exist only for the adapter (a tiny fraction of params), so the 18 B/param cost nearly vanishes for the base. QLoRA goes further: quantize the frozen base to 4-bit NF4 (+ double quantization + paged optimizers), train fp16/bf16 adapters on top — reported to finetune a 65B model on a single 48 GB GPU with no accuracy degradation vs 16-bit. Source: QLoRA paper (https://arxiv.org/abs/2305.14314) and repo (https://github.com/artidoro/qlora). Note: LoRA changes model capacity — it is a different optimization target, not a free OOM trick. Whether the LoRA result matches full-finetune is a science claim → verifying-dl-experiments (REQUIRED).


OOM at a SPECIFIC step (the step number is the diagnosis)

M14 — OOM on the first backward (not the forward)

Symptom: forward pass completes, OOM hits at .backward().

Root cause: forward only allocates activations; backward additionally allocates the full gradient buffers (4 B/param) and needs every cached activation live simultaneously — peak memory is at backward, not forward. A model that forwards fine still OOMs here.

Fix: M7 (checkpointing — recompute instead of store activations) is the targeted fix; then M5/M6. If the peak is a single huge layer, gradient-checkpoint that block specifically.

M15 — OOM only during validation / eval, training was fine

Symptom: training epochs run; the first eval pass OOMs — even with torch.no_grad() / model.eval(), sometimes even at eval batch size 1.

Root cause — two distinct ones:

  1. Eval batch > train batch, or no-grad eval lets a larger batch be attempted that exceeds the train peak. The activation graph isn't kept, but a big single forward + its temporary buffers can still OOM.
  2. HF Trainer accumulates predictions on the GPU: by default eval logits/labels are concatenated on the GPU for the whole eval set before moving to CPU — a large eval set OOMs regardless of batch size (huggingface/transformers#7232).

Fix: set per_device_eval_batch_size explicitly (don't inherit a too-large value); set eval_accumulation_steps=N so predictions move to CPU every N steps instead of piling on the GPU (https://huggingface.co/docs/transformers/main_classes/trainer). In a custom loop: wrap eval in torch.no_grad() / torch.inference_mode(), and .cpu() / .detach() outputs before appending to any list. Eval-artifact sizing (per-sample dumps blowing up) is owned by verifying-dl-experiments.

M16 — OOM mid-epoch on the longest batch (variable-length / bucketed data)

Symptom: thousands of steps succeed, then a random step OOMs; restarting from there OOMs again at the same data; fixed batches never OOM.

Root cause: with variable-length inputs (NLP token batches, point clouds, variable-resolution images), peak activation memory is set by the longest sequence in the batch, not the average. Memory is sized for the worst case, which only appears on certain batches.

Fix: size everything for the max length, not the mean: cap max_length / use length bucketing or sorted batching so long samples share small batches; set group_by_length=True (HF) and a hard max_length. A expandable_segments:True (M8) also helps because the variable shapes otherwise fragment. Don't conclude "data corruption" from a step-N OOM — it's the longest batch.

M17 — OOM on step 2 (after the first optimizer step), step 1 fine

Symptom: step 1 trains; OOM on step 2 or at the first optimizer.step().

Root cause: Adam lazily allocates its m/v state (the 8 B/param) on the first optimizer.step(), not at construction. Peak therefore jumps after step 1. Reserved memory also climbs as the allocator caches backward buffers. Source: the memory-snapshot timeline shows optimizer state appearing after iter 1 (https://pytorch.org/blog/understanding-gpu-memory-1/).

Fix: budget for the post-step peak, not step-1 — measure peak with max_memory_allocated() after two full steps, not one. Then apply M12 (8-bit optimizer halves this jump) or M5.


Debugging tools (measure, don't guess)

M18 — torch.cuda.memory_summary() + the stat functions (first look)

Symptom: need to know what is resident before choosing a fix.

Fix: print torch.cuda.memory_summary() at the OOM point (or in an except torch.cuda.OutOfMemoryError) for a table of allocated/reserved/active by size class. Programmatic: torch.cuda.memory_allocated() (live tensors) vs torch.cuda.memory_reserved() (allocator total) — a big gap = fragmentation/caching (→ M8); torch.cuda.max_memory_allocated() for the true peak (reset with reset_peak_memory_stats() between phases to isolate forward vs backward vs optimizer). Source: https://docs.pytorch.org/docs/stable/notes/cuda.html (memory management functions).

M19 — The CUDA memory snapshot + visualizer (find the exact culprit allocation)

Symptom: summary stats aren't enough — need which line of code allocated the memory that OOMs.

Fix — record a snapshot around the OOM and view the timeline:

torch.cuda.memory._record_memory_history(max_entries=100000)   # start before the step(s)
try:
    train_a_few_steps()
finally:
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")     # write history
    torch.cuda.memory._record_memory_history(enabled=None)      # stop

Drag oom_snapshot.pickle onto https://pytorch.org/memory_viz (snapshots are not uploaded server-side), or python torch/cuda/_memory_viz.py trace_plot oom_snapshot.pickle -o snapshot.html. The timeline colors parameters / gradients / optimizer state / activations / temporaries separately, so the tallest band at the OOM moment names the bucket to attack (→ map back to M5M13). Available PyTorch 2.1+. Source: https://pytorch.org/blog/understanding-gpu-memory-1/. On a remote box: dump the pickle, scp it down (references/ssh_transport.md), view locally — don't try to run the visualizer over ssh.

M20 — empty_cache() and the "memory leak" myths

Symptom: belief that torch.cuda.empty_cache() "frees memory and fixes OOM," or that steadily-rising reserved memory is a leak.

Root cause / myth-busting:

  • torch.cuda.empty_cache() returns cached-but-unused blocks to the driver; it does not free live tensors and does not make more room for your own process (the allocator would reuse that cache anyway). It only helps a second process on the same GPU, or it reduces fragmentation at a ~10% speed cost (HF torch_empty_cache_steps=N runs it every N steps — https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one). It is not a fix for a model that genuinely doesn't fit.
  • Rising reserved ≠ leak. The caching allocator holds freed blocks for reuse; reserved climbing then plateauing is normal. A true leak is rising memory_allocated() across steps with constant batch — usually accumulating tensors that still require grad (appending loss instead of loss.item(), keeping references in a Python list). Fix the reference, not with empty_cache().
  • Calling empty_cache() every step to "stay safe" just slows training and can increase fragmentation.

Real OOM-mechanics leaks (accumulate-loss-tensor, no detach) belong here; whether a metric drift is a real effect vs a bug belongs to verifying-dl-experiments (REQUIRED).


Pointers — memory gotchas catalogued elsewhere (do NOT restate)

  • Host-RAM cgroup-OOM (bare Killed / exit 137, num_workers × big tensor) → gotchas_universal.md U9.
  • VRAM-OOM vs cgroup-OOM, concurrent-job sizing, the expandable_segments one-liner → gotchas_universal.md U10.
  • Zombie holds VRAM nvidia-smi can't see (OOM on an "empty" GPU) → gotchas_universal.md U11.
  • Disk-full crashes torch.save (not memory, but the other "out of space") → gotchas_universal.md U6.
  • Multi-GPU NCCL / fabric for FSDP/ZeRO launches → references/multinode.md.
  • Is the post-fit number correct (precision swap, seq-len change, LoRA-vs-full, accumulation loss weighting, determinism) → verifying-dl-experiments (REQUIRED — this layer makes it fit and run; that one decides if the result is true).