# Model Training & Experimentation

---

### 13) Distributed Training

* **Trigger**
  * Automatic: new `golden_train/val/test.manifest` from Workflow #11, or new auto-mined slice from #12.
  * Manual: engineer launches a training run from the UI/CLI, selecting dataset spec + model recipe.
  * Scheduled: nightly/weekly rebuild on rolling data window.
* **Inputs**
  * Curated dataset manifests + `slices.yaml` (from #11), plus DVC tag/version.
  * Model recipe (backbone, heads, loss weights, augmentations) and training config (`train.yaml`).
  * Pretrained weights (optional) for warm-start or continual learning.
  * Code container image (ECR) and commit SHA; W&B project/entity; environment secrets.
  * Hardware profile (e.g., `p4d.24xlarge` × N nodes) and storage profile (FSx for Lustre).
* **Core Steps**
  * Packaging & Staging
    * Build/push Docker image with PyTorch, NCCL, CUDA, AMP, torchvision/OpenMMLab, and your repo.
    * Stage dataset to **FSx for Lustre** (for I/O throughput) using DVC or AWS DataSync; keep bulk in S3.
    * Pre-generate sharded WebDataset/TFRecord (optional) for faster streaming.
  * Orchestration
    * Launch **SageMaker distributed training** job (preferred) or **EKS** Job with `torchrun` (DDP).
    * Configure NCCL and networking (env: `NCCL_DEBUG=INFO`, `NCCL_SOCKET_IFNAME=eth0`, `NCCL_ASYNC_ERROR_HANDLING=1`).
  * Training Runtime (a minimal code sketch follows this section)
    * Initialize **DDP** with `backend=nccl`, set `torch.set_float32_matmul_precision("high")` where supported.
    * **Data loading**: multi-worker `DataLoader` + persistent workers, pinned memory, async prefetch.
    * **Mixed precision (AMP)** + gradient scaling; gradient accumulation for large effective batch sizes.
    * **Task balancing** for multi-head models (e.g., HydraNet-style): dynamic or curriculum weighting.
    * **Checkpointing**: local → FSx (fast) every N steps; periodic sync to S3 (`s3://…/checkpoints/…`).
    * **Fault tolerance**: resume from last global step on spot interruption; save RNG states/optimizers/scalers.
  * Evaluation & Gating
    * After each epoch: run full **offline eval** on `val` + key **slices**; compute mAP/mIoU, per-class AP, trajectory ADE/FDE, lane F1, etc.
    * Produce **slice dashboards** (per weather/time/road class) and **regression checks** against last prod model.
    * Latency/throughput microbenchmarks: TorchScript export + dummy inference to report p50/p95 latency.
  * Logging & Metadata
    * Log all metrics/artifacts to **Weights & Biases**: run config, gradients, losses, images/videos, PR/ROC curves.
    * Register artifacts: dataset version (DVC hash), code SHA, Docker digest, checkpoints, exported models.
  * Testing inside the run (hard gates)
    * **Sanity checks**: one forward+backward batch before training; fail fast on NaNs or exploding loss.
    * **Determinism smoke test** (seeded) on small shard; tolerance bands on metrics.
    * **Data drift guard**: compare batch feature stats vs. training reference (Evidently profile) and warn/abort on severe drift.
  * Post-Run Actions
    * Generate **model card** (data provenance, metrics, caveats) and attach to W&B run & S3.
    * If gates pass, push **exported model** (TorchScript/ONNX) to artifact store and **Model Registry** (e.g., SageMaker Model Registry or W&B Artifacts promoted alias).
    * Emit event to orchestration (Airflow/Step Functions) to trigger **#14 HPO** (if configured) or **#15 Eval/Sim** (next workflow).
* **AWS/Tooling**
  * **SageMaker Training** (DDP), **ECR**, **S3**, **FSx for Lustre**, **CloudWatch Logs**, **IAM**, optional **EKS**.
  * **PyTorch** (DDP/torchrun), **NCCL**, **torch.cuda.amp**, **OpenMMLab/MMDetection** (optional), **Albumentations**.
  * **W&B** for tracking/artifacts; **DVC** for dataset pinning; **Evidently** for drift profiles.
* **Outputs**
  * Best checkpoint (`.pt`), plus **TorchScript/ONNX** exports; quantization-ready or TensorRT plan (optional).
  * `eval_report.json` (overall + per-slice metrics, latency), confusion matrices, PR curves.
  * **W&B run** (summary, artifacts), **model card** (`model_card.md`), **training logs**.
  * Stored in **S3 (Gold)** under `/models///…`, and linked in **W&B Artifacts** and registry.
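
The Training Runtime bullets above compress several interacting pieces: DDP initialization over NCCL, AMP with gradient scaling, gradient accumulation, and resumable checkpoints that capture optimizer, scaler, and RNG state (on resume, the saved dict is loaded per rank and restored before continuing from the saved step). The sketch below shows one way these fit together under `torchrun`. It is a minimal illustration under stated assumptions, not the production recipe: the model, dataset, hyperparameters, and checkpoint path are placeholder stand-ins.

```python
# Minimal DDP + AMP + gradient-accumulation loop with resumable checkpoints (illustrative only).
# Launch (per node): torchrun --nnodes=<N> --nproc_per_node=<GPUs> train_sketch.py
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_model() -> torch.nn.Module:
    # Placeholder model; the real backbone/heads/losses come from the model recipe (train.yaml).
    return torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 64 * 64, 10))


def build_loader(batch_size: int = 32) -> DataLoader:
    # Placeholder random data; production streams curated manifests staged on FSx for Lustre.
    ds = TensorDataset(torch.randn(1024, 3, 64, 64), torch.randint(0, 10, (1024,)))
    return DataLoader(ds, batch_size=batch_size, sampler=DistributedSampler(ds),
                      num_workers=4, pin_memory=True, persistent_workers=True)


def save_ckpt(path, step, model, optimizer, scaler):
    # Capture everything needed to resume exactly after a spot interruption.
    torch.save({"step": step,
                "model": model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "rng_cpu": torch.get_rng_state(),
                "rng_cuda": torch.cuda.get_rng_state_all()}, path)


def main():
    dist.init_process_group(backend="nccl")            # NCCL backend for multi-GPU/multi-node
    local_rank = int(os.environ["LOCAL_RANK"])          # set by torchrun
    torch.cuda.set_device(local_rank)
    torch.set_float32_matmul_precision("high")          # allow TF32 matmuls where supported

    model = DDP(build_model().cuda(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()                # AMP gradient scaling
    accum_steps, ckpt_every = 4, 20                     # illustrative values

    for step, (images, labels) in enumerate(build_loader()):
        images = images.cuda(local_rank, non_blocking=True)
        labels = labels.cuda(local_rank, non_blocking=True)
        with torch.cuda.amp.autocast():                  # mixed-precision forward/backward
            loss = F.cross_entropy(model(images), labels) / accum_steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:                # gradient-accumulation boundary
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        if dist.get_rank() == 0 and (step + 1) % ckpt_every == 0:
            # Illustrative local path; production writes to FSx and syncs to S3 separately.
            save_ckpt("last_ckpt.pt", step, model, optimizer, scaler)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```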
---

### 14) Hyper-Parameter Optimization / Sweeps

* **Trigger**
  * Auto: training workflow marks a model as “candidate” but below target on one or more slices.
  * Scheduled: weekly sweeps on prioritized tasks (e.g., night/pedestrian detector).
  * Manual: engineer launches a targeted sweep (e.g., loss weights for lane vs. detection).
* **Inputs**
  * Same dataset manifests as #13 (or focused **slice packs** for the weak area).
  * Baseline config (`train.yaml`) with **search space**: LR/WD, warmup, aug policy, loss weights, NMS/score thresholds, backbone/neck options, EMA on/off, AMP level, batch size/accum steps.
  * **Sweep strategy**: Bayesian, Hyperband/ASHA, Random, or **Population-Based Training** (PBT) for long runs.
  * Resource budget: max trials, parallelism, GPU hours, early-stop policy, cost cap.
* **Core Steps**
  * Orchestration & Budgeting
    * Create a **W&B Sweep** config (YAML) with the objective metric (e.g., `val/mAP_weighted`, or **multi-objective** with constraints: maximize `mAP_vehicle` subject to `latency_p95 < X ms` and `regression_Δslice < Y`); a minimal Python sketch follows this section.
    * Choose an executor:
      * **Kubernetes Jobs** on EKS with the **W&B agent** (elastic parallelism), **or**
      * multiple **SageMaker** training jobs tagged to the sweep (agent inside the container).
    * Enforce **cost guardrails**: kill/truncate low performers via **ASHA**; set CloudWatch alarms on spend.
  * Trial Runtime
    * Each trial inherits the **DDP** setup from #13; parameters injected via env/CLI override.
    * Log full telemetry to W&B: metrics, hyperparams, system stats (GPU util, memory), eval artifacts.
    * **Early stopping** on plateau, or rule-based pruning (e.g., after 3 epochs if `val/mAP` < quantile Q).
  * Validation & Gating (per trial)
    * Same eval battery as #13, including **slice metrics** and **latency microbenchmarks**.
    * **Fairness/coverage checks**: ensure no >K% regression on protected or safety-critical slices.
    * **Stability check** (optional): re-run the best few trials with a different seed on a small shard; require consistent ranking.
  * Selection & Promotion
    * W&B **sweep dashboard** computes the leaderboard; export the **Pareto front** for multi-objective cases (see the selection sketch at the end of this document).
    * Auto-package **Top-K** models; re-evaluate on the **holdout test** and simulation smoke set.
    * Generate an **HPO report**: importance analysis (SHAP/ICE plots over hyperparameters), budget used, recommended defaults.
    * Promote the winner to the **Model Registry** with tag `candidate//`; attach sweep summary and config freeze.
  * Learnings Back-Prop
    * Update default training config with the **new best hyperparams** (per task/slice).
    * Persist tuned **augmentation policies** and **loss weights** into the recipe library.
    * If all trials fail gating on a specific slice, emit a **data request** back to #8/#12 to mine more examples.
* **AWS/Tooling**
  * **EKS** (Jobs + W&B agent) or **SageMaker** (parallel training jobs), **S3**, **FSx**, **CloudWatch**, **EventBridge** for triggers.
  * **W&B Sweeps** (Bayes/Random/Hyperband/PBT), **PyTorch DDP**, **Ray Tune** (optional backend for advanced schedulers).
  * **Evidently** (sanity drift checks during trials), **Numba/TVM/TensorRT** (optional latency constraint evaluators).
* **Outputs**
  * `sweep_report.md` + `sweep_summary.json` (leaderboard, Pareto set, importance).
  * Top-K model artifacts + exports; W&B artifacts with pinned configs and datasets.
  * Registry update (winner promoted) + **promotion event** to the downstream eval/simulation pipeline.
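
The sweep described above can be defined as a standalone YAML file (started with `wandb sweep` / `wandb agent` on each worker) or directly through the W&B Python API. Below is a hedged sketch of the latter, assuming Bayesian search with Hyperband early termination; the metric name, parameter ranges, project name, and the toy objective inside `train()` are illustrative placeholders, not the production search space (in production, `train()` would invoke the full #13 training and eval loop).

```python
import wandb

# Illustrative sweep: Bayesian search over a few hyperparameters with Hyperband-style pruning.
# Metric names, ranges, and the project name are placeholders, not the real recipe.
sweep_config = {
    "method": "bayes",                                   # or "random", "grid"
    "metric": {"name": "val/mAP_weighted", "goal": "maximize"},
    "parameters": {
        "lr": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
        "weight_decay": {"values": [0.0, 1e-4, 1e-2]},
        "loss_weight_lane": {"min": 0.5, "max": 2.0},    # multi-head loss balancing
        "batch_size": {"values": [32, 64, 128]},
        "use_ema": {"values": [True, False]},
    },
    "early_terminate": {"type": "hyperband", "min_iter": 3},   # prune weak trials early
}


def train():
    # One trial: wandb.config carries the parameters sampled by the sweep controller.
    with wandb.init() as run:
        cfg = run.config
        for epoch in range(10):
            # Toy surrogate objective; in production this is a full #13 training + eval epoch.
            val_map = (1 - 0.9 ** (epoch + 1)) / (1.0 + abs(cfg.lr - 3e-4) * 1e3)
            run.log({"val/mAP_weighted": val_map, "epoch": epoch})


if __name__ == "__main__":
    sweep_id = wandb.sweep(sweep_config, project="perception-hpo")   # hypothetical project name
    # Each EKS/SageMaker worker can run an agent against the same sweep_id for elastic parallelism.
    wandb.agent(sweep_id, function=train, count=50)
```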
---
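
As a small addendum to the sweep outputs above: the Pareto set recorded in `sweep_summary.json` can be extracted from the exported leaderboard with plain Python. The sketch below assumes two objectives (detection quality up, p95 latency down) and a hard latency budget; the `Trial` records and numbers are made-up placeholders, not real results.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Trial:
    run_id: str
    map_vehicle: float       # higher is better
    latency_p95_ms: float    # lower is better


def dominates(a: Trial, b: Trial) -> bool:
    # a dominates b if it is no worse on both objectives and strictly better on at least one.
    return (a.map_vehicle >= b.map_vehicle and a.latency_p95_ms <= b.latency_p95_ms
            and (a.map_vehicle > b.map_vehicle or a.latency_p95_ms < b.latency_p95_ms))


def pareto_front(trials: List[Trial], latency_budget_ms: float) -> List[Trial]:
    # Apply the hard latency constraint first, then keep only non-dominated trials.
    feasible = [t for t in trials if t.latency_p95_ms < latency_budget_ms]
    return [t for t in feasible if not any(dominates(o, t) for o in feasible if o is not t)]


if __name__ == "__main__":
    leaderboard = [                                      # placeholder sweep results
        Trial("run-a", 0.61, 38.0),
        Trial("run-b", 0.63, 55.0),
        Trial("run-c", 0.59, 31.0),
        Trial("run-d", 0.63, 61.0),
    ]
    for t in pareto_front(leaderboard, latency_budget_ms=60.0):
        print(t.run_id, t.map_vehicle, t.latency_p95_ms)
```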