""" DriftProbe: Training-loop monitoring for conceptual topology preservation. Theory: "Spatial Separation Hypothesis" - Use isolated zone languages (German) as scaffolding for new concepts - Monitor anchors (must not move), bridges (must stay separated), canaries (watch for migration) Key Metrics (refined from peer review): 1. Gini Coefficient: Sparse activations (0.8+) = deep/specific, Diffuse (0.3) = shallow/general 2. Angular Drift: Direction change = definition rewrite, magnitude change = sharpening 3. Cross-Language Similarity: Bridges should stay LOW, anchors should stay HIGH """ import json from pathlib import Path from dataclasses import dataclass, field from typing import Optional from enum import Enum import torch import numpy as np class SentinelType(Enum): ANCHOR = "ANCHOR" # Must not move - core topology BRIDGE = "BRIDGE" # Must stay separated - isolated zone integrity CANARY = "CANARY" # Watch for migration - early warning TARGET = "TARGET" # Want movement - training goals class AlertSeverity(Enum): OK = "OK" WARNING = "WARNING" CRITICAL = "CRITICAL" @dataclass class DriftMetrics: """Metrics for a single sentinel term.""" term: str sentinel_type: SentinelType # Activation metrics gini_coefficient: float = 0.0 activation_norm: float = 0.0 # Drift metrics (vs baseline) angular_drift_degrees: float = 0.0 norm_drift_percent: float = 0.0 gini_drift: float = 0.0 # Valley detection detected_valley: str = "UNKNOWN" depth: int = 0 # Cross-language (for anchors/bridges) cross_lang_similarity: float = 0.0 # Alert alert: AlertSeverity = AlertSeverity.OK alert_message: str = "" @dataclass class DriftReport: """Full drift report for a training checkpoint.""" step: int timestamp: str metrics: list[DriftMetrics] = field(default_factory=list) # Summary critical_count: int = 0 warning_count: int = 0 recommendation: str = "CONTINUE" class DriftProbe: """ Lightweight probe for training-loop monitoring. Optimized for RTX 3090 constraints: - Full probe: ~2 min (run at epoch 0, end of training) - Lite probe: ~10 sec (run every 100 steps) """ def __init__(self, model, tokenizer, sentinels_path: Optional[str] = None): self.model = model self.tokenizer = tokenizer self.baseline_states = {} # term -> hidden state tensor # Load sentinels if sentinels_path is None: sentinels_path = Path(__file__).parent.parent.parent / "data" / "sentinels.json" with open(sentinels_path) as f: self.config = json.load(f) self.sentinels = self.config["sentinels"] self.alert_rules = self.config["alert_rules"] def _get_hidden_state(self, text: str, layer: int = 18) -> torch.Tensor: """Get hidden state at specified layer for last token position.""" inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device) with torch.no_grad(): outputs = self.model(**inputs, output_hidden_states=True) return outputs.hidden_states[layer][0, -1, :].float().cpu() def _compute_gini(self, activations: torch.Tensor) -> float: """ Compute Gini coefficient of activation vector. High Gini (0.8+) = Sparse/Specific (Philosophy/Deep) Low Gini (0.3) = Diffuse/General (Prose/Shallow) """ x = torch.abs(activations).numpy() x = np.sort(x) n = len(x) cumsum = np.cumsum(x) gini = (2 * np.sum((np.arange(1, n+1) * x))) / (n * np.sum(x)) - (n + 1) / n return float(gini) def _compute_angular_drift(self, current: torch.Tensor, baseline: torch.Tensor) -> float: """ Compute angular drift in degrees between current and baseline. > 15° = Definition rewrite (concerning) < 5° = Sharpening only (acceptable) """ cos_sim = torch.nn.functional.cosine_similarity( current.unsqueeze(0), baseline.unsqueeze(0) ).item() # Clamp to valid range for arccos cos_sim = max(-1.0, min(1.0, cos_sim)) angle_rad = np.arccos(cos_sim) return float(np.degrees(angle_rad)) def _compute_cross_lang_sim(self, sentinel: dict, layer: int = 18) -> float: """Compute average cross-language similarity for a sentinel.""" translations = sentinel.get("translations", {}) if len(translations) < 2: return 0.0 states = [] for lang, word in translations.items(): states.append(self._get_hidden_state(word, layer)) # Pairwise similarities sims = [] for i in range(len(states)): for j in range(i + 1, len(states)): sim = torch.nn.functional.cosine_similarity( states[i].unsqueeze(0), states[j].unsqueeze(0) ).item() sims.append(sim) return float(np.mean(sims)) if sims else 0.0 def capture_baseline(self, layer: int = 18): """ Capture baseline hidden states for all sentinels. Run this at epoch 0 before training. """ print("Capturing baseline states...") for sentinel in self.sentinels: term = sentinel["term"] # Use English translation or term itself text = sentinel.get("translations", {}).get("EN", term) self.baseline_states[term] = self._get_hidden_state(text, layer) print(f"Baseline captured for {len(self.baseline_states)} sentinels") def probe_lite(self, step: int, layer: int = 18) -> DriftReport: """ Lite probe - only check key sentinels. Optimized for ~10 second runtime. """ from datetime import datetime # Select subset: 2 anchors, 1 bridge, 2 canaries lite_terms = ["heart", "water", "being", "dasein", "thrownness"] lite_sentinels = [s for s in self.sentinels if s["term"] in lite_terms] return self._run_probe(lite_sentinels, step, layer) def probe_full(self, step: int, layer: int = 18) -> DriftReport: """ Full probe - check all sentinels. Runtime: ~2 minutes. """ return self._run_probe(self.sentinels, step, layer) def _run_probe(self, sentinels: list, step: int, layer: int) -> DriftReport: """Run probe on specified sentinels.""" from datetime import datetime report = DriftReport( step=step, timestamp=datetime.now().isoformat() ) for sentinel in sentinels: term = sentinel["term"] text = sentinel.get("translations", {}).get("EN", term) sentinel_type = SentinelType(sentinel["type"]) thresholds = sentinel.get("thresholds", {}) # Get current state current_state = self._get_hidden_state(text, layer) # Compute metrics gini = self._compute_gini(current_state) norm = float(current_state.norm()) # Drift vs baseline angular_drift = 0.0 norm_drift = 0.0 gini_drift = 0.0 if term in self.baseline_states: baseline = self.baseline_states[term] angular_drift = self._compute_angular_drift(current_state, baseline) baseline_norm = float(baseline.norm()) norm_drift = abs(norm - baseline_norm) / baseline_norm * 100 if baseline_norm > 0 else 0 baseline_gini = self._compute_gini(baseline) gini_drift = gini - baseline_gini # Cross-language similarity cross_lang_sim = self._compute_cross_lang_sim(sentinel, layer) # Determine alert level alert = AlertSeverity.OK alert_message = "" if sentinel_type == SentinelType.ANCHOR: max_drift = thresholds.get("max_drift", 0.05) if angular_drift > 15: alert = AlertSeverity.CRITICAL alert_message = f"Angular drift {angular_drift:.1f}° exceeds 15° - definition rewrite" elif norm_drift > max_drift * 100: alert = AlertSeverity.WARNING alert_message = f"Norm drift {norm_drift:.1f}% exceeds threshold" elif sentinel_type == SentinelType.BRIDGE: collapse_threshold = thresholds.get("collapse_alert_threshold", 0.50) if cross_lang_sim > collapse_threshold: alert = AlertSeverity.CRITICAL alert_message = f"Bridge collapsed - cross-lang sim {cross_lang_sim:.2f} > {collapse_threshold}" elif sentinel_type == SentinelType.CANARY: min_gini = thresholds.get("min_gini", 0.70) if gini < min_gini: alert = AlertSeverity.WARNING alert_message = f"Gini {gini:.2f} below {min_gini} - concept melting into prose" if angular_drift > thresholds.get("max_angular_drift", 15): alert = AlertSeverity.WARNING alert_message = f"Angular drift {angular_drift:.1f}° - definition shifting" metrics = DriftMetrics( term=term, sentinel_type=sentinel_type, gini_coefficient=gini, activation_norm=norm, angular_drift_degrees=angular_drift, norm_drift_percent=norm_drift, gini_drift=gini_drift, cross_lang_similarity=cross_lang_sim, alert=alert, alert_message=alert_message ) report.metrics.append(metrics) if alert == AlertSeverity.CRITICAL: report.critical_count += 1 elif alert == AlertSeverity.WARNING: report.warning_count += 1 # Set recommendation if report.critical_count > 0: report.recommendation = "ROLLBACK" elif report.warning_count > 2: report.recommendation = "REDUCE_LR" else: report.recommendation = "CONTINUE" return report def print_report(self, report: DriftReport): """Pretty print a drift report.""" print(f"\n{'='*60}") print(f"DRIFT REPORT - Step {report.step}") print(f"{'='*60}") for m in report.metrics: status = "✓" if m.alert == AlertSeverity.OK else ("⚠" if m.alert == AlertSeverity.WARNING else "✗") print(f"\n{status} {m.term} ({m.sentinel_type.value})") print(f" Gini: {m.gini_coefficient:.3f} (drift: {m.gini_drift:+.3f})") print(f" Angular drift: {m.angular_drift_degrees:.1f}°") print(f" Cross-lang sim: {m.cross_lang_similarity:.3f}") if m.alert_message: print(f" ALERT: {m.alert_message}") print(f"\n{'='*60}") print(f"SUMMARY: {report.critical_count} critical, {report.warning_count} warnings") print(f"RECOMMENDATION: {report.recommendation}") print(f"{'='*60}\n")