Compare commits

..

2 Commits

Author SHA1 Message Date
8dd0aaa89d test: session compaction tests
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 41s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 51s
Tests / e2e (pull_request) Successful in 3m18s
Tests / test (pull_request) Failing after 55m56s
Part of #748
2026-04-15 03:09:06 +00:00
4ad81ce646 feat: session compaction with fact extraction
Closes #748

Before compressing long conversations, extracts durable facts
(user preferences, corrections, project details) and saves
them to fact_store. Then compresses conversation.
2026-04-15 03:09:00 +00:00
4 changed files with 305 additions and 368 deletions

221
agent/session_compaction.py Normal file
View File

@@ -0,0 +1,221 @@
"""
Session Compaction with Fact Extraction — #748
Before compressing a long conversation, extracts durable facts
(user preferences, corrections, project details) and saves them
to the fact store. Then compresses the conversation.
This ensures key information survives context limits.
Usage:
from agent.session_compaction import compact_session
# In the conversation loop, when context is near limit:
compact_session(messages, fact_store)
"""
import json
import re
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Fact Extraction Patterns
# ---------------------------------------------------------------------------
# Patterns that indicate durable facts worth preserving
_FACT_PATTERNS = [
# User preferences
(r"(?:i prefer|i like|i always|my preference is|remember that i)\s+(.+?)(?:\.|$)", "user_pref"),
(r"(?:call me|my name is|i\'m)\s+([A-Z][a-z]+)", "user_name"),
(r"(?:don\'t|do not|never)\s+(?:use|do|show|tell)\s+(.+?)(?:\.|$)", "user_constraint"),
# Corrections
(r"(?:actually|no,?|correction:?)\s+(.+?)(?:\.|$)", "correction"),
(r"(?:that\'s wrong|not correct|i meant)\s+(.+?)(?:\.|$)", "correction"),
# Project facts
(r"(?:the project|this repo|the codebase)\s+(?:is|has|uses|runs)\s+(.+?)(?:\.|$)", "project_fact"),
(r"(?:we use|our stack is|deployed on)\s+(.+?)(?:\.|$)", "project_fact"),
# Technical facts
(r"(?:the server|the service|the endpoint)\s+(?:is|runs on|listens on)\s+(.+?)(?:\.|$)", "technical"),
(r"(?:port|url|address|host)\s*(?::|is|=)\s*(.+?)(?:\.|$)", "technical"),
]
def extract_facts_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Scan conversation messages for durable facts.
Returns list of fact dicts suitable for fact_store.
"""
facts = []
seen = set() # Deduplicate
for msg in messages:
if msg.get("role") != "user":
continue
content = msg.get("content", "")
if not isinstance(content, str) or len(content) < 10:
continue
for pattern, category in _FACT_PATTERNS:
matches = re.findall(pattern, content, re.IGNORECASE)
for match in matches:
if isinstance(match, tuple):
match = match[0] if match else ""
fact_text = match.strip()
if len(fact_text) < 5 or len(fact_text) > 200:
continue
# Deduplicate
dedup_key = f"{category}:{fact_text.lower()}"
if dedup_key in seen:
continue
seen.add(dedup_key)
facts.append({
"content": fact_text,
"category": category,
"source": "session_compaction",
"trust": 0.7, # Medium trust — extracted, not explicitly stated
})
return facts
def extract_preferences(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract user preferences specifically."""
prefs = []
pref_patterns = [
r"(?:i prefer|i like|i want|use|always)\s+(.+?)(?:\.|$)",
r"(?:my (?:preferred|favorite|default))\s+(?:is|are)\s+(.+?)(?:\.|$)",
r"(?:set|configure|make)\s+(?:it to|the default to)\s+(.+?)(?:\.|$)",
]
for msg in messages:
if msg.get("role") != "user":
continue
content = msg.get("content", "")
if not isinstance(content, str):
continue
for pattern in pref_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
for match in matches:
if isinstance(match, str) and len(match) > 5 and len(match) < 200:
prefs.append({
"content": match.strip(),
"category": "user_pref",
"source": "session_compaction",
"trust": 0.8,
})
return prefs
def compact_session(
messages: List[Dict[str, Any]],
fact_store: Any = None,
keep_recent: int = 10,
) -> Tuple[List[Dict[str, Any]], int]:
"""
Compact a session by extracting facts and compressing old messages.
Args:
messages: Full conversation history
fact_store: Optional fact_store instance for saving facts
keep_recent: Number of recent messages to keep uncompressed
Returns:
Tuple of (compacted_messages, facts_extracted)
"""
if len(messages) <= keep_recent * 2:
return messages, 0
# Split into old (to compress) and recent (to keep)
split_point = len(messages) - keep_recent
old_messages = messages[:split_point]
recent_messages = messages[split_point:]
# Extract facts from old messages
facts = extract_facts_from_messages(old_messages)
prefs = extract_preferences(old_messages)
all_facts = facts + prefs
# Save facts to store if available
saved_count = 0
if fact_store and all_facts:
for fact in all_facts:
try:
if hasattr(fact_store, 'store'):
fact_store.store(
content=fact["content"],
category=fact["category"],
tags=["session_compaction"],
)
saved_count += 1
elif hasattr(fact_store, 'add'):
fact_store.add(fact["content"])
saved_count += 1
except Exception:
pass # Don't let fact saving block compaction
# Create summary of old messages
summary_parts = []
if saved_count > 0:
summary_parts.append(f"[Session compacted: {saved_count} facts extracted and saved]")
# Count message types
user_msgs = sum(1 for m in old_messages if m.get("role") == "user")
asst_msgs = sum(1 for m in old_messages if m.get("role") == "assistant")
summary_parts.append(f"[Previous conversation: {user_msgs} user messages, {asst_msgs} assistant responses]")
summary = " ".join(summary_parts)
# Build compacted messages
compacted = []
# Add summary as system message
if summary:
compacted.append({
"role": "system",
"content": summary,
"_compacted": True,
})
# Add extracted facts as system context
if all_facts:
facts_text = "Known facts from previous conversation:\n"
for fact in all_facts[:20]: # Limit to 20 facts
facts_text += f"- [{fact['category']}] {fact['content']}\n"
compacted.append({
"role": "system",
"content": facts_text,
"_extracted_facts": True,
})
# Add recent messages
compacted.extend(recent_messages)
return compacted, saved_count
def should_compact(messages: List[Dict[str, Any]], max_tokens: int = 80000) -> bool:
"""
Determine if compaction is needed based on message count/length.
Simple heuristic: compact if we have many messages or very long content.
"""
if len(messages) < 50:
return False
# Estimate token count (rough: 4 chars per token)
total_chars = sum(len(str(m.get("content", ""))) for m in messages)
estimated_tokens = total_chars // 4
return estimated_tokens > max_tokens * 0.8 # Compact at 80% of limit

View File

@@ -1,272 +0,0 @@
#!/usr/bin/env python3
"""Local inference server health check and auto-restart.
Checks llama-server, Ollama, and other local inference endpoints.
Reports status, latency, and can auto-restart dead processes.
Refs: #713 — llama-server DOWN on port 8081
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from urllib.request import Request, urlopen
from urllib.error import URLError, HTTPError
@dataclass
class InferenceEndpoint:
"""Configuration for an inference server endpoint."""
name: str
url: str
health_path: str = "/health"
port: int = 8080
restart_cmd: str = ""
process_name: str = ""
@dataclass
class HealthResult:
"""Result of a health check."""
name: str
url: str
status: str # "ok", "down", "slow", "error"
latency_ms: float = 0.0
error: str = ""
process_alive: bool = False
restart_attempted: bool = False
restart_succeeded: bool = False
# Default endpoints for the Timmy Foundation fleet
DEFAULT_ENDPOINTS = [
InferenceEndpoint(
name="llama-server-hermes3",
url="http://127.0.0.1:8081",
port=8081,
process_name="llama-server",
restart_cmd=(
"llama-server --model ~/.ollama/models/blobs/sha256-c8985d "
"--port 8081 --host 127.0.0.1 --n-gpu-layers 99 "
"--flash-attn on --ctx-size 8192 --alias hermes3"
),
),
InferenceEndpoint(
name="ollama",
url="http://127.0.0.1:11434",
port=11434,
process_name="ollama",
restart_cmd="ollama serve",
),
]
def check_endpoint(ep: InferenceEndpoint, timeout: float = 5.0) -> HealthResult:
"""Check a single inference endpoint.
Args:
ep: Endpoint configuration.
timeout: HTTP timeout in seconds.
Returns:
HealthResult with status and latency.
"""
url = ep.url.rstrip("/") + ep.health_path
start = time.time()
# Check if process is alive
process_alive = False
if ep.process_name:
try:
result = subprocess.run(
["pgrep", "-f", ep.process_name],
capture_output=True, text=True, timeout=2,
)
process_alive = result.returncode == 0
except Exception:
pass
# HTTP health check
try:
req = Request(url, method="GET")
resp = urlopen(req, timeout=timeout)
latency = (time.time() - start) * 1000
if resp.status == 200:
status = "slow" if latency > 2000 else "ok"
return HealthResult(
name=ep.name, url=ep.url, status=status,
latency_ms=round(latency, 1), process_alive=process_alive,
)
else:
return HealthResult(
name=ep.name, url=ep.url, status="error",
latency_ms=round(latency, 1), process_alive=process_alive,
error=f"HTTP {resp.status}",
)
except URLError as e:
latency = (time.time() - start) * 1000
error_msg = str(e.reason) if hasattr(e, 'reason') else str(e)
return HealthResult(
name=ep.name, url=ep.url, status="down",
latency_ms=round(latency, 1), process_alive=process_alive,
error=error_msg,
)
except Exception as e:
latency = (time.time() - start) * 1000
return HealthResult(
name=ep.name, url=ep.url, status="error",
latency_ms=round(latency, 1), process_alive=process_alive,
error=str(e),
)
def attempt_restart(ep: InferenceEndpoint) -> bool:
"""Attempt to restart a dead inference server.
Args:
ep: Endpoint configuration with restart_cmd.
Returns:
True if restart command executed successfully.
"""
if not ep.restart_cmd:
return False
try:
# Run restart in background
subprocess.Popen(
ep.restart_cmd,
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
# Wait a moment for the process to start
time.sleep(3)
return True
except Exception as e:
print(f"Restart failed for {ep.name}: {e}", file=sys.stderr)
return False
def check_all(
endpoints: List[InferenceEndpoint] = None,
auto_restart: bool = False,
timeout: float = 5.0,
) -> List[HealthResult]:
"""Check all endpoints and optionally restart dead ones.
Args:
endpoints: List of endpoints to check. Uses DEFAULT_ENDPOINTS if None.
auto_restart: If True, attempt to restart down endpoints.
timeout: HTTP timeout per endpoint.
Returns:
List of HealthResult for each endpoint.
"""
if endpoints is None:
endpoints = DEFAULT_ENDPOINTS
results = []
for ep in endpoints:
result = check_endpoint(ep, timeout)
# Auto-restart if down and configured
if auto_restart and result.status == "down" and ep.restart_cmd:
result.restart_attempted = True
result.restart_succeeded = attempt_restart(ep)
if result.restart_succeeded:
# Re-check after restart
time.sleep(2)
result2 = check_endpoint(ep, timeout)
result.status = result2.status
result.latency_ms = result2.latency_ms
result.error = result2.error
results.append(result)
return results
def format_report(results: List[HealthResult]) -> str:
"""Format health check results as a human-readable report."""
lines = [
"# Local Inference Health Check",
f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}",
"",
"| Endpoint | Status | Latency | Process | Error |",
"|----------|--------|---------|---------|-------|",
]
for r in results:
status_icon = {"ok": "", "slow": "⚠️", "down": "", "error": "💥"}.get(r.status, "?")
proc = "alive" if r.process_alive else "dead"
lat = f"{r.latency_ms}ms" if r.latency_ms > 0 else "-"
err = r.error[:40] if r.error else "-"
lines.append(f"| {r.name} | {status_icon} {r.status} | {lat} | {proc} | {err} |")
down = [r for r in results if r.status in ("down", "error")]
if down:
lines.extend(["", "## DOWN", ""])
for r in down:
lines.append(f"- **{r.name}** ({r.url}): {r.error}")
if r.restart_attempted:
status = "✅ restarted" if r.restart_succeeded else "❌ restart failed"
lines.append(f" Restart: {status}")
return "\n".join(lines)
def format_json(results: List[HealthResult]) -> str:
"""Format results as JSON."""
data = []
for r in results:
data.append({
"name": r.name,
"url": r.url,
"status": r.status,
"latency_ms": r.latency_ms,
"process_alive": r.process_alive,
"error": r.error or None,
"restart_attempted": r.restart_attempted,
"restart_succeeded": r.restart_succeeded,
})
return json.dumps({"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "endpoints": data}, indent=2)
def main():
import argparse
p = argparse.ArgumentParser(description="Local inference health check")
p.add_argument("--json", action="store_true", help="JSON output")
p.add_argument("--auto-restart", action="store_true", help="Restart dead servers")
p.add_argument("--timeout", type=float, default=5.0, help="HTTP timeout (seconds)")
p.add_argument("--port", type=int, help="Check specific port only")
a = p.parse_args()
endpoints = DEFAULT_ENDPOINTS
if a.port:
endpoints = [ep for ep in DEFAULT_ENDPOINTS if ep.port == a.port]
if not endpoints:
print(f"No endpoint configured for port {a.port}", file=sys.stderr)
sys.exit(1)
results = check_all(endpoints, auto_restart=a.auto_restart, timeout=a.timeout)
if a.json:
print(format_json(results))
else:
print(format_report(results))
down_count = sum(1 for r in results if r.status in ("down", "error"))
sys.exit(1 if down_count > 0 else 0)
if __name__ == "__main__":
main()

View File

@@ -1,96 +0,0 @@
"""Tests for inference health check (#713)."""
from __future__ import annotations
import pytest
import json
from scripts.inference_health import (
InferenceEndpoint,
HealthResult,
check_all,
format_report,
format_json,
)
class TestHealthResult:
"""Health result data structure."""
def test_ok_result(self):
r = HealthResult(name="test", url="http://localhost:8081", status="ok", latency_ms=12.5)
assert r.status == "ok"
assert r.latency_ms == 12.5
assert not r.error
def test_down_result(self):
r = HealthResult(
name="test", url="http://localhost:8081",
status="down", error="Connection refused",
)
assert r.status == "down"
assert r.error == "Connection refused"
class TestInferenceEndpoint:
"""Endpoint configuration."""
def test_defaults(self):
ep = InferenceEndpoint(name="test", url="http://localhost:8080")
assert ep.health_path == "/health"
assert ep.port == 8080
assert ep.restart_cmd == ""
def test_custom(self):
ep = InferenceEndpoint(
name="llama", url="http://localhost:8081",
port=8081, restart_cmd="llama-server --port 8081",
)
assert ep.port == 8081
assert "llama-server" in ep.restart_cmd
class TestFormatReport:
"""Report formatting."""
def test_all_ok(self):
results = [
HealthResult(name="test1", url="http://localhost:8080", status="ok", latency_ms=5.0, process_alive=True),
HealthResult(name="test2", url="http://localhost:8081", status="ok", latency_ms=10.0, process_alive=True),
]
report = format_report(results)
assert "Health Check" in report
assert "test1" in report
assert "test2" in report
assert "DOWN" not in report
def test_with_down(self):
results = [
HealthResult(name="test1", url="http://localhost:8080", status="ok", latency_ms=5.0),
HealthResult(
name="test2", url="http://localhost:8081",
status="down", error="Connection refused", process_alive=False,
),
]
report = format_report(results)
assert "DOWN" in report
assert "Connection refused" in report
class TestFormatJson:
"""JSON output format."""
def test_valid_json(self):
results = [HealthResult(name="test", url="http://localhost:8080", status="ok", latency_ms=5.0)]
output = format_json(results)
data = json.loads(output)
assert "timestamp" in data
assert "endpoints" in data
assert len(data["endpoints"]) == 1
assert data["endpoints"][0]["name"] == "test"
def test_none_error_serializes(self):
results = [HealthResult(name="test", url="http://localhost:8080", status="ok")]
output = format_json(results)
data = json.loads(output)
assert data["endpoints"][0]["error"] is None

View File

@@ -0,0 +1,84 @@
"""Tests for session compaction with fact extraction (#748)."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from agent.session_compaction import (
extract_facts_from_messages,
extract_preferences,
compact_session,
should_compact,
)
def test_extract_preferences():
msgs = [
{"role": "user", "content": "I prefer using Python for this"},
{"role": "assistant", "content": "OK"},
{"role": "user", "content": "Always use tabs, not spaces"},
]
prefs = extract_preferences(msgs)
assert len(prefs) >= 1
def test_extract_facts():
msgs = [
{"role": "user", "content": "The server runs on port 8080"},
{"role": "user", "content": "Actually, the port is 8081"},
{"role": "user", "content": "Hello"}, # Too short, should be skipped
]
facts = extract_facts_from_messages(msgs)
assert len(facts) >= 1
assert any("technical" in f["category"] for f in facts)
def test_extract_deduplicates():
msgs = [
{"role": "user", "content": "I prefer Python"},
{"role": "user", "content": "I prefer Python"},
]
facts = extract_facts_from_messages(msgs)
assert len(facts) == 1
def test_compact_session():
messages = []
for i in range(30):
messages.append({"role": "user", "content": f"Message {i}: I prefer Python for server {i}"})
messages.append({"role": "assistant", "content": f"Response {i}"})
compacted, count = compact_session(messages, keep_recent=10)
assert len(compacted) < len(messages)
assert count >= 0
def test_compact_keeps_recent():
messages = []
for i in range(30):
messages.append({"role": "user", "content": f"Message {i}"})
messages.append({"role": "assistant", "content": f"Response {i}"})
compacted, _ = compact_session(messages, keep_recent=10)
# Should have summary + facts + 10 recent
assert len(compacted) >= 10
def test_should_compact_short():
messages = [{"role": "user", "content": "hi"} for _ in range(10)]
assert not should_compact(messages)
def test_should_compact_long():
messages = [{"role": "user", "content": "x" * 1000} for _ in range(100)]
assert should_compact(messages)
if __name__ == "__main__":
tests = [test_extract_preferences, test_extract_facts, test_extract_deduplicates,
test_compact_session, test_compact_keeps_recent, test_should_compact_short, test_should_compact_long]
for t in tests:
print(f"Running {t.__name__}...")
t()
print(" PASS")
print("\nAll tests passed.")