playbook/antigravity-awesome-skills/skills/remote-gpu-trainer/scripts/verify_local.py

146 lines
5.8 KiB
Python

#!/usr/bin/env python
"""Verify integrity of downloaded ckpt directories.
For each <name>/ in the target dir, check:
- best.pth exists
- best.pth loads cleanly via torch.load
- best.pth contains a weights key ('model_state_dict' / 'model' / 'state_dict')
- best_metrics.json exists and is valid JSON
- reports best epoch + main metric per ablation
Usage:
python verify_local.py <path_to_final_ckpts_dir> [--expect N] [--list-metrics]
Exit code:
0 = all OK
1 = at least one error, an empty input dir, or a dir count != --expect
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("ckpt_dir", help="Directory containing ablation subdirs (each with best.pth + best_metrics.json)")
ap.add_argument("--list-metrics", action="store_true", help="Print per-ablation epoch + main metric")
ap.add_argument("--expect", type=int, default=None,
help="Assert exactly N ablation subdirs are present -- guards a teardown gate against a partial/empty pull")
ap.add_argument("--allow-pickle", action="store_true",
help="Permit the weights_only=False fallback (executes pickle) for checkpoints you trust -- "
"needed only when a checkpoint pickles non-tensor objects (e.g. an args Namespace); OFF by default")
args = ap.parse_args()
root = Path(args.ckpt_dir)
if not root.exists():
print(f"ERROR: {root} does not exist")
return 1
if not root.is_dir():
print(f"ERROR: {root} is not a directory")
return 1
# Structural checks BEFORE importing torch: an empty (or short) input must fail
# LOUDLY here -- never silently print "OK: 0/0" and return success, which would let
# a Phase-5 teardown gate destroy the rented disk having verified nothing
# (principle #3: trust the artifact, not a success line; the teardown Iron Law).
dirs = sorted([d for d in root.iterdir() if d.is_dir()])
if not dirs:
print(f"ERROR: no ablation subdirectories found in {root} -- refusing to report success on an empty input")
return 1
if args.expect is not None and len(dirs) != args.expect:
print(f"ERROR: expected {args.expect} ablation dirs but found {len(dirs)} in {root} -- partial/incomplete pull")
return 1
try:
import torch
except ImportError:
print("ERROR: torch not installed in this environment")
return 1
print(f"Found {len(dirs)} ablation dirs in {root}")
print()
ok = 0
errors: list[tuple[str, str]] = []
metrics_rows: list[tuple[str, int, str]] = []
total_size_bytes = 0
for d in dirs:
name = d.name
pth = d / "best.pth"
metrics_path = d / "best_metrics.json"
if not pth.exists():
errors.append((name, "missing best.pth"))
continue
if not metrics_path.exists():
errors.append((name, "missing best_metrics.json"))
continue
# Load safe-by-default: weights_only=True refuses to execute pickle, so a poisoned or
# compromised remote checkpoint cannot run code on the operator's machine. The unsafe
# weights_only=False path (which DOES execute pickle) is OPT-IN via --allow-pickle: an attacker
# who controls the remote file could otherwise craft one that fails the safe load to FORCE the
# fallback, so auto-falling-back would defeat the gate. Pass --allow-pickle ONLY for your own ckpts.
try:
ckpt = torch.load(pth, map_location="cpu", weights_only=True)
except Exception as e_safe:
if not args.allow_pickle:
errors.append((name, f"safe load (weights_only=True) failed: {str(e_safe)[:70]} "
"-- re-run with --allow-pickle if this is your own checkpoint"))
continue
try:
print(
f" [warn] {name}: weights_only=True failed; --allow-pickle set, retrying "
"weights_only=False (executes pickle -- trust this file)"
)
ckpt = torch.load(pth, map_location="cpu", weights_only=False)
except Exception as e:
errors.append((name, f"torch.load failed: {str(e)[:100]}"))
continue
if not isinstance(ckpt, dict) or not any(k in ckpt for k in ("model_state_dict", "model", "state_dict")):
errors.append((name, "no model/model_state_dict/state_dict key in checkpoint"))
continue
try:
with open(metrics_path) as f:
m = json.load(f)
except Exception as e:
errors.append((name, f"best_metrics.json invalid: {str(e)[:80]}"))
continue
epoch = m.get("epoch", "?")
if epoch is None: # {"epoch": null} → .get returns None (not the default); guard the :3 format. `or` would wrongly eat epoch 0.
epoch = "?"
# Pick main metric (PSNR for recon, mAP50 for det, dice for seg, fall back to loss)
main_metric_key = next(
(k for k in ["psnr", "mAP50", "dice"] if k in m),
"loss",
)
main_metric_val = m.get(main_metric_key, "?")
metrics_rows.append((name, epoch, f"{main_metric_key}={main_metric_val}"))
total_size_bytes += pth.stat().st_size
ok += 1
print(f"OK: {ok}/{len(dirs)}")
print(f"Errors: {len(errors)}")
for name, err in errors[:20]:
print(f" - {name}: {err}")
print(f"Total best.pth size: {total_size_bytes / 1e9:.1f} GB")
if args.list_metrics:
print()
print("=== Per-ablation metrics ===")
for name, epoch, metric in metrics_rows:
print(f" {name:40s} epoch={epoch:3} {metric}")
return 0 if not errors else 1
if __name__ == "__main__":
sys.exit(main())