TECH-3501 FRD: stream_mixing + ledger observability metrics

≈2400 words · ≈11 min read · last updated 2026-05-12 · estimate revised 3pt → ≈ 5pt

TL;DR

Scope note. The original ticket was estimated at 3 points (stream_mixing-internal observability only). Through review the user added (a) detection of false ledger evictions — cross-subsystem, requires new ledger-side instrumentation; (b) a new cross-rank aggregator for stream_mixing metrics so the device_per_stream-vs-per_stream sanity-check actually works at cluster scale; (c) a SUM-mode extension to the shared reduce_loss_metrics utility. Revised estimate: ≈ 5 points. The split shows up in the Execution DAG in the Plan document.

Overview

After PR #820 (weighted K-active-pool pick, base commit b5fef955a) and TECH-3491 (pick_mode), StreamMixingDataset's per-step cell-batching behaviour is a black box at runtime — the only hook is debug_log_active_every_n_batches, which dumps identity tuples to stderr (no numbers, no quantitative signal, not in wandb). Separately, the GDM recurrent ledger evicts group_id entries in LRU and PERIODIC_STALE modes; today only the full-wipe count (ledger_wipe_count) is observable — incremental evictions are computed locally and discarded, and there's no signal when an evicted cell is still being trained on (silent corruption of recurrent state).

This ticket promotes both surfaces to first-class wandb scalars, scoped to what's high-signal for the failure modes we care about: detecting pathologies under device_per_stream mode, catching stuck readers from the weighted-pick tail, and surfacing false ledger evictions before they corrupt training.

Goals

  1. G1 — Active-pool visibility. Per-modality breakdown, absolute and fractional remaining-work distributions across active readers, stuck-reader detection.
  2. G2 — Refill visibility. Per-step exhaust rate.
  3. G3 — False-ledger-eviction detection. When the ledger evicts a gid that's still in flight, detect on re-encounter and surface as a metric. Owner-local (no cross-rank sync beyond what already exists in the loss-metric reduction pipe).
  4. G4 — Free observability win. Surface the already-computed-but-unlogged LRU and PERIODIC_STALE eviction counts.
  5. G5 — Cluster-accurate aggregation. Each metric uses its natural reduction (MIN, MAX, SUM) across DP ranks. No rank-0-biased reporting on metrics whose job is diagnosing distributed behaviour.
  6. G6 — Cheap. Per-step overhead < 50 µs; reuse the remaining_picks numpy array already allocated by PR #820 at dataset.py:702-710.
  7. G7 — Resume-neutral. No new persistent state. state_dict contract unchanged.

Use Cases

| Persona | Workflow |
| --- | --- |
| Researcher comparing mixing modes | Opens two wandb runs side-by-side. Reads stream_mixing/active/remaining_fraction_max minus _min as the imbalance spread; reads active/modalities/* to compare per-modality coverage. device_per_stream should show wider per-rank modality variance (covered by the aggregator's MIN/MAX). |
| ML engineer chasing a slow epoch | Sees stream_mixing/active/steps_since_pick_max trending up. Cross-references with remaining_fraction_min: if both spike, the weighted-pick tail is dragging an almost-drained reader; if steps_since_pick_max is high but remaining_fraction_min isn't, a reader is genuinely stuck mid-drain. |
| Researcher debugging recurrent training corruption | Notices loss drift on cell types with long sequences. Reads ledger/false_evictions/stale — non-zero indicates the PERIODIC_STALE sweep is evicting cells the dataset still has in flight. Cross-checks the ledger_stale_threshold setting against typical slice length. |

The metric set (11 wandb scalars, +4 paired _max siblings)

Stream_mixing — 7 scalars

Flow: StreamMixingDataset.drain_step_metrics() → trainer calls StreamMixingDataset.aggregate_step_metrics() (new aggregator) → merged into extra_metrics → MetricsProcessor.log(...) emits on rank-0.

| Key | Source | Reduction | Healthy | Pathological |
| --- | --- | --- | --- | --- |
| stream_mixing/active/remaining_min | remaining_picks.min() on the array already allocated at dataset.py:702-710 | dist_min | Decreases steadily toward 0 within an epoch; forecasts next exhaust | Plateaus near 0 across many steps (stuck near-exhaust reader) |
| stream_mixing/active/remaining_max | remaining_picks.max() | dist_max | Decreases steadily; forecasts slowest exhaust | Way ahead of _min persistently (proportional draining broken) |
| stream_mixing/active/remaining_fraction_min | min over active readers of (slice_end − position) / (slice_end − slice_start) — bs cancels, modality-independent | dist_min | Close to fraction_max (tight cluster) | Far from fraction_max (drain imbalance, independent of slice size) |
| stream_mixing/active/remaining_fraction_max | max of same | dist_max | Close to fraction_min | Spread > 0.3 sustained (proportional draining broken) |
| stream_mixing/active/modalities/&lt;modality&gt; | Counter(_modality_for_reader(r) for r in active) | dist_sum per modality key | Stable mix matching the configured stream proportions | Modality count → 0 while config says it should be present (stream starvation) |
| stream_mixing/active/steps_since_pick_max | max(current_step − last_picked_step[r.stream_id] for r in active) | dist_max | O(K) — each reader picked ≈ once every K steps; brief spikes during weighted-pick tail | Sustained > 5K and growing (genuine stuck reader); read jointly with remaining_fraction_min to distinguish "stuck near exhaust" from "stuck mid-drain" |
| stream_mixing/refill/exhaust_events | per-step delta of natural exhaustions (gated at dataset.py:755-760) | dist_sum | Small steady non-zero (steady-state churn) | Spike (massive refill churn) or 0 across many steps (epoch dragging without progress) |
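The Source column reads as straight-line Python. As a minimal sketch of the per-pick computation, assuming illustrative reader fields (stream_id, position, slice_start, slice_end, modality) and a hypothetical step_metrics helper, none of which are claimed from the codebase:

```python
import numpy as np
from collections import Counter

def step_metrics(active_readers, remaining_picks, last_picked_step, current_step):
    """Sketch of the FR1/FR2 per-pick scalars; all field names are illustrative."""
    # remaining_fraction: bs cancels, so the value is modality-independent.
    fractions = np.array([
        (r["slice_end"] - r["position"]) / max(r["slice_end"] - r["slice_start"], 1)
        for r in active_readers
    ])
    mods = Counter(r["modality"] for r in active_readers)
    out = {
        "stream_mixing/active/remaining_min": float(remaining_picks.min()),
        "stream_mixing/active/remaining_max": float(remaining_picks.max()),
        "stream_mixing/active/remaining_fraction_min": float(fractions.min()),
        "stream_mixing/active/remaining_fraction_max": float(fractions.max()),
        "stream_mixing/active/steps_since_pick_max": float(max(
            current_step - last_picked_step[r["stream_id"]] for r in active_readers
        )),
    }
    for m, c in mods.items():
        out[f"stream_mixing/active/modalities/{m}"] = float(c)
    return out
```

On a two-reader pool with remaining fractions 0.6 and 0.1, the spread (fraction_max minus fraction_min) reads 0.5 regardless of each stream's bs, which is the cancellation property the table relies on.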

Ledger — 4 scalars (+ _max siblings via reduce_loss_metrics)

Flow: counters live in RecurrentEncoder, populated at the eviction sites and at the _store_state fresh-write → returned in last_extra_metrics alongside the existing ledger_wipe_count → reduce_loss_metrics (new SUM mode) → emitted on rank-0 as cluster-wide totals + worst-rank _max.

| Key | Increments when | Healthy | Pathological |
| --- | --- | --- | --- |
| ledger/evictions/lru | del ledger[oldest_gid] in _evict_lru_overflow (recurrent_encoder.py:1193) | Steady non-zero under LRU mode (incremental eviction is the design) | Zero under non-LRU modes; massive spike under LRU = ledger capacity too small |
| ledger/evictions/stale | del ledger[gid] in _sweep_stale_entries (recurrent_encoder.py:660) when gs.last_step < cutoff | Periodic small bursts at sweep cadence under PERIODIC_STALE | Sustained high under PERIODIC_STALE = ledger_stale_threshold is too low for slice length |
| ledger/false_evictions/lru (headline) | A gid in _recently_evicted[mode=LRU] arrives at _store_state with prev is None (recurrent_encoder.py:1115) | 0 — LRU evictions should only fire on truly-stale entries | Non-zero ⇒ LRU is too aggressive for the current per-rank cell touch-rate |
| ledger/false_evictions/stale (headline) | Same, mode=STALE | 0 — STALE evictions should only fire on cells the dataset has finished with | Non-zero ⇒ ledger_stale_threshold shorter than slice consumption; silent recurrent-state corruption |

Removed surface

| Item | Where | Replacement |
| --- | --- | --- |
| debug_log_active_every_n_batches config field | job_config.py:2068 | None — metrics fire every step |
| Factory pass-through | factory.py:209 | — |
| Dataset constructor param + validation + attribute | dataset.py:236, 304, 320 | — |
| Free-form INFO log line (identity tuples only) | dataset.py:721-743 | Quantitative wandb scalars cover the use cases |
| Three tests for the old knob | test_dataset.py:643-691 | New tests in test_metrics.py |

Functional Requirements

Each requirement below is listed as ID · level (MUST / SHOULD) · requirement.
Stream_mixing instrumentation
FR1 MUST Emit the 7 stream_mixing scalars listed in the metric table at the pick site (dataset.py:702-744). Use remaining_picks.min()/.max() on the existing array; compute remaining_fraction as (slice_end − position) / max(slice_end − slice_start, 1).
FR2 MUST Track self._last_picked_step: dict[stream_id, int] updated at the pick site (one dict write per step). On drain_step_metrics(), emit max(current_step − _last_picked_step[r.stream_id] for r in active). Cleared on reader exhaust.
FR3 MUST Add StreamMixingDataset.drain_step_metrics() -> dict[str, float]. Returns the most recent pick's local snapshot scalars plus the accumulated refill/exhaust_events delta, then clears the delta. Returns an empty dict if no pick has occurred since the last drain.
FR4 MUST Add StreamMixingDataset.aggregate_step_metrics(local: dict, world_size: int, device: torch.device) -> dict[str, float]. Mirrors DataStatsCollector.aggregate_distributed in shape: pack per-metric values into typed tensors, perform one all_reduce per op-kind (MIN / MAX / SUM), unpack. Empty-pool ranks contribute sentinels (large for MIN, 0 for MAX/SUM) per the data_stats.py:232 pattern. Single-rank path skips collectives.
FR5 MUST Trainer (train.py, near the existing data_stats_collector.aggregate_distributed(...) call around L842) drains local metrics, calls aggregate, merges into extra_metrics. Guarded so non-StreamMixingDataset training paths don't break.
FR6 MUST Extract the modality-resolution helper from the existing log path (dataset.py:729-732) into a private method: def _modality_for_reader(self, r: StreamReader) -> str handling both NESTED (self.stream_groups[r.stream_id]) and DIRECT (os.path.basename(os.path.dirname(r.path))) layouts. Reused by FR1.
FR7 MUST Hard-remove debug_log_active_every_n_batches config field, factory pass-through, dataset constructor param + validation + attribute, and the free-form log line at dataset.py:721-743. Remove the three tests at test_dataset.py:643-691.
FR8 MUST Per-step overhead < 50 µs. Permitted additions: Counter over active readers (size ≤ K), min/median/max on existing remaining_picks array, fractional computation per reader (K subtractions+divisions), one dict update for _last_picked_step, one comparison-and-update for the max-steps-since-pick. All sub-microsecond at K ≤ 8.
FR9 MUST Resume-reshard truncation at dataset.py:653 (saved_active[:num_streams_mixed]) emits a single INFO log entry at restore: "[stream_mixing] resume reshard dropped %d still-active readers" when the truncated count > 0. Not a per-step metric.
Ledger false-eviction detection
FR10 MUST Add RecurrentEncoder._recently_evicted: dict[int, EvictionMode] mapping gid → mode that evicted it. Populated at each eviction site:
  • recurrent_encoder.py:1193 — LRU (each del ledger[oldest_gid] in _evict_lru_overflow).
  • recurrent_encoder.py:660 — STALE (each del ledger[gid] in _sweep_stale_entries).
Cleared on full-wipe (recurrent_encoder.py:635) — wipe-and-rebuild is intentional, so the wiped gids aren't tracked as candidates for false-eviction detection.
FR11 MUST Instrument _store_state at recurrent_encoder.py:1115: where prev = self.group_state_ledger.get(group_id) returns None (fresh-write branch), check self._recently_evicted.pop(group_id, None) and increment self._false_evict_counts[mode] for the popped mode. Single instrumentation site. Owner-only by enforcement at recurrent_encoder.py:1108-1114 — so the eviction-and-re-store join is purely owner-local with zero cross-rank coordination.
FR12 MUST Surface ledger/evictions/{lru,stale} (per-step delta of total evictions per mode — these were already computed locally but never logged; _evict_lru_overflow already returns the count at L1195) and ledger/false_evictions/{lru,stale} (the new headline metrics) via the recurrent encoder's last_extra_metrics dict (next to existing ledger_wipe_count). No new emission channel.
FR13 MUST Extend reduce_loss_metrics in prima_foundations/distributed/utils.py:84 with an opt-in SUM mode: new kwarg sum_keys: set[str] | None = None. For keys in sum_keys: emit <key> = dist_sum(value) (cluster total) and <key>_max = dist_max(value) (worst rank). Keys not in sum_keys behave exactly as today (mean / count-weighted modes). Counter metrics (ledger/evictions/*, ledger/false_evictions/*) opt into SUM mode via this kwarg from the call site at train.py:821.
Tests + verification
FR14 MUST Unit tests in a new file tests/unit_tests/datasets/stream_mixing/test_metrics.py:
  • Per-step metric shape: all 7 keys present, correct types, expected values on a synthetic 4-stream fixture.
  • Modality resolution for NESTED + DIRECT layouts.
  • remaining_fraction is bs-independent (assertion on two streams with different bs).
  • steps_since_pick_max increments per step for unpicked readers, resets on pick, drops on exhaust.
  • exhaust_events reset semantics on drain; empty-dict behaviour before first pick.
  • aggregate_step_metrics on world_size=1 path = identity; multi-rank path mocked via packed tensors.
  • Empty-pool rank's MIN/MAX sentinels don't bias the cluster reduction.
FR15 MUST Unit tests in tests/unit_tests/models/pleiades2/test_ledger_false_evictions.py:
  • LRU eviction of a gid that's never re-encountered → false_evictions/lru stays 0.
  • LRU eviction of a gid followed by _store_state with prev is None → false_evictions/lru = 1; gid popped from _recently_evicted so subsequent re-stores don't double-count.
  • STALE eviction → false-eviction → counter increments for the right mode.
  • Full-wipe clears _recently_evicted (subsequent re-stores after wipe don't count as false evictions).
  • reduce_loss_metrics SUM mode round-trips a known per-rank value to cluster-wide sum + worst-rank max on world_size = 4 (gloo backend).
FR16 SHOULD Healthy smoke run on prima_foundations/models/pleiades2/train_configs/gdm/2026_05_07_per_modality_packing_smoke/1node_ddp_per_modality_packing_smoke.toml (NESTED, K=4): verify all 11 metrics appear in wandb, headline ledger metrics stay at 0, stream_mixing/active/remaining_fraction_min/_max stay close, steps_since_pick_max stays bounded.
FR17 SHOULD Diagnostic-smoke matrix per below. Three TOMLs under train_configs/gdm/2026_05_12_stream_mixing_diagnostics/, each twisting one knob, each shipping with a wandb-panel screenshot in the PR. Not CI-bound; opt-in for researchers.
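The drain contract in FR3/FR5 (accumulate on the dataset, drain once per step, empty dict before the next pick, guarded trainer call site) can be illustrated with a toy stand-in. StubDataset, record_pick, and merge_into_extra_metrics are hypothetical names for this sketch, not the real StreamMixingDataset or trainer code:

```python
class StubDataset:
    """Toy stand-in demonstrating the FR3 clear-on-drain semantics."""

    def __init__(self):
        self._snapshot = {}       # most recent pick's local scalars
        self._exhaust_delta = 0   # accumulated refill/exhaust_events

    def record_pick(self, snapshot, exhausts=0):
        self._snapshot = dict(snapshot)
        self._exhaust_delta += exhausts

    def drain_step_metrics(self):
        if not self._snapshot:
            return {}             # no pick since last drain
        out = dict(self._snapshot)
        out["stream_mixing/refill/exhaust_events"] = float(self._exhaust_delta)
        self._snapshot, self._exhaust_delta = {}, 0   # clear-on-drain
        return out

def merge_into_extra_metrics(extra_metrics, dataset):
    # Guarded trainer-side call (FR5): non-StreamMixingDataset paths no-op.
    if hasattr(dataset, "drain_step_metrics"):
        extra_metrics.update(dataset.drain_step_metrics())
    return extra_metrics
```

The hasattr guard mirrors the FR5 requirement that training paths without a StreamMixingDataset are unaffected.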

Scope decisions (deltas from the ticket text)

| Original AC | v1 decision | Why |
| --- | --- | --- |
| data/tokens_seen/&lt;modality&gt; cumulative across resume | Dropped from v1. DataStatsCollector's existing per-window data/modality_total_tokens/&lt;modality&gt; is unchanged. | DataStatsCollector has no state_dict infra today; standing it up and wiring through the checkpoint manager is disproportionate to this ticket's intent. The dataset doesn't know tokens (only cells); tokens are computed downstream in the collator. |
| stream_mixing/refill/queue_idx — perm-queue progress | Dropped. | Redundant. Carries the same information as cumulative refill/exhaust_events (each exhaust → one queue_idx advance). Two metrics for one signal is bloat. |
| stream_mixing/active/n_distinct (active-pool cardinality) | Dropped. | Redundant. Equals sum(stream_mixing/active/modalities/*) — the fan-out already captures it. |
| stream_mixing/refill/exhaust_events "since last log", resume-safe cumulative | Reinterpreted as per-step delta reset on drain. | Cumulative across resume needs persistent state (out of scope per user direction). Per-step delta is trivially resume-neutral; wandb's diff tooling makes cumulative reconstruction easy downstream. |
| "Counters survive state_dict / load_state_dict" | No new persistent state; state_dict contract unchanged. | Direct consequence of the above two. |
| "Cadence configurable via debug_log_active_every_n_batches" | Knob hard-removed; metrics emit every step. | Per-step emission is consistent with how DataStatsCollector metrics flow today, and the scalars are cheap. Grep shows zero TOMLs set the old knob. |
| "No new cross-rank sync" | Reframed: use existing reduction pipes; one new opt-in mode added to a shared utility. | Initial framing was a misunderstanding — the codebase already has two cross-rank reduction pipes (DataStatsCollector.aggregate_distributed and last_extra_metrics → reduce_loss_metrics). The new aggregate_step_metrics mirrors the first; the SUM-mode extension is opt-in and additive to the second. |
| Original AC didn't mention false ledger evictions | Added in scope during user review. | High-signal cross-subsystem invariant: ledger eviction races against dataset draining could silently corrupt recurrent training. Detected via the inverted check at _store_state — minimal surface for the bug class it catches. |

Non-Goals

Diagnostic Smoke Matrix (FR17)

Each diagnostic config is a short multi-modality smoke (1 node, ≤ 200 steps, DDP, NESTED, default model) intended to make one metric break out of its healthy range. Lives under train_configs/gdm/2026_05_12_stream_mixing_diagnostics/, not promoted to production. PR ships wandb-panel screenshots for each.

| Config (filename) | Knob twisted | Expected pathology | Metric(s) that should light up |
| --- | --- | --- | --- |
| diag_imbalanced_bs.toml | Per-stream batch sizes deliberately imbalanced — e.g. one modality bs=1, others bs=64. | Proportional draining must keep remaining_fraction aligned across readers despite the bs gulf; the smallest-bs stream gets weight-throttled at the tail and exhibits high steps_since_pick_max; the modality count distribution is lopsided. | active/modalities/* wide spread; steps_since_pick_max grows for the small-bs reader at the tail; remaining_fraction_min/_max stay close (proves draining is healthy under this stress) |
| diag_exhaust_spike.toml | Drastically reduced per-stream slice length (subsample down to ~ K × batches_per_step × 5 cells). | Readers exhaust repeatedly; refill churn dominates per-step behaviour. | refill/exhaust_events non-zero almost every step |
| diag_stale_false_evict.toml | batching_mode = "per_stream", ledger_wipe_mode = "periodic_stale", ledger_stale_threshold very low (e.g. 5 steps — well below typical slice length). | PERIODIC_STALE sweep evicts ledger entries for gids the dataset is still mid-drain on. | ledger/false_evictions/stale non-zero within the first ≈ 20 steps; ledger/evictions/stale ticking alongside. |

Technical Considerations

Ledger integration shape (FR10–FR13) — the inverted design

Why the naive join fails. The ledger keys on group_id: int (extracted by the processor at processors/sc_multimodal/processor.py:172 via element.get_group_id(), downstream of the dataloader). The dataset keys on stream_id: int (positional index into stream_paths) and the path basename string. There is no stream_id → group_id mapping; one stream contains many cells, each carrying many gids.

The inverted idea. Don't try to track "currently active gids" — that's hard and expensive. Instead, track "recently evicted gids" on the ledger side. When a fresh-write happens (the dataset sent us a gid the ledger hadn't seen, or had seen and evicted), check whether we just evicted that gid. If yes → false eviction.

Where the join happens. One single instrumentation site: _store_state at recurrent_encoder.py:1115:

```diff
  prev = self.group_state_ledger.get(group_id)
+ if prev is None:
+     mode = self._recently_evicted.pop(group_id, None)
+     if mode is not None:
+         self._false_evict_counts[mode] += 1
  # ... existing fresh-init logic continues ...
```

Owner-locality. _store_state is owner-only — enforced at L1108-1114 (raise if non-owner attempts a store). The three eviction sites (LRU at L1193, STALE at L660, full-wipe at L635) all operate on the local shard. So eviction and re-store both happen on the owning rank for any given gid. The intra-rank join is exact, not an approximation; there is no cross-rank gap.

Why _lookup_local at L893 is not the instrumentation site. The miss path at L893 (if group_id not in self.group_state_ledger) fires on every rank for every non-owned gid, since non-owners never hold the entry. Using L893 would over-count by world_size and conflate "non-owner doesn't shard this gid" with "owner evicted it." _store_state is owner-gated by construction, which is exactly what we want.

Memory. _recently_evicted grows on eviction, pops on re-store, clears on full-wipe. Bounded by "currently-evicted but not yet re-encountered" gids — at most unique_gids_owned − len(ledger). Empirically a few hundred ints in steady state.
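The evict → fresh-write → count lifecycle above can be modeled in a few lines. A toy sketch that borrows the FR10/FR11 attribute names (_recently_evicted, _false_evict_counts) but none of the real encoder machinery; FalseEvictionProbe is a hypothetical name for illustration only:

```python
class FalseEvictionProbe:
    """Toy model of the FR10/FR11 bookkeeping; not the real RecurrentEncoder."""

    def __init__(self):
        self.ledger = {}                                # gid -> recurrent state
        self._recently_evicted = {}                     # gid -> eviction mode
        self._false_evict_counts = {"lru": 0, "stale": 0}

    def evict(self, gid, mode):
        self.ledger.pop(gid, None)
        self._recently_evicted[gid] = mode              # remember who evicted it

    def wipe(self):
        self.ledger.clear()
        self._recently_evicted.clear()                  # wipe-and-rebuild is intentional

    def store(self, gid, state):
        if gid not in self.ledger:                      # fresh-write branch
            mode = self._recently_evicted.pop(gid, None)
            if mode is not None:                        # evicted gid came back: false eviction
                self._false_evict_counts[mode] += 1
        self.ledger[gid] = state
```

The pop in store is what gives the single-count guarantee: once a false eviction is recorded, subsequent re-stores of the same gid hit the populated-ledger path and never re-increment.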

Distributed semantics. The counters are per-rank-owned-gid quantities. After SUM-mode reduction in reduce_loss_metrics, the wandb keys show cluster-wide totals (the natural reading: "the cluster suffered N false evictions this step"). The paired _max shows the worst single rank — useful for hotspot diagnosis (one rank with a hostile gid distribution).
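The SUM-mode semantics can be pinned down without a process group. A pure-Python stand-in (reduce_sum_mode is a hypothetical name) that operates on already-gathered per-rank dicts, with non-sum keys reduced by a plain mean purely for illustration, since the real reduce_loss_metrics has its own mean/count-weighted modes:

```python
def reduce_sum_mode(per_rank_metrics, sum_keys):
    """Simulate the proposed sum_keys extension on gathered per-rank dicts.

    Keys in sum_keys emit <key> = cluster total and <key>_max = worst rank;
    other keys get a plain mean here as a stand-in for existing behaviour.
    """
    out = {}
    for k in per_rank_metrics[0]:
        vals = [m[k] for m in per_rank_metrics]
        if k in sum_keys:
            out[k] = sum(vals)            # "the cluster suffered N this step"
            out[k + "_max"] = max(vals)   # hotspot diagnosis: worst single rank
        else:
            out[k] = sum(vals) / len(vals)
    return out
```

On a two-rank example where one rank sees 2 false evictions and the other 0, the headline key reads 2 (cluster total) and the _max sibling also reads 2 (all the damage on one rank), which is exactly the hotspot signal described above.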

Stream_mixing aggregator (FR4) — packed all-reduce shape

Mirrors DataStatsCollector.aggregate_distributed at data_stats.py:170. Pseudocode:

```python
def aggregate_step_metrics(local, world_size, device):
    if world_size == 1:
        return dict(local)                            # identity, no collectives

    # Pack scalars by op-kind into typed tensors of fixed shape.
    # Modality count is sparse — flatten over the project's known modality enum
    # (matches data_stats.py:209) so all ranks have identical shapes.
    all_modalities = sorted(MODALITY_TO_INT.keys())

    # Empty-pool ranks (local == {}) contribute sentinels via the .get
    # defaults: +inf for MIN, 0 for MAX/SUM (per the data_stats.py:232 pattern).
    mins = torch.tensor([local.get("remaining_min", float("inf")),
                         local.get("remaining_fraction_min", float("inf"))],
                        dtype=torch.float32, device=device)
    maxs = torch.tensor([local.get("remaining_max", 0.0),
                         local.get("remaining_fraction_max", 0.0),
                         local.get("steps_since_pick_max", 0.0)],
                        dtype=torch.float32, device=device)
    sums = torch.zeros(len(all_modalities) + 1,       # +1 for exhaust_events
                       dtype=torch.float32, device=device)
    for i, m in enumerate(all_modalities):
        sums[i] = local.get("modalities", {}).get(m, 0)
    sums[-1] = local.get("exhaust_events", 0)

    dist.all_reduce(mins, op=dist.ReduceOp.MIN)
    dist.all_reduce(maxs, op=dist.ReduceOp.MAX)
    dist.all_reduce(sums, op=dist.ReduceOp.SUM)
    # Unpack into wandb keys; emit only non-zero modality fan-outs.
    ...
```

Three collectives per log step. Each is a small fixed-shape tensor; total bytes << gradient all-reduces happening every step. Negligible.

Cost analysis (FR8)

Per-step additions to the pick site are exactly the FR8-permitted operations: min/max on the existing remaining_picks array, a Counter over the ≤ K active readers, K fractional computations, one dict write for _last_picked_step, and one compare-and-update for the steps-since-pick max.

Total << 50 µs budget per step. No tensor ops in the hot path beyond what PR #820 already does. No GPU sync, no CPU↔GPU transfers.
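The budget claim is easy to sanity-check empirically. A rough timing of a pure-Python stand-in for the permitted additions at K = 8 (pick_site_bookkeeping is illustrative, not the real pick site; absolute numbers depend on the host, so treat the result as order-of-magnitude only):

```python
import time
from collections import Counter

import numpy as np

def pick_site_bookkeeping(remaining_picks, modalities, last_picked, current_step):
    # The FR8-permitted additions: min/max on the existing array, a Counter
    # over <= K readers, and the steps-since-pick max.
    rmin, rmax = remaining_picks.min(), remaining_picks.max()
    counts = Counter(modalities)
    stale = max(current_step - s for s in last_picked.values())
    return rmin, rmax, counts, stale

remaining = np.arange(8)
mods = ["rna", "atac"] * 4
last = {i: 0 for i in range(8)}

n = 2000
t0 = time.perf_counter()
for step in range(n):
    pick_site_bookkeeping(remaining, mods, last, step)
per_step_us = (time.perf_counter() - t0) / n * 1e6
```

On a typical workstation core this lands in the single-digit-microsecond range, comfortably inside the 50 µs budget even before considering that the real site reuses an already-allocated array.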

Worker-process assumption

The dataset already documents at dataset.py:79-82 that mid-epoch resume requires num_workers=0. The drain_step_metrics() + aggregate_step_metrics() APIs inherit this — both work because the dataset lives in the main process and the trainer reads its attributes directly. If num_workers > 0 support is ever added, drain semantics will need a queue or batch-dict piggyback design (out of scope for v1).

Test footprint

Two new test files (see FR14, FR15) and one deletion (FR7 removes three obsolete tests).

Open Questions (deferred to implementation, not blocking spec)