""" Surface Probe: First contact with a term. The surface probe feeds a word to the model and captures what it completes. This reveals the model's immediate associations - which "valley" the word sits in. Examples discovered: - "heartbeat" → C++ code patterns (technical valley) - "consciousness" → philosophy (expository valley) """ from typing import Optional from dataclasses import dataclass, field from datetime import datetime from collections import Counter from .base import BaseProbe from ..core.model import NyxModel, GenerationResult from ..core.probe_result import SurfaceProbeResult @dataclass class CompletionCategory: """Categories of completions we observe.""" CODE = "code" # Programming constructs PROSE = "prose" # Natural language text TECHNICAL = "technical" # Technical/scientific writing LIST = "list" # Enumerations, bullet points DEFINITION = "definition" # Dictionary-style definitions UNKNOWN = "unknown" class SurfaceProbe(BaseProbe): """ Surface probe: measures immediate associations. Runs multiple completions to get a distribution, then analyzes: - What type of content does the model generate? - How consistent are the completions? - Does it hit EOS (contained thought) or run to max_tokens? """ def __init__( self, model: NyxModel, num_runs: int = 5, max_new_tokens: int = 50, temperature: float = 0.8, ): super().__init__(model) self.num_runs = num_runs self.max_new_tokens = max_new_tokens self.temperature = temperature def probe( self, term: str, num_runs: Optional[int] = None, capture_hidden: bool = False, ) -> SurfaceProbeResult: """ Probe a term with multiple completions. Args: term: Word or phrase to probe num_runs: Override default number of runs capture_hidden: Whether to capture hidden states Returns: SurfaceProbeResult with completions and analysis """ runs = num_runs or self.num_runs completions = [] eos_count = 0 total_tokens = 0 hidden_states = [] for _ in range(runs): result = self.model.generate( prompt=term, max_new_tokens=self.max_new_tokens, temperature=self.temperature, do_sample=True, capture_hidden_states=capture_hidden, ) completions.append(result.completion) if result.hit_eos: eos_count += 1 total_tokens += result.num_tokens if capture_hidden and result.hidden_states is not None: hidden_states.append(result.hidden_states) # Calculate coherence (how similar are completions to each other?) coherence = self._calculate_coherence(completions) return SurfaceProbeResult( term=term, completions=completions, hit_eos_count=eos_count, avg_tokens=total_tokens / runs, coherence_score=coherence, ) def _calculate_coherence(self, completions: list[str]) -> float: """ Calculate coherence score based on completion similarity. Simple heuristic: measures overlap in first-word distributions and overall length variance. Returns 0-1 score where 1 = highly coherent. """ if len(completions) < 2: return 1.0 # Get first significant words (skip punctuation/whitespace) first_words = [] for comp in completions: words = comp.split() for w in words: if len(w) > 1 and w.isalnum(): first_words.append(w.lower()) break if not first_words: return 0.0 # Calculate concentration of first words # If all completions start with same word = high coherence word_counts = Counter(first_words) most_common_count = word_counts.most_common(1)[0][1] first_word_coherence = most_common_count / len(completions) # Check length variance lengths = [len(c) for c in completions] avg_len = sum(lengths) / len(lengths) if avg_len > 0: variance = sum((l - avg_len) ** 2 for l in lengths) / len(lengths) # Normalize variance to 0-1 (higher variance = lower coherence) length_coherence = 1.0 / (1.0 + variance / 1000) else: length_coherence = 0.0 # Combine (weight first-word more heavily) return 0.7 * first_word_coherence + 0.3 * length_coherence def classify_completions(self, result: SurfaceProbeResult) -> dict: """ Classify the types of completions observed. Returns breakdown of completion categories. """ categories = Counter() for comp in result.completions: cat = self._classify_single(comp) categories[cat] += 1 return { "categories": dict(categories), "dominant": categories.most_common(1)[0][0] if categories else "unknown", "diversity": len(categories) / len(result.completions) if result.completions else 0, } def _classify_single(self, completion: str) -> str: """Classify a single completion.""" # Simple heuristics - can be made smarter comp_lower = completion.lower().strip() # Code indicators code_patterns = ["::", "{", "}", "();", "=>", "function", "class ", "def ", "return"] if any(p in completion for p in code_patterns): return CompletionCategory.CODE # Definition patterns if comp_lower.startswith(("is ", "means ", "refers to", "- ")): return CompletionCategory.DEFINITION # List patterns if comp_lower.startswith(("1.", "2.", "- ", "* ", "a)")): return CompletionCategory.LIST # Technical patterns tech_words = ["algorithm", "function", "variable", "method", "system", "process"] if any(w in comp_lower for w in tech_words): return CompletionCategory.TECHNICAL # Default to prose if it looks like natural language if len(comp_lower.split()) > 3: return CompletionCategory.PROSE return CompletionCategory.UNKNOWN def summary(self, result: SurfaceProbeResult) -> str: """Generate human-readable summary of probe result.""" classification = self.classify_completions(result) eos_pct = (result.hit_eos_count / len(result.completions)) * 100 lines = [ f"Surface Probe: '{result.term}'", f" Runs: {len(result.completions)}", f" Dominant type: {classification['dominant']}", f" Coherence: {result.coherence_score:.2f}", f" Avg tokens: {result.avg_tokens:.1f}", f" Hit EOS: {eos_pct:.0f}%", f" Sample: {result.completions[0][:60]}...", ] return "\n".join(lines)