Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Successful in 29s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 33s
Tests / e2e (pull_request) Successful in 3m26s
Tests / test (pull_request) Failing after 1h28m50s
Before compressing conversation context, extract durable facts (user preferences, corrections, project details) and save them to the fact store so they survive compression.

New agent/session_compactor.py:
- extract_facts_from_messages(): scans user messages for preferences and corrections, and user/assistant messages for project/infra facts, using regex
- 3 pattern categories: user_pref (5 patterns), correction (3 patterns), project (4 patterns)
- ExtractedFact: category, entity, content, confidence, source_turn
- save_facts_to_store(): saves to the fact store (callback or auto-detect)
- extract_and_save_facts(): one-call extraction + persistence
- Deduplication by category+content
- Skips tool results, short messages, system messages
- format_facts_summary(): human-readable summary

Tests: tests/test_session_compactor.py (9 tests)

Closes #748
232 lines
8.0 KiB
Python
232 lines
8.0 KiB
Python
"""Session compaction with fact extraction.
|
|
|
|
Before compressing conversation context, extract durable facts (user
preferences, corrections, project details) with regex heuristics and save
them to the fact store so they survive compression.

Usage:

    from agent.session_compactor import extract_and_save_facts

    facts, saved = extract_and_save_facts(messages)
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
# Logger for this module's extraction and persistence diagnostics.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ExtractedFact:
    """One durable fact pulled out of a single conversation turn.

    Instances are produced by the extractors in this module; ``timestamp``
    is optional and defaults to 0.0.
    """

    # Dotted "<group>.<subcategory>" label, e.g. "user_pref.preference",
    # "correction.correction" or "project.infrastructure".
    category: str
    # Subject of the fact: "user" for preferences/corrections, otherwise the
    # project subcategory (e.g. "config").
    entity: str
    # The captured fact text.
    content: str
    # Heuristic confidence in the range 0.0-1.0.
    confidence: float
    # Index of the message the fact was extracted from.
    source_turn: int
    # Unix time (time.time()) at extraction; 0.0 when not set.
    timestamp: float = field(default=0.0)
|
|
|
|
|
|
# Patterns that indicate user preferences
|
|
_PREFERENCE_PATTERNS = [
|
|
(r"(?:I|we) (?:prefer|like|want|need) (.+?)(?:\.|$)", "preference"),
|
|
(r"(?:always|never) (?:use|do|run|deploy) (.+?)(?:\.|$)", "preference"),
|
|
(r"(?:my|our) (?:default|preferred|usual) (.+?) (?:is|are) (.+?)(?:\.|$)", "preference"),
|
|
(r"(?:make sure|ensure|remember) (?:to|that) (.+?)(?:\.|$)", "instruction"),
|
|
(r"(?:don'?t|do not) (?:ever|ever again) (.+?)(?:\.|$)", "constraint"),
|
|
]
|
|
|
|
# Patterns that indicate corrections
|
|
_CORRECTION_PATTERNS = [
|
|
(r"(?:actually|no[, ]|wait[, ]|correction[: ]|sorry[, ]) (.+)", "correction"),
|
|
(r"(?:I meant|what I meant was|the correct) (.+?)(?:\.|$)", "correction"),
|
|
(r"(?:it'?s|its) (?:not|shouldn'?t be|wrong) (.+?)(?:\.|$)", "correction"),
|
|
]
|
|
|
|
# Patterns that indicate project/tool facts
|
|
_PROJECT_PATTERNS = [
|
|
(r"(?:the |our )?(?:project|repo|codebase|code) (?:is|uses|needs|requires) (.+?)(?:\.|$)", "project"),
|
|
(r"(?:deploy|push|commit) (?:to|on) (.+?)(?:\.|$)", "project"),
|
|
(r"(?:this|that|the) (?:server|host|machine|VPS) (?:is|runs|has) (.+?)(?:\.|$)", "infrastructure"),
|
|
(r"(?:model|provider|engine) (?:is|should be|needs to be) (.+?)(?:\.|$)", "config"),
|
|
]
|
|
|
|
|
|
def extract_facts_from_messages(messages: List[Dict[str, Any]]) -> List[ExtractedFact]:
    """Extract durable facts from conversation messages.

    Each eligible message is passed to ``_extract_from_text`` together with
    its index and role. A message is skipped when:

    - its role is neither ``"user"`` nor ``"assistant"`` (e.g. system or
      tool messages);
    - its content is not a string, or is shorter than 10 characters;
    - it is an assistant message that carries ``tool_calls``.

    Facts are de-duplicated on (category, content), comparing the full
    content case-insensitively with whitespace runs collapsed; the first
    occurrence is kept.

    Args:
        messages: Chat messages as dicts with at least ``role`` and
            ``content`` keys.

    Returns:
        The unique facts, in conversation order.
    """
    facts = []
    seen = set()

    for turn_idx, msg in enumerate(messages):
        role = msg.get("role", "")
        content = msg.get("content", "")

        if role not in ("user", "assistant"):
            continue
        # Non-string content (None, structured parts) and very short
        # messages carry nothing the text patterns can use.
        if not isinstance(content, str) or len(content) < 10:
            continue
        # Assistant turns that issue tool calls are not scanned.
        if role == "assistant" and msg.get("tool_calls"):
            continue

        for fact in _extract_from_text(content, turn_idx, role):
            # Key on the whole content (it is already capped by the
            # extractor); a prefix key would merge distinct facts that
            # share a long opening, while case/spacing-only variants of the
            # same fact should merge.
            key = (fact.category, " ".join(fact.content.split()).casefold())
            if key in seen:
                continue
            seen.add(key)
            facts.append(fact)

    return facts
|
|
|
|
|
|
def _extract_from_text(text: str, turn_idx: int, role: str) -> List[ExtractedFact]:
    """Run the module's pattern tables over one message body.

    Args:
        text: Message text to scan; leading/trailing whitespace is ignored.
        turn_idx: Position of the message in the conversation, stored as
            ``source_turn`` on every fact.
        role: Message role. Preference and correction patterns are applied
            only for ``"user"``; project/infrastructure patterns are applied
            for every role.

    Returns:
        Facts in table order (preferences, corrections, project). Each
        fact's content is truncated to 200 characters, and captures of 5
        characters or fewer are discarded as noise.
    """
    timestamp = time.time()
    clean = text.strip()

    # (patterns, category prefix, entity -- None means "use the
    # subcategory", confidence)
    tables = []
    if role == "user":
        tables.append((_PREFERENCE_PATTERNS, "user_pref", "user", 0.7))
        tables.append((_CORRECTION_PATTERNS, "correction", "user", 0.8))
    tables.append((_PROJECT_PATTERNS, "project", None, 0.6))

    facts = []
    for patterns, prefix, entity, confidence in tables:
        for pattern, subcategory in patterns:
            # MULTILINE lets ``$`` end a capture at a line break. Without it
            # ``$`` only matches at the end of the whole text, so a clause
            # on an earlier line with no trailing period was never captured.
            for match in re.finditer(pattern, clean, re.IGNORECASE | re.MULTILINE):
                content = _match_content(match)
                if len(content) > 5:
                    facts.append(ExtractedFact(
                        category=f"{prefix}.{subcategory}",
                        entity=subcategory if entity is None else entity,
                        content=content[:200],
                        confidence=confidence,
                        source_turn=turn_idx,
                        timestamp=timestamp,
                    ))

    return facts


def _match_content(match: re.Match) -> str:
    """Return the fact text captured by *match*.

    Every non-blank capture group is stripped and the parts are joined with
    a single space, so a pattern that captures its fact in several groups
    keeps all of it rather than only the first group. A pattern with no
    groups yields the whole (stripped) match.
    """
    if match.re.groups == 0:
        return match.group(0).strip()
    parts = (group.strip() for group in match.groups() if group)
    return " ".join(part for part in parts if part)
|
|
|
|
|
|
def save_facts_to_store(facts: List[ExtractedFact], fact_store_fn=None) -> int:
    """Save extracted facts to the fact store, best-effort.

    Args:
        facts: Facts to persist.
        fact_store_fn: Optional callable, invoked once per fact as
            ``fact_store_fn(category=..., entity=..., content=..., trust=...)``
            with ``trust`` set to the fact's confidence. When ``None``, the
            ``fact_store`` function from the ``fact_store`` module is used
            if it can be imported: it is called with ``action="add"``,
            ``tags`` set to the entity and ``trust_delta`` set to
            ``confidence - 0.5``. If that import fails, nothing is saved.

    Returns:
        Number of facts saved successfully. A failing save is logged and
        skipped; it never raises to the caller.
    """
    if not facts:
        return 0

    # Compare against None rather than testing truthiness: a callable store
    # object that defines __len__/__bool__ (e.g. an empty container) is
    # still a valid callback.
    if fact_store_fn is not None:
        def store(fact):
            fact_store_fn(
                category=fact.category,
                entity=fact.entity,
                content=fact.content,
                trust=fact.confidence,
            )

        failure_msg = "Failed to save fact: %s"
    else:
        # Try holographic fact store; only the import is guarded here so
        # save errors are handled per fact below.
        try:
            from fact_store import fact_store as _fs
        except ImportError:
            logger.debug("fact_store not available — facts not persisted")
            return 0

        def store(fact):
            _fs(
                action="add",
                content=fact.content,
                category=fact.category,
                tags=fact.entity,
                trust_delta=fact.confidence - 0.5,
            )

        failure_msg = "Failed to save fact via fact_store: %s"

    saved = failed = 0
    for fact in facts:
        try:
            store(fact)
        except Exception as e:
            # Best-effort: one bad fact must not stop the rest from saving.
            failed += 1
            logger.debug(failure_msg, e)
        else:
            saved += 1

    if failed:
        # Per-fact details are at DEBUG; surface the aggregate so losing
        # facts before compaction is visible at default log levels.
        logger.warning("Failed to save %d of %d extracted facts", failed, saved + failed)

    return saved
|
|
|
|
|
|
def extract_and_save_facts(
    messages: List[Dict[str, Any]],
    fact_store_fn=None,
) -> Tuple[List[ExtractedFact], int]:
    """Extract facts from *messages* and persist them in one call.

    Args:
        messages: Conversation messages, forwarded to
            ``extract_facts_from_messages``.
        fact_store_fn: Optional store callback, forwarded unchanged to
            ``save_facts_to_store``.

    Returns:
        ``(extracted_facts, saved_count)``. When nothing is extracted the
        store is not called and ``saved_count`` is 0.
    """
    extracted = extract_facts_from_messages(messages)
    if not extracted:
        return extracted, 0

    total = len(extracted)
    logger.info("Extracted %d facts from conversation", total)
    saved_count = save_facts_to_store(extracted, fact_store_fn)
    logger.info("Saved %d/%d facts to store", saved_count, total)
    return extracted, saved_count
|
|
|
|
|
|
def format_facts_summary(facts: List[ExtractedFact]) -> str:
    """Render *facts* as a plain-text report grouped by category.

    Categories are listed in sorted order; within a category, facts keep
    their input order and each is shown as its first 80 characters.
    Returns ``"No facts extracted."`` for an empty list.
    """
    if not facts:
        return "No facts extracted."

    grouped = {}
    for fact in facts:
        if fact.category not in grouped:
            grouped[fact.category] = []
        grouped[fact.category].append(fact)

    out = [f"Extracted {len(facts)} facts:", ""]
    for category in sorted(grouped):
        out.append(f" {category}:")
        out.extend(f" - {fact.content[:80]}" for fact in grouped[category])
    return "\n".join(out)
|