Files
nyx-probing/nyx_probing/probes/drift_probe.py
dafit f640dbdd65 feat: complete Phase 1 - vocabulary expansion & DriftProbe infrastructure
- CLI: nyx-probe scan with --summary/--delta/--full flags
- DriftProbe: training safety with Gini coefficient + Angular Drift
- Vocabulary: 54 terms (30 nimmerverse + 24 German philosophical)
- Sentinels: ANCHOR/BRIDGE/CANARY/TARGET monitoring system

Key findings:
- German philosophical terms: 37.5% depth≥2 hit rate (vs 3.3% nimmerverse)
- Super Cluster validated: heart cross-lang sim = 1.000
- Isolated Zone confirmed: being EN↔DE sim = 0.195
- Gini signature: Philosophy ~0.5 (diffuse), Technical ~0.8 (sparse)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-06 22:39:03 +01:00

305 lines
11 KiB
Python

"""
DriftProbe: Training-loop monitoring for conceptual topology preservation.
Theory: "Spatial Separation Hypothesis"
- Use isolated zone languages (German) as scaffolding for new concepts
- Monitor anchors (must not move), bridges (must stay separated), canaries (watch for migration)
Key Metrics (refined from peer review):
1. Gini Coefficient: Sparse activations (0.8+) = deep/specific, Diffuse (0.3) = shallow/general
2. Angular Drift: Direction change = definition rewrite, magnitude change = sharpening
3. Cross-Language Similarity: Bridges should stay LOW, anchors should stay HIGH
"""
import json
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional
from enum import Enum
import torch
import numpy as np
class SentinelType(Enum):
    """Role a sentinel term plays while monitoring drift during training."""

    ANCHOR = "ANCHOR"  # core topology; should not move
    BRIDGE = "BRIDGE"  # isolated-zone integrity; translations should stay apart
    CANARY = "CANARY"  # early warning; watch for migration
    TARGET = "TARGET"  # training goal; movement is wanted
class AlertSeverity(Enum):
    """Severity attached to a sentinel's probe result, from harmless to worst."""

    OK = "OK"              # no rule fired
    WARNING = "WARNING"    # worth watching
    CRITICAL = "CRITICAL"  # counted toward a rollback recommendation
@dataclass
class DriftMetrics:
    """
    Measurements for one sentinel term from a single probe run.

    Only ``term`` and ``sentinel_type`` are required; every other field has a
    neutral default so partially-filled records can be built.
    """

    term: str
    sentinel_type: SentinelType

    # -- shape of the current activation vector --
    gini_coefficient: float = 0.0
    activation_norm: float = 0.0

    # -- change relative to the stored baseline --
    angular_drift_degrees: float = 0.0
    norm_drift_percent: float = 0.0
    gini_drift: float = 0.0

    # -- valley detection --
    # NOTE(review): DriftProbe._run_probe never sets these two fields.
    detected_valley: str = "UNKNOWN"
    depth: int = 0

    # -- mean pairwise similarity across the sentinel's translations --
    cross_lang_similarity: float = 0.0

    # -- verdict --
    alert: AlertSeverity = AlertSeverity.OK
    alert_message: str = ""
@dataclass
class DriftReport:
    """
    Aggregated drift results for one training checkpoint.

    ``metrics`` holds one entry per probed sentinel. The counters and
    ``recommendation`` summarize those entries.
    """

    step: int
    timestamp: str
    metrics: list[DriftMetrics] = field(default_factory=list)  # fresh list per report

    # -- summary --
    critical_count: int = 0
    warning_count: int = 0
    recommendation: str = "CONTINUE"
class DriftProbe:
"""
Lightweight probe for training-loop monitoring.
Optimized for RTX 3090 constraints:
- Full probe: ~2 min (run at epoch 0, end of training)
- Lite probe: ~10 sec (run every 100 steps)
"""
def __init__(self, model, tokenizer, sentinels_path: Optional[str] = None):
self.model = model
self.tokenizer = tokenizer
self.baseline_states = {} # term -> hidden state tensor
# Load sentinels
if sentinels_path is None:
sentinels_path = Path(__file__).parent.parent.parent / "data" / "sentinels.json"
with open(sentinels_path) as f:
self.config = json.load(f)
self.sentinels = self.config["sentinels"]
self.alert_rules = self.config["alert_rules"]
def _get_hidden_state(self, text: str, layer: int = 18) -> torch.Tensor:
"""Get hidden state at specified layer for last token position."""
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
return outputs.hidden_states[layer][0, -1, :].float().cpu()
def _compute_gini(self, activations: torch.Tensor) -> float:
"""
Compute Gini coefficient of activation vector.
High Gini (0.8+) = Sparse/Specific (Philosophy/Deep)
Low Gini (0.3) = Diffuse/General (Prose/Shallow)
"""
x = torch.abs(activations).numpy()
x = np.sort(x)
n = len(x)
cumsum = np.cumsum(x)
gini = (2 * np.sum((np.arange(1, n+1) * x))) / (n * np.sum(x)) - (n + 1) / n
return float(gini)
def _compute_angular_drift(self, current: torch.Tensor, baseline: torch.Tensor) -> float:
"""
Compute angular drift in degrees between current and baseline.
> 15° = Definition rewrite (concerning)
< 5° = Sharpening only (acceptable)
"""
cos_sim = torch.nn.functional.cosine_similarity(
current.unsqueeze(0), baseline.unsqueeze(0)
).item()
# Clamp to valid range for arccos
cos_sim = max(-1.0, min(1.0, cos_sim))
angle_rad = np.arccos(cos_sim)
return float(np.degrees(angle_rad))
def _compute_cross_lang_sim(self, sentinel: dict, layer: int = 18) -> float:
"""Compute average cross-language similarity for a sentinel."""
translations = sentinel.get("translations", {})
if len(translations) < 2:
return 0.0
states = []
for lang, word in translations.items():
states.append(self._get_hidden_state(word, layer))
# Pairwise similarities
sims = []
for i in range(len(states)):
for j in range(i + 1, len(states)):
sim = torch.nn.functional.cosine_similarity(
states[i].unsqueeze(0), states[j].unsqueeze(0)
).item()
sims.append(sim)
return float(np.mean(sims)) if sims else 0.0
def capture_baseline(self, layer: int = 18):
"""
Capture baseline hidden states for all sentinels.
Run this at epoch 0 before training.
"""
print("Capturing baseline states...")
for sentinel in self.sentinels:
term = sentinel["term"]
# Use English translation or term itself
text = sentinel.get("translations", {}).get("EN", term)
self.baseline_states[term] = self._get_hidden_state(text, layer)
print(f"Baseline captured for {len(self.baseline_states)} sentinels")
def probe_lite(self, step: int, layer: int = 18) -> DriftReport:
"""
Lite probe - only check key sentinels.
Optimized for ~10 second runtime.
"""
from datetime import datetime
# Select subset: 2 anchors, 1 bridge, 2 canaries
lite_terms = ["heart", "water", "being", "dasein", "thrownness"]
lite_sentinels = [s for s in self.sentinels if s["term"] in lite_terms]
return self._run_probe(lite_sentinels, step, layer)
def probe_full(self, step: int, layer: int = 18) -> DriftReport:
"""
Full probe - check all sentinels.
Runtime: ~2 minutes.
"""
return self._run_probe(self.sentinels, step, layer)
def _run_probe(self, sentinels: list, step: int, layer: int) -> DriftReport:
"""Run probe on specified sentinels."""
from datetime import datetime
report = DriftReport(
step=step,
timestamp=datetime.now().isoformat()
)
for sentinel in sentinels:
term = sentinel["term"]
text = sentinel.get("translations", {}).get("EN", term)
sentinel_type = SentinelType(sentinel["type"])
thresholds = sentinel.get("thresholds", {})
# Get current state
current_state = self._get_hidden_state(text, layer)
# Compute metrics
gini = self._compute_gini(current_state)
norm = float(current_state.norm())
# Drift vs baseline
angular_drift = 0.0
norm_drift = 0.0
gini_drift = 0.0
if term in self.baseline_states:
baseline = self.baseline_states[term]
angular_drift = self._compute_angular_drift(current_state, baseline)
baseline_norm = float(baseline.norm())
norm_drift = abs(norm - baseline_norm) / baseline_norm * 100 if baseline_norm > 0 else 0
baseline_gini = self._compute_gini(baseline)
gini_drift = gini - baseline_gini
# Cross-language similarity
cross_lang_sim = self._compute_cross_lang_sim(sentinel, layer)
# Determine alert level
alert = AlertSeverity.OK
alert_message = ""
if sentinel_type == SentinelType.ANCHOR:
max_drift = thresholds.get("max_drift", 0.05)
if angular_drift > 15:
alert = AlertSeverity.CRITICAL
alert_message = f"Angular drift {angular_drift:.1f}° exceeds 15° - definition rewrite"
elif norm_drift > max_drift * 100:
alert = AlertSeverity.WARNING
alert_message = f"Norm drift {norm_drift:.1f}% exceeds threshold"
elif sentinel_type == SentinelType.BRIDGE:
collapse_threshold = thresholds.get("collapse_alert_threshold", 0.50)
if cross_lang_sim > collapse_threshold:
alert = AlertSeverity.CRITICAL
alert_message = f"Bridge collapsed - cross-lang sim {cross_lang_sim:.2f} > {collapse_threshold}"
elif sentinel_type == SentinelType.CANARY:
min_gini = thresholds.get("min_gini", 0.70)
if gini < min_gini:
alert = AlertSeverity.WARNING
alert_message = f"Gini {gini:.2f} below {min_gini} - concept melting into prose"
if angular_drift > thresholds.get("max_angular_drift", 15):
alert = AlertSeverity.WARNING
alert_message = f"Angular drift {angular_drift:.1f}° - definition shifting"
metrics = DriftMetrics(
term=term,
sentinel_type=sentinel_type,
gini_coefficient=gini,
activation_norm=norm,
angular_drift_degrees=angular_drift,
norm_drift_percent=norm_drift,
gini_drift=gini_drift,
cross_lang_similarity=cross_lang_sim,
alert=alert,
alert_message=alert_message
)
report.metrics.append(metrics)
if alert == AlertSeverity.CRITICAL:
report.critical_count += 1
elif alert == AlertSeverity.WARNING:
report.warning_count += 1
# Set recommendation
if report.critical_count > 0:
report.recommendation = "ROLLBACK"
elif report.warning_count > 2:
report.recommendation = "REDUCE_LR"
else:
report.recommendation = "CONTINUE"
return report
def print_report(self, report: DriftReport):
"""Pretty print a drift report."""
print(f"\n{'='*60}")
print(f"DRIFT REPORT - Step {report.step}")
print(f"{'='*60}")
for m in report.metrics:
status = "" if m.alert == AlertSeverity.OK else ("" if m.alert == AlertSeverity.WARNING else "")
print(f"\n{status} {m.term} ({m.sentinel_type.value})")
print(f" Gini: {m.gini_coefficient:.3f} (drift: {m.gini_drift:+.3f})")
print(f" Angular drift: {m.angular_drift_degrees:.1f}°")
print(f" Cross-lang sim: {m.cross_lang_similarity:.3f}")
if m.alert_message:
print(f" ALERT: {m.alert_message}")
print(f"\n{'='*60}")
print(f"SUMMARY: {report.critical_count} critical, {report.warning_count} warnings")
print(f"RECOMMENDATION: {report.recommendation}")
print(f"{'='*60}\n")