stream_mixing + ledger observability metrics

- Replaces the debug_log_active_every_n_batches log line (just identity tuples, no numbers) with 7 first-class stream_mixing wandb scalars covering active-pool composition, remaining-work distribution, stuck-reader detection, and refill churn.
- Adds ledger false-eviction counters (plus _max siblings) detecting false evictions: the ledger evicts a group_id, then immediately re-encounters it because the dataset wasn't done draining the cell. Implemented as an inverted check at the ledger's fresh-write site (recurrent_encoder.py:1115) — no cross-subsystem coupling, no dataset-side gid extraction, owner-local by construction.
- Adds a StreamMixingDataset.aggregate_step_metrics() aggregator that mirrors DataStatsCollector.aggregate_distributed: each metric picks its natural reduction (MIN, MAX, SUM). One packed all-reduce per op-kind per log step. Without it the metrics would be rank-0-biased — the very device_per_stream-vs-per_stream sanity-check the ticket exists to enable would fail.
- Routes ledger counters through the existing last_extra_metrics → reduce_loss_metrics pipe (the same one ledger_wipe_count uses). Adds a new opt-in SUM mode to reduce_loss_metrics so cluster-wide counter totals emit correctly on rank-0 instead of dist_mean.
- Hard-removes the debug_log_active_every_n_batches config field (grep: zero TOMLs set it; default 0). Three diagnostic-smoke TOMLs (imbalanced bs, exhaust spike, low ledger_stale_threshold) verify each metric actually detects its target pathology, not just renders in wandb.
- Drops from v1: cumulative data/tokens_seen/<modality> and data/cells_seen/<modality>. DataStatsCollector has no state_dict infra today; standing one up is out of proportion. All v1 metrics are per-step snapshots; no new persistent state.
- The two tracks share only the reduce_loss_metrics utility. Revised estimate: ≈ 5 points. The split shows up in the Execution DAG in the Plan document.
After PR #820 (weighted K-active-pool pick, base commit b5fef955a) and TECH-3491 (pick_mode), StreamMixingDataset's per-step cell-batching behaviour is a black box at runtime — the only hook is debug_log_active_every_n_batches, which dumps identity tuples to stderr (no numbers, no quantitative signal, not in wandb). Separately, the GDM recurrent ledger evicts group_id entries in LRU and PERIODIC_STALE modes; today only the full-wipe count (ledger_wipe_count) is observable — incremental evictions are computed locally and discarded, and there's no signal when an evicted cell is still being trained on (silent corruption of recurrent state).
This ticket promotes both surfaces to first-class wandb scalars, scoped to what's high-signal for the failure modes we care about: detecting pathologies under device_per_stream mode, catching stuck readers from the weighted-pick tail, and surfacing false ledger evictions before they corrupt training.
Builds on the remaining_picks numpy array already allocated by PR #820 at dataset.py:702-710. The state_dict contract is unchanged.

| Persona | Workflow |
|---|---|
| Researcher comparing mixing modes | Opens two wandb runs side-by-side. Reads stream_mixing/active/remaining_fraction_max minus _min as the imbalance-spread; reads active/modalities/* to compare per-modality coverage. device_per_stream should show wider per-rank modality variance (covered by the aggregator's MIN/MAX). |
| ML engineer chasing a slow epoch | Sees stream_mixing/active/steps_since_pick_max trending up. Cross-references with remaining_fraction_min: if both spike, the weighted-pick tail is dragging an almost-drained reader; if steps_since_pick_max is high but remaining_fraction_min isn't, a reader is genuinely stuck mid-drain. |
| Researcher debugging recurrent training corruption | Notices loss drift on cell-types with long sequences. Reads ledger/false_evictions/stale — non-zero indicates the PERIODIC_STALE sweep is evicting cells the dataset still has in flight. Cross-checks ledger_stale_threshold setting against typical slice length. |
Stream_mixing metrics (7 scalars + _max siblings)

Flow: StreamMixingDataset.drain_step_metrics() → trainer calls StreamMixingDataset.aggregate_step_metrics() (new aggregator) → merged into extra_metrics → MetricsProcessor.log(...) emits on rank-0.
| Key | Source | Reduction | Healthy | Pathological |
|---|---|---|---|---|
| stream_mixing/active/remaining_min | remaining_picks.min() on the array already allocated at dataset.py:702-710 | dist_min | Decreases steadily toward 0 within an epoch; forecasts next exhaust | Plateaus near 0 across many steps (stuck near-exhaust reader) |
| stream_mixing/active/remaining_max | remaining_picks.max() | dist_max | Decreases steadily; forecasts slowest exhaust | Way ahead of _min persistently (proportional draining broken) |
| stream_mixing/active/remaining_fraction_min | min over active readers of (slice_end − position) / (slice_end − slice_start) — bs cancels, modality-independent | dist_min | Close to fraction_max (tight cluster) | Far from fraction_max (drain imbalance, independent of slice size) |
| stream_mixing/active/remaining_fraction_max | max of same | dist_max | Close to fraction_min | Spread > 0.3 sustained (proportional draining broken) |
| stream_mixing/active/modalities/<modality> | Counter(_modality_for_reader(r) for r in active) | dist_sum per modality key | Stable mix matching the configured stream proportions | Modality count → 0 while config says it should be present (stream starvation) |
| stream_mixing/active/steps_since_pick_max | max(current_step − last_picked_step[r.stream_id] for r in active) | dist_max | O(K) — each reader picked ≈ once every K steps; brief spikes during weighted-pick tail | Sustained > 5K and growing (genuine stuck reader); read jointly with remaining_fraction_min to distinguish "stuck near exhaust" from "stuck mid-drain" |
| stream_mixing/refill/exhaust_events | per-step delta of natural exhaustions (gated at dataset.py:755-760) | dist_sum | Small steady non-zero (steady-state churn) | Spike (massive refill churn) or 0 across many steps (epoch dragging without progress) |
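For concreteness, a minimal sketch of the local per-rank snapshot behind these scalars (FR1/FR2). The free-function shape and the reader field names are illustrative assumptions; in the ticket this logic lives at the pick site inside StreamMixingDataset, and the modality_for_reader parameter stands in for the FR6 helper _modality_for_reader.

```python
from collections import Counter

def pick_site_snapshot(active, remaining_picks, last_picked_step,
                       current_step, modality_for_reader):
    """Illustrative sketch of the local snapshot taken at the pick site.

    `active` is the K-reader pool; `remaining_picks` is the array PR #820
    already allocates; `last_picked_step` maps stream_id -> step (FR2).
    Reader fields (slice_start/slice_end/position/stream_id) are assumed shapes.
    """
    # Fraction remaining is bs-independent: (slice_end - position) / slice length.
    fractions = [
        (r.slice_end - r.position) / max(r.slice_end - r.slice_start, 1)
        for r in active
    ]
    return {
        "remaining_min": float(remaining_picks.min()),
        "remaining_max": float(remaining_picks.max()),
        "remaining_fraction_min": min(fractions),
        "remaining_fraction_max": max(fractions),
        # Modality fan-out over the active pool; K <= 8 keeps this trivial.
        "modalities": Counter(modality_for_reader(r) for r in active),
        # Staleness: steps since each active reader was last picked.
        "steps_since_pick_max": max(
            current_step - last_picked_step[r.stream_id] for r in active
        ),
    }
```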
Ledger counters (+ _max siblings via reduce_loss_metrics)

Flow: counters live in RecurrentEncoder, populated at eviction sites and at _store_state fresh-write → returned in last_extra_metrics alongside existing ledger_wipe_count → reduce_loss_metrics (new SUM mode) → emitted on rank-0 as cluster-wide totals + worst-rank.
| Key | Increments when | Healthy | Pathological |
|---|---|---|---|
| ledger/evictions/lru | del ledger[oldest_gid] in _evict_lru_overflow (recurrent_encoder.py:1193) | Steady non-zero under LRU mode (incremental eviction is the design) | Zero under non-LRU modes; massive spike under LRU = ledger capacity too small |
| ledger/evictions/stale | del ledger[gid] in _sweep_stale_entries (recurrent_encoder.py:660) when gs.last_step < cutoff | Periodic small bursts at sweep cadence under PERIODIC_STALE | Sustained high under PERIODIC_STALE = ledger_stale_threshold is too low for slice length |
| ledger/false_evictions/lru (headline) | A gid in _recently_evicted[mode=LRU] arrives at _store_state with prev is None (recurrent_encoder.py:1115) | 0 — LRU evictions should only fire on truly-stale entries | Non-zero ⇒ LRU is too aggressive for the current per-rank cell touch-rate |
| ledger/false_evictions/stale (headline) | Same, mode=STALE | 0 — STALE evictions should only fire on cells the dataset has finished with | Non-zero ⇒ ledger_stale_threshold shorter than slice consumption; silent recurrent-state corruption |
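A sketch of the emission side (FR12): converting cumulative per-mode counters into the per-step deltas returned via last_extra_metrics next to ledger_wipe_count. The cumulative-counter field names (_evict_counts, _prev_*) are assumptions; the wandb keys are the real ones from the table above.

```python
def drain_ledger_counters(encoder) -> dict[str, float]:
    """Per-step deltas of the eviction counters, merged into last_extra_metrics.

    `_evict_counts` / `_false_evict_counts` / `_prev_*` are illustrative names
    for cumulative per-mode counters on the encoder; deltas keep the SUM-mode
    reduction per-step rather than ever-growing.
    """
    out = {}
    for mode in ("lru", "stale"):
        out[f"ledger/evictions/{mode}"] = (
            encoder._evict_counts[mode] - encoder._prev_evict_counts[mode]
        )
        out[f"ledger/false_evictions/{mode}"] = (
            encoder._false_evict_counts[mode] - encoder._prev_false_evict_counts[mode]
        )
    encoder._prev_evict_counts = dict(encoder._evict_counts)
    encoder._prev_false_evict_counts = dict(encoder._false_evict_counts)
    return out
```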
Removed surface:

| Item | Where | Replacement |
|---|---|---|
| debug_log_active_every_n_batches config field | job_config.py:2068 | None — metrics fire every step |
| Factory pass-through | factory.py:209 | — |
| Dataset constructor param + validation + attribute | dataset.py:236, 304, 320 | — |
| Free-form INFO log line (identity tuples only) | dataset.py:721-743 | Quantitative wandb scalars cover the use cases |
| Three tests for the old knob | test_dataset.py:643-691 | New tests in test_metrics.py |
| ID | Level | Requirement |
|---|---|---|
| Stream_mixing instrumentation | ||
| FR1 | MUST | Emit the 7 stream_mixing scalars listed in the metric table at the pick site (dataset.py:702-744). Use remaining_picks.min()/.max() on the existing array; compute remaining_fraction as (slice_end − position) / max(slice_end − slice_start, 1). |
| FR2 | MUST | Track self._last_picked_step: dict[stream_id, int] updated at the pick site (one dict write per step). On drain_step_metrics(), emit max(current_step − _last_picked_step[r.stream_id] for r in active). Cleared on reader exhaust. |
| FR3 | MUST | Add StreamMixingDataset.drain_step_metrics() -> dict[str, float]. Returns the most recent pick's local snapshot scalars plus the accumulated refill/exhaust_events delta, then clears the delta. Returns an empty dict if no pick has occurred since the last drain. |
| FR4 | MUST | Add StreamMixingDataset.aggregate_step_metrics(local: dict, world_size: int, device: torch.device) -> dict[str, float]. Mirrors DataStatsCollector.aggregate_distributed in shape: pack per-metric values into typed tensors, perform one all_reduce per op-kind (MIN / MAX / SUM), unpack. Empty-pool ranks contribute sentinels (large for MIN, 0 for MAX/SUM) per the data_stats.py:232 pattern. Single-rank path skips collectives. |
| FR5 | MUST | Trainer (train.py, near the existing data_stats_collector.aggregate_distributed(...) call around L842) drains local metrics, calls aggregate, merges into extra_metrics. Guarded so non-StreamMixingDataset training paths don't break (see the sketch after this table). |
| FR6 | MUST | Extract the modality-resolution helper from the existing log path (dataset.py:729-732) into a private method: def _modality_for_reader(self, r: StreamReader) -> str handling both NESTED (self.stream_groups[r.stream_id]) and DIRECT (os.path.basename(os.path.dirname(r.path))) layouts. Reused by FR1. |
| FR7 | MUST | Hard-remove debug_log_active_every_n_batches config field, factory pass-through, dataset constructor param + validation + attribute, and the free-form log line at dataset.py:721-743. Remove the three tests at test_dataset.py:643-691. |
| FR8 | MUST | Per-step overhead < 50 µs. Permitted additions: Counter over active readers (size ≤ K), min/median/max on existing remaining_picks array, fractional computation per reader (K subtractions+divisions), one dict update for _last_picked_step, one comparison-and-update for the max-steps-since-pick. All sub-microsecond at K ≤ 8. |
| FR9 | MUST | Resume-reshard truncation at dataset.py:653 (saved_active[:num_streams_mixed]) emits a single INFO log entry at restore: "[stream_mixing] resume reshard dropped %d still-active readers" when the truncated count > 0. Not a per-step metric. |
| Ledger false-eviction detection | ||
| FR10 | MUST | Add RecurrentEncoder._recently_evicted: dict[int, EvictionMode] mapping gid → mode that evicted it. Populated at each incremental eviction site (LRU at recurrent_encoder.py:1193, STALE at recurrent_encoder.py:660); cleared on full-wipe (recurrent_encoder.py:635) — wipe-and-rebuild is intentional, so the wiped gids aren't tracked as candidates for false-eviction detection. |
| FR11 | MUST | Instrument _store_state at recurrent_encoder.py:1115: where prev = self.group_state_ledger.get(group_id) returns None (fresh-write branch), check self._recently_evicted.pop(group_id, None) and increment self._false_evict_counts[mode] for the popped mode. Single instrumentation site. Owner-only by enforcement at recurrent_encoder.py:1108-1114 — so the eviction-and-re-store join is purely owner-local with zero cross-rank coordination. |
| FR12 | MUST | Surface ledger/evictions/{lru,stale} (per-step delta of total evictions per mode — these were already computed locally but never logged; _evict_lru_overflow already returns the count at L1195) and ledger/false_evictions/{lru,stale} (the new headline metrics) via the recurrent encoder's last_extra_metrics dict (next to existing ledger_wipe_count). No new emission channel. |
| FR13 | MUST | Extend reduce_loss_metrics in prima_foundations/distributed/utils.py:84 with an opt-in SUM mode: new kwarg sum_keys: set[str] | None = None. For keys in sum_keys: emit <key> = dist_sum(value) (cluster total) and <key>_max = dist_max(value) (worst rank). Keys not in sum_keys behave exactly as today (mean / count-weighted modes). Counter metrics (ledger/evictions/*, ledger/false_evictions/*) opt into SUM mode via this kwarg from the call site at train.py:821. |
| Tests + verification | ||
| FR14 | MUST | Unit tests in a new file tests/unit_tests/datasets/stream_mixing/test_metrics.py. |
| FR15 | MUST | Unit tests in tests/unit_tests/models/pleiades2/test_ledger_false_evictions.py. |
| FR16 | SHOULD | Healthy smoke run on prima_foundations/models/pleiades2/train_configs/gdm/2026_05_07_per_modality_packing_smoke/1node_ddp_per_modality_packing_smoke.toml (NESTED, K=4): verify all 11 metrics appear in wandb, headline ledger metrics stay at 0, stream_mixing/active/remaining_fraction_min/_max stay close, steps_since_pick_max stays bounded. |
| FR17 | SHOULD | Diagnostic-smoke matrix per below. Three TOMLs under train_configs/gdm/2026_05_12_stream_mixing_diagnostics/, each twisting one knob, each shipping with a wandb-panel screenshot in the PR. Not CI-bound; opt-in for researchers. |
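A minimal sketch of the trainer-side wiring FR5 calls for (referenced from the FR5 row above); the guard shape, the dataset accessor, and the surrounding variable names are assumptions rather than the final diff.

```python
# train.py (sketch) — near the existing data_stats_collector.aggregate_distributed
# call: drain the dataset-local snapshot, reduce it, merge into extra_metrics.
dataset = getattr(dataloader, "dataset", None)  # hypothetical accessor
if isinstance(dataset, StreamMixingDataset):
    local = dataset.drain_step_metrics()
    if local:  # empty dict => no pick since last drain (FR3); skip collectives
        extra_metrics.update(
            dataset.aggregate_step_metrics(local, world_size=world_size, device=device)
        )
```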
Scope changes from the original acceptance criteria:

| Original AC | v1 decision | Why |
|---|---|---|
| data/tokens_seen/<modality> cumulative across resume | Dropped from v1. DataStatsCollector's existing per-window data/modality_total_tokens/<modality> is unchanged. | DataStatsCollector has no state_dict infra today; standing it up and wiring through the checkpoint manager is disproportionate to this ticket's intent. The dataset doesn't know tokens (only cells); tokens are computed downstream in the collator. |
| stream_mixing/refill/queue_idx — perm-queue progress | Dropped. Redundant. | Carries the same information as cumulative refill/exhaust_events (each exhaust → one queue_idx advance). Two metrics for one signal is bloat. |
| stream_mixing/active/n_distinct — \|active\| | Dropped. Redundant. | Equals sum(stream_mixing/active/modalities/*) — the fan-out already captures it. |
| stream_mixing/refill/exhaust_events "since last log", resume-safe cumulative | Reinterpreted as per-step delta reset on drain. | Cumulative across resume needs persistent state (out of scope per user direction). Per-step delta is trivially resume-neutral; wandb's diff tooling makes cumulative reconstruction easy downstream. |
| "Counters survive state_dict / load_state_dict" | No new persistent state; state_dict contract unchanged. | Direct consequence of the above two. |
| "Cadence configurable via debug_log_active_every_n_batches" | Knob hard-removed; metrics emit every step. | Per-step emission is consistent with how DataStatsCollector metrics flow today, and the scalars are cheap. Grep shows zero TOMLs set the old knob. |
| "No new cross-rank sync" | Reframed: use existing reduction pipes; one new opt-in mode added to a shared utility. | Initial framing was a misunderstanding — the codebase already has two cross-rank reduction pipes (DataStatsCollector.aggregate_distributed and last_extra_metrics → reduce_loss_metrics). The new aggregate_step_metrics mirrors the first; the SUM-mode extension is opt-in and additive to the second. |
| Original AC didn't mention false ledger evictions | Added in scope during user review. | High-signal cross-subsystem invariant: ledger eviction races against dataset draining could silently corrupt recurrent training. Detected via the inverted check at _store_state — minimal surface for the bug class it catches. |
Out of scope:

- Full-wipe eviction metrics: ledger_wipe_count already tracks the event count; ledger/{false_,}evictions/full_wipe would be either redundant (sum = wipe_count × ledger_size) or always-large by design (under PERIODIC mode every step after a wipe re-encounters wiped gids), so not diagnostic.
- num_workers > 0 support beyond what's already in place — the dataset requires num_workers=0 for mid-epoch resume (dataset.py:79-82); new metrics inherit that assumption.
- Changes to existing DataStatsCollector keys (conditional on adding cumulative counterparts, which v1 doesn't).

Diagnostic smoke matrix. Each diagnostic config is a short multi-modality smoke (1 node, ≤ 200 steps, DDP, NESTED, default model) intended to make one metric break out of its healthy range. Lives under train_configs/gdm/2026_05_12_stream_mixing_diagnostics/, not promoted to production. PR ships wandb-panel screenshots for each.
| Config (filename) | Knob twisted | Expected pathology | Metric(s) that should light up |
|---|---|---|---|
| diag_imbalanced_bs.toml | Per-stream batch sizes deliberately imbalanced — e.g. one modality bs=1, others bs=64. | Proportional draining must keep remaining_fraction aligned across readers despite the bs gulf; the smallest-bs stream gets weight-throttled at the tail and exhibits high steps_since_pick_max; the modality count distribution is lopsided. | active/modalities/* wide spread; steps_since_pick_max grows for the small-bs reader at the tail; remaining_fraction_min/_max stay close (proves draining is healthy under this stress) |
| diag_exhaust_spike.toml | Drastically reduced per-stream slice length (subsample down to ~ K × batches_per_step × 5 cells). | Readers exhaust repeatedly; refill churn dominates per-step behaviour. | refill/exhaust_events non-zero almost every step |
| diag_stale_false_evict.toml | batching_mode = "per_stream", ledger_wipe_mode = "periodic_stale", ledger_stale_threshold very low (e.g. 5 steps — well below typical slice length). | PERIODIC_STALE sweep evicts ledger entries for gids the dataset is still mid-drain on. | ledger/false_evictions/stale non-zero within the first ≈ 20 steps; ledger/evictions/stale ticking alongside. |
Why the naive join fails. The ledger keys on group_id: int (extracted by the processor at processors/sc_multimodal/processor.py:172 via element.get_group_id() — downstream of the dataloader). The dataset keys on stream_id: int (positional index into stream_paths) and the path basename string. There is no stream_id → group_id mapping; one stream contains many cells contains many gids.
The inverted idea. Don't try to track "currently active gids" — that's hard and expensive. Instead, track "recently evicted gids" on the ledger side. When a fresh-write happens (the dataset sent us a gid the ledger hadn't seen, or had seen and evicted), check whether we just evicted that gid. If yes → false eviction.
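A minimal sketch of the tracking side of that inversion: each incremental eviction site records the gid and the mode that evicted it. Method bodies and loop shapes here are illustrative assumptions, not the actual encoder code.

```python
# recurrent_encoder.py (sketch) — record what was just evicted so the
# fresh-write site can recognise a premature re-encounter.
def _evict_lru_overflow(self) -> int:
    evicted = 0
    while len(self.group_state_ledger) > self.ledger_capacity:  # illustrative bound
        oldest_gid = next(iter(self.group_state_ledger))  # assumes insertion-order LRU
        del self.group_state_ledger[oldest_gid]
        self._recently_evicted[oldest_gid] = EvictionMode.LRU
        evicted += 1
    return evicted  # count already returned today (L1195); now also tracked

def _sweep_stale_entries(self, cutoff: int) -> None:
    for gid, gs in list(self.group_state_ledger.items()):
        if gs.last_step < cutoff:  # same predicate as the table above
            del self.group_state_ledger[gid]
            self._recently_evicted[gid] = EvictionMode.STALE
```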
Where the join happens. A single instrumentation site: _store_state at recurrent_encoder.py:1115:

```diff
  prev = self.group_state_ledger.get(group_id)
+ if prev is None:
+     mode = self._recently_evicted.pop(group_id, None)
+     if mode is not None:
+         self._false_evict_counts[mode] += 1
  # ... existing fresh-init logic continues ...
```
Owner-locality. _store_state is owner-only — enforced at L1108-1114 (raise if non-owner attempts a store). The three eviction sites (LRU at L1193, STALE at L660, full-wipe at L635) all operate on the local shard. So eviction and re-store both happen on the owning rank for any given gid. The intra-rank join is exact, not an approximation; there is no cross-rank gap.
Why _lookup_local at L893 is not the instrumentation site. The miss path at L893 (if group_id not in self.group_state_ledger) fires on every rank for every non-owned gid, since non-owners never hold the entry. Using L893 would over-count by world_size and conflate "non-owner doesn't shard this gid" with "owner evicted it." _store_state is owner-gated by construction, which is exactly what we want.
Memory. _recently_evicted grows on eviction, pops on re-store, clears on full-wipe. Bounded by "currently-evicted but not yet re-encountered" gids — at most unique_gids_owned − len(ledger). Empirically a few hundred ints in steady state.
Distributed semantics. The counters are per-rank-owned-gid quantities. After SUM-mode reduction in reduce_loss_metrics, the wandb keys show cluster-wide totals (the natural reading: "the cluster suffered N false evictions this step"). The paired _max shows the worst single rank — useful for hotspot diagnosis (one rank with a hostile gid distribution).
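A sketch of the FR13 SUM mode, written as a standalone helper to show the intended semantics; the real change is an opt-in kwarg on the existing reduce_loss_metrics, and the packing shown here is an assumption.

```python
import torch
import torch.distributed as dist

def reduce_sum_keys(metrics: dict[str, float],
                    sum_keys: set[str],
                    device: torch.device) -> dict[str, float]:
    """For each opted-in key, emit <key> = cluster total and <key>_max = worst rank.

    Assumes torch.distributed is already initialised (same precondition as the
    existing mean / count-weighted modes).
    """
    keys = sorted(k for k in metrics if k in sum_keys)
    if not keys:
        return {}
    vals = torch.tensor([metrics[k] for k in keys], dtype=torch.float32, device=device)
    totals = vals.clone()
    worst = vals.clone()
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)  # "the cluster suffered N this step"
    dist.all_reduce(worst, op=dist.ReduceOp.MAX)   # hotspot rank for diagnosis
    out: dict[str, float] = {}
    for i, k in enumerate(keys):
        out[k] = totals[i].item()
        out[f"{k}_max"] = worst[i].item()
    return out
```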
Aggregator design. Mirrors DataStatsCollector.aggregate_distributed at data_stats.py:170. Pseudocode:
```python
def aggregate_step_metrics(local, world_size, device):
    if world_size == 1:
        return local  # identity, no collectives

    # Pack scalars by op-kind into typed tensors of fixed shape.
    # Modality count is sparse — flatten over the project's known modality enum
    # (matches data_stats.py:209) so all ranks have identical shapes.
    all_modalities = sorted(MODALITY_TO_INT.keys())
    mins = torch.tensor([local["remaining_min"],
                         local["remaining_fraction_min"]],
                        dtype=torch.float32, device=device)
    maxs = torch.tensor([local["remaining_max"],
                         local["remaining_fraction_max"],
                         local["steps_since_pick_max"]],
                        dtype=torch.float32, device=device)
    sums = torch.zeros(len(all_modalities) + 1,  # +1 for exhaust_events
                       dtype=torch.float32, device=device)
    for i, m in enumerate(all_modalities):
        sums[i] = local["modalities"].get(m, 0)
    sums[-1] = local["exhaust_events"]

    # Empty-pool ranks contribute sentinels (per data_stats.py:232).
    if local_active_empty:
        mins.fill_(float("inf"))
        maxs.fill_(0.0)

    dist.all_reduce(mins, op=dist.ReduceOp.MIN)
    dist.all_reduce(maxs, op=dist.ReduceOp.MAX)
    dist.all_reduce(sums, op=dist.ReduceOp.SUM)
    # Unpack + emit only non-zero modality fan-outs.
    ...
```
Three collectives per log step. Each is a small fixed-shape tensor; total bytes << gradient all-reduces happening every step. Negligible.
Performance budget. Per-step additions to the pick site:
- remaining_picks.min/max on a size-K array (K ≤ 8): < 1 µs.
- remaining_fraction per reader (K subtractions + divisions): < 2 µs.
- Counter(_modality_for_reader(r) for r in active) with K ≤ 8: < 5 µs.
- self._last_picked_step[stream_id] = current_step + max-over-active: < 1 µs.
- Ledger-side _recently_evicted bookkeeping amortises over expected_ledger_wipe_frequency steps.

Total << 50 µs budget per step. No tensor ops in the hot path beyond what PR #820 already does. No GPU sync, no CPU↔GPU transfers.
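A throwaway harness to sanity-check those sub-microsecond claims on a dev box (purely illustrative; the modality labels are made up):

```python
import timeit
from collections import Counter

import numpy as np

K = 8
remaining_picks = np.random.randint(1, 10_000, size=K)
modalities = ["rna", "atac", "protein", "rna"] * 2  # made-up labels, K entries

n = 100_000  # per-call seconds = total / n
print("min/max:", timeit.timeit(
    lambda: (remaining_picks.min(), remaining_picks.max()), number=n) / n)
print("Counter:", timeit.timeit(lambda: Counter(modalities), number=n) / n)
```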
The dataset already documents at dataset.py:79-82 that mid-epoch resume requires num_workers=0. The drain_step_metrics() + aggregate_step_metrics() APIs inherit this — both work because the dataset lives in the main process and the trainer reads its attributes directly. If num_workers > 0 support is ever added, drain semantics will need a queue or batch-dict piggyback design (out of scope for v1).
Two new test files (see FR14, FR15) and one deletion (FR7 removes three obsolete tests).
Open questions:

- Should stream_mixing/active/modalities/<modality> emit zeros for known-but-currently-empty modalities, or stay sparse? Default: sparse (wandb handles missing keys). Revisit if panel readability suffers.
- Namespace: ledger/false_evictions/* vs. recurrent_encoder/false_evictions/*. Going with ledger/* to match existing ledger_wipe_count precedent.
- Should SUM-mode reduce_loss_metrics emit both <key> (sum) and <key>_max (per-rank max), or only the sum? Going with both for diagnostic consistency with existing modes; the per-rank max is cheap and surfaces hotspots.