Compare commits
1 Commits
step35/144
...
step35/96-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
365ab66e88 |
203
scripts/docstring_generator.py
Normal file
203
scripts/docstring_generator.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Docstring Generator — find and add missing docstrings.
|
||||
|
||||
Scans Python files for functions/async functions lacking docstrings.
|
||||
Generates Google-style docstrings from function signature and body.
|
||||
Inserts them in place.
|
||||
|
||||
Usage:
|
||||
python3 docstring_generator.py scripts/ # Fix in place
|
||||
python3 docstring_generator.py --dry-run scripts/ # Preview changes
|
||||
python3 docstring_generator.py --json scripts/ # Machine-readable output
|
||||
python3 docstring_generator.py path/to/file.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
|
||||
# --- Helper: turn snake_case into Title Case phrase ---
|
||||
def name_to_title(name: str) -> str:
    """Turn a snake_case identifier into a space-separated, capitalised phrase.

    Underscores are treated as word separators, so leading, trailing or
    repeated underscores never produce empty words. Words of one or two
    characters are upper-cased entirely; longer words only get their first
    character upper-cased (the rest is left untouched).

    Args:
        name: Identifier to convert, e.g. a function name.

    Returns:
        The converted phrase, or '' when ``name`` contains no words.
    """
    pieces = name.replace('_', ' ').split()
    return ' '.join(
        piece.upper() if len(piece) <= 2 else piece[0].upper() + piece[1:]
        for piece in pieces
    )
|
||||
|
||||
|
||||
# --- Helper: extract first meaningful statement from body for summary ---
|
||||
def extract_body_hint(body: list[ast.stmt]) -> Optional[str]:
    """Derive a short purpose hint from the first real statement of a body.

    Bare constant expressions (docstring-like statements) are skipped; only
    the first statement after them is inspected.

    Args:
        body: Statement list of a function.

    Returns:
        "Compute or return <expr>" when that statement assigns to a
        result-like name, "Return <expr>" when it returns a value, and None
        in every other case (including an empty body).
    """
    result_like = ('result', 'msg', 'output', 'retval', 'value', 'response', 'data')

    # Locate the first statement that is not a bare constant expression.
    first = next(
        (
            node for node in body
            if not (isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant))
        ),
        None,
    )
    if first is None:
        return None

    if isinstance(first, ast.Assign):
        for tgt in first.targets:
            if isinstance(tgt, ast.Name) and tgt.id in result_like:
                rendered = ast.unparse(first.value).strip()
                if rendered:
                    return f"Compute or return {rendered}"

    if isinstance(first, ast.Return) and first.value:
        rendered = ast.unparse(first.value).strip()
        if rendered:
            return f"Return {rendered}"

    return None
|
||||
|
||||
|
||||
# --- Generate a docstring string for a function ---
|
||||
def _has_value_return(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
    """Report whether the function itself (not a nested scope) returns a value."""
    pending = list(func_node.body)
    while pending:
        node = pending.pop()
        # A `return` inside a nested def/class/lambda belongs to that scope.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)):
            continue
        if isinstance(node, ast.Return) and node.value is not None:
            return True
        pending.extend(ast.iter_child_nodes(node))
    return False


def generate_docstring(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Build a Google-style docstring for the given function node.

    The summary is the title-cased function name, optionally followed by a
    hint taken from the function body. An ``Args`` section lists every
    parameter (positional-only, regular, ``*args``, keyword-only and
    ``**kwargs``) except ``self``/``cls``. A ``Returns`` section is added
    when the function is annotated with a non-``None`` return type or, if
    unannotated, when its own body returns a value.

    Args:
        func_node: Parsed function or async function definition.

    Returns:
        The docstring source text, including the surrounding triple quotes.
    """
    parts: list[str] = []

    # Summary line
    summary = name_to_title(func_node.name)
    body_hint = extract_body_hint(func_node.body)
    if body_hint:
        summary = f"{summary}. {body_hint}"
    parts.append(summary)

    # Collect all parameter kinds in signature order, with their star prefix.
    signature = func_node.args
    params = [(arg, '') for arg in signature.posonlyargs + signature.args]
    if signature.vararg:
        params.append((signature.vararg, '*'))
    params.extend((arg, '') for arg in signature.kwonlyargs)
    if signature.kwarg:
        params.append((signature.kwarg, '**'))

    arg_lines = []
    for arg, prefix in params:
        if arg.arg in ('self', 'cls'):
            continue
        type_ann = ast.unparse(arg.annotation) if arg.annotation else 'Any'
        arg_lines.append(f"{prefix}{arg.arg} ({type_ann}): Parameter {arg.arg}")
    if arg_lines:
        # Google style indents section entries by four spaces.
        parts.append("\nArgs:\n    " + "\n    ".join(arg_lines))

    # Returns section
    if func_node.returns:
        ret_type = ast.unparse(func_node.returns)
        if ret_type != 'None':  # "-> None" has nothing to document
            parts.append(f"\nReturns:\n    {ret_type}: Return value")
    elif _has_value_return(func_node):
        parts.append("\nReturns:\n    Return value")

    return '"""' + '\n'.join(parts) + '\n"""'
|
||||
|
||||
|
||||
# --- Transform source AST ---
|
||||
def process_source(source: str, filename: str) -> Tuple[str, List[str]]:
    """Add docstrings to all undocumented functions.

    NOTE: when at least one docstring is added, the result is regenerated
    with ``ast.unparse``, so comments and original formatting of the input
    are not preserved. Unmodified sources are returned verbatim.

    Args:
        source: Python source code to process.
        filename: Name used only in the warning emitted on a syntax error.

    Returns:
        ``(new_source, func_names)`` where ``func_names`` lists the functions
        that received a docstring. On a syntax error, or when nothing needed
        a docstring, ``(source, [])`` is returned.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError as e:
        print(f" WARNING: Could not parse {filename}: {e}", file=sys.stderr)
        return source, []

    class DocstringInserter(ast.NodeTransformer):
        """Insert a generated docstring into every function that lacks one."""

        def __init__(self):
            self.modified_funcs: list[str] = []

        def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
            return self._process(node)

        def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
            return self._process(node)

        def _process(self, node):
            if ast.get_docstring(node) is None:
                # generate_docstring() returns the text wrapped in triple
                # quotes; the AST constant must hold the bare content or the
                # quote characters end up inside the docstring itself.
                docstring_text = (
                    generate_docstring(node)
                    .removeprefix('"""')
                    .removesuffix('"""')
                    .strip()
                )
                doc_node = ast.Expr(value=ast.Constant(value=docstring_text))
                node.body.insert(0, doc_node)
                ast.fix_missing_locations(node)
                self.modified_funcs.append(node.name)
            # Descend so functions nested inside this one are handled too.
            self.generic_visit(node)
            return node

    inserter = DocstringInserter()
    new_tree = inserter.visit(tree)
    if inserter.modified_funcs:
        return ast.unparse(new_tree), inserter.modified_funcs
    return source, []
|
||||
|
||||
|
||||
# --- File discovery ---
|
||||
def iter_python_files(paths: list[str]) -> list[Path]:
    """Gather the ``.py`` files named by, or found beneath, the given paths.

    Missing paths produce a warning on stderr and are skipped. Directories
    are searched recursively, ignoring anything whose path contains a
    ``.git`` or ``__pycache__`` component. Explicit files that do not end in
    ``.py`` are ignored.

    Args:
        paths: File or directory paths.

    Returns:
        Sorted list of unique, resolved paths.
    """
    ignored_parts = ('.git', '__pycache__')
    found: set[Path] = set()
    for raw in paths:
        entry = Path(raw)
        if not entry.exists():
            print(f"WARNING: Path not found: {raw}", file=sys.stderr)
            continue
        if entry.is_file():
            if entry.suffix == '.py':
                found.add(entry.resolve())
            continue
        if entry.is_dir():
            found.update(
                candidate.resolve()
                for candidate in entry.rglob('*.py')
                if not any(part in ignored_parts for part in candidate.parts)
            )
    return sorted(found)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: add missing docstrings to the given paths.

    Parses the arguments, processes every discovered Python file, writes the
    modified source back unless ``--dry-run`` is given, and prints either a
    JSON summary (``--json``) or a one-line human-readable summary.

    Returns:
        0 on completion. Exits with status 1 when no Python files are found.
    """
    parser = argparse.ArgumentParser(description="Generate docstrings for functions missing them")
    parser.add_argument('paths', nargs='+', help='Python files or directories to process')
    parser.add_argument('--dry-run', action='store_true', help='Show what would change without writing')
    parser.add_argument('--json', action='store_true', help='Output machine-readable JSON summary')
    parser.add_argument('-v', '--verbose', action='store_true', help='Print each file processed')

    args = parser.parse_args()

    files = iter_python_files(args.paths)
    if not files:
        print("No Python files found to process", file=sys.stderr)
        sys.exit(1)

    results = []
    total_funcs = 0

    for pyfile in files:
        try:
            original = pyfile.read_text(encoding='utf-8')
        except (OSError, UnicodeError) as e:
            print(f" ERROR reading {pyfile}: {e}", file=sys.stderr)
            continue

        new_source, modified_funcs = process_source(original, str(pyfile))

        # Compute the display path up front: both the "modified" and the
        # "no changes" verbose messages need it.
        try:
            rel = os.path.relpath(pyfile)
        except ValueError:
            # relpath fails on Windows when the file is on another drive.
            rel = str(pyfile)

        if not modified_funcs:
            if args.verbose:
                print(f" {rel}: no changes")
            continue

        total_funcs += len(modified_funcs)
        if args.verbose:
            print(f" {rel}: +{len(modified_funcs)} docstrings")
        results.append({'file': str(pyfile), 'functions': modified_funcs})
        if not args.dry_run:
            pyfile.write_text(new_source, encoding='utf-8')

    if args.json:
        summary = {'total_files_modified': len(results), 'total_functions': total_funcs, 'files': results}
        print(json.dumps(summary, indent=2))
    else:
        print(f"Generated docstrings for {total_funcs} functions across {len(results)} files")
        if args.dry_run:
            print(" (dry run — no files written)")

    return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
@@ -1,268 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
entity_extractor.py — Extract named entities from text sources.
|
||||
|
||||
Extracts: people, projects, tools, concepts, repos from session transcripts,
|
||||
README files, issue bodies, or any text input.
|
||||
|
||||
Output: knowledge/entities.json with deduplicated entity list and occurrence counts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent.absolute()
|
||||
sys.path.insert(0, str(SCRIPT_DIR))
|
||||
|
||||
from session_reader import read_session, messages_to_text
|
||||
|
||||
# --- Configuration ---
|
||||
DEFAULT_API_BASE = os.environ.get("HARVESTER_API_BASE", "https://api.nousresearch.com/v1")
|
||||
DEFAULT_API_KEY = os.environ.get("HARVESTER_API_KEY", "")
|
||||
DEFAULT_MODEL = os.environ.get("HARVESTER_MODEL", "xiaomi/mimo-v2-pro")
|
||||
KNOWLEDGE_DIR = os.environ.get("HARVESTER_KNOWLEDGE_DIR", "knowledge")
|
||||
PROMPT_PATH = os.environ.get("ENTITY_PROMPT_PATH", str(SCRIPT_DIR.parent / "templates" / "entity-extraction-prompt.md"))
|
||||
|
||||
API_KEY_PATHS = [
|
||||
os.path.expanduser("~/.config/nous/key"),
|
||||
os.path.expanduser("~/.hermes/keymaxxing/active/minimax.key"),
|
||||
os.path.expanduser("~/.config/openrouter/key"),
|
||||
]
|
||||
|
||||
def find_api_key() -> str:
    """Return the first non-empty key stored in one of the API_KEY_PATHS files.

    Files are tried in list order; missing files and files that contain only
    whitespace are skipped.

    Returns:
        The stripped key text, or '' when no file provides one.
    """
    for candidate in API_KEY_PATHS:
        if not os.path.exists(candidate):
            continue
        with open(candidate) as handle:
            secret = handle.read().strip()
        if secret:
            return secret
    return ""
|
||||
|
||||
def load_prompt() -> str:
    """Read the entity-extraction prompt template from PROMPT_PATH.

    Returns:
        The template text, decoded as UTF-8.

    Exits the process with status 1 (after an error message on stderr) when
    the template file does not exist.
    """
    template = Path(PROMPT_PATH)
    if template.exists():
        return template.read_text(encoding='utf-8')
    print(f"ERROR: Entity extraction prompt not found at {template}", file=sys.stderr)
    sys.exit(1)
|
||||
|
||||
def call_llm(prompt: str, text: str, api_base: str, api_key: str, model: str) -> Optional[list]:
    """Send the text to a chat-completions endpoint and parse the entity list.

    Args:
        prompt: System prompt sent as the first message.
        text: Text to extract entities from; embedded in the user message.
        api_base: Base URL; "/chat/completions" is appended to it.
        api_key: Bearer token for the Authorization header.
        model: Model identifier placed in the request payload.

    Returns:
        Whatever ``parse_response`` makes of the first choice's message
        content, or None when the request or response handling fails (the
        error is reported on stderr).
    """
    import urllib.request

    body = {
        "model": model,
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Extract entities from this text:\n\n{text}"},
        ],
        "temperature": 0.0,
        "max_tokens": 2048,
    }
    request = urllib.request.Request(
        f"{api_base}/chat/completions",
        data=json.dumps(body).encode('utf-8'),
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(request, timeout=60) as resp:
            raw_reply = resp.read()
        reply = json.loads(raw_reply.decode('utf-8'))
        return parse_response(reply["choices"][0]["message"]["content"])
    except Exception as e:
        print(f"ERROR: LLM call failed: {e}", file=sys.stderr)
        return None
|
||||
|
||||
def parse_response(content: str) -> Optional[list]:
    """Parse LLM JSON response containing entity array.

    Accepted shapes, tried in order:
      1. the whole content is a JSON array, or a JSON object whose
         ``entities`` value is an array;
      2. the same, wrapped in a Markdown code fence (```` ``` ```` or
         ```` ```json ````) somewhere in the content.

    Args:
        content: Raw message content returned by the model.

    Returns:
        The entity list, or None (with a warning on stderr) when no list
        can be recovered. A non-list ``entities`` value is rejected rather
        than returned.
    """
    import re

    def _entity_list(data):
        """Return the list carried by a decoded payload, or None."""
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and isinstance(data.get('entities'), list):
            return data['entities']
        return None

    if isinstance(content, str):
        candidates = [content]
        fenced = re.search(r'```(?:json)?\s*(.*?)\s*```', content, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1))
        for candidate in candidates:
            try:
                decoded = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            entities = _entity_list(decoded)
            if entities is not None:
                return entities

    print("WARNING: Could not parse LLM response as entity list", file=sys.stderr)
    return None
|
||||
|
||||
def load_existing_entities(knowledge_dir: str) -> dict:
    """Load ``entities.json`` from the knowledge directory.

    The file is read as UTF-8, matching how ``write_entities`` writes it.

    Args:
        knowledge_dir: Directory expected to contain ``entities.json``.

    Returns:
        The decoded index. A fresh empty index
        (``{"version": 1, "last_updated": "", "entities": []}``) is returned
        when the file is missing, unreadable, not valid JSON, or does not
        hold a JSON object; the latter cases also print a warning to stderr.
    """
    def _empty_index() -> dict:
        return {"version": 1, "last_updated": "", "entities": []}

    path = Path(knowledge_dir) / "entities.json"
    if not path.exists():
        return _empty_index()
    try:
        with open(path, encoding='utf-8') as f:
            loaded = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        print(f"WARNING: Could not load entities: {e}", file=sys.stderr)
        return _empty_index()
    if not isinstance(loaded, dict):
        print(f"WARNING: Could not load entities: {path} does not contain a JSON object", file=sys.stderr)
        return _empty_index()
    return loaded
|
||||
|
||||
def entity_key(name: str, etype: str) -> tuple:
    """Build the deduplication key for an entity.

    Both parts are lower-cased and stripped of surrounding whitespace, so
    entities differing only in case or padding share a key.

    Returns:
        ``(normalised_name, normalised_type)``.
    """
    normalised_name = name.lower().strip()
    normalised_type = etype.lower().strip()
    return normalised_name, normalised_type
|
||||
|
||||
def merge_entities(new_entities: list, existing: list) -> list:
    """Merge new entities into existing list, combining counts and sources.

    Entities are matched on ``entity_key(name, type)``. A match increments
    the stored ``count`` by one, unions the ``sources`` (kept sorted) and
    takes the newer ``last_seen``. Unmatched entities are appended with
    default ``count``, ``sources`` and ``first_seen`` values and are indexed
    immediately, so a repeat later in the same ``new_entities`` batch is
    merged instead of appended again.

    Args:
        new_entities: Entity dicts, each with at least ``name`` and ``type``.
        existing: Previously stored entity dicts; mutated in place.

    Returns:
        The ``existing`` list object, after merging.
    """
    existing_by_key = {}
    for e in existing:
        key = entity_key(e.get('name', ''), e.get('type', ''))
        existing_by_key[key] = e

    for e in new_entities:
        key = entity_key(e['name'], e['type'])
        existing_e = existing_by_key.get(key)
        if existing_e is not None:
            existing_e['count'] = existing_e.get('count', 1) + 1
            # Merge sources
            old_sources = set(existing_e.get('sources', []))
            new_sources = set(e.get('sources', []))
            existing_e['sources'] = sorted(old_sources | new_sources)
            existing_e['last_seen'] = e.get('last_seen', existing_e.get('last_seen'))
        else:
            e['count'] = e.get('count', 1)
            e.setdefault('sources', [])
            e.setdefault('first_seen', datetime.now(timezone.utc).isoformat())
            existing.append(e)
            # Register the new entry so duplicates later in this batch merge.
            existing_by_key[key] = e

    return existing
|
||||
|
||||
def write_entities(index: dict, knowledge_dir: str):
    """Persist the entity index as ``entities.json`` in the knowledge directory.

    The directory (and any missing parents) is created first. The index's
    ``last_updated`` field is overwritten in place with the current UTC
    time before the index is serialised as indented UTF-8 JSON.

    Args:
        index: Entity index to store; its ``last_updated`` key is mutated.
        knowledge_dir: Destination directory.
    """
    target_dir = Path(knowledge_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    index['last_updated'] = datetime.now(timezone.utc).isoformat()
    serialised = json.dumps(index, indent=2, ensure_ascii=False)
    (target_dir / "entities.json").write_text(serialised, encoding='utf-8')
|
||||
|
||||
def read_text_from_source(source: str) -> str:
    """Return the text content of a source file.

    ``.jsonl`` files are treated as session transcripts and converted via
    ``session_reader``; every other file is read as UTF-8 text with
    undecodable bytes replaced.

    Args:
        source: Path of the file to read.

    Returns:
        The extracted text.

    Raises:
        FileNotFoundError: If ``source`` does not exist.
    """
    location = Path(source)
    if not location.exists():
        raise FileNotFoundError(source)

    if location.suffix != '.jsonl':
        # Plain text / markdown / issue body
        return location.read_text(encoding='utf-8', errors='replace')

    # Session transcript
    from session_reader import read_session, messages_to_text
    return messages_to_text(read_session(source))
|
||||
|
||||
def extract_from_text(text: str, api_base: str, api_key: str, model: str, source_name: str = "") -> list:
    """Run LLM entity extraction on a text and normalise the result.

    The model output is untrusted: items that are not dicts, or whose
    ``name``/``type`` is missing, empty or not a string, are dropped instead
    of aborting the whole source. A non-string ``context`` becomes ''.

    Args:
        text: Text to extract entities from.
        api_base: Base URL of the chat-completions API.
        api_key: API key passed to the LLM call.
        model: Model identifier.
        source_name: Recorded in each entity's ``sources`` when non-empty.

    Returns:
        List of entity dicts with ``name``, lower-cased ``type``, ``context``
        (at most 200 characters), ``last_seen`` (UTC ISO timestamp) and
        ``sources``. Empty when the LLM call fails.
    """
    prompt = load_prompt()
    raw = call_llm(prompt, text, api_base, api_key, model)
    if raw is None:
        return []

    entities = []
    for e in raw:
        if not isinstance(e, dict):
            continue
        name = e.get('name')
        etype = e.get('type')
        # Guard against JSON nulls / numbers where strings are expected.
        if not isinstance(name, str) or not isinstance(etype, str):
            continue
        name = name.strip()
        etype = etype.strip().lower()
        if not name or not etype:
            continue
        context = e.get('context', '')
        if not isinstance(context, str):
            context = ''
        entities.append({
            'name': name,
            'type': etype,
            'context': context[:200],
            'last_seen': datetime.now(timezone.utc).isoformat(),
            'sources': [source_name] if source_name else [],
        })
    return entities
|
||||
|
||||
def main():
    """Command-line entry point for entity extraction.

    Selects the input sources from ``--file``, ``--dir``, ``--session`` or
    ``--batch`` (in that precedence), extracts entities from each, merges
    them into the existing index and, unless ``--dry-run`` is given, writes
    ``entities.json``. A JSON statistics block is printed at the end.

    The ``--output`` default honours the HARVESTER_KNOWLEDGE_DIR environment
    variable through ``KNOWLEDGE_DIR`` (which itself defaults to
    "knowledge"). Relative output paths are resolved against the parent of
    the script directory.

    Exits with status 1 when no API key is available or no input mode is
    selected.
    """
    parser = argparse.ArgumentParser(description="Extract named entities from text sources")
    parser.add_argument('--file', help='Single file to process')
    parser.add_argument('--dir', help='Directory of files to process')
    parser.add_argument('--session', help='Single session JSONL file')
    parser.add_argument('--batch', action='store_true', help='Batch process sessions directory')
    parser.add_argument('--sessions-dir', default=os.path.expanduser('~/.hermes/sessions'),
                        help='Sessions directory for batch mode')
    parser.add_argument('--output', default=KNOWLEDGE_DIR, help='Knowledge/output directory')
    parser.add_argument('--api-base', default=DEFAULT_API_BASE)
    parser.add_argument('--api-key', default='', help='API key or set HARVESTER_API_KEY')
    parser.add_argument('--model', default=DEFAULT_MODEL)
    parser.add_argument('--dry-run', action='store_true', help='Preview without writing')
    parser.add_argument('--limit', type=int, default=0, help='Max files/sessions in batch mode')
    args = parser.parse_args()

    # Key precedence: command line, then environment, then key files.
    api_key = args.api_key or DEFAULT_API_KEY or find_api_key()
    if not api_key:
        print("ERROR: No API key found", file=sys.stderr)
        sys.exit(1)

    knowledge_dir = args.output
    if not os.path.isabs(knowledge_dir):
        knowledge_dir = str(SCRIPT_DIR.parent / knowledge_dir)

    sources = []
    if args.file:
        sources = [args.file]
    elif args.dir:
        files = sorted(Path(args.dir).rglob("*"))
        sources = [str(f) for f in files if f.is_file() and f.suffix in ('.txt', '.md', '.json', '.jsonl', '.yaml', '.yml')]
        if args.limit > 0:
            sources = sources[:args.limit]
    elif args.session:
        sources = [args.session]
    elif args.batch:
        sess_dir = Path(args.sessions_dir)
        # Reverse name order, then apply the optional limit.
        sources = sorted(sess_dir.glob("*.jsonl"), reverse=True)
        if args.limit > 0:
            sources = sources[:args.limit]
        sources = [str(s) for s in sources]
    else:
        parser.print_help()
        sys.exit(1)

    print(f"Processing {len(sources)} sources...")
    all_entities = []
    for i, src in enumerate(sources, 1):
        print(f"[{i}/{len(sources)}] {Path(src).name}...", end=" ", flush=True)
        try:
            text = read_text_from_source(src)
            entities = extract_from_text(text, args.api_base, api_key, args.model, source_name=Path(src).name)
            all_entities.extend(entities)
            print(f"→ {len(entities)} entities")
        except Exception as e:
            # Best effort: one failing source must not stop the batch.
            print(f"ERROR: {e}")

    # Deduplicate across all sources
    print(f"Total raw entities: {len(all_entities)}")
    existing_index = load_existing_entities(knowledge_dir)
    merged = merge_entities(all_entities, existing_index.get('entities', []))
    print(f"Total unique entities after dedup: {len(merged)}")

    if not args.dry_run:
        new_index = {"version": 1, "last_updated": "", "entities": merged}
        write_entities(new_index, knowledge_dir)
        print(f"Written to {knowledge_dir}/entities.json")

    stats = {
        "sources_processed": len(sources),
        "raw_entities": len(all_entities),
        "unique_entities": len(merged)
    }
    print(json.dumps(stats, indent=2))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,116 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Smoke test for entity_extractor pipeline — verifies:
|
||||
- session/plain text reading
|
||||
- mock LLM entity extraction
|
||||
- deduplication and merging
|
||||
- output file format
|
||||
|
||||
Does NOT call the real LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent.absolute()
|
||||
sys.path.insert(0, str(SCRIPT_DIR))
|
||||
|
||||
from session_reader import read_session, messages_to_text
|
||||
import entity_extractor as ee
|
||||
|
||||
def mock_call_llm(prompt: str, text: str, api_base: str, api_key: str, model: str):
    """Stand-in for ``call_llm``: ignore all arguments, return three fixed entities."""
    canned = (
        ("Hermes", "tool", "Hermes agent uses the tools tool."),
        ("Gitea", "tool", "Gitea is a forge."),
        ("Timmy_Foundation/hermes-agent", "repo", "Clone the repo at forge..."),
    )
    return [
        {"name": name, "type": kind, "context": snippet}
        for name, kind, snippet in canned
    ]
|
||||
|
||||
def test_read_session_text():
    """A two-message JSONL session is rendered as "ROLE: content" text."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write('{"role": "user", "content": "Clone repo", "timestamp": "2026-04-13T10:00:00Z"}\n')
        f.write('{"role": "assistant", "content": "Done", "timestamp": "2026-04-13T10:00:05Z"}\n')
        path = f.name
    try:
        messages = read_session(path)
        text = messages_to_text(messages)
        assert "USER: Clone repo" in text
        assert "ASSISTANT: Done" in text
    finally:
        # Remove the temp file even when an assertion fails.
        os.unlink(path)
    print(" [PASS] session text extraction works")
|
||||
|
||||
def test_entity_deduplication_and_merge():
    """Merging bumps the count of a known entity and appends an unknown one."""
    stored = [
        {"name": "Hermes", "type": "tool", "count": 3, "sources": ["s1.jsonl"]}
    ]
    incoming = [
        {"name": "Hermes", "type": "tool", "sources": ["s2.jsonl"]},
        {"name": "Gitea", "type": "tool", "sources": ["s2.jsonl"]},
    ]
    merged = ee.merge_entities(incoming, stored.copy())
    by_name = {}
    for entry in merged:
        # Keep the first entry per lower-cased name.
        by_name.setdefault(entry['name'].lower(), entry)

    # Hermes count becomes 4, sources combined
    assert by_name['hermes']['count'] == 4
    assert set(by_name['hermes']['sources']) == {'s1.jsonl', 's2.jsonl'}
    # Gitea new entry
    assert by_name['gitea']['count'] == 1
    print(" [PASS] deduplication & merging works")
|
||||
|
||||
def test_write_and_load_entities():
    """An index written with write_entities is read back by load_existing_entities."""
    with tempfile.TemporaryDirectory() as tmp:
        kdir = Path(tmp) / "knowledge"
        kdir.mkdir()
        entry = {"name": "TestTool", "type": "tool", "count": 1, "sources": ["test"]}
        ee.write_entities({"version": 1, "last_updated": "", "entities": [entry]}, str(kdir))
        # load back
        reloaded = ee.load_existing_entities(str(kdir))
        first = reloaded['entities'][0]
        assert first['name'] == 'TestTool'
    print(" [PASS] entities persistence works")
|
||||
|
||||
def test_full_pipeline_mocked():
    """Read, extract (mocked LLM), merge and write two sessions end to end."""
    with tempfile.TemporaryDirectory() as tmpdir:
        workdir = Path(tmpdir)

        # Create two fake session files
        session_lines = {
            "s1.jsonl": '{"role":"user","content":"Use Hermes to clone","timestamp":"..."}\n',
            "s2.jsonl": '{"role":"user","content":"Deploy with Gitea","timestamp":"..."}\n',
        }
        session_paths = []
        for filename, line in session_lines.items():
            session_file = workdir / filename
            session_file.write_text(line)
            session_paths.append(str(session_file))

        knowledge_dir = workdir / "knowledge"
        knowledge_dir.mkdir()

        # Patch call_llm
        with patch('entity_extractor.call_llm', side_effect=mock_call_llm):
            # Simulate processing both sessions via the main logic
            all_entities = []
            for src in session_paths:
                text = ee.read_text_from_source(src)
                all_entities.extend(
                    ee.extract_from_text(text, "http://api", "fake-key", "model", source_name=Path(src).name)
                )

        # Merge into empty index
        merged = ee.merge_entities(all_entities, [])
        assert len(merged) >= 3, f"Expected >=3 unique entities, got {len(merged)}"

        # Write
        ee.write_entities({"version": 1, "last_updated": "", "entities": merged}, str(knowledge_dir))

        # Verify file exists
        out = knowledge_dir / "entities.json"
        assert out.exists()
        data = json.loads(out.read_text())
        assert len(data['entities']) >= 3
    print(f" [PASS] full pipeline (mocked) produced {len(data['entities'])} entities")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_read_session_text()
|
||||
test_entity_deduplication_and_merge()
|
||||
test_write_and_load_entities()
|
||||
test_full_pipeline_mocked()
|
||||
print("\nAll smoke tests passed.")
|
||||
@@ -1,42 +0,0 @@
|
||||
# Entity Extraction Prompt
|
||||
|
||||
## System Prompt
|
||||
You are an entity extraction engine. You read text and output ONLY a JSON array of named entities. You do not infer. You extract only what the text explicitly mentions.
|
||||
|
||||
## Task
|
||||
Extract all named entities from the provided text. Categorize each entity into exactly one of these types:
|
||||
- `person` — individual's name (e.g., Alexander, Rockachopa, Allegro)
|
||||
- `project` — software project or component name (e.g., The Nexus, Timmy Home, compounding-intelligence)
|
||||
- `tool` — software tool, command, library, framework (e.g., git, Docker, PyTorch, Hermes)
|
||||
- `concept` — abstract idea, methodology, paradigm (e.g., compounding intelligence, bootstrap, harvester)
|
||||
- `repo` — repository reference in the form `owner/repo` or URL pointing to a repo
|
||||
|
||||
## Rules
|
||||
1. Extract ONLY names that appear explicitly in the text.
|
||||
2. Do NOT infer, assume, or hallucinate.
|
||||
3. Each entity must have: `name` (exact string), `type` (one of the five above), and `context` (short snippet showing usage, 1-2 sentences).
|
||||
4. The same entity mentioned multiple times should appear only ONCE in the output (deduplicate by name+type).
|
||||
5. For `repo` type, match patterns like `owner/repo`, `github.com/owner/repo`, `forge.alexanderwhitestone.com/owner/repo`.
|
||||
6. For `tool` type, include commands (git, pytest), platforms (Linux, macOS), runtimes (Python, Node.js), and CLI utilities.
|
||||
7. For `person` type, look for capitalized full names, or single names used in personal attribution ("asked Alex", "for Alexander").
|
||||
8. For `concept`, include technical terms that represent an idea rather than a concrete thing.
|
||||
|
||||
## Output Format
|
||||
Return ONLY valid JSON, no markdown, no explanation. Array of objects:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "Hermes",
|
||||
"type": "tool",
|
||||
"context": "Hermes agent uses the tools tool to execute commands."
|
||||
},
|
||||
{
|
||||
"name": "Timmy_Foundation/hermes-agent",
|
||||
"type": "repo",
|
||||
"context": "Clone the repo at forge.../Timmy_Foundation/hermes-agent"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Text to extract from:
|
||||
{{text}}
|
||||
128
tests/test_docstring_generator.py
Normal file
128
tests/test_docstring_generator.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for docstring_generator module (Issue #96)."""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||
|
||||
from docstring_generator import (
|
||||
name_to_title,
|
||||
extract_body_hint,
|
||||
generate_docstring,
|
||||
process_source,
|
||||
iter_python_files,
|
||||
)
|
||||
|
||||
|
||||
class TestNameToTitle:
    """Checks for the snake_case to title conversion."""

    def test_snake_to_title(self):
        """Plain, single-word and dunder names are converted as expected."""
        expectations = (
            ("validate_fact", "Validate Fact"),
            ("docstring_generator", "Docstring Generator"),
            ("main", "Main"),
            ("__init__", "Init"),
        )
        for identifier, title in expectations:
            assert name_to_title(identifier) == title
|
||||
|
||||
|
||||
class TestExtractBodyHint:
    """Checks for the hint derived from a function body."""

    @staticmethod
    def _statements(code):
        """Parse a one-statement snippet into a single-element body list."""
        return [ast.parse(code).body[0]]

    def test_assignment_hint(self):
        """Assigning to a result-like name yields a "Compute or return" hint."""
        hint = extract_body_hint(self._statements("result = compute()"))
        assert hint == "Compute or return compute()"

    def test_return_hint(self):
        """A value-returning statement yields a "Return" hint."""
        hint = extract_body_hint(self._statements("return data"))
        assert hint == "Return data"

    def test_no_hint(self):
        """A body with nothing informative yields no hint."""
        assert extract_body_hint(self._statements("pass")) is None
|
||||
|
||||
|
||||
class TestGenerateDocstring:
    """Checks for the generated docstring text."""

    def test_simple_function(self):
        """Untyped function: summary, Args and Returns sections are present."""
        src = "def add(a, b):\n    return a + b\n"
        tree = ast.parse(src)
        func = tree.body[0]
        doc = generate_docstring(func)
        assert 'Add' in doc
        assert 'a (Any)' in doc and 'b (Any)' in doc
        assert 'Args:' in doc
        assert 'Returns:' in doc

    def test_typed_function(self):
        """Annotations are carried into both the Args and Returns entries."""
        src = "def greet(name: str) -> str:\n    return f'Hello {name}'\n"
        tree = ast.parse(src)
        func = tree.body[0]
        doc = generate_docstring(func)
        assert 'name (str)' in doc
        # Check the return entry specifically; a bare 'str' in doc would
        # already be satisfied by the argument line above.
        assert 'str: Return value' in doc

    def test_async_function(self):
        """Async functions are handled like regular ones."""
        src = "async def fetch():\n    pass\n"
        tree = ast.parse(src)
        func = tree.body[0]
        doc = generate_docstring(func)
        assert 'Fetch' in doc

    def test_self_skipped(self):
        """'self' is left out of the Args section while real parameters stay."""
        src = "class C:\n    def method(self, x):\n        return x\n"
        tree = ast.parse(src)
        cls = tree.body[0]
        method = cls.body[0]
        doc = generate_docstring(method)
        # The section must exist (x is a real parameter) and must not list self.
        assert 'Args:' in doc
        args_section = doc[doc.index('Args:'):]
        assert 'self' not in args_section
        assert 'x (Any)' in args_section
|
||||
|
||||
|
||||
class TestProcessSource:
    """Checks for whole-source processing."""

    def test_adds_docstrings(self):
        """An undocumented function gets a docstring and is reported by name."""
        new_src, funcs = process_source("def foo(x):\n    return x * 2\n", "test.py")
        assert len(funcs) == 1 and funcs[0] == "foo"
        assert '"""' in new_src
        assert 'Foo' in new_src

    def test_preserves_existing_docstrings(self):
        """Already documented source comes back unchanged."""
        src = 'def bar():\n    """Already documented."""\n    return 1\n'
        new_src, funcs = process_source(src, "test.py")
        assert len(funcs) == 0
        assert new_src == src

    def test_multiple_functions(self):
        """Every undocumented function in a file is handled."""
        src = "\n".join(f"def {name}(): pass" for name in ("a", "b", "c")) + "\n"
        new_src, funcs = process_source(src, "test.py")
        assert len(funcs) == 3
        assert '"""' in new_src

    def test_dry_run_no_write(self, tmp_path):
        """process_source itself never touches the file on disk."""
        file = tmp_path / "t.py"
        file.write_text("def f(): pass\n")
        mtime_before = file.stat().st_mtime
        new_src, funcs = process_source(file.read_text(), str(file))
        assert funcs  # detected
        # When caller handles write, dry-run leaves file unchanged
        mtime_after = file.stat().st_mtime
        assert mtime_after == mtime_before
|
||||
|
||||
|
||||
class TestIterPythonFiles:
    """Checks for Python file discovery."""

    def test_single_file(self, tmp_path):
        """A single .py path is returned as the only result."""
        target = tmp_path / "single.py"
        target.write_text("x = 1")
        found = iter_python_files([str(target)])
        assert len(found) == 1
        assert found[0].name == "single.py"

    def test_directory_recursion(self, tmp_path):
        """Files in the directory and in its subdirectories are both found."""
        nested = tmp_path / "sub"
        nested.mkdir()
        (nested / "a.py").write_text("a=1")
        (tmp_path / "b.py").write_text("b=2")
        found = iter_python_files([str(tmp_path)])
        assert len(found) == 2
|
||||
@@ -1,82 +0,0 @@
|
||||
"""
|
||||
Test suite for entity_extractor.py (Issue #144).
|
||||
|
||||
Tests cover:
|
||||
- Text reading from various formats
|
||||
- Entity deduplication logic
|
||||
- Output file structure
|
||||
- Integration: merging several mocked batches deduplicates entities across batches
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# We'll test the pure functions directly; avoid hitting real LLM in unit tests
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts"))
|
||||
|
||||
# The test approach: mock call_llm to return predetermined entities and test
|
||||
# deduplication, merging, and output writing.
|
||||
|
||||
def test_entity_key_normalization():
    """Keys ignore letter case but still distinguish entity types."""
    from entity_extractor import entity_key

    canonical = entity_key("Hermes", "tool")
    assert canonical == entity_key("hermes", "TOOL")
    as_tool = entity_key("Git", "tool")
    as_project = entity_key("Git", "project")
    assert as_tool != as_project
|
||||
|
||||
def test_merge_entities_deduplication():
    """A known entity is merged (count + sources); an unknown one is added."""
    from entity_extractor import merge_entities

    stored = [
        {"name": "Hermes", "type": "tool", "count": 5, "sources": ["a.jsonl"]}
    ]
    incoming = [
        {"name": "Hermes", "type": "tool", "sources": ["b.jsonl"]},
        {"name": "Gitea", "type": "tool", "sources": ["b.jsonl"]}
    ]
    merged = merge_entities(incoming, stored.copy())

    def first_named(wanted):
        """Return the first merged entry whose lower-cased name matches."""
        return next(entry for entry in merged if entry['name'].lower() == wanted)

    # Hermes count should be 5+1=6, sources merged
    hermes = first_named('hermes')
    assert hermes['count'] == 6
    assert set(hermes['sources']) == {"a.jsonl", "b.jsonl"}
    # Gitea added fresh
    assert first_named('gitea')['count'] == 1
|
||||
|
||||
def test_output_schema():
    """write_entities produces an entities.json holding the given entities."""
    from entity_extractor import write_entities, load_existing_entities

    with tempfile.TemporaryDirectory() as tmp:
        kdir = Path(tmp) / "knowledge"
        kdir.mkdir()
        entry = {"name": "Test", "type": "tool", "count": 1, "sources": ["test"]}
        write_entities({"version": 1, "last_updated": "", "entities": [entry]}, str(kdir))
        # Verify file written
        out = kdir / "entities.json"
        assert out.exists()
        data = json.loads(out.read_text())
        assert "entities" in data
        assert data["entities"][0]["name"] == "Test"
|
||||
|
||||
def test_batch_yields_many_entities():
    """Merging two mocked batches keeps one entry per unique entity.

    Six raw entities across two batches contain one duplicate (Hermes), so
    five unique entities must remain.
    """
    from entity_extractor import merge_entities

    # Simulate a few sources each returning a diverse entity set
    mock_sources = [
        [{"name": "Hermes", "type": "tool", "sources": ["s1"]},
         {"name": "Gitea", "type": "tool", "sources": ["s1"]},
         {"name": "Timmy_Foundation/hermes-agent", "type": "repo", "sources": ["s1"]}],
        [{"name": "Hermes", "type": "tool", "sources": ["s2"]},  # duplicate
         {"name": "Docker", "type": "tool", "sources": ["s2"]},
         {"name": "Alexander", "type": "person", "sources": ["s2"]}],
    ]
    merged = []
    for batch in mock_sources:
        merged = merge_entities(batch, merged)

    # Ensure dedup works across batches
    names = [e['name'].lower() for e in merged]
    assert names.count('hermes') == 1
    assert len(merged) == 5  # Hermes, Gitea, repo, Docker, Alexander
|
||||
|
||||
# The real LLM extraction test would require live API key; skip in CI
|
||||
Reference in New Issue
Block a user