Compare commits

1 commit

| Author | SHA1 | Date |
|---|---|---|
| | a9316121a4 | |
```python
@@ -1396,8 +1396,6 @@ def normalize_anthropic_response(
        "tool_use": "tool_calls",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "refusal": "content_filter",
        "model_context_window_exceeded": "length",
    }
    finish_reason = stop_reason_map.get(response.stop_reason, "stop")

@@ -1411,42 +1409,3 @@ def normalize_anthropic_response(
        ),
        finish_reason,
    )


def normalize_anthropic_response_v2(
    response,
    strip_tool_prefix: bool = False,
) -> "NormalizedResponse":
    """Normalize Anthropic response to NormalizedResponse.

    Wraps the existing normalize_anthropic_response() and maps its output
    to the shared transport types. This allows incremental migration
    without disturbing the legacy call sites.
    """
    from agent.transports.types import NormalizedResponse, build_tool_call

    assistant_msg, finish_reason = normalize_anthropic_response(response, strip_tool_prefix)

    tool_calls = None
    if assistant_msg.tool_calls:
        tool_calls = [
            build_tool_call(
                id=tc.id,
                name=tc.function.name,
                arguments=tc.function.arguments,
            )
            for tc in assistant_msg.tool_calls
        ]

    provider_data = {}
    if getattr(assistant_msg, "reasoning_details", None):
        provider_data["reasoning_details"] = assistant_msg.reasoning_details

    return NormalizedResponse(
        content=assistant_msg.content,
        tool_calls=tool_calls,
        finish_reason=finish_reason,
        reasoning=getattr(assistant_msg, "reasoning", None),
        usage=None,
        provider_data=provider_data or None,
    )
```
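A minimal sketch of how a call site consumes the v2 wrapper, using a `SimpleNamespace` stand-in shaped like the mock responses in the regression tests at the end of this diff:

```python
from types import SimpleNamespace

from agent.anthropic_adapter import normalize_anthropic_response_v2

# Stand-in for a raw Anthropic Messages response (same shape as the test mocks below).
resp = SimpleNamespace(
    content=[SimpleNamespace(type="text", text="Hello world")],
    stop_reason="end_turn",
    usage=SimpleNamespace(input_tokens=10, output_tokens=5),
)

normalized = normalize_anthropic_response_v2(resp)
assert normalized.finish_reason == "stop"  # "end_turn" maps to "stop" per the map above
assert normalized.tool_calls is None
print(normalized.content)
```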
```python
@@ -1,197 +1,546 @@
"""Session compaction with fact extraction.
"""Session compaction with structured fact extraction.

Before compressing conversation context, extracts durable facts
(user preferences, corrections, project details) and saves them
to the fact store so they survive compression.

Usage:
    from agent.session_compactor import extract_and_save_facts
    facts = extract_and_save_facts(messages)
Before compressing conversation context, extract durable facts with enough
structure to survive retrieval: source/provenance, temporal anchors,
normalized canonical keys, and contradiction groups.
"""

from __future__ import annotations

import json
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timezone
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)


_DEPLOY_METHOD_RE = re.compile(r"\bdeploy(?:ing)?\s+(?:via|through|with)\s+([A-Za-z0-9_./+-]+)", re.IGNORECASE)
_WATCHDOG_CAP_RE = re.compile(
    r"\b(?:the\s+)?([A-Za-z0-9_-]+(?:\s+watchdog)?)\s+(?:caps|limits)\s+dispatches(?:\s+per\s+cycle)?\s+to\s+([0-9]+)",
    re.IGNORECASE,
)
_PROVIDER_RE = re.compile(
    r"\bprovider\s+(?:is|should\s+stay|should\s+be|needs\s+to\s+be)\s+([A-Za-z0-9._/-]+)",
    re.IGNORECASE,
)
_MODEL_RE = re.compile(
    r"\bmodel\s+(?:is|should\s+stay|should\s+be|needs\s+to\s+be)\s+([A-Za-z0-9._:/-]+)",
    re.IGNORECASE,
)
_PORT_RE = re.compile(r"\bport\s+(?:is|should\s+be)\s+([0-9]+)", re.IGNORECASE)
_PROJECT_USES_RE = re.compile(r"\b(?:the\s+)?project\s+(?:uses|needs|requires)\s+(.+?)(?:[.!?]|$)", re.IGNORECASE)
_PREFERENCE_RE = re.compile(r"\bI\s+(?:prefer|like|want|need)\s+(.+?)(?:[.!?]|$)", re.IGNORECASE)
_CONSTRAINT_RE = re.compile(r"\b(?:do\s+not|don't)\s+(?:ever\s+|again\s+)?(.+?)(?:[.!?]|$)", re.IGNORECASE)
_DECISION_RE = re.compile(r"\b(?:we|the\s+team)\s+(?:decided|agreed|chose)\s+(?:to\s+)?(.+?)(?:[.!?]|$)", re.IGNORECASE)


@dataclass
class ExtractedFact:
    """A fact extracted from conversation."""
    category: str  # "user_pref", "correction", "project", "tool_quirk", "general"
    entity: str  # what the fact is about
    content: str  # the fact itself
    confidence: float  # 0.0-1.0
    source_turn: int  # which message turn it came from
    """A durable fact extracted from conversation."""

    category: str
    entity: str
    content: str
    confidence: float
    source_turn: int
    timestamp: float = 0.0
    source_role: str = "user"
    source_text: str = ""
    normalized_content: str = ""
    canonical_key: str = ""
    relation: str = "general"
    contradiction_group: str = ""
    status: str = "active"
    provenance: str = ""
    observed_at: str = ""
    evidence: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)


# Patterns that indicate user preferences
_PREFERENCE_PATTERNS = [
    (r"(?:I|we) (?:prefer|like|want|need) (.+?)(?:\.|$)", "preference"),
    (r"(?:always|never) (?:use|do|run|deploy) (.+?)(?:\.|$)", "preference"),
    (r"(?:my|our) (?:default|preferred|usual) (.+?) (?:is|are) (.+?)(?:\.|$)", "preference"),
    (r"(?:make sure|ensure|remember) (?:to|that) (.+?)(?:\.|$)", "instruction"),
    (r"(?:don'?t|do not) (?:ever|ever again) (.+?)(?:\.|$)", "constraint"),
]

# Patterns that indicate corrections
_CORRECTION_PATTERNS = [
    (r"(?:actually|no[, ]|wait[, ]|correction[: ]|sorry[, ]) (.+)", "correction"),
    (r"(?:I meant|what I meant was|the correct) (.+?)(?:\.|$)", "correction"),
    (r"(?:it'?s|its) (?:not|shouldn'?t be|wrong) (.+?)(?:\.|$)", "correction"),
]

# Patterns that indicate project/tool facts
_PROJECT_PATTERNS = [
    (r"(?:the |our )?(?:project|repo|codebase|code) (?:is|uses|needs|requires) (.+?)(?:\.|$)", "project"),
    (r"(?:deploy|push|commit) (?:to|on) (.+?)(?:\.|$)", "project"),
    (r"(?:this|that|the) (?:server|host|machine|VPS) (?:is|runs|has) (.+?)(?:\.|$)", "infrastructure"),
    (r"(?:model|provider|engine) (?:is|should be|needs to be) (.+?)(?:\.|$)", "config"),
]
    def __post_init__(self) -> None:
        if not self.timestamp:
            self.timestamp = time.time()
        if not self.observed_at:
            self.observed_at = _iso_from_timestamp(self.timestamp)
        if not self.normalized_content:
            self.normalized_content = _normalize_value(self.content)
        if not self.provenance:
            self.provenance = f"conversation:{self.source_role}:{self.source_turn}"
        if not self.canonical_key:
            self.canonical_key = _canonical_key(self.entity, self.relation, self.normalized_content)
        if not self.evidence:
            self.evidence = [
                {
                    "source_role": self.source_role,
                    "source_turn": self.source_turn,
                    "source_text": self.source_text or self.content,
                    "observed_at": self.observed_at,
                    "provenance": self.provenance,
                }
            ]
        self.metadata = dict(self.metadata or {})
        self.metadata.setdefault("entity", self.entity)
        self.metadata.setdefault("relation", self.relation)
        self.metadata.setdefault("value", self.content)
        self.metadata.setdefault("normalized_value", self.normalized_content)
        self.metadata.setdefault("provenance", [self.provenance])
        self.metadata.setdefault("evidence", list(self.evidence))
        self.metadata.setdefault("observation_count", len(self.evidence))
        self.metadata.setdefault("duplicate_count", max(0, self.metadata["observation_count"] - 1))
        if self.contradiction_group:
            self.metadata.setdefault("contradiction_group", self.contradiction_group)
        self.metadata.setdefault("status", self.status)


def extract_facts_from_messages(messages: List[Dict[str, Any]]) -> List[ExtractedFact]:
    """Extract durable facts from conversation messages.

    Scans user messages for preferences, corrections, project facts,
    and infrastructure details that should survive compression.
    Scans conversation turns for preferences, decisions, corrections, and
    operational state. Raw candidates are normalized into canonical facts so
    near-duplicates merge and contradictions remain inspectable.
    """
    facts = []
    seen_contents = set()

    raw_candidates: list[ExtractedFact] = []
    for turn_idx, msg in enumerate(messages):
        role = msg.get("role", "")
        content = msg.get("content", "")

        # Only scan user messages and assistant responses with corrections
        if role not in ("user", "assistant"):
        if role not in {"user", "assistant"}:
            continue
        if not content or not isinstance(content, str):
            continue
        if len(content) < 10:
            continue

        # Skip tool results and system messages
        if role == "assistant" and msg.get("tool_calls"):
            continue
        if not isinstance(content, str) or len(content.strip()) < 10:
            continue

        extracted = _extract_from_text(content, turn_idx, role)
        timestamp, observed_at = _message_time(msg)
        raw_candidates.extend(
            _extract_from_text(
                content.strip(),
                turn_idx=turn_idx,
                role=role,
                timestamp=timestamp,
                observed_at=observed_at,
            )
        )

        # Deduplicate by content
        for fact in extracted:
            key = f"{fact.category}:{fact.content[:100]}"
            if key not in seen_contents:
                seen_contents.add(key)
                facts.append(fact)
    return _normalize_candidates(raw_candidates)


def evaluate_extraction_quality(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return before/after metrics for raw vs normalized extraction quality."""

    raw_candidates: list[ExtractedFact] = []
    for turn_idx, msg in enumerate(messages):
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role not in {"user", "assistant"}:
            continue
        if role == "assistant" and msg.get("tool_calls"):
            continue
        if not isinstance(content, str) or len(content.strip()) < 10:
            continue
        timestamp, observed_at = _message_time(msg)
        raw_candidates.extend(
            _extract_from_text(
                content.strip(),
                turn_idx=turn_idx,
                role=role,
                timestamp=timestamp,
                observed_at=observed_at,
            )
        )

    normalized = _normalize_candidates(raw_candidates)
    raw_count = len(raw_candidates)
    normalized_count = len(normalized)
    contradiction_groups = {
        fact.contradiction_group
        for fact in normalized
        if fact.status == "contradiction" and fact.contradiction_group
    }
    duplicate_count = max(0, raw_count - normalized_count)
    noise_reduction = (duplicate_count / raw_count) if raw_count else 0.0

    return {
        "raw_candidates": raw_count,
        "normalized_facts": normalized_count,
        "duplicates_merged": duplicate_count,
        "contradiction_groups": len(contradiction_groups),
        "noise_reduction": round(noise_reduction, 3),
    }
```
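To see the before/after effect, a small sketch feeding the evaluator a conversation that repeats itself; the exact counts depend on the regexes above, so treat the numbers as illustrative:

```python
from agent.session_compactor import evaluate_extraction_quality

messages = [
    {"role": "user", "content": "The provider should stay openrouter for now."},
    {"role": "user", "content": "Remember: provider should stay openrouter."},
]

metrics = evaluate_extraction_quality(messages)
# Both turns yield the same canonical key (project|config.provider|openrouter),
# so the two raw candidates merge into one normalized fact, e.g.:
# {"raw_candidates": 2, "normalized_facts": 1, "duplicates_merged": 1,
#  "contradiction_groups": 0, "noise_reduction": 0.5}
print(metrics)
```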
```python
def _extract_from_text(
    text: str,
    *,
    turn_idx: int,
    role: str,
    timestamp: float,
    observed_at: str,
) -> List[ExtractedFact]:
    """Extract raw fact candidates from a single text block."""

    facts: list[ExtractedFact] = []
    if role != "user":
        return facts

    deploy_match = _DEPLOY_METHOD_RE.search(text)
    if deploy_match:
        method = deploy_match.group(1).strip()
        facts.append(
            _build_fact(
                category="project.decision",
                entity="project",
                relation="workflow.deploy_method",
                value=method,
                content=f"Deploy via {method}",
                confidence=0.88,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=True,
            )
        )

    watchdog_match = _WATCHDOG_CAP_RE.search(text)
    if watchdog_match:
        watchdog = watchdog_match.group(1).strip()
        cap = watchdog_match.group(2).strip()
        facts.append(
            _build_fact(
                category="project.operational",
                entity=_normalize_entity(watchdog),
                relation="fleet.dispatch_cap",
                value=cap,
                content=f"{watchdog} caps dispatches per cycle to {cap}",
                confidence=0.92,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=True,
            )
        )

    provider_match = _PROVIDER_RE.search(text)
    if provider_match:
        provider = provider_match.group(1).strip()
        facts.append(
            _build_fact(
                category="project.config",
                entity="project",
                relation="config.provider",
                value=provider,
                content=f"Provider should stay {provider}",
                confidence=0.91,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=True,
            )
        )

    model_match = _MODEL_RE.search(text)
    if model_match:
        model = model_match.group(1).strip()
        facts.append(
            _build_fact(
                category="project.config",
                entity="project",
                relation="config.model",
                value=model,
                content=f"Model should stay {model}",
                confidence=0.9,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=True,
            )
        )

    port_match = _PORT_RE.search(text)
    if port_match:
        port = port_match.group(1).strip()
        facts.append(
            _build_fact(
                category="project.config",
                entity="project",
                relation="config.port",
                value=port,
                content=f"Port is {port}",
                confidence=0.9,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=True,
            )
        )

    project_match = _PROJECT_USES_RE.search(text)
    if project_match:
        value = project_match.group(1).strip().rstrip(".")
        facts.append(
            _build_fact(
                category="project.stack",
                entity="project",
                relation="project.stack",
                value=value,
                content=f"Project uses {value}",
                confidence=0.74,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=False,
            )
        )

    preference_match = _PREFERENCE_RE.search(text)
    if preference_match:
        value = preference_match.group(1).strip().rstrip(".")
        facts.append(
            _build_fact(
                category="user_pref.preference",
                entity="user",
                relation="user.preference",
                value=value,
                content=value,
                confidence=0.72,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=False,
            )
        )

    constraint_match = _CONSTRAINT_RE.search(text)
    if constraint_match:
        value = constraint_match.group(1).strip().rstrip(".")
        facts.append(
            _build_fact(
                category="user_pref.constraint",
                entity="user",
                relation="user.constraint",
                value=value,
                content=f"Do not {value}",
                confidence=0.82,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=False,
            )
        )

    decision_match = _DECISION_RE.search(text)
    if decision_match:
        value = decision_match.group(1).strip().rstrip(".")
        facts.append(
            _build_fact(
                category="project.decision",
                entity="project",
                relation="project.decision",
                value=value,
                content=f"Decision: {value}",
                confidence=0.79,
                source_turn=turn_idx,
                source_role=role,
                source_text=text,
                timestamp=timestamp,
                observed_at=observed_at,
                unique_slot=False,
            )
        )

    return facts
```
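For the anatomy of a single candidate, a sketch of what one deploy-method sentence produces through the public entry point (field values follow the `_build_fact` defaults shown below):

```python
from agent.session_compactor import extract_facts_from_messages

facts = extract_facts_from_messages(
    [{"role": "user", "content": "We are deploying via rsync from now on."}]
)

fact = facts[0]
print(fact.category)       # "project.decision"
print(fact.content)        # "Deploy via rsync"
print(fact.canonical_key)  # "project|workflow.deploy_method|rsync"
print(fact.provenance)     # "conversation:user:0"
```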
```python
def _extract_from_text(text: str, turn_idx: int, role: str) -> List[ExtractedFact]:
    """Extract facts from a single text block."""
    facts = []
    timestamp = time.time()
def _build_fact(
    *,
    category: str,
    entity: str,
    relation: str,
    value: str,
    content: str,
    confidence: float,
    source_turn: int,
    source_role: str,
    source_text: str,
    timestamp: float,
    observed_at: str,
    unique_slot: bool,
) -> ExtractedFact:
    normalized_value = _normalize_value(value.rstrip(".!?"))
    value = value.rstrip(".!?")
    content = content.rstrip(".!?")
    provenance = f"conversation:{source_role}:{source_turn}"
    contradiction_group = relation if unique_slot else ""
    evidence = [
        {
            "source_role": source_role,
            "source_turn": source_turn,
            "source_text": source_text,
            "observed_at": observed_at,
            "provenance": provenance,
        }
    ]
    metadata = {
        "entity": entity,
        "relation": relation,
        "value": value,
        "normalized_value": normalized_value,
        "provenance": [provenance],
        "evidence": list(evidence),
        "observation_count": 1,
        "duplicate_count": 0,
        "status": "active",
    }
    if contradiction_group:
        metadata["contradiction_group"] = contradiction_group
    return ExtractedFact(
        category=category,
        entity=entity,
        content=content,
        confidence=confidence,
        source_turn=source_turn,
        timestamp=timestamp,
        source_role=source_role,
        source_text=source_text,
        normalized_content=normalized_value,
        canonical_key=_canonical_key(entity, relation, normalized_value),
        relation=relation,
        contradiction_group=contradiction_group,
        status="active",
        provenance=provenance,
        observed_at=observed_at,
        evidence=evidence,
        metadata=metadata,
    )

    # Clean text for pattern matching
    clean = text.strip()

    # User preference patterns (from user messages)
    if role == "user":
        for pattern, subcategory in _PREFERENCE_PATTERNS:
            for match in re.finditer(pattern, clean, re.IGNORECASE):
                content = match.group(1).strip() if match.lastindex else match.group(0).strip()
                if len(content) > 5:
                    facts.append(ExtractedFact(
                        category=f"user_pref.{subcategory}",
                        entity="user",
                        content=content[:200],
                        confidence=0.7,
                        source_turn=turn_idx,
                        timestamp=timestamp,
                    ))
def _normalize_candidates(candidates: List[ExtractedFact]) -> List[ExtractedFact]:
    """Merge duplicates and mark contradictions while preserving evidence."""

    # Correction patterns (from user messages)
    if role == "user":
        for pattern, subcategory in _CORRECTION_PATTERNS:
            for match in re.finditer(pattern, clean, re.IGNORECASE):
                content = match.group(1).strip() if match.lastindex else match.group(0).strip()
                if len(content) > 5:
                    facts.append(ExtractedFact(
                        category=f"correction.{subcategory}",
                        entity="user",
                        content=content[:200],
                        confidence=0.8,
                        source_turn=turn_idx,
                        timestamp=timestamp,
                    ))
    by_key: dict[str, ExtractedFact] = {}
    contradiction_groups: dict[str, list[ExtractedFact]] = {}

    # Project/infrastructure patterns (from both user and assistant)
    for pattern, subcategory in _PROJECT_PATTERNS:
        for match in re.finditer(pattern, clean, re.IGNORECASE):
            content = match.group(1).strip() if match.lastindex else match.group(0).strip()
            if len(content) > 5:
                facts.append(ExtractedFact(
                    category=f"project.{subcategory}",
                    entity=subcategory,
                    content=content[:200],
                    confidence=0.6,
                    source_turn=turn_idx,
                    timestamp=timestamp,
                ))
    for candidate in candidates:
        existing = by_key.get(candidate.canonical_key)
        if existing is not None:
            by_key[candidate.canonical_key] = _merge_fact(existing, candidate)
            continue

    return facts
        by_key[candidate.canonical_key] = candidate
        if candidate.contradiction_group:
            contradiction_groups.setdefault(candidate.contradiction_group, []).append(candidate)

    for group, facts in contradiction_groups.items():
        canonical_keys = {fact.canonical_key for fact in facts}
        if len(canonical_keys) <= 1:
            continue
        for fact in facts:
            fact.status = "contradiction"
            fact.metadata["status"] = "contradiction"
            fact.metadata["contradiction_group"] = group
            fact.metadata["contradiction_keys"] = sorted(canonical_keys - {fact.canonical_key})

    return sorted(by_key.values(), key=lambda fact: (fact.source_turn, fact.timestamp, fact.canonical_key))
```
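Contradiction handling in a sketch: two turns that pin the provider to different values share the `config.provider` slot, so both survive with status `"contradiction"` rather than silently merging:

```python
from agent.session_compactor import extract_facts_from_messages

facts = extract_facts_from_messages(
    [
        {"role": "user", "content": "The provider should be openrouter today."},
        {"role": "user", "content": "Actually the provider should be anthropic."},
    ]
)

for fact in facts:
    print(fact.canonical_key, fact.status)
# project|config.provider|openrouter contradiction
# project|config.provider|anthropic contradiction
```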
```python
def _merge_fact(existing: ExtractedFact, incoming: ExtractedFact) -> ExtractedFact:
    existing.confidence = max(existing.confidence, incoming.confidence)
    existing.timestamp = min(existing.timestamp, incoming.timestamp)
    existing.source_turn = min(existing.source_turn, incoming.source_turn)
    if not existing.observed_at or (incoming.observed_at and incoming.observed_at < existing.observed_at):
        existing.observed_at = incoming.observed_at
    existing.provenance = min(existing.provenance, incoming.provenance)

    provenance = _ordered_unique(existing.metadata.get("provenance", []), incoming.metadata.get("provenance", []))
    evidence = _merge_evidence(existing.metadata.get("evidence", []), incoming.metadata.get("evidence", []))
    observation_count = int(existing.metadata.get("observation_count", len(existing.evidence) or 1))
    observation_count += int(incoming.metadata.get("observation_count", len(incoming.evidence) or 1))

    existing.evidence = evidence
    existing.metadata["provenance"] = provenance
    existing.metadata["evidence"] = evidence
    existing.metadata["observation_count"] = observation_count
    existing.metadata["duplicate_count"] = max(0, observation_count - 1)
    existing.metadata["status"] = existing.status
    return existing


def save_facts_to_store(facts: List[ExtractedFact], fact_store_fn=None) -> int:
    """Save extracted facts to the fact store.

    Args:
        facts: List of extracted facts.
        fact_store_fn: Optional callable(category, entity, content, trust).
            If None, uses the holographic fact store if available.

    Returns:
        Number of facts saved.
    If a callback is supplied, prefer the structured signature but fall back to
    the legacy four-argument callback for compatibility.
    """
    saved = 0

    if fact_store_fn:
        for fact in facts:
    saved = 0
    for fact in facts:
        payload = {
            "category": _store_category(fact.category),
            "entity": fact.entity,
            "content": fact.content,
            "trust": fact.confidence,
            "metadata": dict(fact.metadata),
            "canonical_key": fact.canonical_key,
            "observed_at": fact.observed_at,
            "source_role": fact.source_role,
            "source_turn": fact.source_turn,
            "contradiction_group": fact.contradiction_group,
            "status": fact.status,
            "relation": fact.relation,
        }

        if fact_store_fn:
            try:
                fact_store_fn(
                    category=fact.category,
                    entity=fact.entity,
                    content=fact.content,
                    trust=fact.confidence,
                )
                fact_store_fn(**payload)
                saved += 1
            except Exception as e:
                logger.debug("Failed to save fact: %s", e)
    else:
        # Try holographic fact store
                continue
            except TypeError:
                try:
                    fact_store_fn(payload["category"], payload["entity"], payload["content"], payload["trust"])
                    saved += 1
                    continue
                except Exception as exc:
                    logger.debug("Failed to save fact via callback: %s", exc)
                    continue
            except Exception as exc:
                logger.debug("Failed to save fact via callback: %s", exc)
                continue

        try:
            from fact_store import fact_store as _fs
            for fact in facts:
                try:
                    _fs(
                        action="add",
                        content=fact.content,
                        category=fact.category,
                        tags=fact.entity,
                        trust_delta=fact.confidence - 0.5,
                    )
                    saved += 1
                except Exception as e:
                    logger.debug("Failed to save fact via fact_store: %s", e)

            tags = ",".join(filter(None, [fact.entity, fact.relation, fact.status]))
            _fs(
                action="add",
                content=fact.content,
                category=_store_category(fact.category),
                tags=tags,
                trust_delta=fact.confidence - 0.5,
            )
            saved += 1
        except ImportError:
            logger.debug("fact_store not available — facts not persisted")
            break
        except Exception as exc:
            logger.debug("Failed to save fact via fact_store: %s", exc)

    return saved
```
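A sketch of a compatible callback; the structured call receives the full payload as keyword arguments, while an older four-argument callback is still reached through the `TypeError` fallback. Both sink functions here are hypothetical stand-ins:

```python
from agent.session_compactor import extract_facts_from_messages, save_facts_to_store

def structured_sink(**payload):
    # Structured signature: category, entity, content, trust, metadata,
    # canonical_key, observed_at, source_role, source_turn,
    # contradiction_group, status, relation.
    print(payload["canonical_key"], payload["trust"])

def legacy_sink(category, entity, content, trust):
    # Legacy positional form; reached when the structured call raises TypeError.
    print(category, entity, content, trust)

facts = extract_facts_from_messages(
    [{"role": "user", "content": "The provider should stay openrouter."}]
)
save_facts_to_store(facts, fact_store_fn=structured_sink)
save_facts_to_store(facts, fact_store_fn=legacy_sink)
```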
```python
@@ -204,9 +553,10 @@ def extract_and_save_facts(

    Returns (extracted_facts, saved_count).
    """
    facts = extract_facts_from_messages(messages)
    if facts:
        logger.info("Extracted %d facts from conversation", len(facts))
        logger.info("Extracted %d normalized facts from conversation", len(facts))
        saved = save_facts_to_store(facts, fact_store_fn)
        logger.info("Saved %d/%d facts to store", saved, len(facts))
    else:

@@ -216,16 +566,105 @@ def extract_and_save_facts(

def format_facts_summary(facts: List[ExtractedFact]) -> str:
    """Format extracted facts as a readable summary."""

    if not facts:
        return "No facts extracted."

    by_category = {}
    for f in facts:
        by_category.setdefault(f.category, []).append(f)
    by_category: dict[str, list[ExtractedFact]] = {}
    for fact in facts:
        by_category.setdefault(fact.category, []).append(fact)

    lines = [f"Extracted {len(facts)} facts:", ""]
    for cat, cat_facts in sorted(by_category.items()):
        lines.append(f"  {cat}:")
        for f in cat_facts:
            lines.append(f"    - {f.content[:80]}")
    for category, category_facts in sorted(by_category.items()):
        lines.append(f"  {category}:")
        for fact in category_facts:
            suffix = f" [{fact.status}]" if fact.status != "active" else ""
            lines.append(f"    - {fact.content[:80]}{suffix}")
    return "\n".join(lines)


def _store_category(category: str) -> str:
    if category.startswith("user_pref"):
        return "user_pref"
    if category.startswith("project"):
        return "project"
    if category.startswith("tool"):
        return "tool"
    return "general"
```
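The summary formatter groups facts by category and flags anything non-active; roughly:

```python
from agent.session_compactor import extract_facts_from_messages, format_facts_summary

facts = extract_facts_from_messages(
    [{"role": "user", "content": "I prefer tabs over spaces in this repo."}]
)
print(format_facts_summary(facts))
# Extracted 1 facts:
#
#   user_pref.preference:
#     - tabs over spaces in this repo
```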
```python
def _message_time(msg: Dict[str, Any]) -> Tuple[float, str]:
    for key in ("created_at", "timestamp", "time"):
        value = msg.get(key)
        if value is None:
            continue
        if isinstance(value, (int, float)):
            ts = float(value)
            return ts, _iso_from_timestamp(ts)
        if isinstance(value, str):
            parsed = _parse_time_string(value)
            if parsed is not None:
                return parsed, _iso_from_timestamp(parsed) if "T" not in value else value.replace("+00:00", "Z")
            return time.time(), value
    now = time.time()
    return now, _iso_from_timestamp(now)


def _parse_time_string(value: str) -> float | None:
    text = value.strip()
    if not text:
        return None
    try:
        return float(text)
    except ValueError:
        pass
    try:
        normalized = text[:-1] + "+00:00" if text.endswith("Z") else text
        return datetime.fromisoformat(normalized).timestamp()
    except ValueError:
        return None


def _iso_from_timestamp(value: float) -> str:
    return datetime.fromtimestamp(value, tz=timezone.utc).isoformat().replace("+00:00", "Z")
```
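How the time helpers behave, sketched: numeric epochs and ISO-8601 strings both normalize to a `(timestamp, iso_string)` pair, and a message with no timestamp falls back to the current time:

```python
from agent.session_compactor import _message_time

ts, iso = _message_time({"role": "user", "created_at": 1700000000})
print(iso)  # "2023-11-14T22:13:20Z"

ts, iso = _message_time({"role": "user", "created_at": "2023-11-14T22:13:20Z"})
print(ts)   # 1700000000.0

ts, iso = _message_time({"role": "user"})  # no timestamp key: falls back to time.time()
```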
```python
def _normalize_value(value: str) -> str:
    normalized = re.sub(r"[^a-z0-9]+", " ", value.lower())
    normalized = re.sub(r"\s+", " ", normalized).strip()
    return normalized


def _normalize_entity(value: str) -> str:
    return _normalize_value(value).replace(" ", "_") or "entity"


def _canonical_key(entity: str, relation: str, normalized_value: str) -> str:
    return f"{entity}|{relation}|{normalized_value}"


def _ordered_unique(*groups: List[str]) -> List[str]:
    seen: set[str] = set()
    ordered: list[str] = []
    for group in groups:
        for item in group:
            if item and item not in seen:
                seen.add(item)
                ordered.append(item)
    return ordered


def _merge_evidence(existing: List[Dict[str, Any]], incoming: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    seen: set[tuple[str, str, str]] = set()
    merged: list[dict[str, Any]] = []
    for item in list(existing) + list(incoming):
        key = (
            str(item.get("provenance", "")),
            str(item.get("observed_at", "")),
            str(item.get("source_text", "")),
        )
        if key in seen:
            continue
        seen.add(key)
        merged.append(dict(item))
    return merged
```
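The canonical key is what makes near-duplicates collide; a quick sketch:

```python
from agent.session_compactor import _canonical_key, _normalize_value

# Different surface forms, same normalized value -> same canonical key.
a = _canonical_key("project", "config.provider", _normalize_value("OpenRouter"))
b = _canonical_key("project", "config.provider", _normalize_value("openrouter"))
assert a == b == "project|config.provider|openrouter"
```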
```python
@@ -1,57 +0,0 @@
"""Transport layer types and registry for provider response normalization.

Usage:
    from agent.transports import get_transport
    transport = get_transport("anthropic_messages")
    result = transport.normalize_response(raw_response)
"""

from agent.transports.types import (  # noqa: F401
    NormalizedResponse,
    ToolCall,
    Usage,
    build_tool_call,
    map_finish_reason,
)

_REGISTRY: dict = {}


def register_transport(api_mode: str, transport_cls: type) -> None:
    """Register a transport class for an api_mode string."""
    _REGISTRY[api_mode] = transport_cls


def get_transport(api_mode: str):
    """Get a transport instance for the given api_mode.

    Returns None if no transport is registered for this api_mode.
    This allows gradual migration — call sites can check for None
    and fall back to the legacy code path.
    """
    if not _REGISTRY:
        _discover_transports()
    cls = _REGISTRY.get(api_mode)
    if cls is None:
        return None
    return cls()


def _discover_transports() -> None:
    """Import all transport modules to trigger auto-registration."""
    try:
        import agent.transports.anthropic  # noqa: F401
    except ImportError:
        pass
    try:
        import agent.transports.codex  # noqa: F401
    except ImportError:
        pass
    try:
        import agent.transports.chat_completions  # noqa: F401
    except ImportError:
        pass
    try:
        import agent.transports.bedrock  # noqa: F401
    except ImportError:
        pass
```
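The registry is how call sites opt in gradually; a sketch of the intended pattern from the module docstring, with the legacy path kept as fallback (`raw_response` and `legacy_normalize` are placeholders, not names from this diff):

```python
from agent.transports import get_transport

transport = get_transport("anthropic_messages")
if transport is not None:
    normalized = transport.normalize_response(raw_response)
else:
    # No transport registered for this api_mode: keep the legacy code path.
    normalized = legacy_normalize(raw_response)  # hypothetical legacy helper
```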
```python
@@ -1,95 +0,0 @@
"""Anthropic Messages API transport.

Delegates to the existing adapter functions in agent/anthropic_adapter.py.
This transport owns format conversion and normalization — NOT client lifecycle.
"""

from typing import Any, Dict, List, Optional

from agent.transports.base import ProviderTransport
from agent.transports.types import NormalizedResponse


class AnthropicTransport(ProviderTransport):
    """Transport for api_mode='anthropic_messages'."""

    @property
    def api_mode(self) -> str:
        return "anthropic_messages"

    def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        from agent.anthropic_adapter import convert_messages_to_anthropic

        base_url = kwargs.get("base_url")
        return convert_messages_to_anthropic(messages, base_url=base_url)

    def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
        from agent.anthropic_adapter import convert_tools_to_anthropic

        return convert_tools_to_anthropic(tools)

    def build_kwargs(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **params,
    ) -> Dict[str, Any]:
        from agent.anthropic_adapter import build_anthropic_kwargs

        return build_anthropic_kwargs(
            model=model,
            messages=messages,
            tools=tools,
            max_tokens=params.get("max_tokens", 16384),
            reasoning_config=params.get("reasoning_config"),
            tool_choice=params.get("tool_choice"),
            is_oauth=params.get("is_oauth", False),
            preserve_dots=params.get("preserve_dots", False),
            context_length=params.get("context_length"),
            base_url=params.get("base_url"),
            fast_mode=params.get("fast_mode", False),
        )

    def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
        from agent.anthropic_adapter import normalize_anthropic_response_v2

        strip_tool_prefix = kwargs.get("strip_tool_prefix", False)
        return normalize_anthropic_response_v2(response, strip_tool_prefix=strip_tool_prefix)

    def validate_response(self, response: Any) -> bool:
        if response is None:
            return False
        content_blocks = getattr(response, "content", None)
        if not isinstance(content_blocks, list):
            return False
        if not content_blocks:
            return False
        return True

    def extract_cache_stats(self, response: Any):
        usage = getattr(response, "usage", None)
        if usage is None:
            return None
        cached = getattr(usage, "cache_read_input_tokens", 0) or 0
        written = getattr(usage, "cache_creation_input_tokens", 0) or 0
        if cached or written:
            return {"cached_tokens": cached, "creation_tokens": written}
        return None

    _STOP_REASON_MAP = {
        "end_turn": "stop",
        "tool_use": "tool_calls",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "refusal": "content_filter",
        "model_context_window_exceeded": "length",
    }

    def map_finish_reason(self, raw_reason: str) -> str:
        return self._STOP_REASON_MAP.get(raw_reason, "stop")


from agent.transports import register_transport  # noqa: E402

register_transport("anthropic_messages", AnthropicTransport)
```
```python
@@ -1,61 +0,0 @@
"""Abstract base for provider transports.

A transport owns the data path for one api_mode:
    convert_messages → convert_tools → build_kwargs → normalize_response

It does NOT own: client construction, streaming, credential refresh,
prompt caching, interrupt handling, or retry logic. Those stay on AIAgent.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from agent.transports.types import NormalizedResponse


class ProviderTransport(ABC):
    """Base class for provider-specific format conversion and normalization."""

    @property
    @abstractmethod
    def api_mode(self) -> str:
        """The api_mode string this transport handles."""
        ...

    @abstractmethod
    def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        """Convert OpenAI-format messages to provider-native format."""
        ...

    @abstractmethod
    def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
        """Convert OpenAI-format tool definitions to provider-native format."""
        ...

    @abstractmethod
    def build_kwargs(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **params,
    ) -> Dict[str, Any]:
        """Build the complete provider kwargs dict."""
        ...

    @abstractmethod
    def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
        """Normalize a raw provider response to the shared NormalizedResponse type."""
        ...

    def validate_response(self, response: Any) -> bool:
        """Optional structural validation for raw responses."""
        return True

    def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]:
        """Optional cache stats extraction."""
        return None

    def map_finish_reason(self, raw_reason: str) -> str:
        """Optional stop-reason mapping. Defaults to passthrough."""
        return raw_reason
```
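A minimal, purely hypothetical subclass to show the contract; a real transport would delegate to provider adapter functions the way `AnthropicTransport` above does:

```python
from typing import Any, Dict, List, Optional

from agent.transports.base import ProviderTransport
from agent.transports.types import NormalizedResponse


class EchoTransport(ProviderTransport):
    """Illustrative pass-through transport; not part of this change."""

    @property
    def api_mode(self) -> str:
        return "echo"

    def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        return messages  # already in the "native" format for this toy provider

    def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
        return tools

    def build_kwargs(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **params,
    ) -> Dict[str, Any]:
        return {"model": model, "messages": messages, "tools": tools}

    def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
        return NormalizedResponse(content=str(response), tool_calls=None, finish_reason="stop")
```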
```python
@@ -1,58 +0,0 @@
"""Shared types for normalized provider responses."""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ToolCall:
    """A normalized tool call from any provider."""

    id: Optional[str]
    name: str
    arguments: str
    provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False)


@dataclass
class Usage:
    """Token usage from an API response."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    cached_tokens: int = 0


@dataclass
class NormalizedResponse:
    """Normalized API response from any provider."""

    content: Optional[str]
    tool_calls: Optional[List[ToolCall]]
    finish_reason: str
    reasoning: Optional[str] = None
    usage: Optional[Usage] = None
    provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False)


def build_tool_call(
    id: Optional[str],
    name: str,
    arguments: Any,
    **provider_fields: Any,
) -> ToolCall:
    """Build a ToolCall, auto-serialising dict arguments."""
    args_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
    provider_data = dict(provider_fields) if provider_fields else None
    return ToolCall(id=id, name=name, arguments=args_str, provider_data=provider_data)


def map_finish_reason(reason: Optional[str], mapping: Dict[str, str]) -> str:
    """Translate a provider-specific stop reason to the normalized set."""
    if reason is None:
        return "stop"
    return mapping.get(reason, "stop")
```
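`build_tool_call` auto-serialises dict arguments, and `map_finish_reason` centralises the per-provider stop-reason tables; a quick sketch:

```python
from agent.transports.types import build_tool_call, map_finish_reason

tc = build_tool_call(id="toolu_abc", name="terminal", arguments={"command": "ls"})
print(tc.arguments)      # '{"command": "ls"}' (dicts are JSON-serialised)
print(tc.provider_data)  # None, since no extra provider fields were passed

print(map_finish_reason("tool_use", {"tool_use": "tool_calls"}))  # "tool_calls"
print(map_finish_reason(None, {}))                                # "stop"
```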
```python
@@ -356,44 +356,57 @@ class HolographicMemoryProvider(MemoryProvider):
    # -- Auto-extraction (on_session_end) ------------------------------------

    def _auto_extract_facts(self, messages: list) -> None:
        _PREF_PATTERNS = [
            re.compile(r'\bI\s+(?:prefer|like|love|use|want|need)\s+(.+)', re.IGNORECASE),
            re.compile(r'\bmy\s+(?:favorite|preferred|default)\s+\w+\s+is\s+(.+)', re.IGNORECASE),
            re.compile(r'\bI\s+(?:always|never|usually)\s+(.+)', re.IGNORECASE),
        ]
        _DECISION_PATTERNS = [
            re.compile(r'\bwe\s+(?:decided|agreed|chose)\s+(?:to\s+)?(.+)', re.IGNORECASE),
            re.compile(r'\bthe\s+project\s+(?:uses|needs|requires)\s+(.+)', re.IGNORECASE),
        ]
        from agent.session_compactor import evaluate_extraction_quality, extract_facts_from_messages

        def _store_category(category: str) -> str:
            if category.startswith("user_pref"):
                return "user_pref"
            if category.startswith("project"):
                return "project"
            if category.startswith("tool"):
                return "tool"
            return "general"

        facts = extract_facts_from_messages(messages)
        if not facts:
            return

        extracted = 0
        for msg in messages:
            if msg.get("role") != "user":
                continue
            content = msg.get("content", "")
            if not isinstance(content, str) or len(content) < 10:
                continue

            for pattern in _PREF_PATTERNS:
                if pattern.search(content):
                    try:
                        self._store.add_fact(content[:400], category="user_pref")
                        extracted += 1
                    except Exception:
                        pass
                    break

            for pattern in _DECISION_PATTERNS:
                if pattern.search(content):
                    try:
                        self._store.add_fact(content[:400], category="project")
                        extracted += 1
                    except Exception:
                        pass
                    break
        for fact in facts:
            try:
                metadata = dict(fact.metadata)
                metadata.setdefault("relation", fact.relation)
                metadata.setdefault("value", fact.content)
                metadata.setdefault("provenance", [fact.provenance])
                metadata.setdefault("evidence", list(fact.evidence))
                metadata.setdefault("observation_count", len(fact.evidence))
                metadata.setdefault("duplicate_count", max(0, len(fact.evidence) - 1))
                self._store.add_fact(
                    fact.content,
                    category=_store_category(fact.category),
                    tags=",".join(filter(None, [fact.entity, fact.relation, fact.status])),
                    canonical_key=fact.canonical_key,
                    metadata=metadata,
                    confidence=fact.confidence,
                    source_role=fact.source_role,
                    source_turn=fact.source_turn,
                    observed_at=fact.observed_at,
                    contradiction_group=fact.contradiction_group,
                    status=fact.status,
                )
                extracted += 1
            except Exception as exc:
                logger.debug("Structured auto-extract failed for %s: %s", fact.canonical_key, exc)

        if extracted:
            logger.info("Auto-extracted %d facts from conversation", extracted)
            metrics = evaluate_extraction_quality(messages)
            logger.info(
                "Auto-extracted %d structured facts from conversation (raw=%d normalized=%d contradictions=%d)",
                extracted,
                metrics["raw_candidates"],
                metrics["normalized_facts"],
                metrics["contradiction_groups"],
            )


# ---------------------------------------------------------------------------
```
```python
@@ -3,6 +3,7 @@ SQLite-backed fact store with entity resolution and trust scoring.
Single-user Hermes memory store plugin.
"""

import json
import re
import sqlite3
import threading

@@ -15,16 +16,24 @@ except ImportError:

_SCHEMA = """
CREATE TABLE IF NOT EXISTS facts (
    fact_id INTEGER PRIMARY KEY AUTOINCREMENT,
    content TEXT NOT NULL UNIQUE,
    category TEXT DEFAULT 'general',
    tags TEXT DEFAULT '',
    trust_score REAL DEFAULT 0.5,
    retrieval_count INTEGER DEFAULT 0,
    helpful_count INTEGER DEFAULT 0,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    hrr_vector BLOB
    fact_id INTEGER PRIMARY KEY AUTOINCREMENT,
    content TEXT NOT NULL UNIQUE,
    category TEXT DEFAULT 'general',
    tags TEXT DEFAULT '',
    trust_score REAL DEFAULT 0.5,
    retrieval_count INTEGER DEFAULT 0,
    helpful_count INTEGER DEFAULT 0,
    canonical_key TEXT DEFAULT '',
    metadata_json TEXT DEFAULT '{}',
    confidence REAL DEFAULT 0.5,
    source_role TEXT DEFAULT '',
    source_turn INTEGER DEFAULT -1,
    observed_at TEXT DEFAULT '',
    contradiction_group TEXT DEFAULT '',
    status TEXT DEFAULT 'active',
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    hrr_vector BLOB
);

CREATE TABLE IF NOT EXISTS entities (

@@ -41,9 +50,11 @@ CREATE TABLE IF NOT EXISTS fact_entities (
    PRIMARY KEY (fact_id, entity_id)
);

CREATE INDEX IF NOT EXISTS idx_facts_trust ON facts(trust_score DESC);
CREATE INDEX IF NOT EXISTS idx_facts_category ON facts(category);
CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(name);
CREATE INDEX IF NOT EXISTS idx_facts_trust ON facts(trust_score DESC);
CREATE INDEX IF NOT EXISTS idx_facts_category ON facts(category);
CREATE INDEX IF NOT EXISTS idx_facts_canonical_key ON facts(canonical_key);
CREATE INDEX IF NOT EXISTS idx_facts_contradiction_group ON facts(contradiction_group);
CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(name);

CREATE VIRTUAL TABLE IF NOT EXISTS facts_fts
USING fts5(content, tags, content=facts, content_rowid=fact_id);

@@ -129,10 +140,23 @@ class MemoryStore:
        """Create tables, indexes, and triggers if they do not exist. Enable WAL mode."""
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.executescript(_SCHEMA)
        # Migrate: add hrr_vector column if missing (safe for existing databases)
        columns = {row[1] for row in self._conn.execute("PRAGMA table_info(facts)").fetchall()}
        if "hrr_vector" not in columns:
            self._conn.execute("ALTER TABLE facts ADD COLUMN hrr_vector BLOB")
        migrations = {
            "hrr_vector": "ALTER TABLE facts ADD COLUMN hrr_vector BLOB",
            "canonical_key": "ALTER TABLE facts ADD COLUMN canonical_key TEXT DEFAULT ''",
            "metadata_json": "ALTER TABLE facts ADD COLUMN metadata_json TEXT DEFAULT '{}'",
            "confidence": "ALTER TABLE facts ADD COLUMN confidence REAL DEFAULT 0.5",
            "source_role": "ALTER TABLE facts ADD COLUMN source_role TEXT DEFAULT ''",
            "source_turn": "ALTER TABLE facts ADD COLUMN source_turn INTEGER DEFAULT -1",
            "observed_at": "ALTER TABLE facts ADD COLUMN observed_at TEXT DEFAULT ''",
            "contradiction_group": "ALTER TABLE facts ADD COLUMN contradiction_group TEXT DEFAULT ''",
            "status": "ALTER TABLE facts ADD COLUMN status TEXT DEFAULT 'active'",
        }
        for column, ddl in migrations.items():
            if column not in columns:
                self._conn.execute(ddl)
        self._conn.execute("CREATE INDEX IF NOT EXISTS idx_facts_canonical_key ON facts(canonical_key)")
        self._conn.execute("CREATE INDEX IF NOT EXISTS idx_facts_contradiction_group ON facts(contradiction_group)")
        self._conn.commit()

    # ------------------------------------------------------------------

@@ -144,41 +168,148 @@ class MemoryStore:
        content: str,
        category: str = "general",
        tags: str = "",
        *,
        canonical_key: str = "",
        metadata: dict | None = None,
        confidence: float | None = None,
        source_role: str = "",
        source_turn: int = -1,
        observed_at: str = "",
        contradiction_group: str = "",
        status: str = "active",
    ) -> int:
        """Insert a fact and return its fact_id.

        Deduplicates by content (UNIQUE constraint). On duplicate, returns
        the existing fact_id without modifying the row. Extracts entities from
        the content and links them to the fact.
        Exact duplicates are deduplicated by content. Near-duplicates are
        normalized by canonical_key, with provenance/evidence merged into the
        existing row. Contradictions sharing the same contradiction_group remain
        stored as separate rows and are marked inspectably.
        """
        with self._lock:
            content = content.strip()
            if not content:
                raise ValueError("content must not be empty")

            metadata = dict(metadata or {})
            canonical_key = canonical_key.strip()
            contradiction_group = contradiction_group.strip()
            observed_at = observed_at.strip()
            status = status or "active"
            trust_score = self.default_trust if confidence is None else _clamp_trust(confidence)
            metadata_json = json.dumps(metadata, sort_keys=True)

            if canonical_key:
                existing = self._conn.execute(
                    "SELECT fact_id, metadata_json, trust_score, confidence, observed_at FROM facts WHERE canonical_key = ?",
                    (canonical_key,),
                ).fetchone()
                if existing is not None:
                    merged_metadata = self._merge_metadata(existing["metadata_json"], metadata)
                    merged_trust = max(float(existing["trust_score"]), trust_score)
                    merged_observed_at = existing["observed_at"] or observed_at
                    if observed_at and merged_observed_at:
                        merged_observed_at = min(merged_observed_at, observed_at)
                    elif observed_at:
                        merged_observed_at = observed_at
                    self._conn.execute(
                        """
                        UPDATE facts
                        SET metadata_json = ?,
                            trust_score = ?,
                            confidence = ?,
                            observed_at = ?,
                            updated_at = CURRENT_TIMESTAMP
                        WHERE fact_id = ?
                        """,
                        (
                            json.dumps(merged_metadata, sort_keys=True),
                            merged_trust,
                            max(float(existing["confidence"] or 0.0), confidence or trust_score),
                            merged_observed_at,
                            existing["fact_id"],
                        ),
                    )
                    self._conn.commit()
                    return int(existing["fact_id"])

            contradiction_rows = []
            if contradiction_group:
                contradiction_rows = self._conn.execute(
                    """
                    SELECT fact_id, canonical_key, metadata_json
                    FROM facts
                    WHERE contradiction_group = ?
                      AND canonical_key != ?
                    """,
                    (contradiction_group, canonical_key),
                ).fetchall()
                if contradiction_rows:
                    status = "contradiction"
                    metadata = dict(metadata)
                    metadata["status"] = "contradiction"
                    metadata["contradiction_group"] = contradiction_group
                    metadata["contradiction_keys"] = sorted(
                        {
                            canonical_key,
                            *[str(row["canonical_key"]) for row in contradiction_rows if row["canonical_key"]],
                        }
                        - {""}
                    )
                    metadata_json = json.dumps(metadata, sort_keys=True)

            try:
                cur = self._conn.execute(
                    """
                    INSERT INTO facts (content, category, tags, trust_score)
                    VALUES (?, ?, ?, ?)
                    INSERT INTO facts (
                        content,
                        category,
                        tags,
                        trust_score,
                        canonical_key,
                        metadata_json,
                        confidence,
                        source_role,
                        source_turn,
                        observed_at,
                        contradiction_group,
                        status
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (content, category, tags, self.default_trust),
                    (
                        content,
                        category,
                        tags,
                        trust_score,
                        canonical_key,
                        metadata_json,
                        confidence if confidence is not None else trust_score,
                        source_role,
                        source_turn,
                        observed_at,
                        contradiction_group,
                        status,
                    ),
                )
                self._conn.commit()
                fact_id: int = cur.lastrowid  # type: ignore[assignment]
            except sqlite3.IntegrityError:
                # Duplicate content — return existing id
                row = self._conn.execute(
                    "SELECT fact_id FROM facts WHERE content = ?", (content,)
                ).fetchone()
                return int(row["fact_id"])

            # Entity extraction and linking
            if contradiction_rows:
                self._mark_contradictions(
                    contradiction_group=contradiction_group,
                    new_canonical_key=canonical_key,
                    existing_rows=contradiction_rows,
                )

            for name in self._extract_entities(content):
                entity_id = self._resolve_entity(name)
                self._link_fact_entity(fact_id, entity_id)

            # Compute HRR vector after entity linking
            self._compute_hrr_vector(fact_id, content)
            self._rebuild_bank(category)
```
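A sketch of the structured insert path. The MemoryStore constructor is not part of this diff, so assume `store` is an already-constructed instance:

```python
# Assumes `store` is a MemoryStore instance; the constructor is outside this diff.
fact_id = store.add_fact(
    "Provider should stay openrouter",
    category="project",
    tags="project,config.provider,active",
    canonical_key="project|config.provider|openrouter",
    metadata={"relation": "config.provider", "value": "openrouter"},
    confidence=0.91,
    source_role="user",
    source_turn=3,
    observed_at="2023-11-14T22:13:20Z",
    contradiction_group="config.provider",
    status="active",
)
# A second call with the same canonical_key merges metadata and returns the
# same fact_id; a different value in the same contradiction_group inserts a
# new row and marks both rows status='contradiction'.
```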
@@ -211,6 +342,9 @@ class MemoryStore:
|
||||
sql = f"""
|
||||
SELECT f.fact_id, f.content, f.category, f.tags,
|
||||
f.trust_score, f.retrieval_count, f.helpful_count,
|
||||
f.canonical_key, f.metadata_json, f.confidence,
|
||||
f.source_role, f.source_turn, f.observed_at,
|
||||
f.contradiction_group, f.status,
|
||||
f.created_at, f.updated_at
|
||||
FROM facts f
|
||||
JOIN facts_fts fts ON fts.rowid = f.fact_id
|
||||
@@ -336,7 +470,11 @@ class MemoryStore:
|
||||
|
||||
sql = f"""
|
||||
SELECT fact_id, content, category, tags, trust_score,
|
||||
retrieval_count, helpful_count, created_at, updated_at
|
||||
retrieval_count, helpful_count,
|
||||
canonical_key, metadata_json, confidence,
|
||||
source_role, source_turn, observed_at,
|
||||
contradiction_group, status,
|
||||
created_at, updated_at
|
||||
FROM facts
|
||||
WHERE trust_score >= ?
|
||||
{category_clause}
@@ -387,6 +525,89 @@ class MemoryStore:
                "helpful_count": row["helpful_count"] + helpful_increment,
            }

    # ------------------------------------------------------------------
    # Metadata helpers
    # ------------------------------------------------------------------

    def _load_metadata(self, metadata_json: str | None) -> dict:
        if not metadata_json:
            return {}
        try:
            data = json.loads(metadata_json)
            return data if isinstance(data, dict) else {}
        except Exception:
            return {}

    def _merge_metadata(self, existing_json: str | None, incoming: dict | None) -> dict:
        existing = self._load_metadata(existing_json)
        incoming = dict(incoming or {})
        merged = dict(existing)
        merged.update({k: v for k, v in incoming.items() if k not in {"provenance", "evidence", "observation_count", "duplicate_count", "contradiction_keys"}})

        provenance = []
        seen_provenance: set[str] = set()
        for item in list(existing.get("provenance", [])) + list(incoming.get("provenance", [])):
            if item and item not in seen_provenance:
                seen_provenance.add(item)
                provenance.append(item)

        evidence = []
        seen_evidence: set[tuple[str, str, str]] = set()
        for item in list(existing.get("evidence", [])) + list(incoming.get("evidence", [])):
            if not isinstance(item, dict):
                continue
            key = (
                str(item.get("provenance", "")),
                str(item.get("observed_at", "")),
                str(item.get("source_text", "")),
            )
            if key in seen_evidence:
                continue
            seen_evidence.add(key)
            evidence.append(dict(item))

        observation_count = int(existing.get("observation_count", max(1, len(existing.get("evidence", [])) or 1)))
        observation_count += int(incoming.get("observation_count", max(1, len(incoming.get("evidence", [])) or 1)))

        contradiction_keys = []
        seen_keys: set[str] = set()
        for item in list(existing.get("contradiction_keys", [])) + list(incoming.get("contradiction_keys", [])):
            if item and item not in seen_keys:
                seen_keys.add(item)
                contradiction_keys.append(item)

        merged["provenance"] = provenance
        merged["evidence"] = evidence
        merged["observation_count"] = observation_count
        merged["duplicate_count"] = max(0, observation_count - 1)
        if contradiction_keys:
            merged["contradiction_keys"] = contradiction_keys
        return merged

    def _mark_contradictions(self, contradiction_group: str, new_canonical_key: str, existing_rows: list[sqlite3.Row]) -> None:
        for row in existing_rows:
            metadata = self._load_metadata(row["metadata_json"])
            keys = []
            seen: set[str] = set()
            for item in list(metadata.get("contradiction_keys", [])) + [new_canonical_key]:
                if item and item not in seen:
                    seen.add(item)
                    keys.append(item)
            metadata["status"] = "contradiction"
            metadata["contradiction_group"] = contradiction_group
            metadata["contradiction_keys"] = keys
            self._conn.execute(
                """
                UPDATE facts
                SET status = 'contradiction',
                    metadata_json = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE fact_id = ?
                """,
                (json.dumps(metadata, sort_keys=True), row["fact_id"]),
            )
        self._conn.commit()

    # ------------------------------------------------------------------
    # Entity helpers
    # ------------------------------------------------------------------
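
For reference, a minimal sketch of the merge semantics above, run against a hypothetical `store = MemoryStore(...)` instance with two payloads for the same canonical fact; all field values are illustrative, borrowed from the fixture later in this diff:

existing_json = json.dumps({
    "provenance": ["conversation:user:0"],
    "evidence": [{
        "provenance": "conversation:user:0",
        "observed_at": "2026-04-22T10:00:00Z",
        "source_text": "Deploy via Ansible for production changes.",
    }],
    "observation_count": 1,
})
incoming = {
    "provenance": ["conversation:user:1"],
    "evidence": [{
        "provenance": "conversation:user:1",
        "observed_at": "2026-04-22T10:01:00Z",
        "source_text": "We deploy through Ansible on this repo.",
    }],
    "observation_count": 1,
}
merged = store._merge_metadata(existing_json, incoming)
# Provenance is deduplicated in first-seen order across both payloads:
assert merged["provenance"] == ["conversation:user:0", "conversation:user:1"]
# Evidence keeps one entry per (provenance, observed_at, source_text) triple:
assert len(merged["evidence"]) == 2
# Counts accumulate, and duplicate_count is always observation_count - 1:
assert merged["observation_count"] == 2
assert merged["duplicate_count"] == 1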
@@ -560,8 +781,14 @@ class MemoryStore:
    # ------------------------------------------------------------------

    def _row_to_dict(self, row: sqlite3.Row) -> dict:
        """Convert a sqlite3.Row to a plain dict."""
        return dict(row)
        """Convert a sqlite3.Row to a plain dict with decoded metadata."""
        data = dict(row)
        metadata = self._load_metadata(data.get("metadata_json"))
        if metadata:
            data["metadata"] = metadata
            data.setdefault("relation", metadata.get("relation"))
        data.pop("metadata_json", None)
        return data

    def close(self) -> None:
        """Close the database connection."""
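
A quick sketch of the decoded shape `_row_to_dict` now returns, assuming a row whose `metadata_json` column holds `{"relation": "workflow.deploy_method", "duplicate_count": 1}` (hypothetical values):

data = store._row_to_dict(row)
# The raw JSON column is decoded once and exposed under "metadata";
# metadata["relation"] is promoted to a top-level key via setdefault,
# so an existing "relation" column value would win if the table had one.
assert data["metadata"]["duplicate_count"] == 1
assert data["relation"] == "workflow.deploy_method"
# The encoded column itself is dropped from the public dict:
assert "metadata_json" not in data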

@@ -1,213 +0,0 @@
"""Regression tests: normalize_anthropic_response_v2 vs v1.

Constructs mock Anthropic responses and asserts that the v2 function
(returning NormalizedResponse) produces identical field values to the
original v1 function (returning SimpleNamespace + finish_reason).
"""

from types import SimpleNamespace

import pytest

from agent.anthropic_adapter import (
    normalize_anthropic_response,
    normalize_anthropic_response_v2,
)
from agent.transports.types import NormalizedResponse


def _text_block(text: str):
    return SimpleNamespace(type="text", text=text)


def _thinking_block(thinking: str, signature: str = "sig_abc"):
    return SimpleNamespace(type="thinking", thinking=thinking, signature=signature)


def _tool_use_block(id: str, name: str, input: dict):
    return SimpleNamespace(type="tool_use", id=id, name=name, input=input)


def _response(content_blocks, stop_reason="end_turn"):
    return SimpleNamespace(
        content=content_blocks,
        stop_reason=stop_reason,
        usage=SimpleNamespace(input_tokens=10, output_tokens=5),
    )


class TestTextOnly:
    def setup_method(self):
        self.resp = _response([_text_block("Hello world")])
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_type(self):
        assert isinstance(self.v2, NormalizedResponse)

    def test_content_matches(self):
        assert self.v2.content == self.v1_msg.content

    def test_finish_reason_matches(self):
        assert self.v2.finish_reason == self.v1_finish

    def test_no_tool_calls(self):
        assert self.v2.tool_calls is None
        assert self.v1_msg.tool_calls is None

    def test_no_reasoning(self):
        assert self.v2.reasoning is None
        assert self.v1_msg.reasoning is None


class TestWithToolCalls:
    def setup_method(self):
        self.resp = _response(
            [
                _text_block("I'll check that"),
                _tool_use_block("toolu_abc", "terminal", {"command": "ls"}),
                _tool_use_block("toolu_def", "read_file", {"path": "/tmp"}),
            ],
            stop_reason="tool_use",
        )
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_finish_reason(self):
        assert self.v2.finish_reason == "tool_calls"
        assert self.v1_finish == "tool_calls"

    def test_tool_call_count(self):
        assert len(self.v2.tool_calls) == 2
        assert len(self.v1_msg.tool_calls) == 2

    def test_tool_call_ids_match(self):
        for i in range(2):
            assert self.v2.tool_calls[i].id == self.v1_msg.tool_calls[i].id

    def test_tool_call_names_match(self):
        assert self.v2.tool_calls[0].name == "terminal"
        assert self.v2.tool_calls[1].name == "read_file"
        for i in range(2):
            assert self.v2.tool_calls[i].name == self.v1_msg.tool_calls[i].function.name

    def test_tool_call_arguments_match(self):
        for i in range(2):
            assert self.v2.tool_calls[i].arguments == self.v1_msg.tool_calls[i].function.arguments

    def test_content_preserved(self):
        assert self.v2.content == self.v1_msg.content
        assert "check that" in self.v2.content


class TestWithThinking:
    def setup_method(self):
        self.resp = _response([
            _thinking_block("Let me think about this carefully..."),
            _text_block("The answer is 42."),
        ])
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_reasoning_matches(self):
        assert self.v2.reasoning == self.v1_msg.reasoning
        assert "think about this" in self.v2.reasoning

    def test_reasoning_details_in_provider_data(self):
        v1_details = self.v1_msg.reasoning_details
        v2_details = self.v2.provider_data.get("reasoning_details") if self.v2.provider_data else None
        assert v1_details is not None
        assert v2_details is not None
        assert len(v2_details) == len(v1_details)

    def test_content_excludes_thinking(self):
        assert self.v2.content == "The answer is 42."


class TestMixed:
    def setup_method(self):
        self.resp = _response(
            [
                _thinking_block("Planning my approach..."),
                _text_block("I'll run the command"),
                _tool_use_block("toolu_xyz", "terminal", {"command": "pwd"}),
            ],
            stop_reason="tool_use",
        )
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_all_fields_present(self):
        assert self.v2.content is not None
        assert self.v2.tool_calls is not None
        assert self.v2.reasoning is not None
        assert self.v2.finish_reason == "tool_calls"

    def test_content_matches(self):
        assert self.v2.content == self.v1_msg.content

    def test_reasoning_matches(self):
        assert self.v2.reasoning == self.v1_msg.reasoning

    def test_tool_call_matches(self):
        assert self.v2.tool_calls[0].id == self.v1_msg.tool_calls[0].id
        assert self.v2.tool_calls[0].name == self.v1_msg.tool_calls[0].function.name


class TestStopReasons:
    @pytest.mark.parametrize("stop_reason,expected", [
        ("end_turn", "stop"),
        ("tool_use", "tool_calls"),
        ("max_tokens", "length"),
        ("stop_sequence", "stop"),
        ("refusal", "content_filter"),
        ("model_context_window_exceeded", "length"),
        ("unknown_future_reason", "stop"),
    ])
    def test_stop_reason_mapping(self, stop_reason, expected):
        resp = _response([_text_block("x")], stop_reason=stop_reason)
        _v1_msg, v1_finish = normalize_anthropic_response(resp)
        v2 = normalize_anthropic_response_v2(resp)
        assert v2.finish_reason == v1_finish == expected


class TestStripToolPrefix:
    def test_prefix_stripped(self):
        resp = _response(
            [_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})],
            stop_reason="tool_use",
        )
        v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=True)
        v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=True)
        assert v1_msg.tool_calls[0].function.name == "terminal"
        assert v2.tool_calls[0].name == "terminal"

    def test_prefix_kept(self):
        resp = _response(
            [_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})],
            stop_reason="tool_use",
        )
        v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=False)
        v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=False)
        assert v1_msg.tool_calls[0].function.name == "mcp_terminal"
        assert v2.tool_calls[0].name == "mcp_terminal"


class TestEdgeCases:
    def test_empty_content_blocks(self):
        resp = _response([])
        v1_msg, _v1_finish = normalize_anthropic_response(resp)
        v2 = normalize_anthropic_response_v2(resp)
        assert v2.content == v1_msg.content
        assert v2.content is None

    def test_no_reasoning_details_means_none_provider_data(self):
        resp = _response([_text_block("hi")])
        v2 = normalize_anthropic_response_v2(resp)
        assert v2.provider_data is None

    def test_v2_returns_dataclass_not_namespace(self):
        resp = _response([_text_block("hi")])
        v2 = normalize_anthropic_response_v2(resp)
        assert isinstance(v2, NormalizedResponse)
        assert not isinstance(v2, SimpleNamespace)

@@ -1,208 +0,0 @@
"""Tests for the transport ABC, registry, and AnthropicTransport."""

from types import SimpleNamespace

import pytest

from agent.transports import _REGISTRY, get_transport, register_transport
from agent.transports.base import ProviderTransport
from agent.transports.types import NormalizedResponse


class TestProviderTransportABC:
    def test_cannot_instantiate_abc(self):
        with pytest.raises(TypeError):
            ProviderTransport()

    def test_concrete_must_implement_all_abstract(self):
        class Incomplete(ProviderTransport):
            @property
            def api_mode(self):
                return "test"

        with pytest.raises(TypeError):
            Incomplete()

    def test_minimal_concrete(self):
        class Minimal(ProviderTransport):
            @property
            def api_mode(self):
                return "test_minimal"

            def convert_messages(self, messages, **kw):
                return messages

            def convert_tools(self, tools):
                return tools

            def build_kwargs(self, model, messages, tools=None, **params):
                return {"model": model, "messages": messages}

            def normalize_response(self, response, **kw):
                return NormalizedResponse(content="ok", tool_calls=None, finish_reason="stop")

        t = Minimal()
        assert t.api_mode == "test_minimal"
        assert t.validate_response(None) is True
        assert t.extract_cache_stats(None) is None
        assert t.map_finish_reason("end_turn") == "end_turn"


class TestTransportRegistry:
    def test_get_unregistered_returns_none(self):
        assert get_transport("nonexistent_mode") is None

    def test_anthropic_registered_on_import(self):
        import agent.transports.anthropic  # noqa: F401

        t = get_transport("anthropic_messages")
        assert t is not None
        assert t.api_mode == "anthropic_messages"

    def test_register_and_get(self):
        class DummyTransport(ProviderTransport):
            @property
            def api_mode(self):
                return "dummy_test"

            def convert_messages(self, messages, **kw):
                return messages

            def convert_tools(self, tools):
                return tools

            def build_kwargs(self, model, messages, tools=None, **params):
                return {}

            def normalize_response(self, response, **kw):
                return NormalizedResponse(content=None, tool_calls=None, finish_reason="stop")

        register_transport("dummy_test", DummyTransport)
        t = get_transport("dummy_test")
        assert t.api_mode == "dummy_test"
        _REGISTRY.pop("dummy_test", None)


class TestAnthropicTransport:
    @pytest.fixture
    def transport(self):
        import agent.transports.anthropic  # noqa: F401

        return get_transport("anthropic_messages")

    def test_api_mode(self, transport):
        assert transport.api_mode == "anthropic_messages"

    def test_convert_tools_simple(self, transport):
        tools = [{
            "type": "function",
            "function": {
                "name": "test_tool",
                "description": "A test",
                "parameters": {"type": "object", "properties": {}},
            },
        }]
        result = transport.convert_tools(tools)
        assert len(result) == 1
        assert result[0]["name"] == "test_tool"
        assert "input_schema" in result[0]

    def test_validate_response_none(self, transport):
        assert transport.validate_response(None) is False

    def test_validate_response_empty_content(self, transport):
        r = SimpleNamespace(content=[])
        assert transport.validate_response(r) is False

    def test_validate_response_valid(self, transport):
        r = SimpleNamespace(content=[SimpleNamespace(type="text", text="hello")])
        assert transport.validate_response(r) is True

    def test_map_finish_reason(self, transport):
        assert transport.map_finish_reason("end_turn") == "stop"
        assert transport.map_finish_reason("tool_use") == "tool_calls"
        assert transport.map_finish_reason("max_tokens") == "length"
        assert transport.map_finish_reason("stop_sequence") == "stop"
        assert transport.map_finish_reason("refusal") == "content_filter"
        assert transport.map_finish_reason("model_context_window_exceeded") == "length"
        assert transport.map_finish_reason("unknown") == "stop"

    def test_extract_cache_stats_none_usage(self, transport):
        r = SimpleNamespace(usage=None)
        assert transport.extract_cache_stats(r) is None

    def test_extract_cache_stats_with_cache(self, transport):
        usage = SimpleNamespace(cache_read_input_tokens=100, cache_creation_input_tokens=50)
        r = SimpleNamespace(usage=usage)
        result = transport.extract_cache_stats(r)
        assert result == {"cached_tokens": 100, "creation_tokens": 50}

    def test_extract_cache_stats_zero(self, transport):
        usage = SimpleNamespace(cache_read_input_tokens=0, cache_creation_input_tokens=0)
        r = SimpleNamespace(usage=usage)
        assert transport.extract_cache_stats(r) is None

    def test_normalize_response_text(self, transport):
        r = SimpleNamespace(
            content=[SimpleNamespace(type="text", text="Hello world")],
            stop_reason="end_turn",
            usage=SimpleNamespace(input_tokens=10, output_tokens=5),
            model="claude-sonnet-4-6",
        )
        nr = transport.normalize_response(r)
        assert isinstance(nr, NormalizedResponse)
        assert nr.content == "Hello world"
        assert nr.tool_calls is None or nr.tool_calls == []
        assert nr.finish_reason == "stop"

    def test_normalize_response_tool_calls(self, transport):
        r = SimpleNamespace(
            content=[
                SimpleNamespace(type="tool_use", id="toolu_123", name="terminal", input={"command": "ls"}),
            ],
            stop_reason="tool_use",
            usage=SimpleNamespace(input_tokens=10, output_tokens=20),
            model="claude-sonnet-4-6",
        )
        nr = transport.normalize_response(r)
        assert nr.finish_reason == "tool_calls"
        assert len(nr.tool_calls) == 1
        tc = nr.tool_calls[0]
        assert tc.name == "terminal"
        assert tc.id == "toolu_123"
        assert '"command"' in tc.arguments

    def test_normalize_response_thinking(self, transport):
        r = SimpleNamespace(
            content=[
                SimpleNamespace(type="thinking", thinking="Let me think..."),
                SimpleNamespace(type="text", text="The answer is 42"),
            ],
            stop_reason="end_turn",
            usage=SimpleNamespace(input_tokens=10, output_tokens=15),
            model="claude-sonnet-4-6",
        )
        nr = transport.normalize_response(r)
        assert nr.content == "The answer is 42"
        assert nr.reasoning == "Let me think..."

    def test_build_kwargs_returns_dict(self, transport):
        messages = [{"role": "user", "content": "Hello"}]
        kw = transport.build_kwargs(
            model="claude-sonnet-4-6",
            messages=messages,
            max_tokens=1024,
        )
        assert isinstance(kw, dict)
        assert "model" in kw
        assert "max_tokens" in kw
        assert "messages" in kw

    def test_convert_messages_extracts_system(self, transport):
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
        ]
        system, msgs = transport.convert_messages(messages)
        assert system is not None
        assert len(msgs) >= 1

@@ -1,130 +0,0 @@
"""Tests for agent/transports/types.py — dataclass construction + helpers."""

import json

from agent.transports.types import (
    NormalizedResponse,
    ToolCall,
    Usage,
    build_tool_call,
    map_finish_reason,
)


class TestToolCall:
    def test_basic_construction(self):
        tc = ToolCall(id="call_abc", name="terminal", arguments='{"cmd": "ls"}')
        assert tc.id == "call_abc"
        assert tc.name == "terminal"
        assert tc.arguments == '{"cmd": "ls"}'
        assert tc.provider_data is None

    def test_none_id(self):
        tc = ToolCall(id=None, name="read_file", arguments="{}")
        assert tc.id is None

    def test_provider_data(self):
        tc = ToolCall(
            id="call_x",
            name="t",
            arguments="{}",
            provider_data={"call_id": "call_x", "response_item_id": "fc_x"},
        )
        assert tc.provider_data["call_id"] == "call_x"
        assert tc.provider_data["response_item_id"] == "fc_x"


class TestUsage:
    def test_defaults(self):
        u = Usage()
        assert u.prompt_tokens == 0
        assert u.completion_tokens == 0
        assert u.total_tokens == 0
        assert u.cached_tokens == 0

    def test_explicit(self):
        u = Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150, cached_tokens=80)
        assert u.total_tokens == 150


class TestNormalizedResponse:
    def test_text_only(self):
        r = NormalizedResponse(content="hello", tool_calls=None, finish_reason="stop")
        assert r.content == "hello"
        assert r.tool_calls is None
        assert r.finish_reason == "stop"
        assert r.reasoning is None
        assert r.usage is None
        assert r.provider_data is None

    def test_with_tool_calls(self):
        tcs = [ToolCall(id="call_1", name="terminal", arguments='{"cmd":"pwd"}')]
        r = NormalizedResponse(content=None, tool_calls=tcs, finish_reason="tool_calls")
        assert r.finish_reason == "tool_calls"
        assert len(r.tool_calls) == 1
        assert r.tool_calls[0].name == "terminal"

    def test_with_reasoning(self):
        r = NormalizedResponse(
            content="answer",
            tool_calls=None,
            finish_reason="stop",
            reasoning="I thought about it",
        )
        assert r.reasoning == "I thought about it"

    def test_with_provider_data(self):
        r = NormalizedResponse(
            content=None,
            tool_calls=None,
            finish_reason="stop",
            provider_data={"reasoning_details": [{"type": "thinking", "thinking": "hmm"}]},
        )
        assert r.provider_data["reasoning_details"][0]["type"] == "thinking"


class TestBuildToolCall:
    def test_dict_arguments_serialized(self):
        tc = build_tool_call(id="call_1", name="terminal", arguments={"cmd": "ls"})
        assert tc.arguments == json.dumps({"cmd": "ls"})
        assert tc.provider_data is None

    def test_string_arguments_passthrough(self):
        tc = build_tool_call(id="call_2", name="read_file", arguments='{"path": "/tmp"}')
        assert tc.arguments == '{"path": "/tmp"}'

    def test_provider_fields(self):
        tc = build_tool_call(
            id="call_3",
            name="terminal",
            arguments="{}",
            call_id="call_3",
            response_item_id="fc_3",
        )
        assert tc.provider_data == {"call_id": "call_3", "response_item_id": "fc_3"}

    def test_none_id(self):
        tc = build_tool_call(id=None, name="t", arguments="{}")
        assert tc.id is None


class TestMapFinishReason:
    ANTHROPIC_MAP = {
        "end_turn": "stop",
        "tool_use": "tool_calls",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "refusal": "content_filter",
    }

    def test_known_reason(self):
        assert map_finish_reason("end_turn", self.ANTHROPIC_MAP) == "stop"
        assert map_finish_reason("tool_use", self.ANTHROPIC_MAP) == "tool_calls"
        assert map_finish_reason("max_tokens", self.ANTHROPIC_MAP) == "length"
        assert map_finish_reason("refusal", self.ANTHROPIC_MAP) == "content_filter"

    def test_unknown_reason_defaults_to_stop(self):
        assert map_finish_reason("something_new", self.ANTHROPIC_MAP) == "stop"

    def test_none_reason(self):
        assert map_finish_reason(None, self.ANTHROPIC_MAP) == "stop"

tests/fixtures/memory_extraction_fragments.json (vendored, new file, 63 lines)
@@ -0,0 +1,63 @@
{
  "preferences_and_duplicates": [
    {
      "role": "user",
      "content": "Deploy via Ansible for production changes.",
      "created_at": "2026-04-22T10:00:00Z"
    },
    {
      "role": "user",
      "content": "We deploy through Ansible on this repo.",
      "created_at": "2026-04-22T10:01:00Z"
    },
    {
      "role": "user",
      "content": "Gitea-first for repository work.",
      "created_at": "2026-04-22T10:02:00Z"
    }
  ],
  "operational_and_contradictions": [
    {
      "role": "user",
      "content": "The BURN watchdog caps dispatches per cycle to 6.",
      "created_at": "2026-04-22T11:00:00Z"
    },
    {
      "role": "user",
      "content": "The provider should stay openai-codex/gpt-5.4.",
      "created_at": "2026-04-22T11:01:00Z"
    },
    {
      "role": "user",
      "content": "Correction: the provider should stay mimo-v2-pro.",
      "created_at": "2026-04-22T11:02:00Z"
    }
  ],
  "mixed_transcript": [
    {
      "role": "user",
      "content": "Deploy via Ansible for production changes.",
      "created_at": "2026-04-22T10:00:00Z"
    },
    {
      "role": "user",
      "content": "We deploy through Ansible on this repo.",
      "created_at": "2026-04-22T10:01:00Z"
    },
    {
      "role": "user",
      "content": "The BURN watchdog caps dispatches per cycle to 6.",
      "created_at": "2026-04-22T11:00:00Z"
    },
    {
      "role": "user",
      "content": "The provider should stay openai-codex/gpt-5.4.",
      "created_at": "2026-04-22T11:01:00Z"
    },
    {
      "role": "user",
      "content": "Correction: the provider should stay mimo-v2-pro.",
      "created_at": "2026-04-22T11:02:00Z"
    }
  ]
}
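
To make the fixture's intent concrete: the two `preferences_and_duplicates` deploy lines differ in wording but should collapse to one canonical fact. A rough, self-contained illustration of that canonicalization, assuming a simple keyword-slot extractor (the repo's actual extraction rules live in `agent/session_compactor.py` and may differ):

import re
from collections import defaultdict

def normalize(text: str) -> str:
    # Lowercase and keep alphanumeric runs only; this is the style of
    # normalization that turns "openai-codex/gpt-5.4" into
    # "openai codex gpt 5 4" in the tests below.
    return " ".join(re.findall(r"[a-z0-9]+", text.lower()))

fragments = [
    "Deploy via Ansible for production changes.",
    "We deploy through Ansible on this repo.",
]
buckets: dict[str, list[str]] = defaultdict(list)
for text in fragments:
    # Hypothetical slot extraction: both sentences yield the same
    # (relation, value) pair, so they share one canonical key.
    value = "ansible" if "ansible" in normalize(text) else normalize(text)
    buckets[f"workflow.deploy_method|{value}"].append(text)

assert len(buckets) == 1  # one canonical fact
assert len(buckets["workflow.deploy_method|ansible"]) == 2  # two evidence entries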

tests/plugins/memory/test_holographic_auto_extract.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""Integration tests for holographic auto-extraction with structured fact persistence."""

import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

from plugins.memory.holographic import HolographicMemoryProvider

_FIXTURE_PATH = Path(__file__).resolve().parents[2] / "fixtures" / "memory_extraction_fragments.json"


def _load_fixture(name: str):
    return json.loads(_FIXTURE_PATH.read_text())[name]


class TestHolographicAutoExtract:
    def test_auto_extract_persists_structured_metadata_and_normalizes_duplicates(self, tmp_path):
        provider = HolographicMemoryProvider(
            config={
                "db_path": str(tmp_path / "memory_store.db"),
                "auto_extract": True,
                "default_trust": 0.5,
            }
        )
        provider.initialize("test-session")

        messages = _load_fixture("mixed_transcript")
        provider.on_session_end(messages)
        provider.on_session_end(messages)

        facts = provider._store.list_facts(min_trust=0.0, limit=20)
        deploy_facts = [f for f in facts if f.get("relation") == "workflow.deploy_method"]
        provider_facts = [f for f in facts if f.get("contradiction_group") == "config.provider"]

        assert len(deploy_facts) == 1
        assert deploy_facts[0]["metadata"]["duplicate_count"] >= 3
        assert deploy_facts[0]["observed_at"] == "2026-04-22T10:00:00Z"
        assert deploy_facts[0]["metadata"]["provenance"] == [
            "conversation:user:0",
            "conversation:user:1",
        ]

        assert len(provider_facts) == 2
        assert {f["status"] for f in provider_facts} == {"contradiction"}
        assert {f["metadata"]["value"] for f in provider_facts} == {
            "openai-codex/gpt-5.4",
            "mimo-v2-pro",
        }
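
The double `provider.on_session_end(messages)` call above is deliberate: re-ingesting the same transcript should merge into the existing canonical facts rather than insert new rows, so `observation_count` keeps climbing (which is why `duplicate_count >= 3` holds after two passes over two near-duplicate deploy lines) while the deduplicated `provenance` list stays fixed at its two distinct sources.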

@@ -1,6 +1,6 @@
"""Tests for session compaction with fact extraction."""

import pytest
import json
import sys
from pathlib import Path

@@ -8,12 +8,19 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from agent.session_compactor import (
    ExtractedFact,
    extract_facts_from_messages,
    save_facts_to_store,
    evaluate_extraction_quality,
    extract_and_save_facts,
    extract_facts_from_messages,
    format_facts_summary,
    save_facts_to_store,
)

_FIXTURE_PATH = Path(__file__).resolve().parent / "fixtures" / "memory_extraction_fragments.json"


def _load_fixture(name: str):
    return json.loads(_FIXTURE_PATH.read_text())[name]


class TestFactExtraction:
    def test_extract_preference(self):

@@ -60,14 +67,48 @@ class TestFactExtraction:
            {"role": "user", "content": "I prefer Python."},
        ]
        facts = extract_facts_from_messages(messages)
        # Should deduplicate
        python_facts = [f for f in facts if "Python" in f.content]
        assert len(python_facts) == 1

    def test_structured_fact_preserves_provenance_and_temporal_metadata(self):
        facts = extract_facts_from_messages(_load_fixture("preferences_and_duplicates"))
        deploy_fact = next(f for f in facts if f.relation == "workflow.deploy_method")
        assert deploy_fact.source_role == "user"
        assert deploy_fact.source_turn == 0
        assert deploy_fact.observed_at == "2026-04-22T10:00:00Z"
        assert deploy_fact.provenance == "conversation:user:0"
        assert deploy_fact.canonical_key
        assert deploy_fact.evidence
        assert deploy_fact.evidence[0]["source_text"].startswith("Deploy via Ansible")

    def test_near_duplicate_facts_are_normalized_into_one_canonical_fact(self):
        facts = extract_facts_from_messages(_load_fixture("preferences_and_duplicates"))
        deploy_facts = [f for f in facts if f.relation == "workflow.deploy_method"]
        assert len(deploy_facts) == 1
        assert len(deploy_facts[0].evidence) == 2
        assert deploy_facts[0].metadata["duplicate_count"] == 1

    def test_contradictory_facts_are_preserved_for_unique_slots(self):
        facts = extract_facts_from_messages(_load_fixture("operational_and_contradictions"))
        provider_facts = [f for f in facts if f.contradiction_group == "config.provider"]
        assert len(provider_facts) == 2
        assert {f.status for f in provider_facts} == {"contradiction"}
        assert {f.normalized_content for f in provider_facts} == {
            "openai codex gpt 5 4",
            "mimo v2 pro",
        }

    def test_quality_evaluation_reports_noise_reduction(self):
        metrics = evaluate_extraction_quality(_load_fixture("mixed_transcript"))
        assert metrics["raw_candidates"] > metrics["normalized_facts"]
        assert metrics["noise_reduction"] > 0
        assert metrics["contradiction_groups"] == 1


class TestSaveFacts:
    def test_save_with_callback(self):
        saved = []

        def mock_save(category, entity, content, trust):
            saved.append({"category": category, "content": content})

@@ -76,6 +117,38 @@ class TestSaveFacts:
        assert count == 1
        assert len(saved) == 1

    def test_save_with_extended_callback_metadata(self):
        saved = []

        def mock_save(category, entity, content, trust, **kwargs):
            saved.append({
                "category": category,
                "entity": entity,
                "content": content,
                "trust": trust,
                **kwargs,
            })

        fact = ExtractedFact(
            "project.operational",
            "watchdog",
            "BURN watchdog caps dispatches per cycle to 6",
            0.9,
            2,
            source_role="user",
            observed_at="2026-04-22T11:00:00Z",
            provenance="conversation:user:2",
            canonical_key="project.operational|watchdog|dispatch_cap|6",
            relation="fleet.dispatch_cap",
            contradiction_group="fleet.dispatch_cap",
            metadata={"duplicate_count": 0},
        )
        count = save_facts_to_store([fact], fact_store_fn=mock_save)
        assert count == 1
        assert saved[0]["canonical_key"] == fact.canonical_key
        assert saved[0]["observed_at"] == "2026-04-22T11:00:00Z"
        assert saved[0]["metadata"]["duplicate_count"] == 0


class TestFormatSummary:
    def test_empty(self):