# stream_mixing + ledger observability metrics

This plan delivers the 17 functional requirements in FRD v5 — 11 new wandb scalars (7 stream_mixing + 4 ledger, with auto _max siblings on the ledger metrics) across two subsystems, plus the trainer wiring that joins them. The breakdown maximises parallelism while keeping each phase atomic and committable on its own.

Complexity: the ledger false-eviction join (recurrent_encoder.py) is the cross-subsystem invariant — it gets a "Think hard" pass during implementation. Everything else is standard or trivial.

File ownership: P1 owns dataset.py, P3 owns distributed/utils.py, P4 owns recurrent_encoder.py + loss.py. Wave-1 phases never touch each other's files; wave-2 phases (P2 + P4) also don't overlap. Safe to parallelise.
Critical sequencing constraint: P4 (ledger false-eviction) needs the SUM-mode extension from P3 (reduce_loss_metrics) to surface its counters correctly cluster-wide, so P3 must complete before P4 starts. Everything else can proceed in parallel.
From the FRD (currently local-only, see Linear comment for path):
- New stream_mixing scalars (7): active/remaining_min, active/remaining_max, active/remaining_fraction_min, active/remaining_fraction_max, active/modalities/<modality> (sparse fan-out), active/steps_since_pick_max, refill/exhaust_events.
- New ledger counters (4, plus _max siblings via SUM-mode reduce): ledger/evictions/{lru,stale}, ledger/false_evictions/{lru,stale}.
- To be removed: the debug_log_active_every_n_batches config + free-form log line + 3 tests.
- Out of scope: tokens_seen / cells_seen, full-wipe eviction metrics (redundant with ledger_wipe_count), num_workers > 0 support beyond what exists.

Establishing the load-bearing baseline points from research:
- Pick site (dataset.py:695-744): the remaining_picks numpy array is already allocated by PR #820 at L702-710. The free-form active-set log line lives at L721-743; batch_count increments at L744. This is where the new stream_mixing metric population goes.
- Exhaust site (dataset.py:755-760): del active[idx] gated by reader.position >= reader._slice_end. This is where refill/exhaust_events increments and where _last_picked_step entries get cleared.
- Resume-reshard truncation (dataset.py:653): saved_active = saved_active[:self.num_streams_mixed] — the only path that drops still-active readers; needs the one-shot INFO log added.
- Modality resolution (dataset.py:729-732): the existing log line handles NESTED via self.stream_groups[r.stream_id] and DIRECT via path-basename derivation. Extract to a helper.
- State dict (dataset.py:454-464): 9 fields, no change needed — v1 metrics are all per-step snapshots, no persistent state.
- Trainer data-stats hook (train.py:842-862): existing DataStatsCollector.aggregate_distributed(...) + extra_metrics.update(wandb_data_stats) pattern. The stream_mixing drain + aggregate slots in right beside this.
- Loss-metric reduction (train.py:815-828, distributed/utils.py:84-138): the comment at train.py:813 calls out ledger/... keys explicitly. The existing ledger_wipe_count rides this pipe. Default mode emits dist_mean(value) + dist_max(value); count-weighted mode exists for ratios; no SUM-only mode today — P3 adds it.
- Eviction sites: LRU at recurrent_encoder.py:1193 (del ledger[oldest_gid] in _evict_lru_overflow, which already returns the evicted count at L1195). STALE at recurrent_encoder.py:660 (del ledger[gid] in _sweep_stale_entries). Full-wipe at L635 (ledger.clear()) — not instrumented for false_evictions (intentionally cleared on full-wipe).
- Fresh-write site: _store_state at recurrent_encoder.py:1115 — prev = self.group_state_ledger.get(group_id). When prev is None this is a fresh entry. Owner-only by enforcement at L1108-1114. Single instrumentation site for the false-eviction check.
- Metric surfacing: loss.py:914 populates self.last_extra_metrics["ledger_wipe_count"]. New keys go alongside.

After all phases land:
- All 11 metrics (plus _max siblings on the ledger metrics) emit every training step, visible on the rank-0 panel for the healthy smoke run.
- The debug_log_active_every_n_batches config field, factory pass-through, dataset attribute, validation, free-form log line, and 3 obsolete tests are all deleted.
- reduce_loss_metrics gains an opt-in sum_keys kwarg; existing call sites (default-mode and count-weighted-mode) are unaffected.
- The StreamMixingDataset.drain_step_metrics() and aggregate_step_metrics() APIs exist; the trainer calls them in train.py.
- New unit tests pass (.venv/bin/pytest prima_foundations/tests/unit_tests/datasets/stream_mixing/test_metrics.py and .../models/pleiades2/test_ledger_false_evictions.py).
- The state_dict / load_state_dict contract is unchanged — existing resume tests still pass.

Explicitly out of scope:

- data/tokens_seen/<modality> / data/cells_seen/<modality> (needs DataStatsCollector.state_dict infra).
- A dedicated full_wipe metric (redundant with ledger_wipe_count, or always-large by design).
- num_workers > 0 support beyond what's already there.

## P1: stream_mixing dataset instrumentation (standard)

Goal: Land all dataset-side changes for stream_mixing observability — new metric population at the pick site, modality helper extraction, removal of the old debug log knob, resume-reshard INFO log. Pure dataset edits; no trainer wiring yet.
Maps to FRDs: FR1, FR2, FR3, FR6, FR7, FR9.
Changes Required:
- prima_foundations/prima_foundations/datasets/stream_mixing/dataset.py — extract _modality_for_reader(self, r) (FR6) from the existing log path at L729-732; add a self._last_picked_step: dict[int, int] instance attribute, populated at the pick site after the reader is selected (FR2); add self._exhaust_events_delta: int + self._latest_step_metrics: dict[str, float]; populate _latest_step_metrics at the pick site using the existing remaining_picks array (FR1: remaining_min/max, remaining_fraction_min/max, steps_since_pick_max, modality fan-out via Counter); increment _exhaust_events_delta at the exhaust site (L755-760) and clear the reader's _last_picked_step entry; add a drain_step_metrics(self) -> dict[str, float] method that returns + clears (FR3); delete the constructor param + validation + attribute + free-form log line at L721-743 (FR7); emit a one-shot INFO log on resume-reshard truncation at L653 if it drops still-active readers (FR9).
- prima_foundations/prima_foundations/datasets/stream_mixing/factory.py — drop the debug_log_active_every_n_batches=cfg.debug_log_active_every_n_batches kwarg at L209 (FR7).
- prima_foundations/prima_foundations/config/job_config.py — delete the debug_log_active_every_n_batches: int = 0 field at L2068-2079 (FR7).
- prima_foundations/tests/unit_tests/datasets/stream_mixing/test_dataset.py — delete the three tests at L643-691 covering the old knob (FR7).
Success Criteria:

- .venv/bin/pytest prima_foundations/tests/unit_tests/datasets/stream_mixing/ -v passes (existing state-dict + reader tests must stay green).
- ruff check prima_foundations/datasets/stream_mixing/ + ruff format --check clean; mypy prima_foundations/datasets/stream_mixing/ clean.
- Manual check: instantiate StreamMixingDataset in a REPL, drain 5 batches, call drain_step_metrics(), and verify the dict shape matches the FRD spec (7 keys for the local view).

## P2: cross-rank aggregation + trainer wiring

Goal: Add the cross-rank aggregator that gives each metric its correct distributed reduction, and wire it into the trainer's per-step metric flow.
Maps to FRDs: FR4, FR5.
Changes Required:
- prima_foundations/prima_foundations/datasets/stream_mixing/dataset.py — add aggregate_step_metrics(self, local: dict, world_size: int, device: torch.device) -> dict[str, float]. Mirror DataStatsCollector.aggregate_distributed at data_stats.py:170-255: pack scalars into three tensors (MIN-reduce, MAX-reduce, SUM-reduce); use MODALITY_TO_INT from prima_foundations/datasets/processors/types for the sparse modality fan-out's shape; sentinel-fill empty-pool ranks (large for MIN, 0 for MAX/SUM) per data_stats.py:232; the single-rank path is identity.
- prima_foundations/prima_foundations/train.py — inside the per-step metric block around L842-862, after the data_stats_collector.aggregate_distributed(...) block: call local_metrics = self.train_dataset.drain_step_metrics(), guarded by hasattr for non-StreamMixingDataset paths; if non-empty, call aggregated = self.train_dataset.aggregate_step_metrics(local_metrics, world_size=parallel_dims.world_size, device=self.device); then extra_metrics.update(aggregated).
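One possible shape for the packed-reduce aggregator, as a hedged sketch: aggregate_step_metrics is written here as a free function, the _MIN_KEYS/_MAX_KEYS groupings are assumptions about the FRD's per-key reduction semantics, and the sentinel-fill and fixed-shape modality layout (via MODALITY_TO_INT) are only noted in comments rather than implemented.

```python
import torch
import torch.distributed as dist

# Assumed per-key reduction groupings — the real sets come from the FRD.
_MIN_KEYS = {
    "stream_mixing/active/remaining_min",
    "stream_mixing/active/remaining_fraction_min",
}
_MAX_KEYS = {
    "stream_mixing/active/remaining_max",
    "stream_mixing/active/remaining_fraction_max",
    "stream_mixing/active/steps_since_pick_max",
}
# Everything else (exhaust events, modality fan-out) is SUM-reduced here.


def aggregate_step_metrics(
    local: dict[str, float], world_size: int, device: torch.device
) -> dict[str, float]:
    if world_size == 1:
        return dict(local)  # single-rank path is identity
    # Stable order so all ranks pack identically; the real code instead uses
    # MODALITY_TO_INT for a fixed-shape modality slot layout, since sparse
    # fan-out keys can differ per rank.
    keys = sorted(local)
    vals = torch.tensor([local[k] for k in keys], device=device)
    out = dict(local)
    groups = (
        ([i for i, k in enumerate(keys) if k in _MIN_KEYS], dist.ReduceOp.MIN),
        ([i for i, k in enumerate(keys) if k in _MAX_KEYS], dist.ReduceOp.MAX),
        ([i for i, k in enumerate(keys) if k not in _MIN_KEYS | _MAX_KEYS],
         dist.ReduceOp.SUM),
    )
    for idx, op in groups:
        if not idx:
            continue
        packed = vals[idx].clone()  # one packed tensor per reduction op
        # Sentinel-fill for empty-pool ranks (+inf for MIN, 0 for MAX/SUM)
        # would happen before this all_reduce in the real implementation.
        dist.all_reduce(packed, op=op)
        for j, i in enumerate(idx):
            out[keys[i]] = packed[j].item()
    return out
```

Packing into one tensor per op keeps the cost at three collectives per step regardless of how many metrics exist.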
Success Criteria:

- .venv/bin/pytest prima_foundations/tests/unit_tests/datasets/stream_mixing/ -v still passes; ruff check prima_foundations/ + mypy prima_foundations/ clean.
- A single-rank local run (scripts/test-local) shows the new stream_mixing/* keys in wandb-debug output; the keys map 1-to-1 with the FRD spec.

Depends on: P1 (needs the drain_step_metrics API).
## P3: reduce_loss_metrics SUM-mode extension (standard)

Goal: Add an opt-in SUM reduction mode to the shared reduce_loss_metrics utility so counter metrics emit cluster-wide totals instead of per-rank means. Pure shared-utility change; independent of the stream_mixing and ledger work.
Maps to FRDs: FR13.
Changes Required:
prima_foundations/prima_foundations/distributed/utils.py — extend reduce_loss_metrics (L84-138) with a new kwarg sum_keys: set[str] | None = None. For each key in metrics:
- if key in (auxiliary_counts or {}): existing count-weighted branch, unchanged.
- elif key in (sum_keys or set()): new branch — multi-rank: reduced[key] = dist_sum(value, mesh, extra_pg) and reduced[f"{key}_max"] = dist_max(value, mesh, extra_pg); single-rank: reduced[key] = value.item().
- otherwise: existing default mean+max branch, unchanged.
- Validate sum_keys ∩ auxiliary_counts == ∅ — a key can't be both.
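The branch ordering can be sketched as below. reduce_loss_metrics_sketch is an illustrative stand-in, not the real utility; the dist_sum/dist_max placeholders represent the existing mesh-aware helpers in distributed/utils.py, and the unchanged branches are elided.

```python
import torch


def dist_sum(value: torch.Tensor) -> float:
    raise NotImplementedError  # placeholder for the real mesh-aware helper


def dist_max(value: torch.Tensor) -> float:
    raise NotImplementedError  # placeholder for the real mesh-aware helper


def reduce_loss_metrics_sketch(metrics, world_size, auxiliary_counts=None, sum_keys=None):
    auxiliary_counts = auxiliary_counts or {}
    sum_keys = sum_keys or set()
    overlap = sum_keys & set(auxiliary_counts)
    if overlap:  # a key can't be both SUM-mode and count-weighted
        raise ValueError(f"keys cannot be both SUM-mode and count-weighted: {overlap}")
    reduced = {}
    for key, value in metrics.items():
        if key in auxiliary_counts:
            ...  # existing count-weighted branch, unchanged
        elif key in sum_keys:
            if world_size == 1:
                reduced[key] = value.item()  # single-rank pass-through
            else:
                reduced[key] = dist_sum(value)           # cluster-wide total
                reduced[f"{key}_max"] = dist_max(value)  # worst-rank sibling
        else:
            ...  # existing default mean+max branch, unchanged
    return reduced
```

Placing the sum_keys check after the auxiliary_counts check (plus the up-front disjointness validation) guarantees no key can silently take the wrong reduction.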
Success Criteria:

- A test in tests/unit_tests/distributed/test_utils.py (or a new file if none exists): SUM-mode round-trips a known per-rank value through dist_sum + dist_max on world_size=4 (gloo backend). ruff + mypy clean.
- Manual check: call reduce_loss_metrics from a REPL with sum_keys={"foo"} on a single-rank tensor; verify the value passes through unchanged.

Depends on: nothing (Wave 1).
## P4: ledger false-eviction detection (complex)

Goal: Detect false ledger evictions via the inverted check at the _store_state fresh-write. Single instrumentation site, owner-local by construction, surfaced through the existing last_extra_metrics pipe with SUM-mode reduction (from P3).
Why complex: cross-subsystem invariant; ownership semantics enforced by recurrent_encoder.py:1108-1114; must respect the all-to-all-v paths that orchestrate ledger reads/writes across ranks; eviction count must match the existing local counter returned by _evict_lru_overflow at L1195; full-wipe path must clear _recently_evicted so we don't generate spurious false-eviction signal after intentional wipes.
Maps to FRDs: FR10, FR11, FR12.
Changes Required:
prima_foundations/prima_foundations/models/pleiades2/model/gdm/recurrent_encoder.py:
- Add an EvictionMode enum (or reuse LedgerWipeMode if cleaner) with LRU, STALE variants. Add self._recently_evicted: dict[int, EvictionMode], self._evict_counts: dict[EvictionMode, int], and self._false_evict_counts: dict[EvictionMode, int] to __init__ at L533.
- LRU site (_evict_lru_overflow): before del ledger[oldest_gid] at L1193, add self._recently_evicted[oldest_gid] = EvictionMode.LRU; increment self._evict_counts[EvictionMode.LRU] per eviction.
- STALE site (_sweep_stale_entries): in the for gid in doomed: loop body, before del ledger[gid], add self._recently_evicted[gid] = EvictionMode.STALE; increment self._evict_counts[EvictionMode.STALE] per eviction.
- Full-wipe site (_maybe_wipe_ledger): immediately after self.group_state_ledger.clear(), also do self._recently_evicted.clear(). No counter increment for full-wipe — ledger_wipe_count already tracks events.
- Fresh-write check (_store_state): right after prev = self.group_state_ledger.get(group_id), add: if prev is None: mode = self._recently_evicted.pop(group_id, None); if mode is not None: self._false_evict_counts[mode] += 1.
- Add drain_eviction_metrics(self) -> dict[str, int] that returns the 4 deltas (ledger/evictions/{lru,stale}, ledger/false_evictions/{lru,stale}) and resets the internal counters to 0. Called by loss.py at the same point ledger_wipe_count is collected today.
- prima_foundations/prima_foundations/models/pleiades2/model/gdm/loss.py: at L914 (where ledger_wipe_count is populated into self.last_extra_metrics), also drain + merge the 4 new counter keys from recurrent_encoder.drain_eviction_metrics(). Add the 4 new keys to a new self.last_extra_sum_keys: set[str] attribute so the trainer can opt them into SUM mode.
- prima_foundations/prima_foundations/train.py (small edit, no overlap with P2's edit area): at L821 where reduce_loss_metrics(...) is called, pass a sum_keys=getattr(self.loss_fn, "last_extra_sum_keys", None) kwarg, plus a tiny guard so older loss fns without this attribute keep working.
Success Criteria:

- The healthy smoke keeps ledger/false_evictions/* at 0 under healthy config.
- Unit-level: after an LRU eviction, a fresh _store_state with the same gid → assert _false_evict_counts[LRU] == 1 and that _recently_evicted no longer contains the gid.

Depends on: P3 (needs the sum_keys kwarg in reduce_loss_metrics).
## P5: stream_mixing unit tests

Goal: Lock in the stream_mixing metric contract with focused unit tests. Lean coverage per the user's testing preference — fewer high-value tests, not exhaustive accessor coverage.
Maps to FRDs: FR14.
Changes Required:
New test file prima_foundations/tests/unit_tests/datasets/stream_mixing/test_metrics.py covering:
- test_drain_step_metrics_shape: synthetic 4-stream fixture (NESTED layout); drain 5 batches; assert all 7 stream_mixing keys are present with correct types; assert the active/modalities/* fan-out has the expected modality names.
- test_modality_resolution_nested_and_direct: parameterised test for both layouts; _modality_for_reader returns the right string.
- test_remaining_fraction_bs_independence: two streams with bs=1 and bs=64 but the same slice length; after 1 batch, assert their remaining_fraction values are equal (within float tolerance).
- test_steps_since_pick_max_tracks: drive K=4 readers through 20 steps with a deterministic seed; assert steps_since_pick_max grows for unpicked readers, resets on pick, drops on exhaust.
- test_exhaust_events_delta_resets_on_drain: trigger an exhaust; drain; assert exhaust_events ≥ 1; drain again; assert 0.
- test_drain_empty_before_first_pick: fresh dataset; drain_step_metrics() returns {}.
- test_aggregate_step_metrics_world_size_1: identity check on the single-rank path.
- test_aggregate_step_metrics_packed_reduce: mock the dist.all_reduce calls; assert MIN/MAX/SUM ops are issued on the right packed tensors with the right values.
- test_aggregate_empty_pool_sentinels: a rank with an empty active pool contributes +inf to MIN and 0 to MAX/SUM; the cluster result ignores it.
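The mocked-all_reduce pattern behind test_aggregate_step_metrics_packed_reduce can be sketched as follows. _aggregate is a toy stand-in for the multi-rank path of aggregate_step_metrics — the point is only the technique: patch dist.all_reduce, drive the code, and assert on the recorded ops.

```python
from unittest import mock

import torch
import torch.distributed as dist


def _aggregate(vals: torch.Tensor) -> None:
    # Toy stand-in for the real packed-reduce path: one collective per op.
    dist.all_reduce(vals[:1], op=dist.ReduceOp.MIN)
    dist.all_reduce(vals[1:], op=dist.ReduceOp.SUM)


def test_packed_reduce_ops() -> None:
    # Patch the collective so no process group is needed; the mock records
    # every call with its arguments.
    with mock.patch.object(dist, "all_reduce") as ar:
        _aggregate(torch.tensor([1.0, 2.0, 3.0]))
    ops = [call.kwargs["op"] for call in ar.call_args_list]
    assert ops == [dist.ReduceOp.MIN, dist.ReduceOp.SUM]
    # First collective carried the MIN-packed slice.
    assert ar.call_args_list[0].args[0].tolist() == [1.0]
```

This keeps the packing-and-routing test fast and single-process; the real cross-rank arithmetic is covered once, via gloo, in P6.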
Depends on: P2 (needs aggregate_step_metrics to test).
## P6: ledger false-eviction + SUM-mode tests

Goal: Lock in the ledger false-eviction join contract and the SUM-mode round-trip for the cluster reduction.
Maps to FRDs: FR15.
Changes Required:
New test file prima_foundations/tests/unit_tests/models/pleiades2/test_ledger_false_evictions.py covering:
- test_lru_eviction_then_no_re_encounter: evict a gid via LRU; never re-store; assert _false_evict_counts[LRU] == 0; assert the gid remains in _recently_evicted until a subsequent eviction event ages it out or a full-wipe clears it.
- test_lru_eviction_then_fresh_store: evict a gid via LRU; call _store_state(gid, ...) with prev is None; assert _false_evict_counts[LRU] == 1; assert the gid is no longer in _recently_evicted.
- test_stale_eviction_then_fresh_store: mirror of the above for STALE mode.
- test_full_wipe_clears_recently_evicted: evict gids via LRU and STALE; trigger a full-wipe; assert _recently_evicted is empty; assert a subsequent fresh-store of one of the wiped gids does NOT count as a false eviction.
- test_drain_eviction_metrics_resets_counts: trigger a known sequence of events; assert the drained dict has the right key/value pairs; assert internal counters are 0 after drain.
- test_reduce_loss_metrics_sum_mode_roundtrip (in tests/unit_tests/distributed/test_utils.py — adjacent to P3's work): on a 4-rank gloo backend, each rank passes {"foo": rank+1} with sum_keys={"foo"}; assert reduced["foo"] == 10 (1+2+3+4) and reduced["foo_max"] == 4.

Success Criteria:
- pytest -xvs prima_foundations/tests/unit_tests/models/pleiades2/test_ledger_false_evictions.py passes.
- Mutation sanity check: delete the self._recently_evicted[oldest_gid] = EvictionMode.LRU line from P4; test_lru_eviction_then_fresh_store must fail. Revert.

Depends on: P4 (needs the ledger instrumentation), P3 (needs the SUM-mode kwarg).
## P7: diagnostic smoke configs

Goal: Three deliberately-misconfigured TOMLs that drive specific metrics out of their healthy range. Pure config files; no Python edits.
Maps to FRDs: FR17.
Changes Required:
New config directory prima_foundations/prima_foundations/models/pleiades2/train_configs/gdm/2026_05_12_stream_mixing_diagnostics/ with:
- diag_imbalanced_bs.toml — extends the healthy smoke; overrides per-stream bs with wildly different values (e.g. one modality at bs=1, the others at bs=64). Drives a wide active/modalities/* spread and a growing steps_since_pick_max for the small-bs reader at the tail.
- diag_exhaust_spike.toml — extends the healthy smoke; subsamples the dataset down to ~ K × batches_per_step × 5 cells (very short slices). Drives refill/exhaust_events ≈ 1 per step.
- diag_stale_false_evict.toml — extends the healthy smoke; sets batching_mode = "per_stream", ledger_wipe_mode = "periodic_stale", ledger_stale_threshold = 5 (well below typical slice length). Drives ledger/false_evictions/stale and ledger/evictions/stale non-zero within ≈ 20 steps.
- All three must pass config_manager's strict-key check (config/manager.py:181) — copy-paste the healthy smoke and override only the twisted fields.
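As a sketch of the third config, the override surface is tiny. The key placement (top-level vs. nested table) is hypothetical — the real file copies the healthy smoke verbatim and overrides only these fields, which is what keeps config_manager's strict-key check passing:

```toml
# Hypothetical sketch of diag_stale_false_evict.toml — key placement is
# illustrative; only the three field names and values come from the plan.
batching_mode = "per_stream"
ledger_wipe_mode = "periodic_stale"
# Well below the typical slice length, so stale sweeps evict entries whose
# groups are still in flight → false evictions within ~20 steps.
ledger_stale_threshold = 5
```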
Success Criteria:

- python -m prima_foundations.train --config <each-toml> --dry-run (or the equivalent existing dry-run path): all three TOMLs parse and validate without errors; config_manager's strict-key check passes.

Depends on: nothing for writing the configs; Wave 4 (P8) for actual run verification.
## P8: smoke runs + wandb verification

Goal: Run the healthy smoke + the 3 diagnostic smokes; verify the metrics in wandb; attach panel screenshots to the PR. No code change — this is the empirical verification gate.
Maps to FRDs: FR16 (healthy), FR17 (diagnostic — run side).
Changes Required (operational, no code):
Healthy smoke: launch via submit-pleiades-2-job with 1node_ddp_per_modality_packing_smoke.toml. Verify in wandb:
- All 11 keys appear under stream_mixing/* and ledger/*.
- ledger/false_evictions/{lru,stale} stay at 0.
- stream_mixing/active/remaining_fraction_min and _max cluster tightly (spread < 0.2 for most of the run).
- steps_since_pick_max stays bounded.

Diagnostic smokes (one run each):

- diag_imbalanced_bs: wide active/modalities/* spread; steps_since_pick_max grows for the small-bs reader.
- diag_exhaust_spike: refill/exhaust_events non-zero almost every step.
- diag_stale_false_evict: ledger/false_evictions/stale and ledger/evictions/stale non-zero within the first ≈ 20 steps.

Success Criteria:
- All verification points above hold; panel screenshots are attached to the PR.

Depends on: P5 + P6 (code lands), P7 (smoke configs exist).
## FR → phase mapping

| FR | Description (short) | Phase |
|---|---|---|
| FR1 | emit 7 stream_mixing scalars at pick site | P1 |
| FR2 | track _last_picked_step per reader | P1 |
| FR3 | drain_step_metrics() API | P1 |
| FR4 | aggregate_step_metrics() aggregator | P2 |
| FR5 | trainer drain + aggregate + merge into extra_metrics | P2 |
| FR6 | _modality_for_reader helper extraction | P1 |
| FR7 | remove debug knob + log line + 3 tests | P1 |
| FR8 | per-step overhead < 50 µs | P1 + P2 (verified manually) |
| FR9 | resume-reshard INFO log | P1 |
| FR10 | _recently_evicted ring + populate at 2 eviction sites + clear on full-wipe | P4 |
| FR11 | instrument _store_state fresh-write check | P4 |
| FR12 | surface 4 ledger counters via last_extra_metrics | P4 |
| FR13 | reduce_loss_metrics SUM-mode kwarg | P3 |
| FR14 | stream_mixing unit tests | P5 |
| FR15 | ledger false-eviction + SUM-mode round-trip tests | P6 |
| FR16 | healthy smoke run + wandb verification | P8 |
| FR17 | 3 diagnostic smoke configs + verification runs | P7 (configs) + P8 (runs) |
Every FR maps to ≥ 1 phase; FR8 (cost budget) is enforced manually during P1 and P2 via a microbench in a REPL session — not a separate phase since it's a non-functional acceptance check, not a code deliverable.
Unit-test layout: two new test files (P5, P6) plus a small addition to tests/unit_tests/distributed/test_utils.py for the SUM-mode round-trip (part of P6).
Coverage discipline: per the user's "lean but comprehensive" preference, each test covers one specific contract — no accessor sprawl, no "test that the field exists" trivia. Mutation-sanity checked at the end of each test phase by deliberately breaking one line of the corresponding source phase and confirming the right test fails.
Distributed paths:
- The aggregate_step_metrics distributed path is tested via mocked dist.all_reduce (faster than a real gloo backend, and sufficient for the packing-and-routing logic).
- The reduce_loss_metrics SUM-mode is tested via a real gloo backend on world_size=4 — the only place where the actual cross-rank sum matters end-to-end.

Integration verification: P8 is the end-to-end gate. A single-rank local test (via scripts/test-local) during P2 and P4 catches the "metrics appear in the dict" failure mode; the multi-rank smoke catches "metrics appear in wandb with cluster-correct values."
During P2, run a microbench in a REPL: instantiate the dataset and drain 1000 batches in a tight loop, measuring wall time both with and without the new metric population. The delta should be < 50 µs/step at K=4. Record the numbers in the PR description. Acceptance: the delta is < 50 µs; if not, profile and revisit before P4.
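A minimal shape for that microbench — per_batch_microseconds is a hypothetical helper, and the real StreamMixingDataset instance plugs in as the dataset argument:

```python
import time


def per_batch_microseconds(dataset, n_batches: int = 1000) -> float:
    """Wall time per drained batch, in microseconds.

    `dataset` is any iterable. Run once with the metric population enabled
    and once with it disabled (e.g. commented out); the difference between
    the two results is the per-step overhead to compare against the 50 µs
    budget (FR8).
    """
    it = iter(dataset)
    t0 = time.perf_counter()
    for _ in range(n_batches):
        next(it)  # drain one batch
    return (time.perf_counter() - t0) / n_batches * 1e6
```

Using time.perf_counter over the whole loop (rather than per-batch timing) keeps timer overhead out of the sub-50 µs measurement.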
Wave 1's three phases all edit different files (dataset.py / factory.py / job_config.py for P1; distributed/utils.py for P3; new TOML files for P7). No conflict surface — safe to land independently. Wave 2's P2 and P4 also touch disjoint files (dataset.py + train.py:842-862 for P2; recurrent_encoder.py + loss.py + train.py:821 for P4). The two train.py edits live at different lines (L842 vs L821) so even within train.py there's no overlap.
Each phase = ≥ 1 atomic commit on the worktree branch. Suggested commits: