178 lines
6.0 KiB
Python
178 lines
6.0 KiB
Python
"""Lore Injector - Injects retrieved lore into SkyrimNet prompts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import structlog
|
|
|
|
from .models import InjectionResult, LoreEntry, NPCProfile
|
|
|
|
# Module-level structlog logger; LoreInjector.inject uses it to report
# whether lore was injected or no injection point could be found.
logger = structlog.get_logger()
|
|
|
|
|
|
class LoreInjector:
    """Injects Oghma lore into SkyrimNet chat messages.

    The injector renders retrieved lore entries into a text block (see
    ``_build_injection_text``) and splices that block into a copy of the
    chat messages at a position chosen at construction time.
    """

    DEFAULT_TEMPLATE = """
## Relevant Lore Knowledge

Based on your background as a {race} {profession} in {location}, you would know:

{lore_items}

Note: Reference this knowledge naturally when relevant to the conversation. Do not recite it.
"""

    # Section headers that mark where the lore block is inserted in
    # 'after_bio' mode. They are tried in this order; the lore block goes
    # immediately before the first one found.
    BIO_MARKERS = ("## Background", "## Personality", "## Speech Style")

    # Lore entries longer than this many characters are truncated with "...".
    MAX_ENTRY_CHARS = 300

    # Upper bound on the number of "## " headers reported in the warning
    # emitted when no injection point is found.
    MAX_LOGGED_HEADERS = 20

    def __init__(self, template: str | None = None, position: str = "after_bio"):
        """
        Initialize injector.

        Args:
            template: ``str.format``-style template for the injection block.
                It is rendered with the keyword arguments ``race``,
                ``profession``, ``location``, ``lore_items`` and ``name``, so
                placeholders use single braces (``{race}``) and literal braces
                must be doubled. Falls back to ``DEFAULT_TEMPLATE`` when
                ``None`` or empty.
            position: Where to inject - 'after_bio', 'before_conversation',
                'system_suffix'. Any other value never matches an injection
                point, so ``inject`` logs a warning and returns the messages
                unchanged.
        """
        self.template = template or self.DEFAULT_TEMPLATE
        self.position = position

    def inject(
        self,
        messages: list[dict],
        npc_profile: NPCProfile,
        lore_entries: list[LoreEntry],
        query_time_ms: float,
    ) -> tuple[list[dict], InjectionResult]:
        """
        Inject lore into chat messages.

        Behaviour per ``position``:

        * 'after_bio': insert the lore block before the first bio marker
          (see ``BIO_MARKERS``) found in a system message.
        * 'system_suffix': append the lore block to the first system message.
        * 'before_conversation': prepend the lore block to the first user
          message.

        Messages whose ``content`` is not a string (``None`` or a list of
        content parts) are skipped instead of raising ``TypeError``.

        Args:
            messages: Original chat messages
            npc_profile: Extracted NPC profile
            lore_entries: Retrieved lore entries
            query_time_ms: Time taken for retrieval

        Returns:
            Tuple of (modified messages, injection result). When
            ``lore_entries`` is empty the original ``messages`` list is
            returned as-is; otherwise a list of shallow-copied message dicts
            is returned. ``injection_text`` on the result is empty when no
            injection point was found.
        """
        if not lore_entries:
            return messages, InjectionResult(
                npc_profile=npc_profile,
                lore_entries=[],
                injection_text="",
                query_time_ms=query_time_ms,
            )

        injection_text = self._build_injection_text(npc_profile, lore_entries)

        # Shallow-copy each message so replacing "content" below never
        # touches the caller's dicts.
        modified_messages = [dict(msg) for msg in messages]

        injected = False

        # First choice: a system message that offers an injection point.
        for msg in modified_messages:
            if msg.get("role") != "system":
                continue
            new_content = self._inject_into_system(
                msg.get("content", ""), injection_text
            )
            if new_content is not None:
                msg["content"] = new_content
                injected = True
                break

        # 'before_conversation' never matches a system message; it prepends
        # the lore to the first user message instead.
        if not injected and self.position == "before_conversation":
            for msg in modified_messages:
                if msg.get("role") != "user":
                    continue
                content = msg.get("content", "")
                if not isinstance(content, str):
                    # Formatting None or a list of parts into the prompt
                    # would corrupt it; try the next user message.
                    continue
                msg["content"] = (
                    f"[Context for the NPC you're speaking with]\n{injection_text}\n\n"
                    f"[Player speaks]\n{content}"
                )
                injected = True
                break

        if injected:
            logger.info(
                "Injected lore",
                npc_name=npc_profile.name,
                entries_count=len(lore_entries),
                position=self.position,
            )
        else:
            system_msg_count, system_content_chars, seen_headers = (
                self._summarize_system_messages(
                    modified_messages, self.MAX_LOGGED_HEADERS
                )
            )
            logger.warning(
                "Could not find injection point",
                position=self.position,
                npc_name=npc_profile.name,
                system_messages=system_msg_count,
                system_content_chars=system_content_chars,
                seen_headers=seen_headers,
            )

        result = InjectionResult(
            npc_profile=npc_profile,
            lore_entries=lore_entries,
            injection_text=injection_text if injected else "",
            query_time_ms=query_time_ms,
        )

        return modified_messages, result

    def _inject_into_system(self, content, injection_text: str) -> str | None:
        """Return system-message content with lore inserted, or ``None``.

        ``None`` means this message offers no injection point for the
        configured position (including non-string ``content``).
        """
        if not isinstance(content, str):
            return None

        if self.position == "system_suffix":
            return content + "\n\n" + injection_text

        if self.position == "after_bio":
            for marker in self.BIO_MARKERS:
                idx = self._find_marker_index(content, marker)
                if idx is not None:
                    # Insert the lore block immediately before the section.
                    return content[:idx] + injection_text + "\n\n" + content[idx:]

        return None

    @staticmethod
    def _find_marker_index(content: str, marker: str) -> int | None:
        """Return the index of the first line-initial ``marker``, else ``None``.

        A plain substring search would also hit "## Background" inside
        "### Background" and split that header in two, so an occurrence only
        counts when nothing but whitespace precedes it on its line.
        """
        search_from = 0
        while True:
            idx = content.find(marker, search_from)
            if idx == -1:
                return None
            line_start = content.rfind("\n", 0, idx) + 1
            if not content[line_start:idx].strip():
                return idx
            search_from = idx + 1

    @staticmethod
    def _summarize_system_messages(
        messages: list[dict], max_headers: int
    ) -> tuple[int, int, list[str]]:
        """Collect diagnostics about system messages for the failure warning.

        Returns ``(system message count, total characters of string content,
        first ``max_headers`` stripped lines starting with "## ")``. Only the
        header list is capped; the counts always cover every system message.
        """
        system_msg_count = 0
        system_content_chars = 0
        seen_headers: list[str] = []
        for msg in messages:
            if msg.get("role") != "system":
                continue
            system_msg_count += 1
            content = msg.get("content", "")
            if not isinstance(content, str):
                continue
            system_content_chars += len(content)
            for line in content.splitlines():
                if len(seen_headers) >= max_headers:
                    break
                stripped = line.strip()
                if stripped.startswith("## "):
                    seen_headers.append(stripped)
        return system_msg_count, system_content_chars, seen_headers

    def _build_injection_text(
        self,
        npc_profile: NPCProfile,
        lore_entries: list[LoreEntry],
    ) -> str:
        """Build the injection text block.

        Each entry becomes a ``- **topic**: content`` bullet, with content
        longer than ``MAX_ENTRY_CHARS`` cut and suffixed with "...". The
        bullets are substituted into ``self.template`` via ``str.format``;
        missing race / profession / location fall back to "person" /
        "citizen" / "Skyrim". The rendered text is returned stripped of
        leading and trailing whitespace.
        """
        limit = self.MAX_ENTRY_CHARS
        lore_items = []
        for entry in lore_entries:
            content = entry.content
            if len(content) > limit:
                # Keep the bullet within the limit, ellipsis included.
                content = content[: limit - 3] + "..."
            lore_items.append(f"- **{entry.topic}**: {content}")

        injection_text = self.template.format(
            race=npc_profile.race or "person",
            profession=npc_profile.profession or "citizen",
            location=npc_profile.location or "Skyrim",
            lore_items="\n".join(lore_items),
            name=npc_profile.name,
        )

        return injection_text.strip()
|