""" Core Model Loader for nyx-probing. Provides access to Qwen2.5-7B-Base with hidden state capture. The model is an "empty vessel" - it completes, not answers. """ from dataclasses import dataclass, field from typing import Optional, List, Tuple import torch from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig @dataclass class GenerationResult: """Result from a generation with hidden states.""" # The generated text (including prompt) text: str # Just the completion (without prompt) completion: str # Token IDs of the full sequence token_ids: List[int] # Token IDs of just the completion completion_token_ids: List[int] # Hidden states from the last layer for each generated token # Shape: (num_generated_tokens, hidden_dim) hidden_states: Optional[torch.Tensor] = None # Token probabilities for each generated token # Shape: (num_generated_tokens,) token_probs: Optional[torch.Tensor] = None # Whether generation ended with EOS hit_eos: bool = False # Number of tokens generated num_tokens: int = 0 class NyxModel: """ Model wrapper for probing Qwen2.5-7B-Base. Key capabilities: - Hidden state capture during generation - Token probability extraction - Proper handling of base model (no chat template) """ def __init__( self, model_name: str = "Qwen/Qwen2.5-7B", device: str = "cuda", dtype: str = "float16", cache_dir: Optional[str] = None, ): self.model_name = model_name self.device = device self.dtype = getattr(torch, dtype) self.cache_dir = cache_dir self._model = None self._tokenizer = None self._loaded = False def load(self) -> "NyxModel": """Load the model and tokenizer.""" if self._loaded: return self print(f"Loading tokenizer: {self.model_name}") self._tokenizer = AutoTokenizer.from_pretrained( self.model_name, cache_dir=self.cache_dir, ) print(f"Loading model to {self.device}...") self._model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=self.dtype, device_map=self.device, cache_dir=self.cache_dir, # Critical for activation capture output_hidden_states=True, ) self._loaded = True print(f"Model loaded. VRAM: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") return self @property def model(self): if not self._loaded: raise RuntimeError("Model not loaded. Call load() first.") return self._model @property def tokenizer(self): if not self._loaded: raise RuntimeError("Model not loaded. Call load() first.") return self._tokenizer def generate( self, prompt: str, max_new_tokens: int = 50, temperature: float = 0.8, do_sample: bool = True, capture_hidden_states: bool = False, capture_probabilities: bool = False, ) -> GenerationResult: """ Generate completion with optional hidden state capture. 

        Args:
            prompt: Input text to complete
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0 = greedy)
            do_sample: Whether to sample (False = greedy)
            capture_hidden_states: Store hidden states from last layer
            capture_probabilities: Store token probabilities

        Returns:
            GenerationResult with text, tokens, and optionally hidden states
        """
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = inputs.input_ids.shape[1]

        # Generation config. A temperature of 0 falls back to greedy decoding,
        # matching the docstring (transformers rejects temperature=0 with sampling).
        greedy = not do_sample or temperature == 0
        gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=1.0 if greedy else temperature,
            do_sample=not greedy,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            output_hidden_states=capture_hidden_states,
            output_scores=capture_probabilities,
            return_dict_in_generate=True,
        )

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=gen_config,
            )

        # Extract sequences
        full_ids = outputs.sequences[0].tolist()
        completion_ids = full_ids[prompt_length:]

        # Decode
        full_text = self.tokenizer.decode(full_ids)
        completion_text = self.tokenizer.decode(completion_ids)

        # Check if hit EOS
        hit_eos = (
            len(completion_ids) > 0
            and completion_ids[-1] == self.tokenizer.eos_token_id
        )

        # Build result
        result = GenerationResult(
            text=full_text,
            completion=completion_text,
            token_ids=full_ids,
            completion_token_ids=completion_ids,
            hit_eos=hit_eos,
            num_tokens=len(completion_ids),
        )

        # Extract hidden states if requested
        if capture_hidden_states and hasattr(outputs, 'hidden_states'):
            # outputs.hidden_states is a tuple over generation steps; each step
            # is a tuple over layers of (batch, seq, hidden) tensors.
            # We want the last-layer hidden state at each generation step.
            hidden_list = []
            for step_states in outputs.hidden_states:
                # Take last layer, batch 0, last position
                last_layer = step_states[-1]  # (batch, seq, hidden)
                hidden_list.append(last_layer[0, -1, :])  # (hidden,)
            result.hidden_states = torch.stack(hidden_list)  # (tokens, hidden)

        # Extract probabilities if requested
        if capture_probabilities and hasattr(outputs, 'scores'):
            # outputs.scores is a tuple of length num_generated_tokens,
            # each entry of shape (batch, vocab)
            probs_list = []
            for i, score in enumerate(outputs.scores):
                # Apply softmax to get probabilities
                probs = torch.softmax(score[0], dim=-1)
                # Get probability of the token that was actually chosen
                chosen_token = completion_ids[i]
                probs_list.append(probs[chosen_token].item())
            result.token_probs = torch.tensor(probs_list)

        return result

    def get_token_probabilities(
        self,
        prompt: str,
        continuation: str,
    ) -> Tuple[List[float], List[str]]:
        """
        Get probability of each token in a specific continuation.

        Useful for measuring how "expected" a completion is.

        Args:
            prompt: The input text
            continuation: The text that follows

        Returns:
            Tuple of (probabilities, token_strings)
        """
        # Tokenize prompt and full sequence. This assumes the prompt's tokens
        # are a prefix of the full sequence's tokens; that holds for typical
        # splits but can break if the tokenizer merges across the boundary.
        prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        full_text = prompt + continuation
        full_ids = self.tokenizer.encode(full_text, return_tensors="pt").to(self.device)
        prompt_len = prompt_ids.shape[1]

        # Forward pass to get logits
        with torch.no_grad():
            outputs = self.model(full_ids)
            logits = outputs.logits  # (batch, seq, vocab)

        # Get probabilities for continuation tokens
        probs = []
        tokens = []
        for i in range(prompt_len, full_ids.shape[1]):
            # Logits at position i-1 predict the token at position i
            token_logits = logits[0, i - 1, :]
            token_probs = torch.softmax(token_logits, dim=-1)
            actual_token = full_ids[0, i].item()
            prob = token_probs[actual_token].item()
            probs.append(prob)
            tokens.append(self.tokenizer.decode([actual_token]))

        return probs, tokens

    def tokenize(self, text: str) -> List[str]:
        """Get individual tokens for text."""
        ids = self.tokenizer.encode(text)
        return [self.tokenizer.decode([token_id]) for token_id in ids]

    def token_count(self, text: str) -> int:
        """Count tokens in text."""
        return len(self.tokenizer.encode(text))

    def memory_usage(self) -> dict:
        """Get current GPU memory usage."""
        return {
            "allocated_gb": torch.cuda.memory_allocated() / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved() / 1024**3,
            "max_allocated_gb": torch.cuda.max_memory_allocated() / 1024**3,
        }
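

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the loader's API surface).
# Shows how the wrapper above is intended to be driven: load once, generate
# with hidden-state / probability capture, then score a fixed continuation.
# Assumes a CUDA GPU with enough VRAM for Qwen2.5-7B in float16 and that the
# weights are reachable (Hub or cache_dir). The prompts are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    nyx = NyxModel().load()

    # Base-model framing: the model completes text, it does not answer.
    result = nyx.generate(
        "The capital of France is",
        max_new_tokens=20,
        capture_hidden_states=True,
        capture_probabilities=True,
    )
    print(result.completion)
    print(f"{result.num_tokens} tokens, hit_eos={result.hit_eos}")
    if result.hidden_states is not None:
        # (num_generated_tokens, hidden_dim)
        print("hidden states:", tuple(result.hidden_states.shape))

    # How "expected" is a specific continuation under the model?
    probs, tokens = nyx.get_token_probabilities("The capital of France is", " Paris")
    for tok, p in zip(tokens, probs):
        print(f"{tok!r}: {p:.4f}")

    print(nyx.memory_usage())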