- CLI: nyx-probe scan with --summary/--delta/--full flags - DriftProbe: training safety with Gini coefficient + Angular Drift - Vocabulary: 54 terms (30 nimmerverse + 24 German philosophical) - Sentinels: ANCHOR/BRIDGE/CANARY/TARGET monitoring system Key findings: - German philosophical terms: 37.5% depth≥2 hit rate (vs 3.3% nimmerverse) - Super Cluster validated: heart cross-lang sim = 1.000 - Isolated Zone confirmed: being EN↔DE sim = 0.195 - Gini signature: Philosophy ~0.5 (diffuse), Technical ~0.8 (sparse) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
305 lines
11 KiB
Python
305 lines
11 KiB
Python
"""
DriftProbe: Training-loop monitoring for conceptual topology preservation.

Theory: "Spatial Separation Hypothesis"
- Use isolated zone languages (German) as scaffolding for new concepts
- Monitor anchors (must not move), bridges (must stay separated), canaries (watch for migration)

Key Metrics (refined from peer review):
1. Gini Coefficient: Sparse activations (0.8+) = deep/specific, Diffuse (0.3) = shallow/general
2. Angular Drift: Direction change = definition rewrite, magnitude change = sharpening
3. Cross-Language Similarity: Bridges should stay LOW, anchors should stay HIGH
"""
|
|
import json
|
|
from pathlib import Path
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
from enum import Enum
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
class SentinelType(Enum):
    """Role a sentinel term plays in drift monitoring.

    Each member's value is its own name as a string, so a member can be
    recovered from serialized config with ``SentinelType("ANCHOR")``.
    """

    # Must not move - core topology
    ANCHOR = "ANCHOR"
    # Must stay separated - isolated zone integrity
    BRIDGE = "BRIDGE"
    # Watch for migration - early warning
    CANARY = "CANARY"
    # Want movement - training goals
    TARGET = "TARGET"
|
|
|
|
|
|
class AlertSeverity(Enum):
    """Severity attached to a sentinel's metrics after threshold checks.

    Members are declared from least to most severe; each value is the
    member's own name as a string.
    """

    OK = "OK"
    WARNING = "WARNING"
    CRITICAL = "CRITICAL"
|
|
|
|
|
|
@dataclass
class DriftMetrics:
    """Measurements recorded for one sentinel term during a probe pass.

    Only ``term`` and ``sentinel_type`` are required; every other field
    starts at a neutral default and is filled in by whoever builds the record.
    """

    # Identity of the sentinel being measured
    term: str
    sentinel_type: SentinelType

    # Activation statistics of the current hidden state
    gini_coefficient: float = 0.0
    activation_norm: float = 0.0

    # Change relative to the captured baseline
    angular_drift_degrees: float = 0.0
    norm_drift_percent: float = 0.0
    gini_drift: float = 0.0

    # Valley detection
    detected_valley: str = "UNKNOWN"
    depth: int = 0

    # Cross-language similarity (for anchors/bridges)
    cross_lang_similarity: float = 0.0

    # Alert outcome for this term
    alert: AlertSeverity = AlertSeverity.OK
    alert_message: str = ""
|
|
|
|
|
|
@dataclass
class DriftReport:
    """Aggregated drift results for a single training checkpoint.

    ``step`` and ``timestamp`` identify the checkpoint; ``metrics`` holds one
    entry per probed sentinel. The summary fields start at zero / "CONTINUE"
    and are meant to be updated as metrics are appended.
    """

    # Checkpoint identity
    step: int
    timestamp: str

    # One DriftMetrics per sentinel; default_factory gives each report its own list
    metrics: list[DriftMetrics] = field(default_factory=list)

    # Summary
    critical_count: int = 0
    warning_count: int = 0
    recommendation: str = "CONTINUE"
|
|
|
|
|
|
class DriftProbe:
    """
    Lightweight probe for training-loop monitoring.

    Optimized for RTX 3090 constraints:
    - Full probe: ~2 min (run at epoch 0, end of training)
    - Lite probe: ~10 sec (run every 100 steps)

    Usage: construct the probe, call ``capture_baseline()`` once before
    training, then call ``probe_lite()`` / ``probe_full()`` at checkpoints and
    optionally ``print_report()`` on the result. Drift metrics are only
    computed for terms that have a captured baseline; without one they stay 0.
    """

    # Default subset used by probe_lite. The original author describes it as
    # "2 anchors, 1 bridge, 2 canaries"; the actual type of each term comes
    # from the sentinel config file.
    LITE_TERMS = ("heart", "water", "being", "dasein", "thrownness")

    def __init__(self, model, tokenizer, sentinels_path: Optional[str] = None):
        """
        Store the model/tokenizer and load the sentinel configuration.

        Args:
            model: Callable model exposing ``.device`` and returning an object
                with ``hidden_states`` when called with
                ``output_hidden_states=True``.
            tokenizer: Callable turning text into model inputs
                (``return_tensors="pt"``).
            sentinels_path: Path to the sentinel JSON file. Defaults to
                ``data/sentinels.json`` three directories above this module.

        Raises:
            OSError: If the sentinel file cannot be opened.
            KeyError: If the JSON lacks "sentinels" or "alert_rules".
        """
        self.model = model
        self.tokenizer = tokenizer
        self.baseline_states = {}  # term -> hidden state tensor

        # Load sentinels
        if sentinels_path is None:
            sentinels_path = Path(__file__).parent.parent.parent / "data" / "sentinels.json"

        # JSON is UTF-8 by specification; be explicit so non-ASCII terms
        # decode the same way on every platform.
        with open(sentinels_path, encoding="utf-8") as f:
            self.config = json.load(f)

        self.sentinels = self.config["sentinels"]
        self.alert_rules = self.config["alert_rules"]

    def _get_hidden_state(self, text: str, layer: int = 18) -> torch.Tensor:
        """
        Get hidden state at specified layer for last token position.

        Returns the vector for the first batch element, cast to float32 and
        moved to CPU.

        NOTE(review): ``torch.no_grad()`` disables gradient tracking but not
        dropout. If the model is in train mode when probed, states may be
        stochastic - confirm callers switch to eval mode first.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        return outputs.hidden_states[layer][0, -1, :].float().cpu()

    def _compute_gini(self, activations: torch.Tensor) -> float:
        """
        Compute Gini coefficient of activation vector (on absolute values).

        High Gini (0.8+) = Sparse/Specific (Philosophy/Deep)
        Low Gini (0.3) = Diffuse/General (Prose/Shallow)

        Returns 0.0 for an empty or all-zero vector, where the coefficient is
        otherwise undefined (0/0).
        """
        x = np.sort(torch.abs(activations).detach().numpy())
        n = len(x)
        total = np.sum(x)
        if n == 0 or total == 0:
            return 0.0
        # Rank-weighted form of the Gini coefficient over ascending values.
        ranks = np.arange(1, n + 1)
        gini = (2 * np.sum(ranks * x)) / (n * total) - (n + 1) / n
        return float(gini)

    def _compute_angular_drift(self, current: torch.Tensor, baseline: torch.Tensor) -> float:
        """
        Compute angular drift in degrees between current and baseline.

        > 15° = Definition rewrite (concerning)
        < 5° = Sharpening only (acceptable)
        """
        cos_sim = torch.nn.functional.cosine_similarity(
            current.unsqueeze(0), baseline.unsqueeze(0)
        ).item()
        # Clamp to valid range for arccos (rounding can push |cos| past 1)
        cos_sim = max(-1.0, min(1.0, cos_sim))
        angle_rad = np.arccos(cos_sim)
        return float(np.degrees(angle_rad))

    def _compute_cross_lang_sim(self, sentinel: dict, layer: int = 18) -> float:
        """
        Compute average cross-language similarity for a sentinel.

        Takes the mean cosine similarity over every unordered pair of entries
        in ``sentinel["translations"]``. Returns 0.0 when fewer than two
        translations are present.
        """
        from itertools import combinations

        translations = sentinel.get("translations", {})
        if len(translations) < 2:
            return 0.0

        states = [self._get_hidden_state(word, layer) for word in translations.values()]

        # Pairwise similarities; at least two states means at least one pair.
        sims = [
            torch.nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
            for a, b in combinations(states, 2)
        ]
        return float(np.mean(sims))

    def capture_baseline(self, layer: int = 18):
        """
        Capture baseline hidden states for all sentinels.
        Run this at epoch 0 before training.

        The probed text is the sentinel's "EN" translation when present,
        otherwise the term itself. Results are stored in
        ``self.baseline_states`` keyed by term.
        """
        print("Capturing baseline states...")
        for sentinel in self.sentinels:
            term = sentinel["term"]
            # Use English translation or term itself
            text = sentinel.get("translations", {}).get("EN", term)
            self.baseline_states[term] = self._get_hidden_state(text, layer)
        print(f"Baseline captured for {len(self.baseline_states)} sentinels")

    def probe_lite(self, step: int, layer: int = 18, terms: Optional[list[str]] = None) -> DriftReport:
        """
        Lite probe - only check key sentinels.
        Optimized for ~10 second runtime.

        Args:
            step: Training step recorded in the report.
            layer: Hidden-state layer index to probe.
            terms: Terms to include; defaults to ``LITE_TERMS``. Terms not
                present in the loaded sentinels are ignored.
        """
        wanted = set(self.LITE_TERMS if terms is None else terms)
        lite_sentinels = [s for s in self.sentinels if s["term"] in wanted]
        return self._run_probe(lite_sentinels, step, layer)

    def probe_full(self, step: int, layer: int = 18) -> DriftReport:
        """
        Full probe - check all sentinels.
        Runtime: ~2 minutes.
        """
        return self._run_probe(self.sentinels, step, layer)

    @staticmethod
    def _evaluate_alert(
        sentinel_type: SentinelType,
        thresholds: dict,
        *,
        gini: float,
        angular_drift: float,
        norm_drift: float,
        cross_lang_sim: float,
    ) -> tuple[AlertSeverity, str]:
        """Map one sentinel's metrics to an (alert severity, message) pair."""
        if sentinel_type == SentinelType.ANCHOR:
            max_drift = thresholds.get("max_drift", 0.05)
            if angular_drift > 15:
                return (
                    AlertSeverity.CRITICAL,
                    f"Angular drift {angular_drift:.1f}° exceeds 15° - definition rewrite",
                )
            # max_drift is a fraction; norm_drift is a percentage.
            if norm_drift > max_drift * 100:
                return AlertSeverity.WARNING, f"Norm drift {norm_drift:.1f}% exceeds threshold"

        elif sentinel_type == SentinelType.BRIDGE:
            collapse_threshold = thresholds.get("collapse_alert_threshold", 0.50)
            if cross_lang_sim > collapse_threshold:
                return (
                    AlertSeverity.CRITICAL,
                    f"Bridge collapsed - cross-lang sim {cross_lang_sim:.2f} > {collapse_threshold}",
                )

        elif sentinel_type == SentinelType.CANARY:
            # Both checks can fire at once; keep every message instead of
            # letting the later one overwrite the earlier one.
            messages = []
            min_gini = thresholds.get("min_gini", 0.70)
            if gini < min_gini:
                messages.append(f"Gini {gini:.2f} below {min_gini} - concept melting into prose")
            if angular_drift > thresholds.get("max_angular_drift", 15):
                messages.append(f"Angular drift {angular_drift:.1f}° - definition shifting")
            if messages:
                return AlertSeverity.WARNING, "; ".join(messages)

        # TARGET sentinels (and anything within thresholds) raise no alert.
        return AlertSeverity.OK, ""

    def _run_probe(self, sentinels: list, step: int, layer: int) -> DriftReport:
        """
        Run probe on specified sentinels.

        For each sentinel this computes activation metrics, drift against the
        captured baseline (if any), and cross-language similarity, then
        assigns an alert. The report recommends "ROLLBACK" on any critical
        alert, "REDUCE_LR" on more than two warnings, else "CONTINUE".
        """
        from datetime import datetime

        report = DriftReport(
            step=step,
            timestamp=datetime.now().isoformat(),
        )

        for sentinel in sentinels:
            term = sentinel["term"]
            text = sentinel.get("translations", {}).get("EN", term)
            sentinel_type = SentinelType(sentinel["type"])
            thresholds = sentinel.get("thresholds", {})

            # Get current state
            current_state = self._get_hidden_state(text, layer)

            # Compute metrics
            gini = self._compute_gini(current_state)
            norm = float(current_state.norm())

            # Drift vs baseline (stays 0 when no baseline was captured)
            angular_drift = 0.0
            norm_drift = 0.0
            gini_drift = 0.0

            if term in self.baseline_states:
                baseline = self.baseline_states[term]
                angular_drift = self._compute_angular_drift(current_state, baseline)
                baseline_norm = float(baseline.norm())
                if baseline_norm > 0:
                    norm_drift = abs(norm - baseline_norm) / baseline_norm * 100
                gini_drift = gini - self._compute_gini(baseline)

            # Cross-language similarity
            cross_lang_sim = self._compute_cross_lang_sim(sentinel, layer)

            # Determine alert level
            alert, alert_message = self._evaluate_alert(
                sentinel_type,
                thresholds,
                gini=gini,
                angular_drift=angular_drift,
                norm_drift=norm_drift,
                cross_lang_sim=cross_lang_sim,
            )

            report.metrics.append(
                DriftMetrics(
                    term=term,
                    sentinel_type=sentinel_type,
                    gini_coefficient=gini,
                    activation_norm=norm,
                    angular_drift_degrees=angular_drift,
                    norm_drift_percent=norm_drift,
                    gini_drift=gini_drift,
                    cross_lang_similarity=cross_lang_sim,
                    alert=alert,
                    alert_message=alert_message,
                )
            )

            if alert == AlertSeverity.CRITICAL:
                report.critical_count += 1
            elif alert == AlertSeverity.WARNING:
                report.warning_count += 1

        # Set recommendation
        if report.critical_count > 0:
            report.recommendation = "ROLLBACK"
        elif report.warning_count > 2:
            report.recommendation = "REDUCE_LR"
        else:
            report.recommendation = "CONTINUE"

        return report

    def print_report(self, report: DriftReport):
        """Pretty print a drift report."""
        print(f"\n{'='*60}")
        print(f"DRIFT REPORT - Step {report.step}")
        print(f"{'='*60}")

        for m in report.metrics:
            if m.alert == AlertSeverity.OK:
                status = "✓"
            elif m.alert == AlertSeverity.WARNING:
                status = "⚠"
            else:
                status = "✗"
            print(f"\n{status} {m.term} ({m.sentinel_type.value})")
            print(f" Gini: {m.gini_coefficient:.3f} (drift: {m.gini_drift:+.3f})")
            print(f" Angular drift: {m.angular_drift_degrees:.1f}°")
            print(f" Cross-lang sim: {m.cross_lang_similarity:.3f}")
            if m.alert_message:
                print(f" ALERT: {m.alert_message}")

        print(f"\n{'='*60}")
        print(f"SUMMARY: {report.critical_count} critical, {report.warning_count} warnings")
        print(f"RECOMMENDATION: {report.recommendation}")
        print(f"{'='*60}\n")
|