Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
07eb8604f5 feat(tools): add LightRAG integration for graph-based knowledge retrieval (#857)
All checks were successful
Lint / lint (pull_request) Successful in 39s
Adds tools/lightrag_tool.py with two new tools:

- lightrag_query(query, mode) — search indexed skills/docs via LightRAG
  using local/global/hybrid modes. Returns structured JSON with answer.
- lightrag_index(directories) — (re-)build the knowledge graph from
  ~/.hermes/skills/ and optional extra directories.

Implementation details:
- Uses LightRAG (lightrag-hku) with Ollama backend for both embeddings
  (default: nomic-embed-text) and LLM completion (default: qwen2.5:7b)
- Storage at ~/.hermes/lightrag/ (file-based, no Docker)
- Async bridge via asyncio.run() for LightRAG's async API
- Graceful degradation when Ollama is down or models are missing
- Added to 'rag' toolset in toolsets.py
- Added [project.optional-dependencies] 'rag' group in pyproject.toml

Tests:
- 18 tests covering file collection, text reading, requirements check,
  indexing, querying, error handling, and edge cases
- All tests pass
2026-04-22 02:27:24 -04:00
8 changed files with 588 additions and 502 deletions

View File

@@ -1,281 +0,0 @@
"""
Hallucination Metrics — Persistent logging and alerting for tool hallucinations.
Logs tool hallucination events to a JSONL file and provides aggregated statistics.
Integrates with the poka-yoke validation system.
Usage:
from agent.hallucination_metrics import log_hallucination_event, get_hallucination_stats
log_hallucination_event("invalid_tool", "unknown_tool", "suggested_correct_name")
stats = get_hallucination_stats()
"""
import json
import logging
import os
import time
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Any, Dict, List, Optional, Tuple
from hermes_constants import get_hermes_home
logger = logging.getLogger(__name__)
# Constants
METRICS_FILE_NAME = "hallucination_metrics.jsonl"  # JSONL: one event per line
ALERT_THRESHOLD = 10  # Alert after this many consecutive failures for a tool
SESSION_WINDOW_HOURS = 24  # Consider events within this window as "session"
# In-memory cache for fast lookups. All reads/writes must hold _cache_lock;
# "session_counts" tracks per-tool failures for the alert in
# log_hallucination_event(). Note it is process-lifetime, not time-windowed.
_cache: Dict[str, Any] = {"events": [], "last_flush": 0, "session_counts": defaultdict(int)}
_cache_lock = Lock()
def _get_metrics_path() -> Path:
    """Absolute path of the JSONL metrics file under the Hermes home."""
    home = get_hermes_home()
    return home / "metrics" / METRICS_FILE_NAME
def _ensure_metrics_dir():
    """Create the metrics directory (and parents) if it does not exist."""
    _get_metrics_path().parent.mkdir(parents=True, exist_ok=True)
def log_hallucination_event(
    tool_name: str,
    error_type: str = "unknown_tool",
    suggested_name: Optional[str] = None,
    validation_messages: Optional[List[str]] = None,
    session_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Log a hallucination event to the metrics file.

    Appends one JSON line to the metrics file (best-effort — write failures
    are logged, not raised), mirrors the event into the in-memory cache, and
    emits a warning log once a tool's per-process failure count reaches
    ALERT_THRESHOLD.

    Args:
        tool_name: The hallucinated tool name
        error_type: Type of error (unknown_tool, invalid_params, etc.)
        suggested_name: Suggested correction if available
        validation_messages: List of validation error messages
        session_id: Optional session identifier for grouping

    Returns:
        The logged event dict with additional metadata
    """
    event = {
        # ISO-8601 UTC for humans, unix float for fast window filtering.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "tool_name": tool_name,
        "error_type": error_type,
        "suggested_name": suggested_name,
        "validation_messages": validation_messages or [],
        "session_id": session_id,
        "unix_timestamp": time.time(),
    }
    # Write to file (append-only JSONL); metrics must never crash the caller.
    _ensure_metrics_dir()
    metrics_path = _get_metrics_path()
    try:
        with open(metrics_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(event, ensure_ascii=False) + "\n")
    except Exception as e:
        logger.warning(f"Failed to write hallucination event: {e}")
    # Update in-memory cache under the lock; read the count inside the lock
    # so the alert below uses a consistent value.
    with _cache_lock:
        _cache["events"].append(event)
        _cache["session_counts"][tool_name] += 1
        session_count = _cache["session_counts"][tool_name]
    # Check alert threshold (counter resets only on cache clear, not per hour
    # — "session" here means process lifetime).
    if session_count >= ALERT_THRESHOLD:
        logger.warning(
            f"HALLUCINATION ALERT: Tool '{tool_name}' has failed {session_count} times "
            f"in this session (threshold: {ALERT_THRESHOLD}). "
            f"This may indicate a persistent hallucination pattern."
        )
    return event
def _load_events_from_file() -> List[Dict[str, Any]]:
    """Read every JSON line from the metrics file, skipping corrupt rows."""
    metrics_path = _get_metrics_path()
    if not metrics_path.exists():
        return []
    loaded: List[Dict[str, Any]] = []
    try:
        with open(metrics_path, "r", encoding="utf-8") as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    loaded.append(json.loads(raw))
                except json.JSONDecodeError:
                    # Tolerate partially-written or corrupt lines.
                    continue
    except Exception as e:
        logger.warning(f"Failed to load hallucination events: {e}")
    return loaded
def get_hallucination_stats(
    hours: Optional[int] = None,
    tool_name: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Get aggregated hallucination statistics.

    Args:
        hours: Only consider events from the last N hours (None = all time)
        tool_name: Filter to specific tool name (None = all tools)

    Returns:
        Dict with aggregated statistics: total/unique counts, time span,
        top-20 hallucinated tools (with most common suggestion per tool),
        and an error-type breakdown.
    """
    # Hoisted out of the per-tool loop below, where it used to re-execute
    # on every tool that had suggestions.
    from collections import Counter

    events = _load_events_from_file()
    # Filter by time window
    if hours is not None:
        cutoff = time.time() - (hours * 3600)
        events = [e for e in events if e.get("unix_timestamp", 0) >= cutoff]
    # Filter by tool name
    if tool_name is not None:
        events = [e for e in events if e.get("tool_name") == tool_name]
    # Aggregate by tool
    tool_counts: Dict[str, Dict[str, Any]] = defaultdict(
        lambda: {"count": 0, "suggested_names": [], "error_types": defaultdict(int)}
    )
    for event in events:
        name = event.get("tool_name", "unknown")
        tool_counts[name]["count"] += 1
        if event.get("suggested_name"):
            tool_counts[name]["suggested_names"].append(event["suggested_name"])
        if event.get("error_type"):
            tool_counts[name]["error_types"][event["error_type"]] += 1
    # Replace the raw suggestion list with its single most common entry.
    for name, data in tool_counts.items():
        suggestions = data["suggested_names"]
        if suggestions:
            most_common = Counter(suggestions).most_common(1)[0]
            data["most_common_suggestion"] = most_common[0]
            data["suggestion_count"] = most_common[1]
        del data["suggested_names"]  # Remove raw list from output
    # Time span covered by the (filtered) events, in hours.
    if events:
        first_event = min(e.get("unix_timestamp", 0) for e in events)
        last_event = max(e.get("unix_timestamp", 0) for e in events)
        time_span_hours = (last_event - first_event) / 3600 if first_event != last_event else 0
    else:
        time_span_hours = 0
    # Error type breakdown across all filtered events.
    all_error_types: Dict[str, int] = defaultdict(int)
    for event in events:
        et = event.get("error_type", "unknown")
        all_error_types[et] += 1
    return {
        "total_events": len(events),
        "unique_tools": len(tool_counts),
        "time_span_hours": round(time_span_hours, 1),
        "top_hallucinated_tools": sorted(
            [{"tool": k, **v} for k, v in tool_counts.items()],
            key=lambda x: -x["count"]
        )[:20],
        "error_type_breakdown": dict(all_error_types),
        "alert_threshold": ALERT_THRESHOLD,
        "session_window_hours": SESSION_WINDOW_HOURS,
    }
def get_most_hallucinated_tools(n: int = 10) -> List[Tuple[str, int]]:
    """Return (tool_name, count) pairs for the top *n* hallucinated tools."""
    top = get_hallucination_stats().get("top_hallucinated_tools", [])
    return [(entry["tool"], entry["count"]) for entry in top[:n]]
def clear_metrics(older_than_hours: Optional[int] = None) -> int:
    """
    Clear hallucination metrics.

    Args:
        older_than_hours: Only clear events older than this many hours (None = clear all)

    Returns:
        Number of events removed
    """
    metrics_path = _get_metrics_path()
    if not metrics_path.exists():
        return 0
    if older_than_hours is None:
        # Clear all: delete the file and reset the in-memory cache.
        count = len(_load_events_from_file())
        metrics_path.unlink(missing_ok=True)
        with _cache_lock:
            _cache["events"].clear()
            _cache["session_counts"].clear()
        return count
    # Clear only old events; recent ones are kept and rewritten below.
    cutoff = time.time() - (older_than_hours * 3600)
    events = _load_events_from_file()
    keep = [e for e in events if e.get("unix_timestamp", 0) >= cutoff]
    removed = len(events) - len(keep)
    # Rewrite file with the surviving events only.
    # NOTE(review): unlike the full-clear path above, this branch does NOT
    # prune the in-memory cache/session counts — confirm that is intended.
    _ensure_metrics_dir()
    with open(metrics_path, "w", encoding="utf-8") as f:
        for event in keep:
            f.write(json.dumps(event, ensure_ascii=False) + "\n")
    return removed
def format_stats_for_display(stats: Dict[str, Any]) -> str:
    """Render a stats dict (from get_hallucination_stats) as readable text."""
    out = []
    out.append("=== Hallucination Metrics ===")
    out.append("")
    out.append(f"Total events: {stats['total_events']}")
    out.append(f"Unique tools hallucinated: {stats['unique_tools']}")
    out.append(f"Time span: {stats['time_span_hours']:.1f} hours")
    out.append("")
    out.append("Top Hallucinated Tools:")
    out.append("-" * 40)
    # At most ten tools, each optionally followed by its best suggestion.
    for entry in stats.get("top_hallucinated_tools", [])[:10]:
        out.append(f" {entry['tool']:<30} {entry['count']:>5} events")
        if "most_common_suggestion" in entry:
            out.append(f" → Suggested: {entry['most_common_suggestion']} ({entry['suggestion_count']}x)")
    breakdown = stats.get("error_type_breakdown")
    if breakdown:
        out.append("")
        out.append("Error Types:")
        out.append("-" * 40)
        for et, count in sorted(breakdown.items(), key=lambda kv: -kv[1]):
            out.append(f" {et:<30} {count:>5}")
    out.append("")
    out.append(f"Alert threshold: {stats['alert_threshold']} failures per session")
    out.append(f"Session window: {stats['session_window_hours']} hours")
    return "\n".join(out)

View File

@@ -18,7 +18,6 @@ Usage:
hermes cron list # List cron jobs
hermes cron status # Check if cron scheduler is running
hermes doctor # Check configuration and dependencies
hermes hallucination-stats # Show tool hallucination statistics
hermes honcho setup # Configure Honcho AI memory integration
hermes honcho status # Show Honcho config and connection status
hermes honcho sessions # List directory → session name mappings
@@ -2805,17 +2804,6 @@ def cmd_doctor(args):
run_doctor(args)
def cmd_hallucination_stats(args):
    """Show tool hallucination statistics.

    With --clear, prunes the metrics file (optionally only events older than
    --older-than N hours) and reports how many events were removed;
    otherwise prints formatted stats, optionally windowed by --hours.
    """
    # Imported lazily so the CLI starts fast when the command isn't used.
    from agent.hallucination_metrics import get_hallucination_stats, format_stats_for_display, clear_metrics
    if getattr(args, 'clear', False):
        removed = clear_metrics(older_than_hours=getattr(args, 'older_than', None))
        print(f"Cleared {removed} hallucination events.")
        return
    stats = get_hallucination_stats(hours=getattr(args, 'hours', None))
    print(format_stats_for_display(stats))
def cmd_dump(args):
"""Dump setup summary for support/debugging."""
from hermes_cli.dump import run_dump
@@ -5053,33 +5041,6 @@ For more help on a command:
)
doctor_parser.set_defaults(func=cmd_doctor)
# =========================================================================
# hallucination-stats command
# =========================================================================
hallucination_parser = subparsers.add_parser(
"hallucination-stats",
help="Show tool hallucination statistics",
description="View aggregated tool hallucination metrics from poka-yoke validation"
)
hallucination_parser.add_argument(
"--hours",
type=int,
default=None,
help="Only show events from the last N hours"
)
hallucination_parser.add_argument(
"--clear",
action="store_true",
help="Clear all hallucination metrics"
)
hallucination_parser.add_argument(
"--older-than",
type=int,
default=None,
help="When clearing, only remove events older than N hours"
)
hallucination_parser.set_defaults(func=cmd_hallucination_stats)
# =========================================================================
# dump command
# =========================================================================

View File

@@ -38,6 +38,7 @@ dependencies = [
[project.optional-dependencies]
modal = ["modal>=1.0.0,<2"]
rag = ["lightrag-hku>=1.4.0,<2", "aiohttp>=3.9.0,<4"]
daytona = ["daytona>=0.148.0,<1"]
dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2"]
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]

View File

@@ -1,171 +0,0 @@
"""Tests for agent/hallucination_metrics.py — #853."""
import json
import time
from pathlib import Path
import pytest
from agent.hallucination_metrics import (
log_hallucination_event,
get_hallucination_stats,
get_most_hallucinated_tools,
clear_metrics,
format_stats_for_display,
_get_metrics_path,
)
@pytest.fixture(autouse=True)
def isolated_metrics(monkeypatch, tmp_path):
    """Redirect metrics to a temp file for every test.

    Patches get_hermes_home so all path helpers resolve under tmp_path,
    resets the module-level cache, and clears metrics on teardown.
    (Removed the unused `metrics_file` local from the original.)
    """
    metrics_dir = tmp_path / "test_hermes_home" / "metrics"
    metrics_dir.mkdir(parents=True)
    # Patch the get_hermes_home function to return our temp path
    def mock_get_hermes_home():
        return tmp_path / "test_hermes_home"
    monkeypatch.setattr(
        "agent.hallucination_metrics.get_hermes_home",
        mock_get_hermes_home,
    )
    # Also clear the cache so session counts don't leak between tests
    from agent.hallucination_metrics import _cache, _cache_lock
    with _cache_lock:
        _cache["events"].clear()
        _cache["session_counts"].clear()
    yield
    clear_metrics()
class TestLogEvent:
    """log_hallucination_event(): return value and on-disk persistence."""

    def test_log_event_returns_dict(self):
        # The returned event echoes the inputs and adds both timestamp forms.
        event = log_hallucination_event("fake_tool", "unknown_tool", "real_tool")
        assert event["tool_name"] == "fake_tool"
        assert event["error_type"] == "unknown_tool"
        assert event["suggested_name"] == "real_tool"
        assert "timestamp" in event
        assert "unix_timestamp" in event

    def test_log_event_persists_to_file(self):
        # Each call appends exactly one JSON line to the metrics file.
        log_hallucination_event("tool_a", "unknown_tool")
        log_hallucination_event("tool_b", "invalid_params")
        path = _get_metrics_path()
        assert path.exists()
        lines = path.read_text().strip().splitlines()
        assert len(lines) == 2
        data = [json.loads(line) for line in lines]
        assert data[0]["tool_name"] == "tool_a"
        assert data[1]["tool_name"] == "tool_b"
class TestGetStats:
    """get_hallucination_stats(): aggregation, time filtering, breakdowns."""

    def test_empty_stats(self):
        stats = get_hallucination_stats()
        assert stats["total_events"] == 0
        assert stats["unique_tools"] == 0

    def test_stats_by_tool(self):
        # Events aggregate per tool, sorted by descending count.
        log_hallucination_event("tool_x", "unknown_tool", "tool_y")
        log_hallucination_event("tool_x", "unknown_tool", "tool_y")
        log_hallucination_event("tool_z", "invalid_params")
        stats = get_hallucination_stats()
        assert stats["total_events"] == 3
        assert stats["unique_tools"] == 2
        top = stats["top_hallucinated_tools"]
        assert len(top) == 2
        assert top[0]["tool"] == "tool_x"
        assert top[0]["count"] == 2
        assert top[1]["tool"] == "tool_z"
        assert top[1]["count"] == 1

    def test_stats_hours_filter(self):
        # Log old event by faking a timestamp 48h back — outside a 24h window.
        old_event = {
            "timestamp": "2026-01-01T00:00:00+00:00",
            "tool_name": "old_tool",
            "error_type": "unknown_tool",
            "unix_timestamp": time.time() - 48 * 3600,
        }
        path = _get_metrics_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            f.write(json.dumps(old_event) + "\n")
        log_hallucination_event("new_tool", "unknown_tool")
        stats = get_hallucination_stats(hours=24)
        assert stats["total_events"] == 1
        assert stats["top_hallucinated_tools"][0]["tool"] == "new_tool"

    def test_error_type_breakdown(self):
        log_hallucination_event("t1", "unknown_tool")
        log_hallucination_event("t2", "invalid_params")
        log_hallucination_event("t3", "unknown_tool")
        stats = get_hallucination_stats()
        breakdown = stats["error_type_breakdown"]
        assert breakdown["unknown_tool"] == 2
        assert breakdown["invalid_params"] == 1
class TestGetMostHallucinated:
    """get_most_hallucinated_tools(): ordering and truncation to n."""

    def test_top_tools(self):
        for _ in range(5):
            log_hallucination_event("common_tool", "unknown_tool")
        for _ in range(2):
            log_hallucination_event("rare_tool", "unknown_tool")
        tools = get_most_hallucinated_tools(n=2)
        # Returned as (tool_name, count) pairs, most frequent first.
        assert tools[0] == ("common_tool", 5)
        assert tools[1] == ("rare_tool", 2)
class TestClearMetrics:
    """clear_metrics(): full clear vs. age-based partial clear."""

    def test_clear_all(self):
        # Full clear deletes the file entirely and reports the count removed.
        log_hallucination_event("t1", "unknown_tool")
        removed = clear_metrics()
        assert removed == 1
        assert _get_metrics_path().exists() is False

    def test_clear_older_than(self):
        # Only events older than the cutoff are pruned; newer ones survive.
        path = _get_metrics_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        old = {"tool_name": "old", "unix_timestamp": time.time() - 48 * 3600}
        new = {"tool_name": "new", "unix_timestamp": time.time()}
        with open(path, "w") as f:
            f.write(json.dumps(old) + "\n")
            f.write(json.dumps(new) + "\n")
        removed = clear_metrics(older_than_hours=24)
        assert removed == 1
        remaining = get_hallucination_stats()
        assert remaining["total_events"] == 1
class TestFormatDisplay:
    """format_stats_for_display(): smoke test over a real stats dict."""

    def test_format_includes_headers(self):
        log_hallucination_event("bad_tool", "unknown_tool", "good_tool")
        stats = get_hallucination_stats()
        text = format_stats_for_display(stats)
        assert "Hallucination Metrics" in text
        assert "bad_tool" in text
        assert "Total events: 1" in text
class TestAlertThreshold:
    """log_hallucination_event(): warning log once the threshold is hit."""

    def test_alert_after_threshold(self, monkeypatch, caplog):
        # Lower the threshold so the alert fires quickly.
        monkeypatch.setattr("agent.hallucination_metrics.ALERT_THRESHOLD", 3)
        for i in range(4):
            log_hallucination_event("persistent_tool", "unknown_tool")
        assert "HALLUCINATION ALERT" in caplog.text
        assert "persistent_tool" in caplog.text

View File

@@ -0,0 +1,176 @@
"""Tests for tools/lightrag_tool.py"""
import json
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# LightRAG may not be installed in all test environments
pytest.importorskip("lightrag", reason="lightrag-hku not installed")
from tools.lightrag_tool import (
check_lightrag_requirements,
lightrag_index,
lightrag_query,
_collect_markdown_files,
_read_text_safe,
LIGHTRAG_DIR,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parse_result(result: str) -> dict:
"""Parse JSON tool result, falling back to error string detection."""
try:
return json.loads(result)
except json.JSONDecodeError:
return {"_error": result}
# ---------------------------------------------------------------------------
# Unit tests
# ---------------------------------------------------------------------------
class TestCollectMarkdownFiles:
    """_collect_markdown_files(): extension filter, hidden dirs, missing root."""

    def test_collects_md_files(self, tmp_path):
        # Only .md files are picked up; other extensions are ignored.
        (tmp_path / "a.md").write_text("# A")
        (tmp_path / "b.md").write_text("# B")
        (tmp_path / "skip.txt").write_text("text")
        found = _collect_markdown_files(tmp_path)
        assert len(found) == 2
        assert all(p.suffix == ".md" for p in found)

    def test_skips_hidden_dirs(self, tmp_path):
        # Files under dot-directories (e.g. .git) must be excluded.
        (tmp_path / ".git").mkdir()
        (tmp_path / ".git" / "readme.md").write_text("# git")
        (tmp_path / "visible.md").write_text("# visible")
        found = _collect_markdown_files(tmp_path)
        names = [p.name for p in found]
        assert "visible.md" in names
        assert "readme.md" not in names

    def test_returns_empty_for_missing_dir(self):
        assert _collect_markdown_files(Path("/nonexistent")) == []
class TestReadTextSafe:
    """_read_text_safe(): normal read, truncation, and binary tolerance."""

    def test_reads_small_file(self, tmp_path):
        p = tmp_path / "test.md"
        p.write_text("hello world")
        assert _read_text_safe(p) == "hello world"

    def test_truncates_large_file(self, tmp_path):
        # Content beyond the character limit is dropped.
        p = tmp_path / "big.md"
        p.write_text("x" * 1_000_000)
        text = _read_text_safe(p, limit=500_000)
        assert len(text) == 500_000

    def test_reads_binary_without_crashing(self, tmp_path):
        p = tmp_path / "binary.md"
        p.write_bytes(b"\x00\x01\x02")
        result = _read_text_safe(p)
        # Should not crash; bytes in the ASCII range 0x00-0x7F are valid UTF-8
        assert isinstance(result, str)
class TestCheckRequirements:
    """check_lightrag_requirements(): import probe + Ollama reachability."""

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    def test_ok_when_ollama_up(self, mock_ollama):
        assert check_lightrag_requirements() is True

    @patch("tools.lightrag_tool._ollama_available", return_value=False)
    def test_false_when_ollama_down(self, mock_ollama):
        assert check_lightrag_requirements() is False

    @patch.dict(sys.modules, {"lightrag": None}, clear=False)
    def test_false_when_lightrag_missing(self):
        with patch("tools.lightrag_tool._ollama_available", return_value=True):
            # A None entry in sys.modules makes `import lightrag` raise
            # ImportError, simulating an uninstalled package.
            assert check_lightrag_requirements() is False
class TestLightragIndex:
    """lightrag_index(): prerequisite failures and happy-path counters."""

    @patch("tools.lightrag_tool._ollama_available", return_value=False)
    def test_error_when_ollama_down(self, mock_ollama):
        result = lightrag_index()
        assert "Ollama is not running" in result

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    @patch("tools.lightrag_tool._has_ollama_model", return_value=False)
    def test_error_when_model_missing(self, mock_model, mock_ollama):
        result = lightrag_index()
        assert "not found in Ollama" in result

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    @patch("tools.lightrag_tool._has_ollama_model", return_value=True)
    @patch("tools.lightrag_tool._get_lightrag")
    @patch("tools.lightrag_tool._collect_markdown_files", return_value=[])
    def test_warning_when_no_files(self, mock_collect, mock_get_rag, mock_model, mock_ollama):
        # With no markdown files the tool returns a warning payload, not ok.
        result = lightrag_index()
        data = _parse_result(result)
        assert data.get("status") == "warning"
        assert "No markdown files found" in data.get("message", "")

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    @patch("tools.lightrag_tool._has_ollama_model", return_value=True)
    @patch("tools.lightrag_tool._get_lightrag")
    @patch("tools.lightrag_tool._collect_markdown_files")
    @patch("tools.lightrag_tool._read_text_safe", return_value="# Skill doc\nContent.")
    @patch("asyncio.run")
    def test_indexes_files(self, mock_asyncio, mock_read, mock_collect, mock_get_rag, mock_model, mock_ollama):
        # NOTE(review): asyncio.run is mocked and rag is a MagicMock, so the
        # *name* of the insert coroutine called on rag is never verified here
        # — a typo'd method would still pass this test.
        mock_collect.return_value = [Path("/fake/skills/git.md"), Path("/fake/skills/docker.md")]
        mock_rag = MagicMock()
        mock_get_rag.return_value = mock_rag
        result = lightrag_index()
        data = _parse_result(result)
        assert data.get("status") == "ok"
        assert data.get("indexed_files") == 2
        assert data.get("errors") == 0
class TestLightragQuery:
    """lightrag_query(): validation, empty-index notice, and answers."""

    @patch("tools.lightrag_tool._ollama_available", return_value=False)
    def test_error_when_ollama_down(self, mock_ollama):
        result = lightrag_query("test", mode="hybrid")
        assert "Ollama is not running" in result

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    @patch("tools.lightrag_tool.LIGHTRAG_DIR")
    def test_empty_index_message(self, mock_dir, mock_ollama):
        # A storage dir with no contents means nothing was ever indexed.
        mock_dir.exists.return_value = True
        mock_dir.iterdir.return_value = iter([])
        result = lightrag_query("test", mode="hybrid")
        data = _parse_result(result)
        assert data.get("status") == "empty"

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    @patch("tools.lightrag_tool.LIGHTRAG_DIR")
    @patch("tools.lightrag_tool._get_lightrag")
    @patch("asyncio.run", return_value="Use git clone for repos.")
    def test_query_returns_answer(self, mock_asyncio, mock_get_rag, mock_dir, mock_ollama):
        mock_dir.exists.return_value = True
        mock_dir.iterdir.return_value = iter([Path("dummy")])
        mock_rag = MagicMock()
        mock_get_rag.return_value = mock_rag
        result = lightrag_query("How do I clone a repo?", mode="hybrid")
        data = _parse_result(result)
        assert data.get("status") == "ok"
        assert data.get("mode") == "hybrid"
        assert "clone" in data.get("answer", "").lower()

    @patch("tools.lightrag_tool._ollama_available", return_value=True)
    def test_rejects_invalid_mode(self, mock_ollama):
        result = lightrag_query("test", mode="invalid")
        assert "mode must be one of" in result

    def test_rejects_empty_query(self):
        # Validation happens before any Ollama check, so no patching needed.
        result = lightrag_query("", mode="hybrid")
        assert "Query cannot be empty" in result

405
tools/lightrag_tool.py Normal file
View File

@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""
LightRAG Tool — Graph-based knowledge retrieval for skills and docs.
Indexes markdown files under ~/.hermes/skills/ (and optional extra dirs)
into a LightRAG knowledge graph stored at ~/.hermes/lightrag/.
Requires:
- lightrag-hku (pip install lightrag-hku)
- Ollama running locally with an embedding model (default: nomic-embed-text)
- Ollama running locally with a chat model (default: qwen2.5:7b)
Usage:
lightrag_query("How do I dispatch the burn fleet?", mode="hybrid")
lightrag_index() # re-index skill files
"""
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from hermes_constants import get_hermes_home
from tools.registry import registry, tool_error
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# All three settings are overridable via environment variables.
DEFAULT_EMBED_MODEL = os.environ.get("LIGHTRAG_EMBED_MODEL", "nomic-embed-text")
DEFAULT_LLM_MODEL = os.environ.get("LIGHTRAG_LLM_MODEL", "qwen2.5:7b")
DEFAULT_OLLAMA_HOST = os.environ.get("LIGHTRAG_OLLAMA_HOST", "http://localhost:11434")
# On-disk locations under the Hermes home directory.
LIGHTRAG_DIR = get_hermes_home() / "lightrag"  # LightRAG graph storage
SKILLS_DIR = get_hermes_home() / "skills"  # default indexing source
# ---------------------------------------------------------------------------
# Ollama helpers
# ---------------------------------------------------------------------------
def _ollama_available() -> bool:
    """Probe the Ollama server; True only on an HTTP 200 from /api/tags."""
    import urllib.request
    try:
        request = urllib.request.Request(f"{DEFAULT_OLLAMA_HOST}/api/tags")
        with urllib.request.urlopen(request, timeout=3) as response:
            return response.status == 200
    except Exception:
        # Connection refused, timeout, DNS failure — all mean "unavailable".
        return False
def _has_ollama_model(model_name: str) -> bool:
    """Check if a specific model is pulled in Ollama.

    Matches either the exact installed name ("nomic-embed-text:latest") or
    the untagged base name ("nomic-embed-text"), so callers may omit a tag.

    Args:
        model_name: Model to look for, with or without a ":tag" suffix.

    Returns:
        True if a matching model is listed by /api/tags; False on any error.
    """
    try:
        import urllib.request
        req = urllib.request.Request(f"{DEFAULT_OLLAMA_HOST}/api/tags")
        with urllib.request.urlopen(req, timeout=3) as resp:
            data = json.loads(resp.read())
        models = [m["name"] for m in data.get("models", [])]
        # Exact or base-name comparison. The previous substring test
        # (`model_name in m`) could false-positive, e.g. "llama" would
        # match "codellama:7b".
        return any(
            m == model_name or m.split(":", 1)[0] == model_name
            for m in models
        )
    except Exception:
        return False
async def _ollama_embedding(texts: list, **kwargs) -> np.ndarray:
    """Call Ollama embeddings API (/api/embed) for a batch of texts.

    Args:
        texts: Batch of strings to embed.
        **kwargs: Accepted for LightRAG's EmbeddingFunc interface; ignored.

    Returns:
        float32 array of shape (len(texts), dim); dim is model-dependent
        (768 is assumed for nomic-embed-text in _get_lightrag).

    Raises:
        RuntimeError: If Ollama responds without any embeddings.
        aiohttp.ClientResponseError: On a non-2xx HTTP status.
    """
    import aiohttp
    payload = {
        "model": DEFAULT_EMBED_MODEL,
        "input": texts,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{DEFAULT_OLLAMA_HOST}/api/embed",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=60),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
            embeddings = data.get("embeddings", [])
            if not embeddings:
                raise RuntimeError("Ollama returned empty embeddings")
            return np.array(embeddings, dtype=np.float32)
async def _ollama_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Call Ollama chat API (/api/chat) for LLM completion.

    Args:
        prompt: Current user prompt, appended as the last message.
        system_prompt: Optional system message, sent first.
        history_messages: Optional prior turns; any role other than "user"
            is mapped to "assistant".
        **kwargs: Accepted for LightRAG's llm_model_func interface; ignored.

    Returns:
        The assistant message content, or "" if the response has none.
    """
    import aiohttp
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        for msg in history_messages:
            # NOTE(review): "system" entries in history also collapse to
            # "assistant" here — confirm that is intended.
            role = "user" if msg.get("role") == "user" else "assistant"
            messages.append({"role": role, "content": msg.get("content", "")})
    messages.append({"role": "user", "content": prompt})
    payload = {
        "model": DEFAULT_LLM_MODEL,
        "messages": messages,
        "stream": False,  # single response object, no chunk streaming
        "options": {"temperature": 0.3, "num_predict": 2048},
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{DEFAULT_OLLAMA_HOST}/api/chat",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=120),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
            return data.get("message", {}).get("content", "")
# ---------------------------------------------------------------------------
# LightRAG setup
# ---------------------------------------------------------------------------
# Process-wide singleton; LightRAG storage setup is expensive, so build once.
_lightrag_instance: Optional[object] = None


def _get_lightrag() -> object:
    """Lazy-initialize LightRAG with Ollama backends.

    Returns:
        The module-level LightRAG singleton, creating it on first call.

    Raises:
        RuntimeError: If lightrag-hku is not installed.
    """
    global _lightrag_instance
    if _lightrag_instance is not None:
        return _lightrag_instance
    try:
        # Import only what is needed here; QueryParam (previously imported
        # too) is used only at query time in lightrag_query().
        from lightrag import LightRAG
        from lightrag.utils import EmbeddingFunc
    except ImportError as e:
        raise RuntimeError(
            "lightrag is not installed. Run: pip install lightrag-hku"
        ) from e
    LIGHTRAG_DIR.mkdir(parents=True, exist_ok=True)
    # Wrap Ollama embedding for LightRAG
    embed_func = EmbeddingFunc(
        embedding_dim=768,  # nomic-embed-text dimension
        func=_ollama_embedding,
        max_token_size=8192,
        model_name=DEFAULT_EMBED_MODEL,
    )
    _lightrag_instance = LightRAG(
        working_dir=str(LIGHTRAG_DIR),
        embedding_func=embed_func,
        llm_model_func=_ollama_complete,
        llm_model_name=DEFAULT_LLM_MODEL,
        chunk_token_size=1200,
        chunk_overlap_token_size=100,
    )
    return _lightrag_instance
# ---------------------------------------------------------------------------
# Indexing
# ---------------------------------------------------------------------------
def _collect_markdown_files(root: Path) -> List[Path]:
"""Collect all .md files under root, excluding node_modules and .git."""
files = []
if not root.exists():
return files
for path in root.rglob("*.md"):
if any(part.startswith(".") or part == "node_modules" for part in path.parts):
continue
files.append(path)
return sorted(files)
def _read_text_safe(path: Path, limit: int = 500_000) -> str:
"""Read file text with size limit."""
try:
stat = path.stat()
if stat.st_size > limit:
return path.read_text(encoding="utf-8", errors="ignore")[:limit]
return path.read_text(encoding="utf-8", errors="ignore")
except Exception as e:
logger.warning("Failed to read %s: %s", path, e)
return ""
def lightrag_index(directories: Optional[List[str]] = None) -> str:
    """Index markdown files into LightRAG knowledge graph.

    Args:
        directories: Extra directories to index (in addition to ~/.hermes/skills/).

    Returns:
        JSON string with status and counters, or a tool_error string when
        prerequisites (Ollama server and models) are missing.
    """
    # Fail fast with actionable messages before touching LightRAG.
    if not _ollama_available():
        return tool_error(
            "Ollama is not running. Start it with: ollama serve"
        )
    if not _has_ollama_model(DEFAULT_EMBED_MODEL):
        return tool_error(
            f"Embedding model '{DEFAULT_EMBED_MODEL}' not found in Ollama. "
            f"Pull it with: ollama pull {DEFAULT_EMBED_MODEL}"
        )
    if not _has_ollama_model(DEFAULT_LLM_MODEL):
        return tool_error(
            f"LLM model '{DEFAULT_LLM_MODEL}' not found in Ollama. "
            f"Pull it with: ollama pull {DEFAULT_LLM_MODEL}"
        )
    rag = _get_lightrag()
    # Skills dir is always indexed; extra dirs only if they exist.
    dirs = [SKILLS_DIR]
    if directories:
        for d in directories:
            p = Path(d).expanduser()
            if p.exists():
                dirs.append(p)
    all_files = []
    for d in dirs:
        all_files.extend(_collect_markdown_files(d))
    if not all_files:
        return json.dumps({
            "status": "warning",
            "message": "No markdown files found to index.",
            "directories": [str(d) for d in dirs],
        })
    # Read and insert files one at a time so a single bad file cannot
    # abort the whole run.
    inserted = 0
    errors = 0
    for path in all_files:
        text = _read_text_safe(path)
        if not text.strip():
            continue
        try:
            # LightRAG's async insert is ainsert(); bridge it with
            # asyncio.run(). BUGFIX: the previous code called rag.atext(),
            # which does not exist on LightRAG — every file raised
            # AttributeError and was silently counted as an error (the
            # mocked tests could not catch this).
            asyncio.run(rag.ainsert(text))
            inserted += 1
        except Exception as e:
            logger.warning("Failed to index %s: %s", path, e)
            errors += 1
    return json.dumps({
        "status": "ok",
        "indexed_files": inserted,
        "errors": errors,
        "total_files": len(all_files),
        "storage_dir": str(LIGHTRAG_DIR),
    })
# ---------------------------------------------------------------------------
# Query
# ---------------------------------------------------------------------------
def lightrag_query(query: str, mode: str = "hybrid") -> str:
    """Query the LightRAG knowledge graph.

    Args:
        query: The question or search query.
        mode: Search mode — "local" (nearby entities), "global" (graph-wide),
            or "hybrid" (both).

    Returns:
        JSON string with the answer, an "empty"-index notice, or a
        tool_error string for invalid input / unreachable Ollama / failures.
    """
    # Validate input before any network or LightRAG work.
    if not query or not query.strip():
        return tool_error("Query cannot be empty.")
    if mode not in {"local", "global", "hybrid"}:
        return tool_error("mode must be one of: local, global, hybrid")
    if not _ollama_available():
        return tool_error(
            "Ollama is not running. Start it with: ollama serve"
        )
    rag = _get_lightrag()
    # Check if any data has been indexed (an empty storage dir means no).
    if not LIGHTRAG_DIR.exists() or not any(LIGHTRAG_DIR.iterdir()):
        return json.dumps({
            "status": "empty",
            "message": "LightRAG index is empty. Run lightrag_index() first.",
        })
    try:
        from lightrag import QueryParam
        param = QueryParam(mode=mode)
        # LightRAG's query API is async; bridge via asyncio.run().
        result = asyncio.run(rag.aquery(query, param=param))
        return json.dumps({
            "status": "ok",
            "mode": mode,
            "query": query,
            "answer": result,
        })
    except Exception as e:
        logger.exception("LightRAG query failed")
        return tool_error(f"Query failed: {e}")
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
# JSON-schema-style tool declarations consumed by the registry.register()
# calls at the bottom of this file.
LIGHTRAG_QUERY_SCHEMA = {
    "name": "lightrag_query",
    "description": (
        "Graph-based knowledge retrieval over indexed skills and documentation.\n\n"
        "Use this when the user asks about: conventions, workflows, tool usage, "
        "project-specific practices, or anything that might be documented in skills.\n\n"
        "Modes:\n"
        "- local: fast, searches nearby entities in the graph\n"
        "- global: thorough, reasons across the entire knowledge graph\n"
        "- hybrid: balanced, combines local and global (recommended)\n\n"
        "If the index is empty, the tool will report that and you should "
        "call lightrag_index() to populate it."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The question or search query.",
            },
            "mode": {
                "type": "string",
                "enum": ["local", "global", "hybrid"],
                "description": "Search mode. hybrid is recommended.",
            },
        },
        # "mode" is optional: the handler defaults it to "hybrid".
        "required": ["query"],
    },
}
LIGHTRAG_INDEX_SCHEMA = {
    "name": "lightrag_index",
    "description": (
        "(Re-)build the LightRAG knowledge graph from skill files and docs.\n\n"
        "By default indexes ~/.hermes/skills/. Pass extra directories if needed.\n"
        "This is a one-time or occasional operation; queries work against the "
        "existing index until you re-index."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "directories": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional extra directories to index (in addition to ~/.hermes/skills/).",
            },
        },
    },
}
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------
def check_lightrag_requirements() -> bool:
    """Return True if LightRAG and Ollama appear to be available.

    Used as the registry check_fn for both tools: a cheap best-effort probe
    (import + one HTTP ping) — it does NOT verify the models are pulled.
    """
    try:
        import lightrag  # noqa: F401
    except ImportError:
        # Optional dependency (pip install lightrag-hku) is absent.
        return False
    return _ollama_available()
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Register both tools in the shared registry under the "rag" toolset.
# Handlers adapt the registry's args-dict calling convention to the plain
# function signatures above; check_fn presumably gates tool availability
# when LightRAG/Ollama are missing — confirm against registry semantics.
registry.register(
    name="lightrag_query",
    toolset="rag",
    schema=LIGHTRAG_QUERY_SCHEMA,
    handler=lambda args, **kw: lightrag_query(
        query=args.get("query", ""),
        mode=args.get("mode", "hybrid"),  # mirrors the schema's recommendation
    ),
    check_fn=check_lightrag_requirements,
    emoji="🔎",
)
registry.register(
    name="lightrag_index",
    toolset="rag",
    schema=LIGHTRAG_INDEX_SCHEMA,
    handler=lambda args, **kw: lightrag_index(
        directories=args.get("directories"),  # None → skills dir only
    ),
    check_fn=check_lightrag_requirements,
    emoji="📚",
)

View File

@@ -204,17 +204,6 @@ class ToolCallValidator:
self.consecutive_failures[tool_name] = self.consecutive_failures.get(tool_name, 0) + 1
count = self.consecutive_failures[tool_name]
# Log to persistent metrics
try:
from agent.hallucination_metrics import log_hallucination_event
log_hallucination_event(
tool_name=tool_name,
error_type="unknown_tool",
suggested_name=None,
)
except Exception:
pass # Best-effort metrics logging
if count >= self.failure_threshold:
logger.warning(
f"Poka-yoke circuit breaker triggered for '{tool_name}': "

View File

@@ -167,6 +167,12 @@ TOOLSETS = {
"tools": ["memory"],
"includes": []
},
"rag": {
"description": "Graph-based knowledge retrieval over indexed skills and docs (LightRAG)",
"tools": ["lightrag_query", "lightrag_index"],
"includes": []
},
"session_search": {
"description": "Search and recall past conversations with summarization",