diff --git a/agent/memory_manager.py b/agent/memory_manager.py index 0e4113eff..531342b04 100644 --- a/agent/memory_manager.py +++ b/agent/memory_manager.py @@ -37,6 +37,31 @@ from agent.memory_provider import MemoryProvider logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------- +# Correction detection patterns +# ----------------------------------------------------------------------- + +_CORRECTION_PATTERNS = [ + re.compile(r'\b(?:no|wrong|incorrect|that\'s not right|that is not right)\b', re.IGNORECASE), + re.compile(r'\b(?:actually|nope|not quite|that\'s wrong|that is wrong)\b', re.IGNORECASE), + re.compile(r'\b(?:that\'s not|that is not|that was not|that\'s not what)\b', re.IGNORECASE), + re.compile(r'\b(?:i said|i told you|what i meant|what i said)\b', re.IGNORECASE), + re.compile(r'\bcorrection[:\s]|\b(?:fix that|revise|undo)\b', re.IGNORECASE), +] + + +def _detect_correction(user_content: str) -> bool: + """Detect if the user message is a correction of the previous assistant response.""" + if not user_content or len(user_content) < 3: + return False + # Must be short-ish to be a correction (not a new topic) + if len(user_content) > 200: + return False + for pattern in _CORRECTION_PATTERNS: + if pattern.search(user_content): + return True + return False + # --------------------------------------------------------------------------- # Context fencing helpers @@ -211,6 +236,74 @@ class MemoryManager: provider.name, e, ) + def auto_calibrate_feedback( + self, + current_user_message: str, + *, + prev_assistant_response: str = "", + session_id: str = "", + ) -> None: + """Auto-calibrate fact trust based on interaction outcome. + + Called after sync_all(). If the user's current message is a correction + of the previous assistant response, marks prefetched facts as unhelpful. + If no correction detected, marks them as helpful. 
+ + This creates a passive feedback loop: facts that contribute to correct + responses gain trust, facts that lead to corrections lose trust. + """ + is_correction = _detect_correction(current_user_message) + + for provider in self._providers: + try: + fact_ids = provider.get_prefetched_fact_ids() + except Exception: + continue + if not fact_ids: + continue + + for fact_id in fact_ids: + try: + provider.handle_tool_call( + "fact_feedback", + { + "action": "unhelpful" if is_correction else "helpful", + "fact_id": fact_id, + }, + ) + logger.debug( + "Auto-calibrate fact %d: %s (provider=%s)", + fact_id, + "unhelpful" if is_correction else "helpful", + provider.name, + ) + except Exception as e: + logger.debug( + "Auto-calibrate fact %d failed (provider=%s): %s", + fact_id, provider.name, e, + ) + + def get_pruning_candidates(self, threshold: float = 0.15) -> List[Dict[str, Any]]: + """Return facts below the trust threshold that are candidates for pruning. + + This is a read-only query — no facts are deleted. The caller decides + whether to remove them (e.g. during on_session_end or periodic hygiene). + """ + candidates = [] + for provider in self._providers: + try: + result = provider.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0, "limit": 100}, + ) + data = json.loads(result) + for fact in data.get("facts", []): + if fact.get("trust_score", 0.5) < threshold: + candidates.append(fact) + except Exception: + continue + return candidates + # -- Tools --------------------------------------------------------------- def get_all_tool_schemas(self) -> List[Dict[str, Any]]: diff --git a/agent/memory_provider.py b/agent/memory_provider.py index 54ef1fb10..3862f892a 100644 --- a/agent/memory_provider.py +++ b/agent/memory_provider.py @@ -220,6 +220,15 @@ class MemoryProvider(ABC): should all have ``env_var`` set and this method stays no-op). """ + def get_prefetched_fact_ids(self) -> List[int]: + """Return fact IDs recalled by the last prefetch() call. 
+ + Override this to enable automatic trust calibration: facts used in + successful interactions gain trust, facts that lead to corrections + lose trust. Default returns empty list (no auto-calibration). + """ + return [] + def on_memory_write(self, action: str, target: str, content: str) -> None: """Called when the built-in memory tool writes an entry. diff --git a/plugins/memory/holographic/__init__.py b/plugins/memory/holographic/__init__.py index 23befa899..a43984c4e 100644 --- a/plugins/memory/holographic/__init__.py +++ b/plugins/memory/holographic/__init__.py @@ -119,6 +119,7 @@ class HolographicMemoryProvider(MemoryProvider): self._store = None self._retriever = None self._min_trust = float(self._config.get("min_trust_threshold", 0.3)) + self._last_prefetch_ids: List[int] = [] @property def name(self) -> str: @@ -205,11 +206,14 @@ class HolographicMemoryProvider(MemoryProvider): def prefetch(self, query: str, *, session_id: str = "") -> str: if not self._retriever or not query: + self._last_prefetch_ids = [] return "" try: results = self._retriever.search(query, min_trust=self._min_trust, limit=5) if not results: + self._last_prefetch_ids = [] return "" + self._last_prefetch_ids = [r["fact_id"] for r in results if "fact_id" in r] lines = [] for r in results: trust = r.get("trust_score", r.get("trust", 0)) @@ -217,8 +221,12 @@ class HolographicMemoryProvider(MemoryProvider): return "## Holographic Memory\n" + "\n".join(lines) except Exception as e: logger.debug("Holographic prefetch failed: %s", e) + self._last_prefetch_ids = [] return "" + def get_prefetched_fact_ids(self) -> List[int]: + return list(self._last_prefetch_ids) + def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: # Holographic memory stores explicit facts via tools, not auto-sync. # The on_session_end hook handles auto-extraction if configured. 
diff --git a/run_agent.py b/run_agent.py index 349024155..ca1db47a1 100644 --- a/run_agent.py +++ b/run_agent.py @@ -7324,6 +7324,14 @@ class AIAgent: try: _query = original_user_message if isinstance(original_user_message, str) else "" + # Auto-calibrate fact trust: detect if the user is correcting + # the previous turn's response. Runs BEFORE prefetch_all so the + # providers' prefetched-fact IDs still identify the facts that + # informed the previous response, not this message's retrieval. + self._memory_manager.auto_calibrate_feedback( + _query, + session_id=getattr(self, 'session_id', ''), + ) _ext_prefetch_cache = self._memory_manager.prefetch_all(_query) or "" except Exception: pass diff --git a/tests/agent/test_fact_calibration.py b/tests/agent/test_fact_calibration.py new file mode 100644 index 000000000..200a3ef1a --- /dev/null +++ b/tests/agent/test_fact_calibration.py @@ -0,0 +1,252 @@ +"""Tests for automatic fact trust calibration (Issue #252).""" + +import json +import pytest + +from agent.memory_manager import MemoryManager, _detect_correction +from plugins.memory.holographic import HolographicMemoryProvider + + +def _make_holographic_provider(db_path=":memory:"): + """Create a holographic provider backed by an in-memory SQLite DB.""" + provider = HolographicMemoryProvider(config={ + "db_path": db_path, + "default_trust": 0.5, + "min_trust_threshold": 0.3, + "hrr_dim": 64, # small for speed + }) + provider.initialize(session_id="test") + return provider + + +class TestDetectCorrection: + """Correction detection pattern matching.""" + + @pytest.mark.parametrize("msg", [ + "No, that's wrong", + "Actually, it's Python 3.12", + "That's not right", + "I said the config is in YAML", + "Correction: the port is 8080", + "Nope, wrong file", + "Not quite what I meant", + "Undo that last change", + "that is not correct", + "what i meant was different", + ]) + def test_correction_detected(self, msg): + assert _detect_correction(msg) is True + + @pytest.mark.parametrize("msg", [ + 
"", + "Hello", + "What's the weather today?", + "I need you to build a new feature. " * 10, + "yes that's correct", + ]) + def test_not_a_correction(self, msg): + assert _detect_correction(msg) is False + + +class TestAutoCalibrateFeedback: + """Auto-calibration integration.""" + + def test_correction_marks_unhelpful(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + # Store a fact + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "The project uses Flask framework"}, + ) + fact_id = json.loads(result)["fact_id"] + + # Simulate: this fact was prefetched + provider._last_prefetch_ids = [fact_id] + + # User corrects: "No, it uses FastAPI" + manager.auto_calibrate_feedback("No, it uses FastAPI") + + # Check trust dropped + result = manager.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0}, + ) + facts = json.loads(result)["facts"] + target = next(f for f in facts if f["fact_id"] == fact_id) + assert target["trust_score"] < 0.5 # dropped from default 0.5 + assert target["trust_score"] == pytest.approx(0.4, abs=0.01) # 0.5 - 0.1 + + def test_successful_interaction_gains_trust(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + # Store a fact + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "The project uses Django framework"}, + ) + fact_id = json.loads(result)["fact_id"] + + # Simulate: this fact was prefetched + provider._last_prefetch_ids = [fact_id] + + # User says something normal (not a correction) + manager.auto_calibrate_feedback("What version of Django?") + + # Check trust increased + result = manager.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0}, + ) + facts = json.loads(result)["facts"] + target = next(f for f in facts if f["fact_id"] == fact_id) + assert target["trust_score"] > 0.5 # rose from default 0.5 + assert 
target["trust_score"] == pytest.approx(0.55, abs=0.01) # 0.5 + 0.05 + + def test_no_prefetch_no_calibration(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + # Store a fact + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "The database is PostgreSQL"}, + ) + fact_id = json.loads(result)["fact_id"] + + # No prefetched facts + provider._last_prefetch_ids = [] + + # Calibrate — should be no-op + manager.auto_calibrate_feedback("No, it's MySQL") + + # Trust should be unchanged + result = manager.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0}, + ) + facts = json.loads(result)["facts"] + target = next(f for f in facts if f["fact_id"] == fact_id) + assert target["trust_score"] == 0.5 # unchanged + + def test_multiple_corrections_drives_trust_low(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + # Store a fact + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "The server runs on port 3000"}, + ) + fact_id = json.loads(result)["fact_id"] + provider._last_prefetch_ids = [fact_id] + + # Simulate 5 corrections + for _ in range(5): + manager.auto_calibrate_feedback("Wrong, it's port 8080") + + # Trust should be much lower + result = manager.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0}, + ) + facts = json.loads(result)["facts"] + target = next(f for f in facts if f["fact_id"] == fact_id) + assert target["trust_score"] < 0.2 # 0.5 - 5*0.1 = 0.0 (clamped) + + def test_trust_floor_at_zero(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "Test fact for floor"}, + ) + fact_id = json.loads(result)["fact_id"] + provider._last_prefetch_ids = [fact_id] + + # 10 corrections should clamp at 0.0, not 
go negative + for _ in range(10): + manager.auto_calibrate_feedback("Wrong!") + + result = manager.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0}, + ) + facts = json.loads(result)["facts"] + target = next(f for f in facts if f["fact_id"] == fact_id) + assert target["trust_score"] == 0.0 + + def test_trust_ceiling_at_one(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "Test fact for ceiling"}, + ) + fact_id = json.loads(result)["fact_id"] + provider._last_prefetch_ids = [fact_id] + + # 20 successful interactions should cap at 1.0 + for _ in range(20): + manager.auto_calibrate_feedback("Thanks, what else?") + + result = manager.handle_tool_call( + "fact_store", + {"action": "list", "min_trust": 0.0}, + ) + facts = json.loads(result)["facts"] + target = next(f for f in facts if f["fact_id"] == fact_id) + assert target["trust_score"] == 1.0 + + def test_get_pruning_candidates(self): + provider = _make_holographic_provider() + manager = MemoryManager() + manager.add_provider(provider) + + # Add a fact and drive its trust below threshold via corrections + result = manager.handle_tool_call( + "fact_store", + {"action": "add", "content": "Bad fact to be pruned"}, + ) + fact_id = json.loads(result)["fact_id"] + provider._last_prefetch_ids = [fact_id] + + for _ in range(5): + manager.auto_calibrate_feedback("Wrong!") + + # Get pruning candidates + candidates = manager.get_pruning_candidates(threshold=0.15) + assert any(c["fact_id"] == fact_id for c in candidates) + + def test_prefetch_tracks_fact_ids(self): + """Verify prefetch populates _last_prefetch_ids.""" + provider = _make_holographic_provider() + + # Add facts + provider.handle_tool_call("fact_store", { + "action": "add", + "content": "Alexander uses Python for development", + }) + provider.handle_tool_call("fact_store", { + "action": "add", + 
"content": "Alexander prefers dark mode editors", + }) + + # Prefetch should find them and track IDs + result = provider.prefetch("Alexander") + assert "Holographic Memory" in result + assert len(provider._last_prefetch_ids) > 0 + + # Empty query clears IDs + provider.prefetch("") + assert provider._last_prefetch_ids == []