Some checks failed
Test / pytest (pull_request) Failing after 8s
Add scripts/entity_extractor.py — LLM-based named entity recognition from session transcripts, READMEs, and issues. Extracts people, projects, tools, concepts, and repos. Outputs to knowledge/entities.json. Includes: - templates/entity-extraction-prompt.md — extraction prompt - tests/test_entity_extractor.py — unit tests for dedup/merge logic - scripts/test_entity_extractor.py — smoke test (mocked pipeline) Accepts --file, --dir, --session, --batch modes. Deduplicates by name+type, merges with existing entities.json. Designed to yield 100+ entities per batch run. Closes #144
269 lines
9.8 KiB
Python
Executable File
269 lines
9.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
entity_extractor.py — Extract named entities from text sources.
|
|
|
|
Extracts: people, projects, tools, concepts, repos from session transcripts,
|
|
README files, issue bodies, or any text input.
|
|
|
|
Output: knowledge/entities.json with deduplicated entity list and occurrence counts.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
# Make sibling modules in scripts/ (e.g. session_reader) importable when this
# file is run directly as a script rather than as a package module.
SCRIPT_DIR = Path(__file__).parent.absolute()
sys.path.insert(0, str(SCRIPT_DIR))

from session_reader import read_session, messages_to_text

# --- Configuration ---
# Every setting below can be overridden via environment variables.
DEFAULT_API_BASE = os.environ.get("HARVESTER_API_BASE", "https://api.nousresearch.com/v1")  # OpenAI-compatible endpoint
DEFAULT_API_KEY = os.environ.get("HARVESTER_API_KEY", "")
DEFAULT_MODEL = os.environ.get("HARVESTER_MODEL", "xiaomi/mimo-v2-pro")
KNOWLEDGE_DIR = os.environ.get("HARVESTER_KNOWLEDGE_DIR", "knowledge")  # default output directory name
# System-prompt template for the extraction LLM call; lives in ../templates by default.
PROMPT_PATH = os.environ.get("ENTITY_PROMPT_PATH", str(SCRIPT_DIR.parent / "templates" / "entity-extraction-prompt.md"))

# Well-known key-file locations checked (in order) by find_api_key() when no
# key was supplied via --api-key or HARVESTER_API_KEY.
API_KEY_PATHS = [
    os.path.expanduser("~/.config/nous/key"),
    os.path.expanduser("~/.hermes/keymaxxing/active/minimax.key"),
    os.path.expanduser("~/.config/openrouter/key"),
]
|
|
|
|
def find_api_key() -> str:
    """Return the first non-empty API key found in API_KEY_PATHS, or ""."""
    for candidate in API_KEY_PATHS:
        if not os.path.exists(candidate):
            continue
        with open(candidate) as fh:
            contents = fh.read().strip()
        if contents:
            return contents
    return ""
|
|
|
|
def load_prompt() -> str:
    """Return the entity-extraction system prompt text, exiting if missing."""
    prompt_file = Path(PROMPT_PATH)
    if prompt_file.exists():
        return prompt_file.read_text(encoding='utf-8')
    # The prompt template is required; a missing file is a setup error.
    print(f"ERROR: Entity extraction prompt not found at {prompt_file}", file=sys.stderr)
    sys.exit(1)
|
|
|
|
def call_llm(prompt: str, text: str, api_base: str, api_key: str, model: str) -> Optional[list]:
    """Call LLM API to extract entities.

    Sends *prompt* as the system message and *text* as the user message to an
    OpenAI-compatible chat-completions endpoint. Returns the parsed entity
    list, or None on any request/response failure (logged to stderr).
    """
    import urllib.request

    body = {
        "model": model,
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Extract entities from this text:\n\n{text}"},
        ],
        "temperature": 0.0,  # deterministic extraction
        "max_tokens": 2048,
    }

    request = urllib.request.Request(
        f"{api_base}/chat/completions",
        data=json.dumps(body).encode('utf-8'),
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        method="POST",
    )

    # Broad catch is deliberate: any network, HTTP, or response-shape failure
    # is reported and treated as "no entities" by the caller.
    try:
        with urllib.request.urlopen(request, timeout=60) as response:
            reply = json.loads(response.read().decode('utf-8'))
            return parse_response(reply["choices"][0]["message"]["content"])
    except Exception as exc:
        print(f"ERROR: LLM call failed: {exc}", file=sys.stderr)
        return None
|
|
|
|
def parse_response(content: str) -> Optional[list]:
    """Parse LLM JSON response containing entity array.

    Accepts either a bare JSON array, a JSON object with an "entities" key,
    or either shape wrapped in a ```json fenced code block (models often
    fence their output). Returns the entity list, or None — with a warning
    on stderr — when nothing parseable is found.
    """
    def _coerce(data) -> Optional[list]:
        # Normalize the two accepted response shapes to a plain list.
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and 'entities' in data:
            return data['entities']
        return None

    # 1) The whole response is valid JSON.
    try:
        result = _coerce(json.loads(content))
        if result is not None:
            return result
    except json.JSONDecodeError:
        pass

    # 2) JSON embedded in a fenced code block. The pattern now accepts
    # fenced objects ({...}) as well as arrays, matching the shapes the
    # whole-response path already accepted.
    import re
    match = re.search(r'```(?:json)?\s*([\[{].*?[\]}])\s*```', content, re.DOTALL)
    if match:
        try:
            result = _coerce(json.loads(match.group(1)))
            if result is not None:
                return result
        except json.JSONDecodeError:
            pass

    # Plain string (was an f-string with no placeholders).
    print("WARNING: Could not parse LLM response as entity list", file=sys.stderr)
    return None
|
|
|
|
def load_existing_entities(knowledge_dir: str) -> dict:
    """Load knowledge_dir/entities.json; return an empty index if absent or corrupt."""
    empty_index = {"version": 1, "last_updated": "", "entities": []}
    index_path = Path(knowledge_dir) / "entities.json"
    if not index_path.exists():
        return empty_index
    try:
        with open(index_path) as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError) as exc:
        # Corrupt index is non-fatal: warn and start fresh.
        print(f"WARNING: Could not load entities: {exc}", file=sys.stderr)
        return empty_index
|
|
|
|
def entity_key(name: str, etype: str) -> tuple:
    """Build the case- and whitespace-insensitive dedup key for an entity."""
    normalized_name = name.lower().strip()
    normalized_type = etype.lower().strip()
    return (normalized_name, normalized_type)
|
|
|
|
def merge_entities(new_entities: list, existing: list) -> list:
    """Merge new entities into existing list, combining counts and sources.

    Entities are keyed by (name, type), case- and whitespace-insensitively
    (see entity_key). Duplicates get their count incremented, their source
    lists unioned (sorted for determinism), and last_seen refreshed.
    Mutates and returns *existing*; first-seen entries from *new_entities*
    are mutated in place before being appended.
    """
    by_key = {}
    for entity in existing:
        by_key[entity_key(entity.get('name', ''), entity.get('type', ''))] = entity

    for entity in new_entities:
        key = entity_key(entity['name'], entity['type'])
        if key in by_key:
            target = by_key[key]
            target['count'] = target.get('count', 1) + 1
            # Union the source lists, keeping deterministic order.
            merged_sources = set(target.get('sources', [])) | set(entity.get('sources', []))
            target['sources'] = sorted(merged_sources)
            target['last_seen'] = entity.get('last_seen', target.get('last_seen'))
        else:
            entity['count'] = entity.get('count', 1)
            entity.setdefault('sources', [])
            entity.setdefault('first_seen', datetime.now(timezone.utc).isoformat())
            existing.append(entity)
            # BUG FIX: register the newly appended entity so later duplicates
            # within the same batch merge into it instead of being appended
            # as separate entries (previously within-batch dedup silently
            # failed and counts were wrong).
            by_key[key] = entity

    return existing
|
|
|
|
def write_entities(index: dict, knowledge_dir: str):
    """Write the entity index to knowledge_dir/entities.json, stamping last_updated."""
    out_dir = Path(knowledge_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Stamp the write time (UTC, ISO-8601) before serializing.
    index['last_updated'] = datetime.now(timezone.utc).isoformat()
    target = out_dir / "entities.json"
    with open(target, 'w', encoding='utf-8') as fh:
        json.dump(index, fh, indent=2, ensure_ascii=False)
|
|
|
|
def read_text_from_source(source: str) -> str:
    """Read text from a file (plain text, markdown, or session JSONL)."""
    src = Path(source)
    if not src.exists():
        raise FileNotFoundError(source)
    if src.suffix != '.jsonl':
        # Plain text / markdown / issue body.
        return src.read_text(encoding='utf-8', errors='replace')
    # Session transcript: flatten its messages to plain text.
    from session_reader import read_session, messages_to_text
    return messages_to_text(read_session(source))
|
|
|
|
def extract_from_text(text: str, api_base: str, api_key: str, model: str, source_name: str = "") -> list:
    """Run LLM entity extraction over *text* and normalize the results.

    Returns a list of entity dicts with name/type/context/last_seen/sources.
    Malformed items from the model (non-dicts, missing name or type) are
    silently dropped; a failed LLM call yields an empty list.
    """
    raw = call_llm(load_prompt(), text, api_base, api_key, model)
    if raw is None:
        return []

    cleaned = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        name = item.get('name', '').strip()
        etype = item.get('type', '').strip().lower()
        if not name or not etype:
            continue
        cleaned.append({
            'name': name,
            'type': etype,
            'context': item.get('context', '')[:200],  # cap context length
            'last_seen': datetime.now(timezone.utc).isoformat(),
            'sources': [source_name] if source_name else [],
        })
    return cleaned
|
|
|
|
def main():
    """CLI entry point: gather sources, extract entities, merge, and report."""
    parser = argparse.ArgumentParser(description="Extract named entities from text sources")
    # Exactly one of --file/--dir/--session/--batch selects the input mode
    # (checked in order below; anything else prints help and exits).
    parser.add_argument('--file', help='Single file to process')
    parser.add_argument('--dir', help='Directory of files to process')
    parser.add_argument('--session', help='Single session JSONL file')
    parser.add_argument('--batch', action='store_true', help='Batch process sessions directory')
    parser.add_argument('--sessions-dir', default=os.path.expanduser('~/.hermes/sessions'),
                        help='Sessions directory for batch mode')
    parser.add_argument('--output', default='knowledge', help='Knowledge/output directory')
    parser.add_argument('--api-base', default=DEFAULT_API_BASE)
    parser.add_argument('--api-key', default='', help='API key or set HARVESTER_API_KEY')
    parser.add_argument('--model', default=DEFAULT_MODEL)
    parser.add_argument('--dry-run', action='store_true', help='Preview without writing')
    parser.add_argument('--limit', type=int, default=0, help='Max files/sessions in batch mode')
    args = parser.parse_args()

    # Key resolution order: CLI flag > HARVESTER_API_KEY env > well-known key files.
    api_key = args.api_key or DEFAULT_API_KEY or find_api_key()
    if not api_key:
        print("ERROR: No API key found", file=sys.stderr)
        sys.exit(1)

    # Relative output paths are resolved against the repo root (parent of scripts/).
    knowledge_dir = args.output
    if not os.path.isabs(knowledge_dir):
        knowledge_dir = str(SCRIPT_DIR.parent / knowledge_dir)

    # Build the list of source file paths according to the selected mode.
    sources = []
    if args.file:
        sources = [args.file]
    elif args.dir:
        files = sorted(Path(args.dir).rglob("*"))
        # Only text-like files are candidates for extraction.
        sources = [str(f) for f in files if f.is_file() and f.suffix in ('.txt','.md','.json','.jsonl','.yaml','.yml')]
        if args.limit > 0:
            sources = sources[:args.limit]
    elif args.session:
        sources = [args.session]
    elif args.batch:
        sess_dir = Path(args.sessions_dir)
        # Reverse lexicographic filename order — newest sessions first for
        # timestamp-named files.
        sources = sorted(sess_dir.glob("*.jsonl"), reverse=True)
        if args.limit > 0:
            sources = sources[:args.limit]
        sources = [str(s) for s in sources]
    else:
        parser.print_help()
        sys.exit(1)

    print(f"Processing {len(sources)} sources...")
    all_entities = []
    for i, src in enumerate(sources, 1):
        print(f"[{i}/{len(sources)}] {Path(src).name}...", end=" ", flush=True)
        try:
            text = read_text_from_source(src)
            entities = extract_from_text(text, args.api_base, api_key, args.model, source_name=Path(src).name)
            all_entities.extend(entities)
            print(f"→ {len(entities)} entities")
        except Exception as e:
            # Best-effort batch: a failing source is reported but does not
            # abort the remaining sources.
            print(f"ERROR: {e}")

    # Deduplicate across all sources
    print(f"Total raw entities: {len(all_entities)}")
    existing_index = load_existing_entities(knowledge_dir)
    merged = merge_entities(all_entities, existing_index.get('entities', []))
    print(f"Total unique entities after dedup: {len(merged)}")

    if not args.dry_run:
        # last_updated is stamped inside write_entities just before writing.
        new_index = {"version": 1, "last_updated": "", "entities": merged}
        write_entities(new_index, knowledge_dir)
        print(f"Written to {knowledge_dir}/entities.json")

    # Summary stats emitted as JSON for scripting/CI consumption.
    stats = {
        "sources_processed": len(sources),
        "raw_entities": len(all_entities),
        "unique_entities": len(merged)
    }
    print(json.dumps(stats, indent=2))


if __name__ == '__main__':
    main()
|