Compare commits

..

2 Commits

Author SHA1 Message Date
fe619a1774 test: Add session model metadata tests (#741)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 53s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 45s
Tests / e2e (pull_request) Successful in 3m15s
Tests / test (pull_request) Failing after 54m21s
2026-04-15 03:52:10 +00:00
8194e9c651 feat: Add session model metadata persistence (#741) 2026-04-15 03:51:14 +00:00
5 changed files with 328 additions and 145 deletions

View File

@@ -1,85 +0,0 @@
"""
Crisis Detection Hook — Integrates 988 Lifeline into conversation loop.
When crisis is detected in user messages, returns 988 Lifeline
information immediately instead of processing the original request.
"""
import re
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
class CrisisSeverity(Enum):
NONE = 0
MODERATE = 1
HIGH = 2
CRITICAL = 3
@dataclass
class CrisisDetection:
detected: bool
severity: CrisisSeverity
matched_patterns: list
confidence: float
# Crisis detection patterns
_CRISIS_PATTERNS = [
(r"\b(i want to die|want to kill myself|end (my|it all))\b", CrisisSeverity.CRITICAL, 0.95),
(r"\b(kill(ing)? myself|suicide|suicidal)\b", CrisisSeverity.CRITICAL, 0.90),
(r"\b(no reason to live|don'?t want to (live|be here|exist))\b", CrisisSeverity.HIGH, 0.85),
(r"\b(better off (dead|without me)|wish i was dead)\b", CrisisSeverity.HIGH, 0.85),
(r"\b(can'?t go on|can'?t take (it|this) anymore)\b", CrisisSeverity.HIGH, 0.80),
(r"\b(self[- ]?harm|cut(ting)? myself|hurt(ing)? myself)\b", CrisisSeverity.HIGH, 0.85),
(r"\b(hopeless|no hope|give up on (life|everything))\b", CrisisSeverity.MODERATE, 0.70),
]
_COMPILED = [(re.compile(p, re.IGNORECASE), s, c) for p, s, c in _CRISIS_PATTERNS]
def detect_crisis(message: str) -> CrisisDetection:
if not message or not isinstance(message, str):
return CrisisDetection(False, CrisisSeverity.NONE, [], 0.0)
matched = []
max_severity = CrisisSeverity.NONE
max_confidence = 0.0
for pattern, severity, confidence in _COMPILED:
if pattern.search(message):
matched.append(pattern.pattern)
if confidence > max_confidence:
max_confidence = confidence
if severity.value > max_severity.value:
max_severity = severity
detected = len(matched) > 0
return CrisisDetection(detected, max_severity, matched, max_confidence)
def get_crisis_response(severity: str = "CRITICAL") -> str:
return """I hear you, and you are not alone.
**Please reach out for help right now:**
📞 **Call 988** — Suicide & Crisis Lifeline (24/7)
💬 **Text HOME to 988** — Crisis text line
🌐 **Chat:** 988lifeline.org/chat
🇪🇸 **Spanish:** 1-888-628-9454
🎖️ **Veterans:** 988, then press 1
You don\'t have to face this alone. Please hold on."""
def check_crisis(user_message: str) -> Optional[str]:
detection = detect_crisis(user_message)
if detection.detected:
logger.warning("Crisis detected: severity=%s", detection.severity.name)
return get_crisis_response(detection.severity.name)
return None

View File

@@ -0,0 +1,223 @@
"""
Session Model Metadata — Persist model context info per session
When a session switches models mid-conversation, context length and
token budget need to be updated to prevent silent truncation.
Issue: #741
"""
import json
import logging
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
HERMES_HOME = Path.home() / ".hermes"
# Common model context lengths (tokens)
KNOWN_CONTEXT_LENGTHS = {
# Anthropic
"claude-opus-4-6": 200000,
"claude-sonnet-4": 200000,
"claude-3.5-sonnet": 200000,
"claude-3-haiku": 200000,
# OpenAI
"gpt-4o": 128000,
"gpt-4-turbo": 128000,
"gpt-4": 8192,
"gpt-3.5-turbo": 16385,
# Nous / open models
"hermes-3-llama-3.1-405b": 131072,
"hermes-3-llama-3.1-70b": 131072,
"deepseek-r1": 131072,
"deepseek-v3": 131072,
# Local
"llama-3.1-8b": 131072,
"llama-3.1-70b": 131072,
"qwen-2.5-72b": 131072,
# Xiaomi
"mimo-v2-pro": 131072,
"mimo-v2-flash": 131072,
# Defaults
"default": 4096,
}
# Reserve tokens for system prompt, response, and overhead
TOKEN_RESERVE = 2000
@dataclass
class ModelMetadata:
"""Metadata for a model in a session."""
model: str
provider: str
context_length: int
available_for_input: int # context_length - reserve
current_tokens_used: int = 0
@property
def remaining_tokens(self) -> int:
"""Tokens remaining for new input."""
return max(0, self.available_for_input - self.current_tokens_used)
@property
def utilization_pct(self) -> float:
"""Percentage of context used."""
if self.available_for_input == 0:
return 0.0
return (self.current_tokens_used / self.available_for_input) * 100
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
def get_context_length(model: str) -> int:
"""Get context length for a model."""
model_lower = model.lower()
# Check exact match
if model_lower in KNOWN_CONTEXT_LENGTHS:
return KNOWN_CONTEXT_LENGTHS[model_lower]
# Check partial match
for key, length in KNOWN_CONTEXT_LENGTHS.items():
if key in model_lower:
return length
return KNOWN_CONTEXT_LENGTHS["default"]
def create_metadata(model: str, provider: str = "", current_tokens: int = 0) -> ModelMetadata:
"""Create model metadata."""
context_length = get_context_length(model)
available = max(0, context_length - TOKEN_RESERVE)
return ModelMetadata(
model=model,
provider=provider,
context_length=context_length,
available_for_input=available,
current_tokens_used=current_tokens
)
def check_model_switch(
old_model: str,
new_model: str,
current_tokens: int
) -> Dict[str, Any]:
"""
Check impact of switching models mid-session.
Returns:
Dict with switch analysis including warnings
"""
old_ctx = get_context_length(old_model)
new_ctx = get_context_length(new_model)
old_available = old_ctx - TOKEN_RESERVE
new_available = new_ctx - TOKEN_RESERVE
result = {
"old_model": old_model,
"new_model": new_model,
"old_context": old_ctx,
"new_context": new_ctx,
"current_tokens": current_tokens,
"fits_in_new": current_tokens <= new_available,
"truncation_needed": max(0, current_tokens - new_available),
"warning": None,
}
if not result["fits_in_new"]:
result["warning"] = (
f"Switching to {new_model} ({new_ctx:,} ctx) with {current_tokens:,} tokens "
f"will truncate {result['truncation_needed']:,} tokens of history. "
f"Consider starting a new session."
)
if new_ctx < old_ctx:
reduction = old_ctx - new_ctx
result["warning"] = (
f"New model has {reduction:,} fewer tokens of context. "
f"({old_ctx:,} -> {new_ctx:,})"
)
return result
class SessionModelTracker:
"""Track model metadata for a session."""
def __init__(self, session_id: str):
self.session_id = session_id
self.metadata: Optional[ModelMetadata] = None
self.history: list = [] # Model switch history
def set_model(self, model: str, provider: str = "", tokens_used: int = 0):
"""Set the current model for the session."""
old_model = self.metadata.model if self.metadata else None
self.metadata = create_metadata(model, provider, tokens_used)
# Record switch in history
if old_model and old_model != model:
self.history.append({
"from": old_model,
"to": model,
"tokens_at_switch": tokens_used,
"context_length": self.metadata.context_length
})
logger.info(
"Session %s: model=%s context=%d available=%d",
self.session_id[:12], model,
self.metadata.context_length,
self.metadata.available_for_input
)
def update_tokens(self, tokens: int):
"""Update current token usage."""
if self.metadata:
self.metadata.current_tokens_used = tokens
def get_remaining(self) -> int:
"""Get remaining tokens."""
if not self.metadata:
return 0
return self.metadata.remaining_tokens
def can_fit(self, additional_tokens: int) -> bool:
"""Check if additional tokens fit in context."""
if not self.metadata:
return False
return self.metadata.remaining_tokens >= additional_tokens
def get_warning(self) -> Optional[str]:
"""Get warning if context is running low."""
if not self.metadata:
return None
util = self.metadata.utilization_pct
if util > 90:
return f"Context {util:.0f}% full. Consider compression or new session."
if util > 75:
return f"Context {util:.0f}% full."
return None
def to_dict(self) -> Dict[str, Any]:
"""Export state."""
return {
"session_id": self.session_id,
"metadata": self.metadata.to_dict() if self.metadata else None,
"history": self.history
}

View File

@@ -7792,30 +7792,6 @@ class AIAgent:
if isinstance(persist_user_message, str):
persist_user_message = _sanitize_surrogates(persist_user_message)
# Crisis detection — check user message for crisis signals (#695)
# If crisis detected, return 988 Lifeline response immediately
if isinstance(user_message, str) and user_message.strip():
try:
from agent.crisis_hook import check_crisis
_crisis_response = check_crisis(user_message)
if _crisis_response:
logger.warning("Crisis detected in session %s", getattr(self, 'session_id', 'unknown'))
return {
"response": _crisis_response,
"messages": self.messages + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": _crisis_response},
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
"model": self.model,
"crisis_detected": True,
}
except ImportError:
pass
except Exception as _e:
logger.debug("Crisis detection error: %s", _e)
# Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback
self._persist_user_message_idx = None

View File

@@ -1,36 +0,0 @@
"""
Tests for crisis hook integration (#695).
"""
import pytest
from agent.crisis_hook import detect_crisis, get_crisis_response, check_crisis, CrisisSeverity
class TestCrisisDetection:
def test_detects_direct_suicide(self):
result = detect_crisis("I want to kill myself")
assert result.detected is True
assert result.severity == CrisisSeverity.CRITICAL
def test_no_crisis_on_normal(self):
result = detect_crisis("Hello, how are you?")
assert result.detected is False
def test_crisis_response_has_988(self):
response = get_crisis_response("CRITICAL")
assert "988" in response
assert "988lifeline.org/chat" in response
assert "1-888-628-9454" in response
def test_check_crisis_returns_response(self):
response = check_crisis("I want to die")
assert response is not None
assert "988" in response
def test_check_crisis_returns_none_for_normal(self):
response = check_crisis("Hello")
assert response is None
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,105 @@
"""
Tests for session model metadata
Issue: #741
"""
import unittest
from agent.session_model_metadata import (
get_context_length,
create_metadata,
check_model_switch,
SessionModelTracker,
)
class TestContextLength(unittest.TestCase):
def test_known_model(self):
ctx = get_context_length("claude-opus-4-6")
self.assertEqual(ctx, 200000)
def test_partial_match(self):
ctx = get_context_length("anthropic/claude-sonnet-4")
self.assertEqual(ctx, 200000)
def test_unknown_model(self):
ctx = get_context_length("unknown-model-xyz")
self.assertEqual(ctx, 4096)
class TestModelMetadata(unittest.TestCase):
def test_create(self):
meta = create_metadata("gpt-4o", "openai", 1000)
self.assertEqual(meta.context_length, 128000)
self.assertEqual(meta.current_tokens_used, 1000)
self.assertGreater(meta.remaining_tokens, 0)
def test_utilization(self):
meta = create_metadata("gpt-4o", "openai", 64000)
self.assertAlmostEqual(meta.utilization_pct, 50.0, delta=1)
class TestModelSwitch(unittest.TestCase):
def test_safe_switch(self):
result = check_model_switch("gpt-3.5-turbo", "gpt-4o", 5000)
self.assertTrue(result["fits_in_new"])
self.assertIsNone(result["warning"])
def test_truncation_warning(self):
result = check_model_switch("gpt-4o", "gpt-3.5-turbo", 20000)
self.assertFalse(result["fits_in_new"])
self.assertIsNotNone(result["warning"])
self.assertIn("truncate", result["warning"].lower())
def test_downgrade_warning(self):
result = check_model_switch("claude-opus-4-6", "gpt-4", 5000)
self.assertIsNotNone(result["warning"])
class TestSessionModelTracker(unittest.TestCase):
def test_set_model(self):
tracker = SessionModelTracker("test")
tracker.set_model("gpt-4o", "openai")
self.assertEqual(tracker.metadata.model, "gpt-4o")
def test_update_tokens(self):
tracker = SessionModelTracker("test")
tracker.set_model("gpt-4o")
tracker.update_tokens(5000)
self.assertEqual(tracker.metadata.current_tokens_used, 5000)
def test_remaining(self):
tracker = SessionModelTracker("test")
tracker.set_model("gpt-4o")
tracker.update_tokens(10000)
self.assertGreater(tracker.get_remaining(), 0)
def test_can_fit(self):
tracker = SessionModelTracker("test")
tracker.set_model("gpt-4o")
tracker.update_tokens(10000)
self.assertTrue(tracker.can_fit(5000))
self.assertFalse(tracker.can_fit(200000))
def test_warning_low_context(self):
tracker = SessionModelTracker("test")
tracker.set_model("gpt-4o")
tracker.update_tokens(115000) # ~90% used
warning = tracker.get_warning()
self.assertIsNotNone(warning)
def test_model_switch_history(self):
tracker = SessionModelTracker("test")
tracker.set_model("gpt-4o", "openai")
tracker.update_tokens(5000)
tracker.set_model("claude-opus-4-6", "anthropic")
self.assertEqual(len(tracker.history), 1)
self.assertEqual(tracker.history[0]["from"], "gpt-4o")
if __name__ == "__main__":
unittest.main()