Model Training & Experimentation¶
13) Distributed Training¶
Trigger
Automatic: new golden_train/val/test.manifest from Workflow #11, or new auto-mined slice from #12.
Manual: engineer launches a training run from the UI/CLI, selecting dataset spec + model recipe.
Scheduled: nightly/weekly rebuild on rolling data window.
Inputs
Curated dataset manifests + slices.yaml (from #11), plus DVC tag/version.
Model recipe (backbone, heads, loss weights, augmentations) and training config (train.yaml).
Pretrained weights (optional) for warm-start or continual learning.
Code container image (ECR) and commit SHA; W&B project/entity; environment secrets.
Hardware profile (e.g., p4d.24xlarge × N nodes) and storage profile (FSx for Lustre).
Core Steps
Packaging & Staging
Build/push Docker image with PyTorch, NCCL, CUDA, AMP, torchvision/OpenMMLab, and your repo.
Stage dataset to FSx for Lustre (for I/O throughput) using DVC or AWS DataSync; keep bulk in S3.
Pre-generate sharded WebDataset/TFRecord (optional) for faster streaming.
Orchestration
Launch SageMaker distributed training job (preferred) or EKS Job with torchrun (DDP).
Configure NCCL and networking (env: NCCL_DEBUG=INFO, NCCL_SOCKET_IFNAME=eth0, NCCL_ASYNC_ERROR_HANDLING=1).
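A minimal launcher sketch for the EKS path, assuming a single entrypoint script; the script name, node/GPU counts, and rendezvous endpoint are illustrative placeholders, not the actual job definition:

```python
import os
import subprocess

# Set the NCCL/networking env vars described above, then hand off to torchrun (DDP).
os.environ.setdefault("NCCL_DEBUG", "INFO")
os.environ.setdefault("NCCL_SOCKET_IFNAME", "eth0")
os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1")

cmd = [
    "torchrun",
    "--nnodes", os.environ.get("NNODES", "1"),
    "--nproc_per_node", os.environ.get("GPUS_PER_NODE", "8"),
    "--rdzv_backend", "c10d",
    "--rdzv_endpoint", os.environ.get("MASTER_ADDR", "localhost") + ":29500",
    "train.py",            # placeholder entrypoint
    "--config", "train.yaml",
]
subprocess.run(cmd, check=True)
```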
Training Runtime
Initialize DDP with backend=nccl, set torch.set_float32_matmul_precision("high") where supported.
Data loading: multi-worker DataLoader + persistent workers, pinned memory, async prefetch.
Mixed precision (AMP) + gradient scaling; gradient accumulation for large effective batch sizes.
Task balancing for multi-head models (e.g., HydraNet-style): dynamic or curriculum weighting.
Checkpointing: local → FSx (fast) every N steps; periodic sync to S3 (s3://…/checkpoints/…).
Fault tolerance: resume from last global step on spot interruption; save RNG states/optimizers/scalers.
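A condensed sketch of the runtime pieces above (DDP init, AMP with gradient accumulation, and a checkpoint that captures optimizer/scaler/RNG state for spot-interruption resume). The model, loader, paths, and the assumption that the forward pass returns a scalar loss are all placeholders:

```python
import os, random
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train(model, loader, optimizer, ckpt_path, accum_steps=4, save_every=1000):
    # torchrun sets RANK/LOCAL_RANK/WORLD_SIZE in the environment.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    torch.set_float32_matmul_precision("high")  # where the GPU supports TF32

    model = DDP(model.cuda(local_rank), device_ids=[local_rank])
    scaler = torch.cuda.amp.GradScaler()

    for step, (images, targets) in enumerate(loader):
        with torch.cuda.amp.autocast():
            # Assumption for this sketch: the model's forward returns a scalar loss.
            loss = model(images.cuda(local_rank, non_blocking=True),
                         targets.cuda(local_rank, non_blocking=True))
        scaler.scale(loss / accum_steps).backward()   # gradient accumulation

        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # Periodic checkpoint to fast local/FSx storage; includes RNG/optimizer/scaler
        # state so a spot interruption can resume from the last global step.
        if dist.get_rank() == 0 and (step + 1) % save_every == 0:
            torch.save({
                "step": step + 1,
                "model": model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "rng": {
                    "torch": torch.get_rng_state(),
                    "cuda": torch.cuda.get_rng_state_all(),
                    "numpy": np.random.get_state(),
                    "python": random.getstate(),
                },
            }, ckpt_path)
```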
Evaluation & Gating
After each epoch: run full offline eval on val + key slices; compute mAP/mIoU, per-class AP, trajectory ADE/FDE, lane F1, etc.
Produce slice dashboards (per weather/time/road class) and regression checks against last prod model.
Latency/throughput microbenchmarks: TorchScript export + dummy inference to report p50/p95 latency.
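A sketch of the latency microbenchmark: TorchScript trace plus timed dummy inference, reporting p50/p95. The input shape and warmup/iteration counts are illustrative:

```python
import numpy as np
import torch

def latency_benchmark(model, input_shape=(1, 3, 1080, 1920), warmup=20, iters=200):
    """Trace the model and time dummy inferences; returns p50/p95 latency in ms."""
    model = model.eval().cuda()
    dummy = torch.randn(*input_shape, device="cuda")
    scripted = torch.jit.trace(model, dummy)   # TorchScript export

    timings = []
    with torch.no_grad():
        for _ in range(warmup):
            scripted(dummy)
        torch.cuda.synchronize()
        for _ in range(iters):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            scripted(dummy)
            end.record()
            torch.cuda.synchronize()
            timings.append(start.elapsed_time(end))   # milliseconds

    return {"p50_ms": float(np.percentile(timings, 50)),
            "p95_ms": float(np.percentile(timings, 95))}
```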
Logging & Metadata
Log all metrics/artifacts to Weights & Biases: run config, gradients, losses, images/videos, PR/ROC curves.
Register artifacts: dataset version (DVC hash), code SHA, Docker digest, checkpoints, exported models.
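A minimal W&B logging sketch covering run config, per-epoch metrics, and artifact registration; the project/entity names, metric values, and file paths are placeholders:

```python
import wandb

run = wandb.init(project="perception-train", entity="my-team",
                 config={"git_sha": "abc123", "dvc_tag": "golden-v42",
                         "docker_digest": "sha256:<digest>"})

# Per-epoch metrics (overall + slice) and example media; values are placeholders.
wandb.log({"epoch": 3, "val/mAP": 0.612, "val/mAP_night": 0.534,
           "examples": [wandb.Image("preds_batch0.jpg")]})

# Register checkpoint + eval report as a versioned artifact tied to this run.
artifact = wandb.Artifact("detector-ckpt", type="model",
                          metadata={"dataset": "golden-v42", "git_sha": "abc123"})
artifact.add_file("best.pt")
artifact.add_file("eval_report.json")
run.log_artifact(artifact)
run.finish()
```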
Testing inside the run (hard gates)
Sanity checks: one forward+backward batch before training; fail fast on NaNs or exploding loss.
Determinism smoke (seeded) on small shard; tolerance bands on metrics.
Data drift guard: compare batch feature stats vs. training reference (Evidently profile) and warn/abort on severe drift.
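A sketch of the pre-training sanity gate: one forward+backward pass that fails fast on NaN/Inf or exploding loss. The loss threshold and the assumption that the model returns a scalar loss are placeholders:

```python
import torch

def sanity_check(model, batch, max_loss=1e4):
    """Run one forward+backward pass before training; abort on NaN/Inf or exploding loss."""
    model.train()
    loss = model(*batch)   # assumption: forward returns a scalar loss
    if not torch.isfinite(loss):
        raise RuntimeError(f"Sanity check failed: non-finite loss {loss.item()}")
    if loss.item() > max_loss:
        raise RuntimeError(f"Sanity check failed: exploding loss {loss.item():.1f} > {max_loss}")
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm; use it to catch NaN/Inf gradients.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    if not torch.isfinite(grad_norm):
        raise RuntimeError("Sanity check failed: non-finite gradient norm")
    model.zero_grad(set_to_none=True)
```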
Post-Run Actions
Generate model card (data provenance, metrics, caveats) and attach to W&B run & S3.
If gates pass, push exported model (TorchScript/ONNX) to artifact store and Model Registry (e.g., SageMaker Model Registry or W&B Artifacts promoted alias).
Emit event to orchestration (Airflow/Step Functions) to trigger #14 HPO (if configured) or #15 Eval/Sim (next workflow).
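A sketch of the export-and-promote step using a W&B artifact alias as the registry pointer; the artifact name, alias, and input shape are illustrative, and the SageMaker Model Registry path would replace the alias mechanics if that registry is used instead:

```python
import torch
import wandb

def export_and_promote(model, run, input_shape=(1, 3, 1080, 1920)):
    """Export the gated model (TorchScript + ONNX) and promote it via W&B aliases."""
    model = model.eval().cpu()
    dummy = torch.randn(*input_shape)

    torch.jit.trace(model, dummy).save("model.torchscript.pt")
    torch.onnx.export(model, dummy, "model.onnx", opset_version=17)

    artifact = wandb.Artifact("detector-candidate", type="model")
    artifact.add_file("model.torchscript.pt")
    artifact.add_file("model.onnx")
    # The "candidate" alias acts as the pointer picked up by downstream #14/#15 workflows.
    run.log_artifact(artifact, aliases=["candidate", "latest"])
```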
AWS/Tooling
SageMaker Training (DDP), ECR, S3, FSx for Lustre, CloudWatch Logs, IAM, optional EKS.
PyTorch (DDP/torchrun), NCCL, torch.cuda.amp, OpenMMLab/MMDetection (optional), Albumentations.
W&B for tracking/artifacts; DVC for dataset pinning; Evidently for drift profiles.
Outputs
Best checkpoint (.pt), plus TorchScript/ONNX exports; quantization-ready or TensorRT plan (optional).
eval_report.json (overall + per-slice metrics, latency), confusion matrices, PR curves.
W&B run (summary, artifacts), model card (model_card.md), training logs.
Stored in S3 (Gold) under /models/<project>/<semver or run_id>/…, and linked in W&B Artifacts and registry.
14) Hyper-Parameter Optimization / Sweeps¶
Trigger
Auto: training workflow marks model as “candidate” but below target on one or more slices.
Scheduled: weekly sweeps on prioritized tasks (e.g., night/pedestrian detector).
Manual: engineer launches targeted sweep (e.g., loss weights for lane vs. detection).
Inputs
Same dataset manifests as #13 (or focused slice packs for the weak area).
Baseline config (train.yaml) with search space: LR/WD, warmup, aug policy, loss weights, NMS/score thresholds, backbone/neck options, EMA on/off, AMP level, batch size/accum steps.
Sweep strategy: Bayesian, Hyperband/ASHA, Random, or Population-Based Training (PBT) for long runs.
Resource budget: max trials, parallelism, GPU hours, early-stop policy, cost cap.
Core Steps
Orchestration & Budgeting
Create W&B Sweep config (YAML) with an objective metric (e.g., val/mAP_weighted, or multi-objective with constraints: maximize mAP_vehicle subject to latency_p95 < X ms and regression_Δslice < Y).
Choose executor:
Kubernetes Jobs on EKS with the W&B agent (elastic parallelism), or
multiple SageMaker training jobs tagged to the sweep (agent inside the container).
Enforce cost guardrails: kill/truncate low-performers via ASHA; set CloudWatch alarms on spend.
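A sketch of such a sweep definition, written as the Python-dict equivalent of the YAML form; the metric names, search ranges, Hyperband bounds, and project name are placeholders:

```python
import wandb

sweep_config = {
    "method": "bayes",   # or "random"; Hyperband-style pruning configured below
    "metric": {"name": "val/mAP_weighted", "goal": "maximize"},
    "parameters": {
        "lr":           {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
        "weight_decay": {"values": [0.0, 1e-4, 1e-2]},
        "loss_w_lane":  {"min": 0.5, "max": 2.0},
        "ema":          {"values": [True, False]},
    },
    # ASHA/Hyperband-style truncation of low performers (cost guardrail).
    "early_terminate": {"type": "hyperband", "min_iter": 3},
}

def train_trial():
    """Trial entrypoint; expanded in the Trial Runtime sketch below."""
    ...

sweep_id = wandb.sweep(sweep_config, project="perception-hpo")
# Each W&B agent (EKS Job pod or SageMaker container) pulls trials from the controller:
wandb.agent(sweep_id, function=train_trial, count=20)
```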
Trial Runtime
Each trial inherits DDP setup from #13; parameters injected via env/CLI override.
Log full telemetry to W&B: metrics, hyperparams, system stats (GPU util, memory), eval artifacts.
Early stopping on plateau or rule-based pruning (e.g., after 3 epochs if val/mAP < quantile Q).
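A sketch of the trial entrypoint referenced in the sweep config above: it merges the sweep's sampled parameters into the baseline train.yaml before handing off to the #13 loop. The run_training_epoch helper and file names are placeholders:

```python
import wandb
import yaml

def train_trial():
    # wandb.agent populates the run config with this trial's sampled hyperparameters.
    run = wandb.init()
    with open("train.yaml") as f:
        cfg = yaml.safe_load(f)
    cfg.update(dict(run.config))   # sweep overrides win over baseline values

    for epoch in range(cfg["epochs"]):
        metrics = run_training_epoch(cfg, epoch)   # placeholder: the #13 training/eval loop
        wandb.log(metrics)                         # overall + per-slice metrics
        # Hyperband/ASHA in the sweep config handles pruning; the controller
        # stops the trial when it decides to early-terminate it.
    run.finish()
```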
Validation & Gating (per trial)
Same eval battery as #13, including slice metrics and latency microbenchmarks.
Fairness/coverage checks: ensure no >K% regression on protected or safety-critical slices.
Stability check (optional): re-run best few trials with a different seed on a small shard; require consistent ranking.
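A sketch of the per-slice regression gate, comparing a trial's slice metrics against the current prod model and failing on >K% regression; the metric dictionaries and 3% threshold are illustrative:

```python
def passes_slice_gate(trial_metrics, prod_metrics, critical_slices, max_regression=0.03):
    """Return (ok, failures): fail if any safety-critical slice regresses by more than max_regression."""
    failures = []
    for slice_name in critical_slices:
        prod = prod_metrics[slice_name]
        trial = trial_metrics[slice_name]
        if prod > 0 and (prod - trial) / prod > max_regression:
            failures.append((slice_name, prod, trial))
    return len(failures) == 0, failures

# Example: night-pedestrian AP drops from 0.50 to 0.46 (-8%) -> gate fails.
ok, failures = passes_slice_gate(
    trial_metrics={"night_pedestrian_AP": 0.46, "rain_AP": 0.41},
    prod_metrics={"night_pedestrian_AP": 0.50, "rain_AP": 0.40},
    critical_slices=["night_pedestrian_AP", "rain_AP"],
)
```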
Selection & Promotion
W&B sweep dashboard computes leaderboard; export Pareto front for multi-objective cases.
Auto-package Top-K models; re-evaluate on holdout test and simulation smoke set.
Generate HPO report: importance analysis (SHAP/ICE plots of hyperparams), budget used, recommended defaults.
Promote winner to Model Registry with tag candidate/<task>/<date>; attach sweep summary and config freeze.
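A sketch of exporting the Pareto front for the two-objective case (maximize accuracy, minimize latency); the trial tuples are illustrative:

```python
def pareto_front(trials):
    """trials: dicts with 'mAP' (higher is better) and 'latency_ms' (lower is better).
    Returns the non-dominated subset."""
    front = []
    for t in trials:
        dominated = any(
            o["mAP"] >= t["mAP"] and o["latency_ms"] <= t["latency_ms"]
            and (o["mAP"] > t["mAP"] or o["latency_ms"] < t["latency_ms"])
            for o in trials
        )
        if not dominated:
            front.append(t)
    return front

# Example: the 70 ms trial is dominated (lower mAP and higher latency than the 55 ms one).
print(pareto_front([
    {"run": "a", "mAP": 0.61, "latency_ms": 48},
    {"run": "b", "mAP": 0.63, "latency_ms": 55},
    {"run": "c", "mAP": 0.62, "latency_ms": 70},
]))
```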
Learnings Back-Prop
Update default training config with new best hyperparams (per task/slice).
Persist tuned augmentation policies and loss weights into recipe library.
If all trials fail gating on a specific slice, emit data request back to #8/#12 to mine more examples.
AWS/Tooling
EKS (Jobs + W&B agent) or SageMaker (parallel training jobs), S3, FSx, CloudWatch, EventBridge for triggers.
W&B Sweeps (Bayes/Random/Hyperband/PBT), PyTorch DDP, Ray Tune (optional backend for advanced schedulers).
Evidently (sanity drift checks during trials), Numba/TVM/TensorRT (optional latency constraint evaluators).
Outputs
sweep_report.md + sweep_summary.json (leaderboard, Pareto set, importance).
Top-K model artifacts + exports; W&B artifacts with pinned configs and datasets.
Registry update (winner promoted) + promotion event to downstream eval/simulation pipeline.