# Correct checkpointing & idempotent resume — full state, atomic write, sharded checkpoints, framework APIs Make a training job resume **exactly where it stopped** after any kill — not "reload the weights and silently restart the epoch." This layer owns the *mechanics*: what FULL state to save, how to write it without corruption, how to load it unconditionally, and the framework-specific knobs (FSDP / DeepSpeed / HF Trainer / Accelerate / Lightning) plus the resume **bugs** that make a job look resumed while it quietly lost progress. **verifying-dl-experiments** (**REQUIRED**) owns *is the resumed number correct* — e.g. proving step/epoch/loss actually continued instead of resetting is its reproducibility check applied here. The spot/preemption *cadence* (when + how often, Young/Daly) lives in `references/spot-resilience.md` (**REQUIRED** for any interruptible/spot tier) — this file is the *content and correctness* of each checkpoint; that file is the *timing*. To jump: `grep -in '' references/training/checkpoint-resume.md` (e.g. `atomic`, `rename`, `scaler`, `ema`, `sampler`, `fsdp`, `sharded`, `zero_to_fp32`, `dcp`, `resume_from_checkpoint`, `save_state`, `ckpt_path`, `save_total_limit`, `reshuffle`). ## Table of contents - **The contract** — C1 full-state-list · C2 atomic-write · C3 load-latest-unconditionally · C4 durable-location - **Sharded checkpoints (multi-GPU)** — C5 FSDP-FULL_STATE_DICT-rank0-OOM · C6 FSDP-SHARDED_STATE_DICT · C7 DCP-(dcp.save/load) · C8 DeepSpeed-ZeRO-dir+zero_to_fp32 - **Framework APIs** — C9 HF-Trainer-resume_from_checkpoint+save_total_limit · C10 Accelerate-save_state/load_state · C11 Lightning-ModelCheckpoint+ckpt_path - **The resume BUGS** — C12 epoch-restarts · C13 data-reshuffles/order · C14 LR-schedule-resets · C15 scaler-not-restored · C16 EMA-not-saved · C17 save_total_limit-deletes-best · C18 strict-load-key-mismatch - **Pointers** — disk-full on save → gotchas_universal.md U6 · silent sync → U33 · keepable-policy/save_top_k → verifying-dl-experiments (skill) · cadence/Young-Daly → spot-resilience.md --- ## The contract ### C1 — A checkpoint that restores only weights is NOT a resume — save the FULL training state **Symptom**: resume "works" (no crash) but the loss jumps up, accuracy regresses, or training takes more total epochs than an uninterrupted run — because the resume silently restarted the epoch, reset the optimizer momentum, and reshuffled the data. **Root cause**: `torch.save(model.state_dict())` captures *weights only*. Optimizer momentum/variance, the LR-scheduler position, the epoch/step counter, RNG state, the AMP scaler, and the dataloader position are all lost, so the restarted run is a *different* trajectory, not a continuation. **Fix**: every checkpoint must carry the full state (PyTorch tutorial [saving multiple / general checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html); the spot-resilience §3 list): | Must save | Why losing it breaks resume | |---|---| | model `state_dict` | the weights (obvious) | | optimizer `state_dict` | Adam m/v momentum — losing it = a cold optimizer restart (C12) | | LR-scheduler `state_dict` | step-based LR position — losing it resets the schedule (C14) | | `epoch` **and** global `step`/iteration | resume the exact position, not the epoch start (C12) | | RNG state: Python `random`, NumPy, `torch`, **CUDA** (`torch.cuda.get_rng_state_all()`) | reproducible augmentation/dropout stream after restart | | dataloader / sampler position | so the next batch is the *next* unseen one, not a reshuffle (C13) | | AMP `GradScaler` `state_dict` | the loss-scale + growth tracker — losing it triggers an inf-scale stall (C15) | | EMA / SWA shadow weights (if used) | the EMA copy is often what's evaluated — losing it = eval on the wrong weights (C16) | | best-metric-so-far + `best.pth` selection state | so "best" survives a restart instead of resetting | The runnable atomic skeleton that assembles this dict is in `references/spot-resilience.md` §5 — do not duplicate it; this table is the *checklist*, that is the *code*. ### C2 — Write atomically: tmp → fsync → os.replace (a kill mid-write corrupts a naive save) **Symptom**: after a preemption/OOM, `latest.pth` is truncated/zero-byte or `torch.load` raises `RuntimeError: PytorchStreamReader failed reading zip archive`; a `latest.pth.tmp` is left behind. **Root cause**: overwriting `latest.pth` in place is **not** atomic — a kill partway through leaves a corrupt file and (if it was the only checkpoint) zero good ones. `torch.save` itself does *not* fsync. **Fix**: write to a temp file, force bytes to disk, then atomically rename (POSIX `rename`/`os.replace` is atomic on the **same filesystem**): ```python tmp = ckpt_path + ".tmp" with open(tmp, "wb") as f: torch.save(state, f); f.flush(); os.fsync(f.fileno()) # bytes hit disk BEFORE the swap os.replace(tmp, ckpt_path) # all-or-nothing; keep prev until this returns ``` Keep the previous `latest.pth` valid until the rename returns (a kill at any instant leaves one intact file). `os.replace` (not `os.rename`) also works on Windows for the local-test path. Full recipe + rationale: `references/spot-resilience.md` §3. Disk-full *during* the save is a separate failure with the same `.tmp` left behind → `references/gotchas_universal.md` U6 (pre-budget + prune `latest`, keep `best`). ### C3 — Load-latest UNCONDITIONALLY on startup → idempotent resume **Symptom**: a relaunch starts from scratch because the resume is gated behind a `--resume` flag the launch wrapper forgot to pass; or two code paths (fresh vs resume) diverge. **Root cause**: making resume *opt-in* means a generic relaunch (spot recovery, SSH-drop restart, queue retry) re-trains from zero. A divergent "first launch" code path also drifts from the resume path. **Fix**: one code path that loads the latest checkpoint if it exists, else starts fresh — so the **identical launch command** converges to the same end state no matter how many times it runs. This is what makes principle #7's "retry the identical config" actually *resume* instead of restart, and it is the universal spine (principle #8) under SSH-drop / Slurm-walltime / K8s-reschedule / spot-preemption. Skeleton: `references/spot-resilience.md` §3 (`load_latest_if_any`). ### C4 — Checkpoint to the platform's DURABLE location, not local scratch **Symptom**: resume after a managed-spot replacement (or a `terminate`/`destroy`) finds no checkpoint — the box came up *fresh* and the only copy was on the dead instance's local disk. **Root cause**: a replacement node is clean; anything not on a cloud bucket / network volume / shared FS is gone (principle #4 — know what survives stop vs destroy). **Fix**: write checkpoints to the profile's durable mount (`DURABLE_DIR` in `profiles/.md` §8), or mirror local→durable on the checkpoint timer. The single biggest portability trap is assuming local disk survives — see each profile's STORAGE survival-matrix and the SKILL Quick-reference table. Gate the sync on the actual copy result, never an unconditional `echo synced` → `references/gotchas_universal.md` U33. --- ## Sharded checkpoints (multi-GPU) ### C5 — FSDP `FULL_STATE_DICT` OOMs on rank 0 when gathering a large model **Symptom**: an FSDP job trains fine but **crashes at the first checkpoint** with CUDA OOM on rank 0; the model is larger than one GPU. **Root cause**: `StateDictType.FULL_STATE_DICT` all-gathers every shard onto **one rank** to assemble the unsharded dict. For a model that only fits *because* it's sharded, materializing the whole thing on rank 0 exceeds that GPU's VRAM. **Fix**: when taking a full (consolidated) dict, offload it to CPU and build it on rank 0 only — `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)`. This all-gathers parameters one-by-one, offloading each to CPU on rank 0, so peak GPU memory stays bounded and non-rank-0 workers skip the GPU→CPU copy entirely ([HF Accelerate FSDP guide](https://huggingface.co/docs/accelerate/en/usage_guides/fsdp), [Lightning issue #11207](https://github.com/Lightning-AI/pytorch-lightning/issues/11207)). The full dict is only viable when it fits in CPU RAM; past that, use sharded (C6). Save a full dict only at the **end** for a portable single-file artifact; checkpoint *during* training as sharded. ### C6 — `SHARDED_STATE_DICT`: each rank saves its own shard (no gather, no rank-0 OOM) **Symptom**: need to checkpoint a model too big to consolidate even on CPU, or want a fast resume that re-shards onto a *different* world size. **Root cause**: `FULL_STATE_DICT` is fundamentally a single-rank materialization; it does not scale and cannot reshard. **Fix**: use `StateDictType.SHARDED_STATE_DICT` — every rank writes only its own shard, so there is no all-gather and no OOM, and the per-rank files load back in parallel. Pair it with Distributed Checkpoint (C7), which is the production path for sharded save/load and supports **resharding** (resume on a different GPU count). Tradeoff: a sharded checkpoint is a *directory of N files*, not a single `.pth` — convert to a full dict for export/inference (C7's `get_model_state_dict`, or the DeepSpeed analogue C8). ### C7 — Distributed Checkpoint (DCP): `dcp.save` / `dcp.load` for FSDP/sharded models **Symptom**: hand-rolling FSDP state-dict context managers is brittle, slow, and breaks when the world size changes between save and resume. **Root cause**: `torch.save` produces a single file and has no notion of sharding or FQN remapping; manually toggling `FSDP.state_dict_type` is error-prone. **Fix**: use `torch.distributed.checkpoint` (DCP), the current PyTorch-2.x sharded-checkpoint API ([DCP recipe](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html), [2.12 reference](https://docs.pytorch.org/docs/2.12/distributed.checkpoint.html)). **Save**: get canonical dicts with `get_state_dict(model, optimizer)` from `torch.distributed.checkpoint.state_dict`, then `dcp.save(state_dict, checkpoint_id=DIR)` — it writes **≥1 file per rank in parallel** and auto-manages FQN mappings. **Load**: allocate the model first, then `dcp.load(state_dict, checkpoint_id=DIR)` (loads **in place** and **auto-reshards** to the current world size), then `set_state_dict(...)`. DCP beats `torch.save` for any distributed model because it shards the write across ranks (no rank-0 gather, C5) and reshards on load. For a single portable inference file, convert offline with `torch.distributed.checkpoint.format_utils.dcp_to_torch_save(DIR, "out.pt")` (or the CLI `python -m torch.distributed.checkpoint.format_utils dcp_to_torch DIR out.pt`). ### C8 — DeepSpeed ZeRO: a checkpoint *directory* per save + `zero_to_fp32.py` to consolidate **Symptom**: `model_engine.save_checkpoint(dir)` writes a *folder* of `mp_rank_*` / `zero_pp_rank_*` files, not a `.pth`; loading the weights into a plain (non-DeepSpeed) model for inference fails. **Root cause**: ZeRO **partitions** optimizer state (stage 1), gradients (2), and parameters (3) across ranks; the on-disk checkpoint is inherently sharded across per-rank files — it is not a single fp32 model. **Fix** ([DeepSpeed model-checkpointing](https://deepspeed.readthedocs.io/en/stable/model-checkpointing.html), [ZeRO tutorial](https://www.deepspeed.ai/tutorials/zero/)): - **Save/resume training** — `model_engine.save_checkpoint(save_dir, tag)` / `model_engine.load_checkpoint(save_dir, tag)`. **All ranks must call both** (they're collective; rank-0 only deadlocks/corrupts). Round-trips full sharded optimizer+param state. - **Export a single fp32 model** — DeepSpeed auto-drops a `zero_to_fp32.py` into the checkpoint dir; run `python zero_to_fp32.py pytorch_model.bin`, or in-process `from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint(dir)` / `convert_zero_checkpoint_to_fp32_state_dict(...)` / `load_state_dict_from_zero_checkpoint(model, dir)` (the last returns a model that **can't continue training** without re-init). The consolidated file no longer needs DeepSpeed. For ZeRO-3, set `"zero_optimization": {"stage3_gather_16bit_weights_on_model_save": true}` + `engine.save_16bit_model(dir)`. --- ## Framework APIs ### C9 — HF Trainer: `resume_from_checkpoint` + `save_total_limit` (and what it actually saves) **Symptom**: assuming `Trainer.save_model()` is a resume point (it saves *weights only*); or a relaunch re-trains from step 0 because `resume_from_checkpoint` wasn't passed; or the disk fills with `checkpoint-*` dirs. **Root cause**: `save_model` ≠ a training checkpoint. A real Trainer checkpoint dir (`checkpoint-`) contains the model **plus** `optimizer.pt`, `scheduler.pt`, `rng_state.pth`, `trainer_state.json`, and the AMP `scaler.pt` — the full state. Without `resume_from_checkpoint` the run starts cold. **Fix** ([Trainer docs](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)): `trainer.train(resume_from_checkpoint="path/to/checkpoint-1500")` resumes that exact dir; `resume_from_checkpoint=True` auto-finds the **last** checkpoint in `args.output_dir` (idempotent spelling, C3; `trainer_utils.get_last_checkpoint(output_dir)` finds it in code). `save_strategy="steps"` + `save_steps=N` (or `"epoch"`) sets cadence; **`save_total_limit=k`** keeps only the `k` most-recent `checkpoint-*` and **deletes older ones in `output_dir`** — the built-in disk-budget knob (pairs with `references/gotchas_universal.md` U6). `load_best_model_at_end=True` + `metric_for_best_model` + `greater_is_better` reloads the best checkpoint at the end **and** protects it from `save_total_limit` deletion (C17). ### C10 — Accelerate: `accelerator.save_state(dir)` / `load_state(dir)` + dataloader skip **Symptom**: a custom (non-Trainer) Accelerate loop resumes with a cold optimizer/scaler, or the LR scheduler resets, or it replays already-seen batches. **Root cause**: saving only `accelerator.get_state_dict(model)` drops optimizer/scaler/RNG; and a mid-epoch resume re-iterates the dataloader from batch 0. **Fix** ([Accelerate checkpoint guide](https://huggingface.co/docs/accelerate/en/usage_guides/checkpoint)): `accelerator.save_state(output_dir)` saves model, optimizer, **GradScaler**, and RNG generators in one call; `accelerator.load_state(output_dir)` restores all of it (objects must come from the *same* script). The LR scheduler (and any object with `state_dict`/`load_state_dict`) **must** be registered first — `accelerator.register_for_checkpointing(my_scheduler)` — or it is not saved and resets (C14). For mid-epoch resume, skip consumed batches with `accelerator.skip_first_batches(train_dataloader, N)` on the first resumed epoch, then fall back to the full dataloader (C13). `ProjectConfiguration(automatic_checkpoint_naming=True, total_limit=k)` gives rolling `checkpoints/checkpoint_` dirs with a built-in limit. ### C11 — Lightning: `ModelCheckpoint` + `trainer.fit(ckpt_path=...)` (don't use `resume_from_checkpoint`) **Symptom**: an old tutorial's `Trainer(resume_from_checkpoint=...)` is ignored/deprecated; or `save_top_k` quietly deletes the checkpoint needed to resume. **Root cause**: `resume_from_checkpoint` moved to `fit(ckpt_path=...)` (deprecated since 1.x). A Lightning `.ckpt` is a full dump — epoch, global step, LightningModule `state_dict`, **all** optimizer + LR-scheduler states, callback states, loop state, and the 16-bit scaling factor (AMP) ([Lightning checkpointing basics](https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html)). **Fix**: - Configure `ModelCheckpoint(dirpath=..., monitor="val_loss", mode="min", save_top_k=k, save_last=True)`; resume with `trainer.fit(model, datamodule, ckpt_path="path/to/last.ckpt")`, or `ckpt_path="last"` to auto-pick the `save_last=True` file (the idempotent spelling, C3). Best/last paths read back from `cb.best_model_path` / `cb.last_model_path`. - `save_top_k` keeps only the k best by `monitor`; **always set `save_last=True`** so a resume target exists even when the latest step isn't a top-k metric (otherwise resume may have no recent checkpoint). Add custom state (EMA, C16) via `on_save_checkpoint` / `on_load_checkpoint` on the module or a stateful callback. Lightning's DeepSpeed strategy writes a ZeRO dir — convert with `lightning.pytorch.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict` (C8 analogue). --- ## The resume BUGS (looks resumed, silently lost progress) These are the "it ran without error but the result is wrong" traps — confirm the fix with the `verifying-dl-experiments` reproducibility check (**REQUIRED**): kill mid-run, relaunch the *identical* command, and verify step/epoch/loss **continue** rather than reset. ### C12 — Epoch/step restarts from 0 despite "resuming" **Symptom**: tracker shows a second run starting at epoch 1; total trained epochs exceed the schedule; LR warm-up replays. (The remote-ops version of this — a tmux script re-executed mid-run — is `references/gotchas_universal.md` U2.) **Root cause**: the loop is `for epoch in range(total_epochs)` with a hardcoded `0` start; the saved `epoch`/`step` was never read back, or was saved but not used to seed the range. **Fix**: `start_epoch, start_step = load_latest_if_any(...)` then `for epoch in range(start_epoch, total_epochs)` and seed the step counter from `start_step`. The counter **must** be in the checkpoint (C1) *and* consumed on load. ### C13 — Data reshuffles / repeats the same order after resume **Symptom**: resume re-shows already-seen samples (worse, the *same* batch every epoch even without resume), hurting convergence or leaking. **Root cause**: two distinct bugs. (a) Resume restarts the epoch from batch 0 without skipping consumed batches. (b) `DistributedSampler` seeds its shuffle from an internal epoch that defaults to 0 forever unless `sampler.set_epoch(epoch)` is called each epoch — so every epoch (and every resume) produces the **identical** order ([PyTorch #31771](https://github.com/pytorch/pytorch/issues/31771), [DistributedSampler docs](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler)). **Fix**: call `train_sampler.set_epoch(epoch)` at the top of every epoch (restore the epoch counter on resume so the shuffle stream continues). For mid-epoch resume, fast-forward consumed batches (`accelerator.skip_first_batches`, C10) or use a resumable/stateful sampler (`torchdata` `StatefulDataLoader`) whose offset is in the checkpoint (C1). ### C14 — LR schedule resets (cosine restarts, warm-up replays) **Symptom**: the LR curve restarts from the initial/warm-up value on resume; final LR is wrong; cosine decay never reaches its floor. **Root cause**: the LR scheduler's `state_dict` (its `last_epoch`/step counter) was not saved or not restored. With Accelerate, the scheduler wasn't `register_for_checkpointing`-ed (C10). **Fix**: save `scheduler.state_dict()` and call `scheduler.load_state_dict(...)` on resume (C1). Note a step-based scheduler advanced *per optimizer step* must restore the **step**, not the epoch — restoring only `epoch` under-/over-shoots the schedule. ### C15 — AMP `GradScaler` not restored → "No inf checks were recorded" / scale stall **Symptom**: resuming a mixed-precision run raises `AssertionError: No inf checks were recorded for this optimizer`, or training stalls/NaNs because the loss-scale snapped back to the default and re-enters the scale-search. **Root cause**: the `GradScaler` holds dynamic state — `scale`, `growth_factor`, `backoff_factor`, `growth_interval`, `_growth_tracker` — that evolves during training; dropping it resets the scaler ([PyTorch AMP recipe](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html), [forum: No inf checks were recorded](https://discuss.pytorch.org/t/resume-training-with-mixed-precision-lead-to-no-inf-checks-were-recorded-for-this-optimizer/115828)). **Fix**: save `scaler.state_dict()` (call it **after** `scaler.update()` in the iteration) and `scaler.load_state_dict(checkpoint["scaler"])` on resume. HF Trainer (`scaler.pt`), Accelerate (`save_state`), and Lightning (16-bit factor) all do this automatically — the bug bites hand-written loops. Resuming a *non-AMP* checkpoint into an AMP run has no saved scaler → start a **fresh** `GradScaler`. ### C16 — EMA / SWA shadow weights not saved → eval on the wrong weights after resume **Symptom**: pre-resume eval (using EMA weights) is good; post-resume eval drops sharply, then recovers over many steps — because the EMA copy restarted from the raw weights. **Root cause**: EMA/SWA maintain a *separate* shadow parameter set that is what gets evaluated/exported; saving only the live model `state_dict` loses it, so EMA reinitializes from the (noisier) live weights. **Fix**: include `ema.state_dict()` (and SWA `AveragedModel` / `swa_scheduler` state) in the checkpoint dict (C1) and restore it. In Lightning, persist it via `on_save_checkpoint`/`on_load_checkpoint` (C11). This is a *which-weights-are-correct* concern at the boundary — cross-link **verifying-dl-experiments** (**REQUIRED**) for confirming the evaluated weights are the intended ones. ### C17 — `save_total_limit` / `save_top_k` deletes the very checkpoint resume needs **Symptom**: resume fails because the target checkpoint was auto-pruned; or `load_best_model_at_end` errors because the best checkpoint was rotated out. **Root cause**: a rolling limit prunes by *recency* (`save_total_limit`) or by *metric* (`save_top_k`), and neither guarantees the most-recent-step checkpoint is the one kept — so the resume anchor can be the one deleted. **Fix**: keep an explicit `last`/`latest` alongside the top-k (`save_last=True` in Lightning, C11; in HF, `load_best_model_at_end=True` makes Trainer preserve the best checkpoint past `save_total_limit`). General keepable-checkpoint *policy* (how many, which selection criterion, `save_top_k ≤ 3`, prune `latest`) is owned by **verifying-dl-experiments** (**REQUIRED**); the disk-budget consequence is `references/gotchas_universal.md` U6. ### C18 — `load_state_dict` key mismatch on resume (`module.` prefix, compiled-model prefix) **Symptom**: resume raises `Missing key(s)` / `Unexpected key(s) ... module.` or `_orig_mod.`, or strict load fails after switching DDP/`torch.compile` on or off. **Root cause**: `DataParallel`/DDP wrap adds a `module.` prefix and `torch.compile` adds `_orig_mod.` to every key; a checkpoint saved wrapped and loaded unwrapped (or vice-versa) won't key-match under `strict=True`. **Fix**: save the **unwrapped** module — `model.module.state_dict()` (DDP) / `accelerator.unwrap_model(model).state_dict()` / `model._orig_mod.state_dict()` (compiled) — so the checkpoint is wrapper-agnostic. On load, strip the prefix if present (`{k.replace("module.", "").replace("_orig_mod.", ""): v for k, v in sd.items()}`). Keep `strict=True` while debugging a resume so a silent partial load can't masquerade as success; only relax it deliberately. --- ## Pointers — owned elsewhere, do NOT restate here - **Cadence — when/how often** (Young/Daly `W = sqrt(2·mu·C)`, grace windows, opportunistic SIGTERM last-flush, the runnable atomic skeleton) → `references/spot-resilience.md` (**REQUIRED**, spot tier). - **Disk-full on save** (pre-budget, prune `latest`, keep `best`, `.tmp` recovery) → `references/gotchas_universal.md` U6; **silent "synced" line** → U33; **inode exhaustion** → U7. - **Sharding a model that won't fit** (FSDP wrap policy, ZeRO stages, offload) is the *fitting* concern → `references/training/oom-memory.md` M9/M10; this file owns *checkpointing* the sharded state. - **Multi-rank save/load collectives + elastic restart** (torchrun `--max-restarts` restores from the checkpoint) → `references/training/distributed-launch.md`, `references/multinode.md`. - **Keepable-checkpoint policy + "is the resumed/best number real"** (selection criterion, `save_top_k`, proving step/epoch/loss continued) → **verifying-dl-experiments** (**REQUIRED**).