feat: complete Phase 1 - vocabulary expansion & DriftProbe infrastructure
- CLI: nyx-probe scan with --summary/--delta/--full flags
- DriftProbe: training safety with Gini coefficient + Angular Drift
- Vocabulary: 54 terms (30 nimmerverse + 24 German philosophical)
- Sentinels: ANCHOR/BRIDGE/CANARY/TARGET monitoring system

Key findings:
- German philosophical terms: 37.5% depth≥2 hit rate (vs 3.3% nimmerverse)
- Super Cluster validated: heart cross-lang sim = 1.000
- Isolated Zone confirmed: being EN↔DE sim = 0.195
- Gini signature: Philosophy ~0.5 (diffuse), Technical ~0.8 (sparse)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
304
nyx_probing/probes/drift_probe.py
Normal file
304
nyx_probing/probes/drift_probe.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
DriftProbe: Training-loop monitoring for conceptual topology preservation.
|
||||
|
||||
Theory: "Spatial Separation Hypothesis"
|
||||
- Use isolated zone languages (German) as scaffolding for new concepts
|
||||
- Monitor anchors (must not move), bridges (must stay separated), canaries (watch for migration)
|
||||
|
||||
Key Metrics (refined from peer review):
|
||||
1. Gini Coefficient: Sparse activations (0.8+) = deep/specific, Diffuse (0.3) = shallow/general
|
||||
2. Angular Drift: Direction change = definition rewrite, magnitude change = sharpening
|
||||
3. Cross-Language Similarity: Bridges should stay LOW, anchors should stay HIGH
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SentinelType(Enum):
    """Role a sentinel term plays during drift monitoring.

    Each member's value is its own name as a string, so a member can be
    looked up directly from a string tag (``SentinelType("ANCHOR")``).
    """

    # Core topology: these terms must not move.
    ANCHOR = "ANCHOR"

    # Isolated-zone integrity: these terms must stay separated.
    BRIDGE = "BRIDGE"

    # Early warning: watch these terms for migration.
    CANARY = "CANARY"

    # Training goals: movement is wanted for these terms.
    TARGET = "TARGET"
|
||||
|
||||
|
||||
class AlertSeverity(Enum):
    """Severity level attached to a single sentinel's drift metrics.

    Members are listed from least to most severe; each value is the
    member's own name as a string.
    """

    OK = "OK"
    WARNING = "WARNING"
    CRITICAL = "CRITICAL"
|
||||
|
||||
|
||||
@dataclass
class DriftMetrics:
    """Per-term measurement record produced for one sentinel.

    Only ``term`` and ``sentinel_type`` are required; every measurement
    field defaults to a neutral value (0 / "UNKNOWN" / OK / empty string).
    """

    # --- Identity -------------------------------------------------------
    term: str
    sentinel_type: SentinelType

    # --- Activation metrics (current state) -----------------------------
    gini_coefficient: float = 0.0
    activation_norm: float = 0.0

    # --- Drift metrics (current state compared with the baseline) -------
    angular_drift_degrees: float = 0.0
    norm_drift_percent: float = 0.0
    gini_drift: float = 0.0

    # --- Valley detection -----------------------------------------------
    detected_valley: str = "UNKNOWN"
    depth: int = 0

    # --- Cross-language similarity (relevant to anchors / bridges) ------
    cross_lang_similarity: float = 0.0

    # --- Alert outcome --------------------------------------------------
    alert: AlertSeverity = AlertSeverity.OK
    alert_message: str = ""
|
||||
|
||||
|
||||
@dataclass
class DriftReport:
    """Aggregate drift result for one training checkpoint.

    ``step`` and ``timestamp`` are required. ``metrics`` starts as a fresh
    empty list per instance, and the summary fields start at zero alerts
    with a "CONTINUE" recommendation.
    """

    # --- Checkpoint identity --------------------------------------------
    step: int
    timestamp: str

    # --- Per-sentinel results (one DriftMetrics entry per probed term) --
    metrics: list[DriftMetrics] = field(default_factory=list)

    # --- Summary --------------------------------------------------------
    critical_count: int = 0
    warning_count: int = 0
    recommendation: str = "CONTINUE"
|
||||
|
||||
|
||||
class DriftProbe:
    """
    Lightweight probe for training-loop monitoring.

    Optimized for RTX 3090 constraints:
    - Full probe: ~2 min (run at epoch 0, end of training)
    - Lite probe: ~10 sec (run every 100 steps)
    """

    def __init__(self, model, tokenizer, sentinels_path: Optional[str] = None):
        """
        Store the model/tokenizer pair and load the sentinel configuration.

        Args:
            model: Model object; it is called with the tokenizer output plus
                ``output_hidden_states=True`` and must expose ``.device``.
            tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``.
            sentinels_path: Path to the sentinel JSON file. When ``None``,
                ``data/sentinels.json`` three directories above this file
                is used.

        Raises:
            OSError: If the sentinel file cannot be opened.
            KeyError: If the JSON lacks a "sentinels" or "alert_rules" key.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.baseline_states = {}  # term -> hidden state tensor

        # Load sentinels
        if sentinels_path is None:
            sentinels_path = Path(__file__).parent.parent.parent / "data" / "sentinels.json"

        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding when reading it.
        with open(sentinels_path, encoding="utf-8") as f:
            self.config = json.load(f)

        self.sentinels = self.config["sentinels"]
        self.alert_rules = self.config["alert_rules"]

    def _get_hidden_state(self, text: str, layer: int = 18) -> torch.Tensor:
        """Get hidden state at specified layer for last token position.

        The result is converted to float and moved to the CPU before it is
        returned.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # First batch element, last token position, all hidden units.
        return outputs.hidden_states[layer][0, -1, :].float().cpu()

    def _compute_gini(self, activations: torch.Tensor) -> float:
        """
        Compute the Gini coefficient of the absolute activation values.

        Project interpretation: high Gini (0.8+) = sparse/specific,
        low Gini (~0.3) = diffuse/general.

        Args:
            activations: Activation vector (assumed 1-D, as returned by
                ``_get_hidden_state``).

        Returns:
            The Gini coefficient as a float. An empty or all-zero vector
            yields 0.0 instead of a 0/0 NaN.
        """
        x = np.sort(activations.detach().abs().cpu().numpy())
        n = len(x)
        total = np.sum(x)
        if n == 0 or total == 0:
            # No mass to distribute: treat as perfectly even rather than NaN.
            return 0.0
        ranks = np.arange(1, n + 1)
        gini = (2 * np.sum(ranks * x)) / (n * total) - (n + 1) / n
        return float(gini)

    def _compute_angular_drift(self, current: torch.Tensor, baseline: torch.Tensor) -> float:
        """
        Compute angular drift in degrees between current and baseline.

        > 15° = Definition rewrite (concerning)
        < 5° = Sharpening only (acceptable)
        """
        cos_sim = torch.nn.functional.cosine_similarity(
            current.unsqueeze(0), baseline.unsqueeze(0)
        ).item()
        # Clamp to valid range for arccos (rounding can push it past ±1)
        cos_sim = max(-1.0, min(1.0, cos_sim))
        angle_rad = np.arccos(cos_sim)
        return float(np.degrees(angle_rad))

    def _compute_cross_lang_sim(self, sentinel: dict, layer: int = 18) -> float:
        """Compute average cross-language similarity for a sentinel.

        Takes the mean cosine similarity over every unordered pair of the
        sentinel's "translations" values. Returns 0.0 when fewer than two
        translations are present.
        """
        from itertools import combinations

        translations = sentinel.get("translations", {})
        if len(translations) < 2:
            return 0.0

        states = [self._get_hidden_state(word, layer) for word in translations.values()]

        # Pairwise similarities
        sims = [
            torch.nn.functional.cosine_similarity(
                a.unsqueeze(0), b.unsqueeze(0)
            ).item()
            for a, b in combinations(states, 2)
        ]

        return float(np.mean(sims)) if sims else 0.0

    def capture_baseline(self, layer: int = 18):
        """
        Capture baseline hidden states for all sentinels.
        Run this at epoch 0 before training.

        Results are stored in ``self.baseline_states`` keyed by term.
        """
        print("Capturing baseline states...")
        for sentinel in self.sentinels:
            term = sentinel["term"]
            # Use English translation or term itself
            text = sentinel.get("translations", {}).get("EN", term)
            self.baseline_states[term] = self._get_hidden_state(text, layer)
        print(f"Baseline captured for {len(self.baseline_states)} sentinels")

    def probe_lite(self, step: int, layer: int = 18) -> DriftReport:
        """
        Lite probe - only check key sentinels.
        Optimized for ~10 second runtime.
        """
        # Select subset: 2 anchors, 1 bridge, 2 canaries
        lite_terms = {"heart", "water", "being", "dasein", "thrownness"}
        lite_sentinels = [s for s in self.sentinels if s["term"] in lite_terms]

        return self._run_probe(lite_sentinels, step, layer)

    def probe_full(self, step: int, layer: int = 18) -> DriftReport:
        """
        Full probe - check all sentinels.
        Runtime: ~2 minutes.
        """
        return self._run_probe(self.sentinels, step, layer)

    def _evaluate_alert(
        self,
        sentinel_type: SentinelType,
        thresholds: dict,
        gini: float,
        angular_drift: float,
        norm_drift: float,
        cross_lang_sim: float,
    ) -> tuple[AlertSeverity, str]:
        """Map one sentinel's metrics to an (alert severity, message) pair.

        Rules by sentinel type:
        - ANCHOR: CRITICAL above 15° angular drift; otherwise WARNING when the
          norm drift percentage exceeds ``max_drift`` (a fraction, default 0.05).
        - BRIDGE: CRITICAL when cross-language similarity exceeds
          ``collapse_alert_threshold`` (default 0.50).
        - CANARY: WARNING when Gini is below ``min_gini`` (default 0.70) and/or
          angular drift exceeds ``max_angular_drift`` (default 15). If both
          fire, both messages are reported, joined by "; ".
        - Any other type: always OK.
        """
        if sentinel_type == SentinelType.ANCHOR:
            max_drift = thresholds.get("max_drift", 0.05)
            if angular_drift > 15:
                return (
                    AlertSeverity.CRITICAL,
                    f"Angular drift {angular_drift:.1f}° exceeds 15° - definition rewrite",
                )
            if norm_drift > max_drift * 100:
                return (
                    AlertSeverity.WARNING,
                    f"Norm drift {norm_drift:.1f}% exceeds threshold",
                )
            return AlertSeverity.OK, ""

        if sentinel_type == SentinelType.BRIDGE:
            collapse_threshold = thresholds.get("collapse_alert_threshold", 0.50)
            if cross_lang_sim > collapse_threshold:
                return (
                    AlertSeverity.CRITICAL,
                    f"Bridge collapsed - cross-lang sim {cross_lang_sim:.2f} > {collapse_threshold}",
                )
            return AlertSeverity.OK, ""

        if sentinel_type == SentinelType.CANARY:
            messages = []
            min_gini = thresholds.get("min_gini", 0.70)
            if gini < min_gini:
                messages.append(f"Gini {gini:.2f} below {min_gini} - concept melting into prose")
            if angular_drift > thresholds.get("max_angular_drift", 15):
                messages.append(f"Angular drift {angular_drift:.1f}° - definition shifting")
            if messages:
                # Keep every triggered warning instead of letting the last
                # one overwrite the earlier message.
                return AlertSeverity.WARNING, "; ".join(messages)
            return AlertSeverity.OK, ""

        return AlertSeverity.OK, ""

    def _run_probe(self, sentinels: list, step: int, layer: int) -> DriftReport:
        """Run probe on specified sentinels.

        For each sentinel this computes the current hidden state, its Gini
        coefficient and norm, drift against the captured baseline (left at
        0.0 when no baseline exists for the term), and cross-language
        similarity, then evaluates the alert rules.

        The report recommendation is "ROLLBACK" when any alert is CRITICAL,
        "REDUCE_LR" when more than two are WARNING, and "CONTINUE" otherwise.
        """
        from datetime import datetime

        report = DriftReport(
            step=step,
            timestamp=datetime.now().isoformat()
        )

        for sentinel in sentinels:
            term = sentinel["term"]
            text = sentinel.get("translations", {}).get("EN", term)
            sentinel_type = SentinelType(sentinel["type"])
            thresholds = sentinel.get("thresholds", {})

            # Get current state
            current_state = self._get_hidden_state(text, layer)

            # Compute metrics
            gini = self._compute_gini(current_state)
            norm = float(current_state.norm())

            # Drift vs baseline
            angular_drift = 0.0
            norm_drift = 0.0
            gini_drift = 0.0

            if term in self.baseline_states:
                baseline = self.baseline_states[term]
                angular_drift = self._compute_angular_drift(current_state, baseline)
                baseline_norm = float(baseline.norm())
                if baseline_norm > 0:
                    norm_drift = abs(norm - baseline_norm) / baseline_norm * 100
                baseline_gini = self._compute_gini(baseline)
                gini_drift = gini - baseline_gini

            # Cross-language similarity
            cross_lang_sim = self._compute_cross_lang_sim(sentinel, layer)

            # Determine alert level
            alert, alert_message = self._evaluate_alert(
                sentinel_type, thresholds, gini, angular_drift, norm_drift, cross_lang_sim
            )

            metrics = DriftMetrics(
                term=term,
                sentinel_type=sentinel_type,
                gini_coefficient=gini,
                activation_norm=norm,
                angular_drift_degrees=angular_drift,
                norm_drift_percent=norm_drift,
                gini_drift=gini_drift,
                cross_lang_similarity=cross_lang_sim,
                alert=alert,
                alert_message=alert_message
            )

            report.metrics.append(metrics)

            if alert == AlertSeverity.CRITICAL:
                report.critical_count += 1
            elif alert == AlertSeverity.WARNING:
                report.warning_count += 1

        # Set recommendation
        if report.critical_count > 0:
            report.recommendation = "ROLLBACK"
        elif report.warning_count > 2:
            report.recommendation = "REDUCE_LR"
        else:
            report.recommendation = "CONTINUE"

        return report

    def print_report(self, report: DriftReport):
        """Pretty print a drift report."""
        print(f"\n{'='*60}")
        print(f"DRIFT REPORT - Step {report.step}")
        print(f"{'='*60}")

        for m in report.metrics:
            status = "✓" if m.alert == AlertSeverity.OK else ("⚠" if m.alert == AlertSeverity.WARNING else "✗")
            print(f"\n{status} {m.term} ({m.sentinel_type.value})")
            print(f" Gini: {m.gini_coefficient:.3f} (drift: {m.gini_drift:+.3f})")
            print(f" Angular drift: {m.angular_drift_degrees:.1f}°")
            print(f" Cross-lang sim: {m.cross_lang_similarity:.3f}")
            if m.alert_message:
                print(f" ALERT: {m.alert_message}")

        print(f"\n{'='*60}")
        print(f"SUMMARY: {report.critical_count} critical, {report.warning_count} warnings")
        print(f"RECOMMENDATION: {report.recommendation}")
        print(f"{'='*60}\n")
|
||||
Reference in New Issue
Block a user