$ cat projects/Straggler-Aware-Scheduler.md

Straggler-Aware Scheduler for Distributed Training

Persistence-filtered straggler detection and adaptive rate allocation for gradient synchronization.

2024-12-15
Python · PyTorch · Gloo · Distributed Systems

Optimizes collective completion time instead of per-flow fairness for gradient synchronization. 45% iteration time reduction under persistent stragglers, <1% overhead under transient conditions.

The Problem

Distributed training iteration time = max flow time:

T_iter = max_{i ∈ [1,N]} T_flow_i

Traditional congestion control (DCTCP, TIMELY) optimizes per-flow fairness, which is the wrong objective here: speeding up already-fast flows does nothing to reduce the max. Naive straggler detection, meanwhile, oscillates under microbursts.

What I Built

Persistence-Filtered Detection:

streak_i = streak_i + 1 if T_i > 1.2 × T_med else 0
confirmed_straggler = (streak_i ≥ 3)
  • Only trigger after K=3 consecutive slow iterations
  • Filters transient slowdowns (microbursts, CPU spikes)
  • Median-based threshold adapts to background load
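A minimal sketch of this detector (variable names and structure are my illustration, not the project's actual code):

```python
# Persistence-filtered straggler detection: confirm only after K
# consecutive iterations above 1.2x the median flow time.
from statistics import median

K = 3            # consecutive slow iterations required to confirm
THRESHOLD = 1.2  # multiple of the median flow time

def update_streaks(times, streaks):
    """Update per-worker slow streaks in place; return confirmed stragglers."""
    t_med = median(times)
    for i, t in enumerate(times):
        streaks[i] = streaks[i] + 1 if t > THRESHOLD * t_med else 0
    return {i for i, s in enumerate(streaks) if s >= K}
```

A single-iteration microburst resets the streak before it reaches K, so no reallocation fires.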

Adaptive Rate Reallocation:

r'_i = r_i × (1 + 0.3) if straggler else r_i × (1 - 0.15 × |S|/(N-|S|))
  • Asymmetric: help stragglers aggressively (α=0.3), penalize donors gently (β=0.15)
  • α > 2β directly reduces T_iter
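The asymmetric update above can be sketched as follows (the constants mirror α=0.3 and β=0.15; function and variable names are assumptions):

```python
ALPHA = 0.3   # aggressive boost for confirmed stragglers
BETA = 0.15   # gentle, load-proportional penalty for donor flows

def reallocate(rates, stragglers):
    """Boost stragglers; shave donors in proportion to |S| / (N - |S|)."""
    n, s = len(rates), len(stragglers)
    donor_cut = BETA * s / (n - s) if n > s else 0.0
    return [r * (1 + ALPHA) if i in stragglers else r * (1 - donor_cut)
            for i, r in enumerate(rates)]
```

With one straggler among four workers, the three donors each give up only 5% while the straggler gains 30%.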

Gradual Recovery:

r_{t+1} = 0.5 × r_t + 0.5 × r_base
  • Exponential moving average prevents bounce-back oscillation
  • Recovery slower than punishment

Cooldown:

  • Wait 5 iterations after reallocation before next adjustment
  • Prevents rapid oscillation
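The recovery EMA and cooldown gate might compose in a control step like this (a hedged sketch; the real scheduler's structure may differ):

```python
COOLDOWN = 5   # iterations to wait after a reallocation
GAMMA = 0.5    # EMA weight pulling rates back toward baseline

def control_step(rates, r_base, stragglers, cooldown_left):
    """One scheduler step: reallocate if allowed, otherwise recover gradually."""
    if stragglers and cooldown_left == 0:
        # (rate reallocation as in the previous section would go here)
        return rates, COOLDOWN
    # Gradual recovery: exponential moving average toward the baseline rate
    recovered = [GAMMA * r + (1 - GAMMA) * r_base for r in rates]
    return recovered, max(0, cooldown_left - 1)
```

Because recovery halves the distance to baseline each iteration, rates glide back rather than snapping, while the cooldown blocks a second reallocation for five iterations.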

Architecture

Custom Ring All-Reduce:

  • Point-to-point send/recv for per-worker timing hooks
  • 2(N−1) stages: N−1 for scatter-reduce, then N−1 for all-gather
  • Enables rate manipulation via delays
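To illustrate the data movement (not the project's Gloo-based transport), here is a stdlib-only sketch of the ring pattern over plain Python lists:

```python
def ring_allreduce(data):
    """data[w][c]: worker w's contribution to chunk c.
    Returns the element-wise sum of all contributions on every worker."""
    n = len(data)
    buf = [row[:] for row in data]
    # Scatter-reduce: after n-1 stages, worker w holds the
    # fully reduced chunk (w + 1) % n
    for step in range(n - 1):
        snap = [row[:] for row in buf]   # model simultaneous p2p sends
        for w in range(n):
            src = (w - 1) % n            # ring neighbor
            c = (src - step) % n         # chunk src forwards at this stage
            buf[w][c] += snap[src][c]
    # All-gather: n-1 more stages circulate the reduced chunks
    for step in range(n - 1):
        snap = [row[:] for row in buf]
        for w in range(n):
            src = (w - 1) % n
            c = (src + 1 - step) % n     # reduced chunk src owns at this stage
            buf[w][c] = snap[src][c]
    return buf
```

Each stage is a fixed neighbor-to-neighbor transfer, which is what makes per-worker timing hooks and delay injection straightforward.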

Network Model:

  • Simulated via delay_i = d_base / r_i (blocking sleep)
  • No real incast/drops (tests control logic, not transport)

Four Profiles:

  • Uniform: all 10ms
  • Straggler: one 3× persistent
  • Variable: N(d_base, σ²) per iteration
  • Bursty: random 5× with 20% probability
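The four profiles and the delay_i = d_base / r_i model can be sketched together (σ, function shape, and names are assumptions):

```python
import random

D_BASE = 0.010  # 10 ms baseline, as in the uniform profile

def worker_delay(profile, worker, rate=1.0, rng=random):
    """Per-iteration simulated network delay for one worker, in seconds."""
    d = D_BASE
    if profile == "straggler" and worker == 0:
        d *= 3                                   # one persistent 3x straggler
    elif profile == "variable":
        d = max(0.0, rng.gauss(D_BASE, 0.002))   # N(d_base, sigma^2); sigma assumed
    elif profile == "bursty" and rng.random() < 0.2:
        d *= 5                                   # random 5x burst, 20% probability
    return d / rate                              # delay_i = d_base / r_i
```

Raising a straggler's rate r_i directly shortens its blocking sleep, which is the only lever the control loop needs in this simulation.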

Results

| Profile | Baseline | Ours | Improvement |
|-----------|----------|--------|-------------|
| Straggler | 1298ms | 717ms | 44.8% |
| Variable | 268ms | 268ms | <1% |
| Bursty | 747ms | 750ms | <1% |

Straggler Profile CDF:

  • Median: 1314ms → 708ms (46% improvement)
  • p99: 1365ms → 1001ms (27% improvement)

Ablation (K threshold on bursty):

  • K=1: 878ms, 179 reallocations → +23% regression
  • K=3: 750ms, 5 reallocations → <1% overhead

Implementation

Workload:

  • Small CNN (201K params) on CIFAR-10
  • Communication is ~70% of iteration time (realistic for large-scale training)

Statistical Testing:

  • 5 runs × 250 iterations = 1,250 per config
  • Welch's t-test: p < 0.0001 for straggler improvement
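For reference, Welch's t statistic can be computed with the standard library alone (the project presumably used a stats package; this sketch is illustrative):

```python
from math import sqrt
from statistics import mean, variance

def welch_t(a, b):
    """Welch's t statistic and degrees of freedom for unequal-variance samples."""
    va, vb = variance(a) / len(a), variance(b) / len(b)
    t = (mean(a) - mean(b)) / sqrt(va + vb)
    # Welch-Satterthwaite approximation for degrees of freedom
    df = (va + vb) ** 2 / (va**2 / (len(a) - 1) + vb**2 / (len(b) - 1))
    return t, df
```

Unlike Student's t-test, Welch's version does not assume the baseline and treatment runs share a variance, which matters when the scheduler changes the iteration-time distribution itself.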

Lessons Learned

Systems:

  • Persistence filtering critical—instant reaction causes oscillation
  • Asymmetric adjustment: help stragglers aggressively, penalize donors gently
  • Cooldown prevents thrashing
  • Median threshold adapts to background automatically

Distributed Training:

  • Max-of-flows metric correct for barriers (not mean/sum)
  • Per-flow fairness wrong for collective operations

Evaluation:

  • Test failure modes (stragglers, bursts) not just normal case
  • Ablation shows why K≥3 works

Limitations & Future Work

Current:

  • Simulated network (sleep not congestion)
  • Single machine (processes not nodes)
  • Relative detection (all-slow undetected)

Future:

  • Real transport (ECN/RTT signals → cwnd)
  • Distributed cluster evaluation
  • Distinguish compute vs. network stragglers
  • Multi-job fairness