#!/usr/bin/env python
"""Verify integrity of downloaded ckpt directories.

For each <ablation>/ subdirectory of the target dir, check:
  - best.pth exists
  - best.pth loads cleanly via torch.load
  - best.pth contains a non-empty weights key ('model_state_dict' / 'model' / 'state_dict')
  - best_metrics.json exists and is a valid JSON object
  - reports best epoch + main metric per ablation

Usage:
  python verify_local.py <ckpt_dir> [--expect N] [--list-metrics] [--allow-pickle]

Exit code:
  0 = all OK
  1 = at least one error, an empty input dir, or a dir count != --expect
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Keys under which a checkpoint may store its weights (first match wins).
_WEIGHT_KEYS = ("model_state_dict", "model", "state_dict")
# Preferred "main" metric: PSNR for recon, mAP50 for det, dice for seg;
# "loss" is the fallback when none of these is present.
_MAIN_METRIC_KEYS = ("psnr", "mAP50", "dice")
# How many per-ablation errors are printed individually before summarising.
_MAX_ERRORS_SHOWN = 20


def _load_checkpoint(pth: Path, name: str, torch, allow_pickle: bool) -> tuple[object, str | None]:
    """Load *pth* onto the CPU, safe-by-default.

    Returns ``(ckpt, None)`` on success or ``(None, error_message)`` on failure.

    weights_only=True refuses to execute pickle, so a poisoned or compromised
    remote checkpoint cannot run code on the operator's machine. The unsafe
    weights_only=False path (which DOES execute pickle) is OPT-IN via
    *allow_pickle*. An attacker who controls the remote file could otherwise
    craft one that fails the safe load to FORCE the fallback, so an automatic
    fallback would defeat the gate. Enable it ONLY for your own checkpoints.
    """
    try:
        return torch.load(pth, map_location="cpu", weights_only=True), None
    except Exception as e_safe:
        if not allow_pickle:
            return None, (
                f"safe load (weights_only=True) failed: {str(e_safe)[:70]} "
                "-- re-run with --allow-pickle if this is your own checkpoint"
            )
    try:
        print(
            f" [warn] {name}: weights_only=True failed; --allow-pickle set, retrying "
            "weights_only=False (executes pickle -- trust this file)"
        )
        return torch.load(pth, map_location="cpu", weights_only=False), None
    except Exception as e:
        return None, f"torch.load failed: {str(e)[:100]}"


def _verify_ablation(d: Path, torch, allow_pickle: bool) -> tuple[str | None, tuple[str, object, str] | None]:
    """Run every integrity check on one ablation directory *d*.

    Returns ``(None, (name, epoch, "metric=value"))`` when all checks pass, or
    ``(error_message, None)`` describing the first check that failed.
    """
    name = d.name
    pth = d / "best.pth"
    metrics_path = d / "best_metrics.json"
    if not pth.exists():
        return "missing best.pth", None
    if not metrics_path.exists():
        return "missing best_metrics.json", None

    ckpt, load_error = _load_checkpoint(pth, name, torch, allow_pickle)
    if load_error is not None:
        return load_error, None

    weights_key = None
    if isinstance(ckpt, dict):
        weights_key = next((k for k in _WEIGHT_KEYS if k in ckpt), None)
    if weights_key is None:
        return "no model/model_state_dict/state_dict key in checkpoint", None
    weights = ckpt[weights_key]
    # A present-but-null/empty weights entry is as useless as a missing one. Only None
    # and empty dicts are rejected; tensors / pickled modules are never truth-tested.
    if weights is None or (isinstance(weights, dict) and not weights):
        return f"weights key '{weights_key}' is empty", None

    try:
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(metrics_path, encoding="utf-8") as f:
            m = json.load(f)
    except Exception as e:
        return f"best_metrics.json invalid: {str(e)[:80]}", None
    # Valid JSON that isn't an object (e.g. a list) has no .get -- report it instead of
    # crashing and abandoning the remaining ablations.
    if not isinstance(m, dict):
        return f"best_metrics.json is not a JSON object (got {type(m).__name__})", None

    epoch = m.get("epoch", "?")
    if epoch is None:
        # {"epoch": null} -> .get returns None (not the default); guard the :3 format.
        # `or` would wrongly eat epoch 0.
        epoch = "?"
    elif not isinstance(epoch, (int, float, str)):
        # Lists/objects reject a width format spec; stringify so the table can't crash.
        epoch = str(epoch)

    main_metric_key = next((k for k in _MAIN_METRIC_KEYS if k in m), "loss")
    main_metric_val = m.get(main_metric_key, "?")
    return None, (name, epoch, f"{main_metric_key}={main_metric_val}")


def main() -> int:
    """Verify every ablation subdir of the given ckpt dir; return the process exit code."""
    ap = argparse.ArgumentParser()
    ap.add_argument("ckpt_dir", help="Directory containing ablation subdirs (each with best.pth + best_metrics.json)")
    ap.add_argument("--list-metrics", action="store_true", help="Print per-ablation epoch + main metric")
    ap.add_argument("--expect", type=int, default=None,
                    help="Assert exactly N ablation subdirs are present -- guards a teardown gate against a partial/empty pull")
    ap.add_argument("--allow-pickle", action="store_true",
                    help="Permit the weights_only=False fallback (executes pickle) for checkpoints you trust -- "
                         "needed only when a checkpoint pickles non-tensor objects (e.g. an args Namespace); OFF by default")
    args = ap.parse_args()

    root = Path(args.ckpt_dir)
    if not root.exists():
        print(f"ERROR: {root} does not exist")
        return 1
    if not root.is_dir():
        print(f"ERROR: {root} is not a directory")
        return 1

    # Structural checks BEFORE importing torch: an empty (or short) input must fail
    # LOUDLY here -- never silently print "OK: 0/0" and return success, which would let
    # a teardown gate destroy the rented disk having verified nothing.
    dirs = sorted(d for d in root.iterdir() if d.is_dir())
    if not dirs:
        print(f"ERROR: no ablation subdirectories found in {root} -- refusing to report success on an empty input")
        return 1
    if args.expect is not None and len(dirs) != args.expect:
        print(f"ERROR: expected {args.expect} ablation dirs but found {len(dirs)} in {root} -- partial/incomplete pull")
        return 1

    try:
        import torch
    except ImportError:
        print("ERROR: torch not installed in this environment")
        return 1

    print(f"Found {len(dirs)} ablation dirs in {root}")
    print()

    errors: list[tuple[str, str]] = []
    metrics_rows: list[tuple[str, object, str]] = []
    total_size_bytes = 0
    for d in dirs:
        error, row = _verify_ablation(d, torch, args.allow_pickle)
        if error is not None:
            errors.append((d.name, error))
            continue
        metrics_rows.append(row)
        total_size_bytes += (d / "best.pth").stat().st_size

    ok = len(metrics_rows)
    print(f"OK: {ok}/{len(dirs)}")
    print(f"Errors: {len(errors)}")
    for name, err in errors[:_MAX_ERRORS_SHOWN]:
        print(f" - {name}: {err}")
    if len(errors) > _MAX_ERRORS_SHOWN:
        # Don't let truncation hide how many ablations actually failed.
        print(f" ... and {len(errors) - _MAX_ERRORS_SHOWN} more")
    print(f"Total best.pth size: {total_size_bytes / 1e9:.1f} GB")

    if args.list_metrics:
        print()
        print("=== Per-ablation metrics ===")
        for name, epoch, metric in metrics_rows:
            print(f" {name:40s} epoch={epoch:3} {metric}")

    return 0 if not errors else 1


if __name__ == "__main__":
    sys.exit(main())