Compare commits
1 Commits
fix/793-so
...
fix/792-gr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55c8100b8f |
155
scripts/grounding.py
Executable file
155
scripts/grounding.py
Executable file
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python3
|
||||
# grounding.py - Grounding before generation.
|
||||
# SOUL.md: "When I have verified sources, I must consult them
|
||||
# before I generate from pattern alone. Retrieval is not a feature.
|
||||
# It is the primary mechanism by which I avoid lying."
|
||||
# Part of #792
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
MEMORY_DIR = HERMES_HOME / "memory"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroundingResult:
|
||||
query: str
|
||||
sources_found: List[Dict[str, Any]] = field(default_factory=list)
|
||||
grounded: bool = False
|
||||
confidence: float = 0.0
|
||||
source_text: str = ""
|
||||
source_type: str = "" # memory, file, chain, tool_result
|
||||
|
||||
@property
|
||||
def needs_hedging(self):
|
||||
return not self.grounded
|
||||
|
||||
|
||||
class GroundingLayer:
|
||||
def __init__(self, memory_dir=None):
|
||||
self.memory_dir = Path(memory_dir) if memory_dir else MEMORY_DIR
|
||||
|
||||
def ground(self, query, context=None):
|
||||
"""Query local sources before generation."""
|
||||
sources = []
|
||||
|
||||
# 1. Search memory files
|
||||
memory_hits = self._search_memory(query)
|
||||
sources.extend(memory_hits)
|
||||
|
||||
# 2. Search context files if provided
|
||||
if context:
|
||||
context_hits = self._search_context(query, context)
|
||||
sources.extend(context_hits)
|
||||
|
||||
# 3. Build result
|
||||
grounded = len(sources) > 0
|
||||
confidence = min(0.95, 0.3 + len(sources) * 0.2) if grounded else 0.0
|
||||
|
||||
source_text = ""
|
||||
source_type = ""
|
||||
if sources:
|
||||
best = max(sources, key=lambda s: s.get("score", 0))
|
||||
source_text = best.get("text", "")[:200]
|
||||
source_type = best.get("type", "unknown")
|
||||
|
||||
return GroundingResult(
|
||||
query=query, sources_found=sources, grounded=grounded,
|
||||
confidence=confidence, source_text=source_text, source_type=source_type,
|
||||
)
|
||||
|
||||
def _search_memory(self, query):
|
||||
"""Search memory files for relevant content."""
|
||||
results = []
|
||||
if not self.memory_dir.exists():
|
||||
return results
|
||||
|
||||
query_lower = query.lower()
|
||||
query_words = set(query_lower.split())
|
||||
|
||||
for mem_file in self.memory_dir.rglob("*.md"):
|
||||
try:
|
||||
content = mem_file.read_text(encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
content_lower = content.lower()
|
||||
# Simple relevance: count query word matches
|
||||
matches = sum(1 for w in query_words if w in content_lower)
|
||||
if matches > 0:
|
||||
score = matches / max(len(query_words), 1)
|
||||
# Extract relevant snippet
|
||||
lines = content.split("\n")
|
||||
snippet = ""
|
||||
for line in lines:
|
||||
if any(w in line.lower() for w in query_words):
|
||||
snippet = line.strip()[:200]
|
||||
break
|
||||
|
||||
results.append({
|
||||
"text": snippet or content[:200],
|
||||
"source": str(mem_file.relative_to(self.memory_dir)),
|
||||
"type": "memory",
|
||||
"score": round(score, 3),
|
||||
})
|
||||
|
||||
return sorted(results, key=lambda r: -r["score"])[:5]
|
||||
|
||||
def _search_context(self, query, context):
|
||||
"""Search provided context text for relevant content."""
|
||||
results = []
|
||||
if not context:
|
||||
return results
|
||||
|
||||
query_lower = query.lower()
|
||||
query_words = set(query_lower.split())
|
||||
|
||||
for ctx in context:
|
||||
if isinstance(ctx, dict):
|
||||
text = ctx.get("content", "") or ctx.get("text", "")
|
||||
source = ctx.get("source", "context")
|
||||
else:
|
||||
text = str(ctx)
|
||||
source = "context"
|
||||
|
||||
text_lower = text.lower()
|
||||
matches = sum(1 for w in query_words if w in text_lower)
|
||||
if matches > 0:
|
||||
score = matches / max(len(query_words), 1)
|
||||
results.append({
|
||||
"text": text[:200],
|
||||
"source": source,
|
||||
"type": "context",
|
||||
"score": round(score, 3),
|
||||
})
|
||||
|
||||
return sorted(results, key=lambda r: -r["score"])[:5]
|
||||
|
||||
def format_sources(self, result):
|
||||
"""Format grounding result for display."""
|
||||
if not result.grounded:
|
||||
return "No verified sources found. Proceeding from pattern matching."
|
||||
|
||||
lines = ["Based on verified sources:"]
|
||||
for s in result.sources_found[:3]:
|
||||
ref = s.get("source", "unknown")
|
||||
text = s.get("text", "")[:100]
|
||||
lines.append(" - [" + ref + "] " + text)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Convenience
|
||||
_default_layer = None
|
||||
|
||||
def get_grounding_layer():
|
||||
global _default_layer
|
||||
if _default_layer is None:
|
||||
_default_layer = GroundingLayer()
|
||||
return _default_layer
|
||||
|
||||
def ground(query, **kwargs):
|
||||
return get_grounding_layer().ground(query, **kwargs)
|
||||
67
tests/test_grounding.py
Normal file
67
tests/test_grounding.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Tests for grounding-before-generation - SOUL.md compliance."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
|
||||
class TestGrounding:
|
||||
def test_ground_with_memory(self, tmp_path):
|
||||
from scripts.grounding import GroundingLayer
|
||||
mem_dir = tmp_path / "memory"
|
||||
mem_dir.mkdir()
|
||||
(mem_dir / "test.md").write_text("Python is a programming language created by Guido.")
|
||||
|
||||
layer = GroundingLayer(memory_dir=mem_dir)
|
||||
result = layer.ground("What is Python?")
|
||||
|
||||
assert result.grounded
|
||||
assert result.confidence > 0
|
||||
assert len(result.sources_found) > 0
|
||||
|
||||
def test_ground_no_sources(self, tmp_path):
|
||||
from scripts.grounding import GroundingLayer
|
||||
mem_dir = tmp_path / "memory"
|
||||
mem_dir.mkdir()
|
||||
|
||||
layer = GroundingLayer(memory_dir=mem_dir)
|
||||
result = layer.ground("What is quantum physics?")
|
||||
|
||||
assert not result.grounded
|
||||
assert result.needs_hedging
|
||||
assert result.confidence == 0.0
|
||||
|
||||
def test_ground_with_context(self):
|
||||
from scripts.grounding import GroundingLayer
|
||||
layer = GroundingLayer(memory_dir=Path("/nonexistent"))
|
||||
|
||||
context = [{"content": "The fleet uses tmux for agent management", "source": "fleet-ops"}]
|
||||
result = layer.ground("How does the fleet work?", context=context)
|
||||
|
||||
assert result.grounded
|
||||
assert result.source_type == "context"
|
||||
|
||||
def test_format_sources_grounded(self):
|
||||
from scripts.grounding import GroundingLayer, GroundingResult
|
||||
layer = GroundingLayer()
|
||||
result = GroundingResult(
|
||||
query="test", grounded=True,
|
||||
sources_found=[{"text": "test info", "source": "test.md", "type": "memory", "score": 0.8}],
|
||||
)
|
||||
formatted = layer.format_sources(result)
|
||||
assert "verified sources" in formatted
|
||||
assert "test.md" in formatted
|
||||
|
||||
def test_format_sources_ungrounded(self):
|
||||
from scripts.grounding import GroundingLayer, GroundingResult
|
||||
layer = GroundingLayer()
|
||||
result = GroundingResult(query="test", grounded=False)
|
||||
formatted = layer.format_sources(result)
|
||||
assert "pattern matching" in formatted
|
||||
|
||||
def test_empty_memory_dir(self, tmp_path):
|
||||
from scripts.grounding import GroundingLayer
|
||||
mem_dir = tmp_path / "empty"
|
||||
mem_dir.mkdir()
|
||||
layer = GroundingLayer(memory_dir=mem_dir)
|
||||
result = layer.ground("anything")
|
||||
assert not result.grounded
|
||||
Reference in New Issue
Block a user