#!/usr/bin/env python3
# File: scripts/graph_query.py
"""
Graph Query Engine — traverse the knowledge graph.

Usage:
    python3 scripts/graph_query.py neighbors <fact_id> [--knowledge-dir knowledge/]
    python3 scripts/graph_query.py path <from_id> <to_id> [--max-hops 10]
    python3 scripts/graph_query.py subgraph <fact_id> [--depth 2]
    python3 scripts/graph_query.py stats    # Graph statistics

Outputs JSON to stdout.
"""

import argparse
import json
import sys
import time
from pathlib import Path
from collections import defaultdict, deque
from typing import Optional

# --- Graph building ---


def load_index(knowledge_dir: Path) -> dict:
    """Load ``index.json`` from *knowledge_dir*.

    Returns an empty index (version 1, no facts) when the file does not
    exist. A file that exists but is not valid JSON raises
    ``json.JSONDecodeError``.
    """
    index_path = Path(knowledge_dir) / "index.json"
    if not index_path.exists():
        return {"version": 1, "total_facts": 0, "facts": []}
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(index_path, encoding="utf-8") as f:
        return json.load(f)


def build_adjacency(facts: list[dict]) -> tuple[dict, dict]:
    """Build an undirected adjacency map from the facts' 'related' fields.

    Returns ``(adj, id_to_fact)`` where ``adj`` maps a node id to the set of
    its neighbor ids and ``id_to_fact`` maps a fact id to its fact dict.

    Every fact with an id gets an entry in ``adj`` even when it has no
    relations, so isolated facts are distinguishable from unknown ids.
    Ids that only appear inside a 'related' list also become nodes in ``adj``
    but have no entry in ``id_to_fact``. Self-references are ignored.
    """
    adj = defaultdict(set)
    id_to_fact = {}
    for fact in facts:
        if not isinstance(fact, dict):
            continue
        fid = fact.get("id")
        if not fid:
            continue
        id_to_fact[fid] = fact
        own_neighbors = adj[fid]  # registers the node even if it stays isolated
        related = fact.get("related") or []  # tolerate a missing or null field
        if isinstance(related, str):
            # A bare string would otherwise be iterated character by character.
            related = [related]
        for related_id in related:
            if not related_id or related_id == fid:
                continue
            own_neighbors.add(related_id)
            adj[related_id].add(fid)  # undirected
    return dict(adj), id_to_fact


def _node_summary(node_id, id_to_fact: dict) -> dict:
    """Return the JSON-friendly summary used for nodes in query results."""
    fact = id_to_fact.get(node_id, {})
    return {"id": node_id, "fact": fact.get("fact", ""), "category": fact.get("category", "")}


# --- Queries ---


def query_neighbors(fact_id: str, adj: dict, id_to_fact: dict) -> dict:
    """Return the facts directly connected to *fact_id*.

    Only neighbors that exist in *id_to_fact* are listed; ``count`` is the
    length of that list. Neighbors are sorted so the output is deterministic.
    An ``error`` key is added when *fact_id* is unknown to both maps.
    """
    known = sorted((nid for nid in adj.get(fact_id, ()) if nid in id_to_fact), key=str)
    result = {
        "query": "neighbors",
        "fact_id": fact_id,
        "neighbors": [_node_summary(nid, id_to_fact) for nid in known],
        "count": len(known),
    }
    if fact_id not in adj and fact_id not in id_to_fact:
        result["error"] = "Fact not found"
    return result


def query_path(from_id: str, to_id: str, adj: dict, max_hops: int = 10) -> dict:
    """Find a shortest path between two facts using breadth-first search.

    On success the result holds ``path`` (list of ids, endpoints included)
    and ``length`` (number of hops, at most *max_hops*). Otherwise ``path``
    is ``None`` and ``error`` explains why.
    """
    if from_id not in adj or to_id not in adj:
        return {"query": "path", "from": from_id, "to": to_id, "path": None, "error": "Fact not found in graph"}

    if from_id == to_id:
        return {"query": "path", "from": from_id, "to": to_id, "path": [from_id], "length": 0}

    # Parent pointers instead of one copied path per queue entry.
    parent = {from_id: None}
    frontier = deque([(from_id, 0)])

    while frontier:
        current, hops = frontier.popleft()
        if hops >= max_hops:
            # Any neighbor would be hops + 1 > max_hops away.
            continue
        # Sorted so ties between equally short paths resolve deterministically.
        for neighbor in sorted(adj.get(current, ()), key=str):
            if neighbor in parent:
                continue
            parent[neighbor] = current
            if neighbor == to_id:
                path = [to_id]
                while path[-1] != from_id:
                    path.append(parent[path[-1]])
                path.reverse()
                return {"query": "path", "from": from_id, "to": to_id, "path": path, "length": len(path) - 1}
            frontier.append((neighbor, hops + 1))

    return {"query": "path", "from": from_id, "to": to_id, "path": None, "error": f"No path found within {max_hops} hops"}


def query_subgraph(fact_id: str, adj: dict, id_to_fact: dict, depth: int = 2) -> dict:
    """Extract the subgraph of nodes within *depth* hops of *fact_id*.

    ``nodes`` are sorted by id. ``edges`` lists each undirected edge once and
    only when both endpoints are inside the subgraph. A negative *depth*
    yields an empty subgraph.
    """
    if fact_id not in adj:
        return {"query": "subgraph", "fact_id": fact_id, "nodes": [], "edges": [], "error": "Fact not found"}

    # BFS recording the hop distance at which each node is first reached.
    depth_of = {fact_id: 0} if depth >= 0 else {}
    frontier = deque(depth_of)
    while frontier:
        node = frontier.popleft()
        if depth_of[node] >= depth:
            continue
        for neighbor in adj.get(node, ()):
            if neighbor not in depth_of:
                depth_of[neighbor] = depth_of[node] + 1
                frontier.append(neighbor)

    # Keep only edges internal to the subgraph, one record per undirected edge.
    edge_pairs = set()
    for node in depth_of:
        for neighbor in adj.get(node, ()):
            if neighbor in depth_of and neighbor != node:
                edge_pairs.add(tuple(sorted((node, neighbor), key=str)))
    ordered_edges = sorted(edge_pairs, key=lambda pair: (str(pair[0]), str(pair[1])))

    return {
        "query": "subgraph",
        "fact_id": fact_id,
        "depth": depth,
        "nodes": [_node_summary(nid, id_to_fact) for nid in sorted(depth_of, key=str)],
        "edges": [{"source": source, "target": target} for source, target in ordered_edges],
        "node_count": len(depth_of),
        "edge_count": len(ordered_edges),
    }


def query_stats(adj: dict, id_to_fact: dict) -> dict:
    """Compute graph statistics.

    Nodes are every id seen in *adj* or *id_to_fact*, so isolated facts count
    toward ``average_degree`` and ``connected_components``. *adj* is expected
    to be symmetric, as produced by ``build_adjacency``.
    """
    nodes = set(adj) | set(id_to_fact)
    for neighbors in adj.values():
        nodes.update(neighbors)

    # Each undirected edge once; self-loops are not counted.
    edges = {
        frozenset((node, neighbor))
        for node, neighbors in adj.items()
        for neighbor in neighbors
        if neighbor != node
    }

    # Count components by flood-filling from every not-yet-reached node.
    unvisited = set(nodes)
    components = 0
    while unvisited:
        components += 1
        frontier = deque([unvisited.pop()])
        while frontier:
            node = frontier.popleft()
            for neighbor in adj.get(node, ()):
                if neighbor in unvisited:
                    unvisited.remove(neighbor)
                    frontier.append(neighbor)

    return {
        "query": "stats",
        "statistics": {
            "total_facts": len(id_to_fact),
            "total_edges": len(edges),
            "connected_components": components,
            "average_degree": 2 * len(edges) / len(nodes) if nodes else 0,
        },
    }


# --- CLI ---


def main(argv: Optional[list[str]] = None) -> None:
    """Run the CLI; *argv* defaults to ``sys.argv[1:]``.

    Prints the query result as JSON on stdout. Exits with status 1 (message
    on stderr) when a required positional argument is missing or the index
    cannot be read.
    """
    parser = argparse.ArgumentParser(description="Graph query engine for knowledge store")
    parser.add_argument("command", choices=["neighbors", "path", "subgraph", "stats"])
    parser.add_argument("from_id", nargs="?", help="Starting fact ID")
    parser.add_argument("to_id", nargs="?", help="Target fact ID (for path query)")
    parser.add_argument("--knowledge-dir", default="knowledge", help="Knowledge directory")
    parser.add_argument("--depth", type=int, default=2, help="Depth for subgraph query")
    parser.add_argument("--max-hops", type=int, default=10, help="Max hops for path query")
    args = parser.parse_args(argv)

    # Validate positionals before touching the disk.
    if args.command in ("neighbors", "subgraph") and not args.from_id:
        print(f"ERROR: {args.command} requires <fact_id>", file=sys.stderr)
        sys.exit(1)
    if args.command == "path" and not (args.from_id and args.to_id):
        print("ERROR: path requires <from_id> <to_id>", file=sys.stderr)
        sys.exit(1)

    start = time.perf_counter()  # monotonic; wall-clock time can jump
    try:
        index = load_index(Path(args.knowledge_dir))
    except (OSError, ValueError) as exc:  # JSONDecodeError is a ValueError
        print(f"ERROR: cannot read index: {exc}", file=sys.stderr)
        sys.exit(1)
    facts = (index.get("facts") or []) if isinstance(index, dict) else []
    adj, id_to_fact = build_adjacency(facts)

    if args.command == "neighbors":
        result = query_neighbors(args.from_id, adj, id_to_fact)
    elif args.command == "path":
        result = query_path(args.from_id, args.to_id, adj, max_hops=args.max_hops)
    elif args.command == "subgraph":
        result = query_subgraph(args.from_id, adj, id_to_fact, depth=args.depth)
    else:  # "stats" — argparse `choices` rules out anything else
        result = query_stats(adj, id_to_fact)

    result["elapsed_ms"] = round((time.perf_counter() - start) * 1000, 2)
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# File: scripts/test_graph_query.py
"""
Tests for scripts/graph_query.py — Graph Query Engine.

Run directly: python3 scripts/test_graph_query.py
"""

import json
import subprocess
import sys
import tempfile
from pathlib import Path

# Make the sibling graph_query module importable regardless of the CWD.
sys.path.insert(0, str(Path(__file__).resolve().parent))

from graph_query import (
    load_index,
    build_adjacency,
    query_neighbors,
    query_path,
    query_subgraph,
    query_stats,
)

SCRIPT_PATH = Path(__file__).resolve().parent / "graph_query.py"


def make_index(facts: list[dict], tmp_dir: Path) -> Path:
    """Write an index.json containing *facts* into *tmp_dir*; return its path."""
    index = {
        "version": 1,
        "last_updated": "2026-04-13T20:00:00Z",
        "total_facts": len(facts),
        "facts": facts,
    }
    path = tmp_dir / "index.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(index, f)
    return path


def run_cli(*cli_args: str, cwd: Path) -> subprocess.CompletedProcess:
    """Run graph_query.py in a subprocess with *cli_args*, capturing its output."""
    return subprocess.run(
        [sys.executable, str(SCRIPT_PATH), *cli_args],
        capture_output=True, text=True, cwd=str(cwd),
    )


def test_load_index():
    """load_index returns an empty index when missing and round-trips a file."""
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        empty = load_index(tmp_dir)
        assert empty["facts"] == []
        assert empty["total_facts"] == 0

        facts = [{"id": "a", "fact": "A", "related": []}]
        make_index(facts, tmp_dir)
        loaded = load_index(tmp_dir)
        assert loaded["facts"] == facts
        assert loaded["total_facts"] == 1
    print("PASS: load_index")


def test_neighbors():
    """Neighbor query returns directly connected facts."""
    facts = [
        {"id": "a", "fact": "A", "category": "fact", "related": ["b", "c"]},
        {"id": "b", "fact": "B", "category": "fact", "related": ["a"]},
        {"id": "c", "fact": "C", "category": "fact", "related": ["a"]},
        {"id": "d", "fact": "D", "category": "fact", "related": []},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_neighbors("a", adj, id_to_fact)
    neighbor_ids = {n["id"] for n in result["neighbors"]}
    assert neighbor_ids == {"b", "c"}, f"Expected b,c got {neighbor_ids}"
    assert result["count"] == 2
    print("PASS: neighbors")


def test_path_found():
    """Path query finds shortest path."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a", "c"]},
        {"id": "c", "fact": "C", "related": ["b", "d"]},
        {"id": "d", "fact": "D", "related": ["c"]},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_path("a", "d", adj)
    assert result["path"] == ["a", "b", "c", "d"], f"Got path {result['path']}"
    assert result["length"] == 3
    print("PASS: path_found")


def test_path_not_found():
    """Path query returns error when no path exists."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a"]},
        {"id": "c", "fact": "C", "related": ["d"]},
        {"id": "d", "fact": "D", "related": ["c"]},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_path("a", "c", adj, max_hops=5)
    assert result["path"] is None
    assert "error" in result
    print("PASS: path_not_found")


def test_path_max_hops_boundary():
    """A path of exactly max_hops hops is found; one hop fewer is not enough."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a", "c"]},
        {"id": "c", "fact": "C", "related": ["b", "d"]},
        {"id": "d", "fact": "D", "related": ["c"]},
    ]
    adj, _ = build_adjacency(facts)
    assert query_path("a", "d", adj, max_hops=3)["path"] == ["a", "b", "c", "d"]
    too_short = query_path("a", "d", adj, max_hops=2)
    assert too_short["path"] is None
    assert "error" in too_short
    print("PASS: path_max_hops_boundary")


def test_path_same_node():
    """A path from a node to itself has length zero."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a"]},
    ]
    adj, _ = build_adjacency(facts)
    result = query_path("a", "a", adj)
    assert result["path"] == ["a"]
    assert result["length"] == 0
    print("PASS: path_same_node")


def test_subgraph_extraction():
    """Subgraph extraction returns nodes within depth."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b", "c"]},
        {"id": "b", "fact": "B", "related": ["a", "d"]},
        {"id": "c", "fact": "C", "related": ["a"]},
        {"id": "d", "fact": "D", "related": ["b", "e"]},
        {"id": "e", "fact": "E", "related": ["d"]},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_subgraph("a", adj, id_to_fact, depth=1)
    node_ids = {n["id"] for n in result["nodes"]}
    assert node_ids == {"a", "b", "c"}, f"Got {node_ids}"
    assert result["node_count"] == 3
    print("PASS: subgraph_depth1")


def test_subgraph_depth2():
    """Depth-2 subgraph includes further nodes."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a", "c"]},
        {"id": "c", "fact": "C", "related": ["b", "d"]},
        {"id": "d", "fact": "D", "related": ["c"]},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_subgraph("a", adj, id_to_fact, depth=2)
    node_ids = {n["id"] for n in result["nodes"]}
    assert node_ids == {"a", "b", "c"}, f"Got {node_ids}"
    print("PASS: subgraph_depth2")


def test_subgraph_unknown_fact():
    """Subgraph query on an unknown id reports an error and no nodes."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a"]},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_subgraph("missing", adj, id_to_fact)
    assert result["nodes"] == []
    assert "error" in result
    print("PASS: subgraph_unknown_fact")


def test_stats():
    """Statistics query returns graph metrics."""
    facts = [
        {"id": "a", "fact": "A", "related": ["b"]},
        {"id": "b", "fact": "B", "related": ["a", "c"]},
        {"id": "c", "fact": "C", "related": ["b"]},
    ]
    adj, id_to_fact = build_adjacency(facts)
    result = query_stats(adj, id_to_fact)
    assert result["statistics"]["total_facts"] == 3
    assert result["statistics"]["total_edges"] == 2  # each undirected edge counted once
    assert result["statistics"]["average_degree"] > 0
    print("PASS: stats")


def test_cli_integration():
    """CLI produces valid JSON with correct query types."""
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        facts = [
            {"id": "x", "fact": "X", "related": ["y"]},
            {"id": "y", "fact": "Y", "related": ["x", "z"]},
            {"id": "z", "fact": "Z", "related": ["y"]},
        ]
        knowledge_dir = str(make_index(facts, tmp_dir).parent)

        result = run_cli("neighbors", "x", "--knowledge-dir", knowledge_dir, cwd=tmp_dir)
        assert result.returncode == 0, f"neighbors failed: {result.stderr}"
        out = json.loads(result.stdout)
        assert out["query"] == "neighbors"
        assert out["fact_id"] == "x"
        assert out["count"] == 1

        result = run_cli("path", "x", "z", "--knowledge-dir", knowledge_dir, cwd=tmp_dir)
        assert result.returncode == 0, f"path failed: {result.stderr}"
        out = json.loads(result.stdout)
        assert out["path"] == ["x", "y", "z"]
        assert out["length"] == 2

        result = run_cli("subgraph", "x", "--depth", "1", "--knowledge-dir", knowledge_dir, cwd=tmp_dir)
        assert result.returncode == 0, f"subgraph failed: {result.stderr}"
        out = json.loads(result.stdout)
        assert out["query"] == "subgraph"
        assert {n["id"] for n in out["nodes"]} == {"x", "y"}

        result = run_cli("stats", "--knowledge-dir", knowledge_dir, cwd=tmp_dir)
        assert result.returncode == 0, f"stats failed: {result.stderr}"
        out = json.loads(result.stdout)
        assert out["statistics"]["total_facts"] == 3
        assert out["statistics"]["total_edges"] == 2

        # A missing positional argument is reported on stderr with exit status 1.
        result = run_cli("neighbors", "--knowledge-dir", knowledge_dir, cwd=tmp_dir)
        assert result.returncode == 1
        assert "neighbors requires" in result.stderr

    print("PASS: cli_integration")


if __name__ == "__main__":
    test_load_index()
    test_neighbors()
    test_path_found()
    test_path_not_found()
    test_path_max_hops_boundary()
    test_path_same_node()
    test_subgraph_extraction()
    test_subgraph_depth2()
    test_subgraph_unknown_fact()
    test_stats()
    test_cli_integration()
    print("\nAll graph_query tests passed!")