Correct checkpointing & idempotent resume — full state, atomic write, sharded checkpoints, framework APIs

Make a training job resume exactly where it stopped after any kill, not "reload the weights and silently restart the epoch." This layer owns the mechanics: what FULL state to save, how to write it without corruption, how to load it unconditionally, and the framework-specific knobs (FSDP / DeepSpeed / HF Trainer / Accelerate / Lightning), plus the resume bugs that make a job look resumed while it quietly lost progress. verifying-dl-experiments (REQUIRED) owns the question "is the resumed number correct?": proving that step/epoch/loss actually continued instead of resetting is its reproducibility check applied here. The spot/preemption cadence (when and how often, Young/Daly) lives in references/spot-resilience.md (REQUIRED for any interruptible/spot tier). This file covers the content and correctness of each checkpoint; that file covers the timing.

To jump: grep -in '<keyword>' references/training/checkpoint-resume.md (e.g. atomic, rename, scaler, ema, sampler, fsdp, sharded, zero_to_fp32, dcp, resume_from_checkpoint, save_state, ckpt_path, save_total_limit, reshuffle).

Table of contents

  • The contract — C1 full-state-list · C2 atomic-write · C3 load-latest-unconditionally · C4 durable-location
  • Sharded checkpoints (multi-GPU) — C5 FSDP-FULL_STATE_DICT-rank0-OOM · C6 FSDP-SHARDED_STATE_DICT · C7 DCP-(dcp.save/load) · C8 DeepSpeed-ZeRO-dir+zero_to_fp32
  • Framework APIs — C9 HF-Trainer-resume_from_checkpoint+save_total_limit · C10 Accelerate-save_state/load_state · C11 Lightning-ModelCheckpoint+ckpt_path
  • The resume BUGS — C12 epoch-restarts · C13 data-reshuffles/order · C14 LR-schedule-resets · C15 scaler-not-restored · C16 EMA-not-saved · C17 save_total_limit-deletes-best · C18 strict-load-key-mismatch
  • Pointers — disk-full on save → gotchas_universal.md U6 · silent sync → U33 · keepable-policy/save_top_k → verifying-dl-experiments (skill) · cadence/Young-Daly → spot-resilience.md

The contract

C1 — A checkpoint that restores only weights is NOT a resume — save the FULL training state

Symptom: resume "works" (no crash) but the loss jumps up, accuracy regresses, or training takes more total epochs than an uninterrupted run — because the resume silently restarted the epoch, reset the optimizer momentum, and reshuffled the data.

Root cause: torch.save(model.state_dict()) captures weights only. Optimizer momentum/variance, the LR-scheduler position, the epoch/step counter, RNG state, the AMP scaler, and the dataloader position are all lost, so the restarted run is a different trajectory, not a continuation.

Fix: every checkpoint must carry the full state (PyTorch tutorial saving multiple / general checkpoint; the spot-resilience §3 list):

| Must save | Why losing it breaks resume |
| --- | --- |
| model state_dict | the weights (obvious) |
| optimizer state_dict | Adam m/v momentum; losing it = a cold optimizer restart (C12) |
| LR-scheduler state_dict | step-based LR position; losing it resets the schedule (C14) |
| epoch and global step/iteration | resume the exact position, not the epoch start (C12) |
| RNG state: Python random, NumPy, torch, CUDA (torch.cuda.get_rng_state_all()) | reproducible augmentation/dropout stream after restart |
| dataloader / sampler position | so the next batch is the next unseen one, not a reshuffle (C13) |
| AMP GradScaler state_dict | the loss-scale + growth tracker; losing it triggers an inf-scale stall (C15) |
| EMA / SWA shadow weights (if used) | the EMA copy is often what's evaluated; losing it = eval on the wrong weights (C16) |
| best-metric-so-far + best.pth selection state | so "best" survives a restart instead of resetting |

The runnable atomic skeleton that assembles this dict is in references/spot-resilience.md §5 — do not duplicate it; this table is the checklist, that is the code.
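The RNG row is the one most often left out, so here is a minimal sketch of that single row (not the full skeleton, which stays in spot-resilience.md). It uses only Python's random module as a stand-in for the four generators in the table; capture_rng and restore_rng are hypothetical helper names, and a real checkpoint would snapshot NumPy, torch and CUDA state alongside.

```python
import random

def capture_rng():
    """Snapshot the Python RNG; a real checkpoint would also snapshot NumPy, torch and CUDA."""
    return {"python": random.getstate()}

def restore_rng(snapshot):
    """Put the Python RNG back exactly where the snapshot was taken."""
    random.setstate(snapshot["python"])

random.seed(1234)
_ = [random.random() for _ in range(5)]          # draws made before the checkpoint
saved = capture_rng()                            # this goes into the checkpoint dict
expected = [random.random() for _ in range(3)]   # what an uninterrupted run draws next

random.seed(999)                                 # a restarted process starts somewhere else
restore_rng(saved)
resumed = [random.random() for _ in range(3)]    # continues the original stream
```

With the state restored, the resumed draws match the uninterrupted ones, which is what keeps augmentation and dropout on the same stream across a restart.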

C2 — Write atomically: tmp → fsync → os.replace (a kill mid-write corrupts a naive save)

Symptom: after a preemption/OOM, latest.pth is truncated/zero-byte or torch.load raises RuntimeError: PytorchStreamReader failed reading zip archive; a latest.pth.tmp is left behind.

Root cause: overwriting latest.pth in place is not atomic — a kill partway through leaves a corrupt file and (if it was the only checkpoint) zero good ones. torch.save itself does not fsync.

Fix: write to a temp file, force bytes to disk, then atomically rename (POSIX rename/os.replace is atomic on the same filesystem):

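import os
import torch
tmp = ckpt_path + ".tmp"
with open(tmp, "wb") as f:
    torch.save(state, f)        # serialize into the temp file, never into latest.pth directly
    f.flush()                   # drain Python's buffer to the OS
    os.fsync(f.fileno())        # bytes hit disk BEFORE the swap
os.replace(tmp, ckpt_path)      # all-or-nothing; keep prev until this returns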

Keep the previous latest.pth valid until the rename returns (a kill at any instant leaves one intact file). os.replace (not os.rename) also works on Windows for the local-test path. Full recipe + rationale: references/spot-resilience.md §3. Disk-full during the save is a separate failure with the same .tmp left behind → references/gotchas_universal.md U6 (pre-budget + prune latest, keep best).
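To see why the previous file survives, the following sketch simulates a failure in the middle of the write. It is an illustration under assumptions: pickle.dump stands in for torch.save so the block runs without PyTorch, crashing_dump is a hypothetical fault injector, and a raised exception only approximates a real SIGKILL.

```python
import os
import pickle
import tempfile

def atomic_save(state, ckpt_path, dump=pickle.dump):
    """Write state to ckpt_path via tmp -> fsync -> os.replace."""
    tmp = ckpt_path + ".tmp"
    with open(tmp, "wb") as f:
        dump(state, f)            # pickle.dump stands in for torch.save here
        f.flush()
        os.fsync(f.fileno())      # bytes are on disk before the swap
    os.replace(tmp, ckpt_path)    # atomic on the same filesystem

def crashing_dump(state, f):
    """Hypothetical failure: writes a few bytes, then dies mid-write."""
    f.write(b"partial")
    raise RuntimeError("killed mid-write")

workdir = tempfile.mkdtemp()
path = os.path.join(workdir, "latest.pth")
atomic_save({"step": 100}, path)                 # first good checkpoint

try:
    atomic_save({"step": 200}, path, dump=crashing_dump)
except RuntimeError:
    pass                                         # the "kill" happened before os.replace

with open(path, "rb") as f:
    survivor = pickle.load(f)                    # previous checkpoint is still intact
stale_tmp_exists = os.path.exists(path + ".tmp") # the truncated write is left in .tmp
```

The truncated bytes land in latest.pth.tmp and never in latest.pth, so the loader still finds step 100. Deleting a stale .tmp on startup is a reasonable follow-up.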

C3 — Load-latest UNCONDITIONALLY on startup → idempotent resume

Symptom: a relaunch starts from scratch because the resume is gated behind a --resume flag the launch wrapper forgot to pass; or two code paths (fresh vs resume) diverge.

Root cause: making resume opt-in means a generic relaunch (spot recovery, SSH-drop restart, queue retry) re-trains from zero. A divergent "first launch" code path also drifts from the resume path.

Fix: one code path that loads the latest checkpoint if it exists, else starts fresh — so the identical launch command converges to the same end state no matter how many times it runs. This is what makes principle #7's "retry the identical config" actually resume instead of restart, and it is the universal spine (principle #8) under SSH-drop / Slurm-walltime / K8s-reschedule / spot-preemption. Skeleton: references/spot-resilience.md §3 (load_latest_if_any).

C4 — Checkpoint to the platform's DURABLE location, not local scratch

Symptom: resume after a managed-spot replacement (or a terminate/destroy) finds no checkpoint — the box came up fresh and the only copy was on the dead instance's local disk.

Root cause: a replacement node is clean; anything not on a cloud bucket / network volume / shared FS is gone (principle #4 — know what survives stop vs destroy).

Fix: write checkpoints to the profile's durable mount (DURABLE_DIR in profiles/<platform>.md §8), or mirror local→durable on the checkpoint timer. The single biggest portability trap is assuming local disk survives; see each profile's STORAGE survival-matrix and the SKILL Quick-reference table. Gate the sync on the actual copy result, never an unconditional echo synced → references/gotchas_universal.md U33.
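A minimal sketch of gating the "synced" line on the real copy result, assuming the durable location is a mounted directory. mirror_checkpoint is a hypothetical helper; a bucket-backed profile would replace the copy with its own upload call and check that call's result instead.

```python
import filecmp
import os
import shutil
import tempfile

def mirror_checkpoint(local_path, durable_dir):
    """Copy a checkpoint to durable storage; return True only if the copy is verifiably there."""
    try:
        dest = os.path.join(durable_dir, os.path.basename(local_path))
        tmp = dest + ".tmp"
        shutil.copy2(local_path, tmp)     # copy next to the target first
        os.replace(tmp, dest)             # then swap it in atomically
    except OSError as exc:
        print(f"sync FAILED: {exc}")      # e.g. the mount is missing or full
        return False
    ok = filecmp.cmp(local_path, dest, shallow=False)   # byte-for-byte comparison
    print("synced" if ok else "sync FAILED: content mismatch")
    return ok

scratch = tempfile.mkdtemp()              # stands in for local scratch
durable = tempfile.mkdtemp()              # stands in for DURABLE_DIR
local = os.path.join(scratch, "latest.pth")
with open(local, "wb") as f:
    f.write(b"checkpoint-bytes")

ok_result = mirror_checkpoint(local, durable)
bad_result = mirror_checkpoint(local, os.path.join(durable, "missing-mount"))
```

The word "synced" is printed only after the destination file exists and matches the source, so a missing mount shows up as a failure line, not as false reassurance.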


Sharded checkpoints (multi-GPU)

C5 — FSDP FULL_STATE_DICT OOMs on rank 0 when gathering a large model

Symptom: an FSDP job trains fine but crashes at the first checkpoint with CUDA OOM on rank 0; the model is larger than one GPU.

Root cause: StateDictType.FULL_STATE_DICT all-gathers every shard onto one rank to assemble the unsharded dict. For a model that only fits because it's sharded, materializing the whole thing on rank 0 exceeds that GPU's VRAM.

Fix: when taking a full (consolidated) dict, offload it to CPU and build it on rank 0 only — FullStateDictConfig(offload_to_cpu=True, rank0_only=True). This all-gathers parameters one-by-one, offloading each to CPU on rank 0, so peak GPU memory stays bounded and non-rank-0 workers skip the GPU→CPU copy entirely (HF Accelerate FSDP guide, Lightning issue #11207). The full dict is only viable when it fits in CPU RAM; past that, use sharded (C6). Save a full dict only at the end for a portable single-file artifact; checkpoint during training as sharded.

C6 — SHARDED_STATE_DICT: each rank saves its own shard (no gather, no rank-0 OOM)

Symptom: need to checkpoint a model too big to consolidate even on CPU, or want a fast resume that re-shards onto a different world size.

Root cause: FULL_STATE_DICT is fundamentally a single-rank materialization; it does not scale and cannot reshard.

Fix: use StateDictType.SHARDED_STATE_DICT — every rank writes only its own shard, so there is no all-gather and no OOM, and the per-rank files load back in parallel. Pair it with Distributed Checkpoint (C7), which is the production path for sharded save/load and supports resharding (resume on a different GPU count). Tradeoff: a sharded checkpoint is a directory of N files, not a single .pth — convert to a full dict for export/inference (C7's get_model_state_dict, or the DeepSpeed analogue C8).

C7 — Distributed Checkpoint (DCP): dcp.save / dcp.load for FSDP/sharded models

Symptom: hand-rolling FSDP state-dict context managers is brittle, slow, and breaks when the world size changes between save and resume.

Root cause: torch.save produces a single file and has no notion of sharding or FQN remapping; manually toggling FSDP.state_dict_type is error-prone.

Fix: use torch.distributed.checkpoint (DCP), the current PyTorch-2.x sharded-checkpoint API (DCP recipe, 2.12 reference). Save: get canonical dicts with get_state_dict(model, optimizer) from torch.distributed.checkpoint.state_dict, then dcp.save(state_dict, checkpoint_id=DIR) — it writes ≥1 file per rank in parallel and auto-manages FQN mappings. Load: allocate the model first, then dcp.load(state_dict, checkpoint_id=DIR) (loads in place and auto-reshards to the current world size), then set_state_dict(...). DCP beats torch.save for any distributed model because it shards the write across ranks (no rank-0 gather, C5) and reshards on load. For a single portable inference file, convert offline with torch.distributed.checkpoint.format_utils.dcp_to_torch_save(DIR, "out.pt") (or the CLI python -m torch.distributed.checkpoint.format_utils dcp_to_torch DIR out.pt).

C8 — DeepSpeed ZeRO: a checkpoint directory per save + zero_to_fp32.py to consolidate

Symptom: model_engine.save_checkpoint(dir) writes a folder of mp_rank_* / zero_pp_rank_* files, not a .pth; loading the weights into a plain (non-DeepSpeed) model for inference fails.

Root cause: ZeRO partitions optimizer state (stage 1), gradients (2), and parameters (3) across ranks; the on-disk checkpoint is inherently sharded across per-rank files — it is not a single fp32 model.

Fix (DeepSpeed model-checkpointing, ZeRO tutorial):

  • Save/resume training: model_engine.save_checkpoint(save_dir, tag) / model_engine.load_checkpoint(save_dir, tag). All ranks must call both (they're collective; rank-0 only deadlocks/corrupts). Round-trips full sharded optimizer+param state.
  • Export a single fp32 model: DeepSpeed auto-drops a zero_to_fp32.py into the checkpoint dir; run python zero_to_fp32.py <checkpoint_dir> pytorch_model.bin, or do it in-process with the helpers in deepspeed.utils.zero_to_fp32: get_fp32_state_dict_from_zero_checkpoint(dir), convert_zero_checkpoint_to_fp32_state_dict(...), or load_state_dict_from_zero_checkpoint(model, dir) (the last returns a model that can't continue training without re-init). The consolidated file no longer needs DeepSpeed. For ZeRO-3, set "zero_optimization": {"stage3_gather_16bit_weights_on_model_save": true} and call engine.save_16bit_model(dir).

Framework APIs

C9 — HF Trainer: resume_from_checkpoint + save_total_limit (and what it actually saves)

Symptom: assuming Trainer.save_model() is a resume point (it saves weights only); or a relaunch re-trains from step 0 because resume_from_checkpoint wasn't passed; or the disk fills with checkpoint-* dirs.

Root cause: save_model ≠ a training checkpoint. A real Trainer checkpoint dir (checkpoint-<step>) contains the model plus optimizer.pt, scheduler.pt, rng_state.pth, trainer_state.json, and the AMP scaler.pt — the full state. Without resume_from_checkpoint the run starts cold.

Fix (Trainer docs): trainer.train(resume_from_checkpoint="path/to/checkpoint-1500") resumes that exact dir; resume_from_checkpoint=True auto-finds the last checkpoint in args.output_dir (idempotent spelling, C3; trainer_utils.get_last_checkpoint(output_dir) finds it in code). save_strategy="steps" + save_steps=N (or "epoch") sets cadence; save_total_limit=k keeps only the k most-recent checkpoint-* and deletes older ones in output_dir — the built-in disk-budget knob (pairs with references/gotchas_universal.md U6). load_best_model_at_end=True + metric_for_best_model + greater_is_better reloads the best checkpoint at the end and protects it from save_total_limit deletion (C17).
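The "auto-find the last checkpoint" behaviour can be pictured with a stdlib-only sketch. This is an illustration of the idea, not the Transformers implementation; find_last_checkpoint is a hypothetical name. The detail worth noticing is that steps are compared as integers, since a plain string sort would rank checkpoint-90 above checkpoint-1500.

```python
import os
import re
import tempfile

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")

def find_last_checkpoint(output_dir):
    """Return the checkpoint-<step> dir with the highest step, or None if there is none."""
    if not os.path.isdir(output_dir):
        return None
    steps = []
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append(int(match.group(1)))    # numeric, not lexicographic
    if not steps:
        return None
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")

out = tempfile.mkdtemp()
empty_result = find_last_checkpoint(out)         # fresh run: nothing to resume from
for step in (500, 1000, 1500, 90):
    os.mkdir(os.path.join(out, f"checkpoint-{step}"))
last = find_last_checkpoint(out)                 # relaunch: newest step wins
```

A launch script built on this pattern can pass the result (or None on a fresh run) as resume_from_checkpoint, which keeps the fresh and resumed paths on the same command line.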

C10 — Accelerate: accelerator.save_state(dir) / load_state(dir) + dataloader skip

Symptom: a custom (non-Trainer) Accelerate loop resumes with a cold optimizer/scaler, or the LR scheduler resets, or it replays already-seen batches.

Root cause: saving only accelerator.get_state_dict(model) drops optimizer/scaler/RNG; and a mid-epoch resume re-iterates the dataloader from batch 0.

Fix (Accelerate checkpoint guide): accelerator.save_state(output_dir) saves model, optimizer, GradScaler, and RNG generators in one call; accelerator.load_state(output_dir) restores all of it (objects must come from the same script). An LR scheduler that was not passed through accelerator.prepare (and any other custom object with state_dict/load_state_dict) must be registered first with accelerator.register_for_checkpointing(my_scheduler), or it is not saved and resets (C14). For mid-epoch resume, skip consumed batches with accelerator.skip_first_batches(train_dataloader, N) on the first resumed epoch, then fall back to the full dataloader (C13). ProjectConfiguration(automatic_checkpoint_naming=True, total_limit=k) gives rolling checkpoints/checkpoint_<n> dirs with a built-in limit.

C11 — Lightning: ModelCheckpoint + trainer.fit(ckpt_path=...) (don't use resume_from_checkpoint)

Symptom: an old tutorial's Trainer(resume_from_checkpoint=...) is ignored/deprecated; or save_top_k quietly deletes the checkpoint needed to resume.

Root cause: Trainer(resume_from_checkpoint=...) was deprecated in Lightning 1.5 in favor of fit(ckpt_path=...) and removed in 2.0. A Lightning .ckpt is a full dump: epoch, global step, LightningModule state_dict, all optimizer + LR-scheduler states, callback states, loop state, and the 16-bit scaling factor (AMP) (Lightning checkpointing basics).

Fix:

  • Configure ModelCheckpoint(dirpath=..., monitor="val_loss", mode="min", save_top_k=k, save_last=True); resume with trainer.fit(model, datamodule, ckpt_path="path/to/last.ckpt"), or ckpt_path="last" to auto-pick the save_last=True file (the idempotent spelling, C3). Best/last paths read back from cb.best_model_path / cb.last_model_path.
  • save_top_k keeps only the k best by monitor; always set save_last=True so a resume target exists even when the latest step isn't a top-k metric (otherwise resume may have no recent checkpoint). Add custom state (EMA, C16) via on_save_checkpoint / on_load_checkpoint on the module or a stateful callback. Lightning's DeepSpeed strategy writes a ZeRO dir — convert with lightning.pytorch.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict (C8 analogue).

The resume BUGS (looks resumed, silently lost progress)

These are the "it ran without error but the result is wrong" traps — confirm the fix with the verifying-dl-experiments reproducibility check (REQUIRED): kill mid-run, relaunch the identical command, and verify step/epoch/loss continue rather than reset.

C12 — Epoch/step restarts from 0 despite "resuming"

Symptom: tracker shows a second run starting at epoch 1; total trained epochs exceed the schedule; LR warm-up replays. (The remote-ops version of this — a tmux script re-executed mid-run — is references/gotchas_universal.md U2.)

Root cause: the loop is for epoch in range(total_epochs) with a hardcoded 0 start; the saved epoch/step was never read back, or was saved but not used to seed the range.

Fix: start_epoch, start_step = load_latest_if_any(...) then for epoch in range(start_epoch, total_epochs) and seed the step counter from start_step. The counter must be in the checkpoint (C1) and consumed on load.
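A toy loop makes the fix concrete. Everything here is a sketch under assumptions: the checkpoint is a small JSON file holding only the counters, load_latest_if_any is a stand-in for the helper of the same name in spot-resilience.md, and die_after is a hypothetical switch that simulates a kill after a given epoch.

```python
import json
import os
import tempfile

def load_latest_if_any(path):
    """Return (start_epoch, start_step); (0, 0) when no checkpoint exists yet."""
    if not os.path.exists(path):
        return 0, 0
    with open(path) as f:
        state = json.load(f)
    return state["epoch"], state["step"]

def save(path, epoch, step):
    """Atomically record the NEXT epoch to run and the global step reached."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"epoch": epoch, "step": step}, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)

def train(path, total_epochs, steps_per_epoch, die_after=None):
    """Toy loop: one checkpoint per completed epoch; die_after simulates a kill."""
    start_epoch, step = load_latest_if_any(path)
    visited = []
    for epoch in range(start_epoch, total_epochs):   # NOT range(total_epochs)
        for _ in range(steps_per_epoch):
            step += 1                                # stands in for one optimizer step
        visited.append(epoch)
        save(path, epoch + 1, step)
        if die_after is not None and epoch == die_after:
            break
    return visited, step

ckpt = os.path.join(tempfile.mkdtemp(), "latest.json")
first, _ = train(ckpt, total_epochs=5, steps_per_epoch=10, die_after=1)
second, final_step = train(ckpt, total_epochs=5, steps_per_epoch=10)
third, third_step = train(ckpt, total_epochs=5, steps_per_epoch=10)
```

The same call is issued three times: the first run covers epochs 0 and 1, the relaunch covers 2 to 4, and a further relaunch does nothing because the schedule is already complete. That is the idempotent behaviour C3 asks for.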

C13 — Data reshuffles / repeats the same order after resume

Symptom: resume re-shows already-seen samples (or, worse, every epoch replays the identical sample order even without a resume), hurting convergence or leaking.

Root cause: two distinct bugs. (a) Resume restarts the epoch from batch 0 without skipping consumed batches. (b) DistributedSampler seeds its shuffle from an internal epoch that defaults to 0 forever unless sampler.set_epoch(epoch) is called each epoch — so every epoch (and every resume) produces the identical order (PyTorch #31771, DistributedSampler docs).

Fix: call train_sampler.set_epoch(epoch) at the top of every epoch (restore the epoch counter on resume so the shuffle stream continues). For mid-epoch resume, fast-forward consumed batches (accelerator.skip_first_batches, C10) or use a resumable/stateful sampler (torchdata StatefulDataLoader) whose offset is in the checkpoint (C1).
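Both bugs can be reproduced without PyTorch. The sketch below assumes a shuffle keyed on seed + epoch, which is the idea behind sampler.set_epoch, and uses itertools.islice as a stand-in for fast-forwarding consumed batches. epoch_order and batches are hypothetical helpers, not library functions.

```python
import itertools
import random

def epoch_order(num_samples, epoch, seed=0):
    """Deterministic shuffle keyed on (seed, epoch), the idea behind sampler.set_epoch."""
    order = list(range(num_samples))
    random.Random(seed + epoch).shuffle(order)
    return order

def batches(order, batch_size, skip=0):
    """Return the epoch's batches, skipping the first `skip` consumed before the kill."""
    all_batches = (order[i:i + batch_size] for i in range(0, len(order), batch_size))
    return list(itertools.islice(all_batches, skip, None))

# Uninterrupted epoch 3, for reference.
full_epoch3 = batches(epoch_order(12, epoch=3), batch_size=4)

# Killed after 2 batches of epoch 3; the checkpoint holds epoch=3 and batches_done=2.
resumed_epoch3 = batches(epoch_order(12, epoch=3), batch_size=4, skip=2)

# Bug (b): the epoch is never advanced, so the same order is replayed forever.
stuck = [epoch_order(12, epoch=0) for _ in range(3)]
```

Because the order is a pure function of the restored epoch, the resumed run sees exactly the batches the uninterrupted run had left, and nothing is replayed.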

C14 — LR schedule resets (cosine restarts, warm-up replays)

Symptom: the LR curve restarts from the initial/warm-up value on resume; final LR is wrong; cosine decay never reaches its floor.

Root cause: the LR scheduler's state_dict (its last_epoch/step counter) was not saved or not restored. With Accelerate, the scheduler wasn't register_for_checkpointing-ed (C10).

Fix: save scheduler.state_dict() and call scheduler.load_state_dict(...) on resume (C1). Note a step-based scheduler advanced per optimizer step must restore the step, not the epoch — restoring only epoch under-/over-shoots the schedule.
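The step-versus-epoch distinction is easier to see with numbers. The sketch below assumes a warm-up plus cosine schedule written as a pure function of the global step; lr_at and its default values are illustrative choices, not taken from any particular scheduler class.

```python
import math

def lr_at(step, base_lr=1e-3, warmup_steps=100, total_steps=1000, min_lr=0.0):
    """Linear warm-up then cosine decay, as a pure function of the global optimizer step."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_lr + (base_lr - min_lr) * cosine

steps_per_epoch = 100
saved_step = 640                                   # killed mid-epoch 6

from_step = lr_at(saved_step)                      # step restored: correct LR
epoch_start = (saved_step // steps_per_epoch) * steps_per_epoch
from_epoch_only = lr_at(epoch_start)               # only the epoch restored: LR too high
from_nothing = lr_at(0)                            # scheduler state lost: warm-up replays
```

Restoring only the epoch rewinds the schedule to step 600, so the resumed run trains at a higher LR than the uninterrupted one would have. Losing the state entirely drops it back to the first warm-up value.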

C15 — AMP GradScaler not restored → "No inf checks were recorded" / scale stall

Symptom: resuming a mixed-precision run raises AssertionError: No inf checks were recorded for this optimizer, or training stalls/NaNs because the loss-scale snapped back to the default and re-enters the scale-search.

Root cause: the GradScaler holds dynamic state — scale, growth_factor, backoff_factor, growth_interval, _growth_tracker — that evolves during training; dropping it resets the scaler (PyTorch AMP recipe, forum: No inf checks were recorded).

Fix: save scaler.state_dict() (call it after scaler.update() in the iteration) and scaler.load_state_dict(checkpoint["scaler"]) on resume. HF Trainer (scaler.pt), Accelerate (save_state), and Lightning (16-bit factor) all do this automatically — the bug bites hand-written loops. Resuming a non-AMP checkpoint into an AMP run has no saved scaler → start a fresh GradScaler.

C16 — EMA / SWA shadow weights not saved → eval on the wrong weights after resume

Symptom: pre-resume eval (using EMA weights) is good; post-resume eval drops sharply, then recovers over many steps — because the EMA copy restarted from the raw weights.

Root cause: EMA/SWA maintain a separate shadow parameter set that is what gets evaluated/exported; saving only the live model state_dict loses it, so EMA reinitializes from the (noisier) live weights.

Fix: include ema.state_dict() (and SWA AveragedModel / swa_scheduler state) in the checkpoint dict (C1) and restore it. In Lightning, persist it via on_save_checkpoint/on_load_checkpoint (C11). This is a which-weights-are-correct concern at the boundary — cross-link verifying-dl-experiments (REQUIRED) for confirming the evaluated weights are the intended ones.
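A toy tracker shows what is lost. ToyEMA is a hypothetical scalar stand-in for a real EMA over tensors; the only assumption that matters is that it exposes state_dict / load_state_dict like the objects the checkpoint already carries.

```python
class ToyEMA:
    """Scalar-per-key EMA tracker; a stand-in for a real EMA over tensors."""

    def __init__(self, params, decay=0.9):
        self.decay = decay
        self.shadow = dict(params)          # starts as a copy of the live weights

    def update(self, params):
        for key, value in params.items():
            self.shadow[key] = self.decay * self.shadow[key] + (1.0 - self.decay) * value

    def state_dict(self):
        return {"decay": self.decay, "shadow": dict(self.shadow)}

    def load_state_dict(self, state):
        self.decay = state["decay"]
        self.shadow = dict(state["shadow"])

live = {"w": 0.0}
ema = ToyEMA(live)
for step in range(1, 11):                   # live weight drifts from 1.0 to 10.0
    live["w"] = float(step)
    ema.update(live)

checkpoint = {"model": dict(live), "ema": ema.state_dict()}

restored = ToyEMA(checkpoint["model"])      # constructor alone restarts from live weights
lost_shadow = restored.shadow["w"]          # the bug: shadow equals the live weight
restored.load_state_dict(checkpoint["ema"]) # the fix: shadow is the smoothed value again
```

Without the load_state_dict call, the first post-resume eval runs on the live weight instead of the smoothed one, which is the drop described in the symptom.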

C17 — save_total_limit / save_top_k deletes the very checkpoint resume needs

Symptom: resume fails because the target checkpoint was auto-pruned; or load_best_model_at_end errors because the best checkpoint was rotated out.

Root cause: a rolling limit prunes by recency (save_total_limit) or by metric (save_top_k), and neither guarantees the most-recent-step checkpoint is the one kept — so the resume anchor can be the one deleted.

Fix: keep an explicit last/latest alongside the top-k (save_last=True in Lightning, C11; in HF, load_best_model_at_end=True makes Trainer preserve the best checkpoint past save_total_limit). General keepable-checkpoint policy (how many, which selection criterion, save_top_k ≤ 3, prune latest) is owned by verifying-dl-experiments (REQUIRED); the disk-budget consequence is references/gotchas_universal.md U6.
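For a hand-written loop, the same protection can be sketched as a pruning rule that never touches the latest step. checkpoints_to_delete is a hypothetical policy helper, not what Trainer or Lightning run internally, and the metric values are made up for the example.

```python
def checkpoints_to_delete(metrics_by_step, top_k, mode="min"):
    """Given {step: metric}, return steps safe to delete: not top-k by metric, not the latest."""
    if not metrics_by_step:
        return []
    ranked = sorted(
        metrics_by_step,
        key=lambda s: metrics_by_step[s],
        reverse=(mode == "max"),            # best first in either mode
    )
    keep = set(ranked[:top_k])              # the k best by metric
    keep.add(max(metrics_by_step))          # always keep the resume anchor
    return sorted(s for s in metrics_by_step if s not in keep)

# Illustrative validation losses per saved step (lower is better).
val_loss = {500: 0.91, 1000: 0.62, 1500: 0.58, 2000: 0.71, 2500: 0.80}
doomed = checkpoints_to_delete(val_loss, top_k=2)
```

Step 2500 is not among the two best by loss, yet it survives because it is the newest. Only steps 500 and 2000 are pruned.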

C18 — load_state_dict key mismatch on resume (module. prefix, compiled-model prefix)

Symptom: resume raises Missing key(s) / Unexpected key(s) ... module.<name> or _orig_mod.<name>, or strict load fails after switching DDP/torch.compile on or off.

Root cause: DataParallel/DDP wrap adds a module. prefix and torch.compile adds _orig_mod. to every key; a checkpoint saved wrapped and loaded unwrapped (or vice-versa) won't key-match under strict=True.

Fix: save the unwrapped module, i.e. model.module.state_dict() (DDP) / accelerator.unwrap_model(model).state_dict() / model._orig_mod.state_dict() (compiled), so the checkpoint is wrapper-agnostic. On load, strip a wrapper prefix only where it leads the key (str.removeprefix or a startswith check); a blanket k.replace("module.", "") also rewrites legitimate keys such as submodule.weight into subweight. Keep strict=True while debugging a resume so a silent partial load can't masquerade as success; only relax it deliberately.
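A small sketch of prefix stripping that only touches the front of each key. strip_wrapper_prefixes is a hypothetical helper and the integer values stand in for tensors. The comparison with a blanket str.replace shows why anchoring to the start of the key matters.

```python
def strip_wrapper_prefixes(state_dict, prefixes=("_orig_mod.", "module.")):
    """Remove leading wrapper prefixes from every key; leave the rest of the key alone."""
    cleaned = {}
    for key, value in state_dict.items():
        changed = True
        while changed:                       # handles stacked wrappers in any order
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        cleaned[key] = value
    return cleaned

saved = {
    "module._orig_mod.encoder.weight": 1,    # DDP around a compiled model
    "_orig_mod.module.head.bias": 2,         # compiled around DDP
    "submodule.weight": 3,                   # contains "module." but is NOT a wrapper prefix
}
cleaned = strip_wrapper_prefixes(saved)

# A blanket replace matches anywhere in the key and mangles the third entry.
naive = {k.replace("module.", "").replace("_orig_mod.", ""): v for k, v in saved.items()}
```

With strict=True the mangled key from the blanket replace would surface as a missing/unexpected key error; with strict=False it would silently leave that parameter at its initial value.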


Pointers — owned elsewhere, do NOT restate here

  • Cadence — when/how often (Young/Daly W = sqrt(2·mu·C), grace windows, opportunistic SIGTERM last-flush, the runnable atomic skeleton) → references/spot-resilience.md (REQUIRED, spot tier).
  • Disk-full on save (pre-budget, prune latest, keep best, .tmp recovery) → references/gotchas_universal.md U6; silent "synced" line → U33; inode exhaustion → U7.
  • Sharding a model that won't fit (FSDP wrap policy, ZeRO stages, offload) is the fitting concern → references/training/oom-memory.md M9/M10; this file owns checkpointing the sharded state.
  • Multi-rank save/load collectives + elastic restart (torchrun --max-restarts restores from the checkpoint) → references/training/distributed-launch.md, references/multinode.md.
  • Keepable-checkpoint policy + "is the resumed/best number real" (selection criterion, save_top_k, proving step/epoch/loss continued) → verifying-dl-experiments (REQUIRED).