Compare commits
1 Commits
main
...
step35/150
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11a4666363 |
170
scripts/graph_query.py
Executable file
170
scripts/graph_query.py
Executable file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Graph Query Engine — traverse the knowledge graph.
|
||||
|
||||
Usage:
|
||||
python3 scripts/graph_query.py neighbors <fact_id> [--knowledge-dir knowledge/]
|
||||
python3 scripts/graph_query.py path <from_id> <to_id> [--max-hops 10]
|
||||
python3 scripts/graph_query.py subgraph <fact_id> [--depth 2]
|
||||
python3 scripts/graph_query.py stats # Graph statistics
|
||||
|
||||
Outputs JSON to stdout.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from collections import defaultdict, deque
|
||||
from typing import Optional
|
||||
|
||||
# --- Graph building ---
|
||||
|
||||
def load_index(knowledge_dir: Path) -> dict:
|
||||
index_path = knowledge_dir / "index.json"
|
||||
if not index_path.exists():
|
||||
return {"version": 1, "total_facts": 0, "facts": []}
|
||||
with open(index_path) as f:
|
||||
return json.load(f)
|
||||
|
||||
def build_adjacency(facts: list[dict]) -> dict:
|
||||
"""Build undirected adjacency list from fact 'related' fields."""
|
||||
adj = defaultdict(set)
|
||||
id_to_fact = {}
|
||||
for fact in facts:
|
||||
fid = fact.get("id")
|
||||
if not fid:
|
||||
continue
|
||||
id_to_fact[fid] = fact
|
||||
for related_id in fact.get("related", []):
|
||||
adj[fid].add(related_id)
|
||||
adj[related_id].add(fid) # undirected
|
||||
return dict(adj), id_to_fact
|
||||
|
||||
# --- Queries ---
|
||||
|
||||
def query_neighbors(fact_id: str, adj: dict, id_to_fact: dict) -> dict:
|
||||
"""Return directly connected facts."""
|
||||
neighbors = list(adj.get(fact_id, set()))
|
||||
return {
|
||||
"query": "neighbors",
|
||||
"fact_id": fact_id,
|
||||
"neighbors": [
|
||||
{"id": nid, "fact": id_to_fact.get(nid, {}).get("fact", ""), "category": id_to_fact.get(nid, {}).get("category", "")}
|
||||
for nid in neighbors if nid in id_to_fact
|
||||
],
|
||||
"count": len(neighbors),
|
||||
}
|
||||
|
||||
def query_path(from_id: str, to_id: str, adj: dict, max_hops: int = 10) -> dict:
|
||||
"""Find shortest path between two facts using BFS."""
|
||||
if from_id not in adj or to_id not in adj:
|
||||
return {"query": "path", "from": from_id, "to": to_id, "path": None, "error": "Fact not found in graph"}
|
||||
|
||||
if from_id == to_id:
|
||||
return {"query": "path", "from": from_id, "to": to_id, "path": [from_id], "length": 0}
|
||||
|
||||
queue = deque([(from_id, [from_id])])
|
||||
visited = {from_id}
|
||||
|
||||
while queue:
|
||||
current, path = queue.popleft()
|
||||
if len(path) > max_hops:
|
||||
continue
|
||||
for neighbor in adj.get(current, []):
|
||||
if neighbor == to_id:
|
||||
return {"query": "path", "from": from_id, "to": to_id, "path": path + [to_id], "length": len(path)}
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [neighbor]))
|
||||
|
||||
return {"query": "path", "from": from_id, "to": to_id, "path": None, "error": f"No path found within {max_hops} hops"}
|
||||
|
||||
def query_subgraph(fact_id: str, adj: dict, id_to_fact: dict, depth: int = 2) -> dict:
|
||||
"""Extract connected subgraph within N hops."""
|
||||
if fact_id not in adj:
|
||||
return {"query": "subgraph", "fact_id": fact_id, "nodes": [], "edges": [], "error": "Fact not found"}
|
||||
|
||||
visited = set()
|
||||
queue = deque([(fact_id, 0)])
|
||||
subgraph_nodes = set()
|
||||
subgraph_edges = []
|
||||
|
||||
while queue:
|
||||
node, d = queue.popleft()
|
||||
if node in visited or d > depth:
|
||||
continue
|
||||
visited.add(node)
|
||||
subgraph_nodes.add(node)
|
||||
for neighbor in adj.get(node, []):
|
||||
subgraph_edges.append({"source": node, "target": neighbor})
|
||||
if neighbor not in visited:
|
||||
queue.append((neighbor, d + 1))
|
||||
|
||||
return {
|
||||
"query": "subgraph",
|
||||
"fact_id": fact_id,
|
||||
"depth": depth,
|
||||
"nodes": [
|
||||
{"id": nid, "fact": id_to_fact.get(nid, {}).get("fact", ""), "category": id_to_fact.get(nid, {}).get("category", "")}
|
||||
for nid in sorted(subgraph_nodes)
|
||||
],
|
||||
"edges": [{"source": e["source"], "target": e["target"]} for e in subgraph_edges],
|
||||
"node_count": len(subgraph_nodes),
|
||||
"edge_count": len(subgraph_edges),
|
||||
}
|
||||
|
||||
def query_stats(adj: dict, id_to_fact: dict) -> dict:
|
||||
"""Graph statistics."""
|
||||
return {
|
||||
"statistics": {
|
||||
"total_facts": len(id_to_fact),
|
||||
"total_edges": sum(len(neighbors) for neighbors in adj.values()) // 2,
|
||||
"connected_components": 0, # TODO: compute if needed
|
||||
"average_degree": sum(len(neighbors) for neighbors in adj.values()) / len(adj) if adj else 0,
|
||||
}
|
||||
}
|
||||
|
||||
# --- CLI ---
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Graph query engine for knowledge store")
|
||||
parser.add_argument("command", choices=["neighbors", "path", "subgraph", "stats"])
|
||||
parser.add_argument("from_id", nargs="?", help="Starting fact ID")
|
||||
parser.add_argument("to_id", nargs="?", help="Target fact ID (for path query)")
|
||||
parser.add_argument("--knowledge-dir", default="knowledge", help="Knowledge directory")
|
||||
parser.add_argument("--depth", type=int, default=2, help="Depth for subgraph query")
|
||||
parser.add_argument("--max-hops", type=int, default=10, help="Max hops for path query")
|
||||
args = parser.parse_args()
|
||||
|
||||
start = time.time()
|
||||
knowledge_dir = Path(args.knowledge_dir)
|
||||
index = load_index(knowledge_dir)
|
||||
facts = index.get("facts", [])
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
|
||||
result = None
|
||||
if args.command == "neighbors":
|
||||
if not args.from_id:
|
||||
print("ERROR: neighbors requires <fact_id>", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
result = query_neighbors(args.from_id, adj, id_to_fact)
|
||||
elif args.command == "path":
|
||||
if not args.from_id or not args.to_id:
|
||||
print("ERROR: path requires <from_id> <to_id>", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
result = query_path(args.from_id, args.to_id, adj, max_hops=args.max_hops)
|
||||
elif args.command == "subgraph":
|
||||
if not args.from_id:
|
||||
print("ERROR: subgraph requires <fact_id>", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
result = query_subgraph(args.from_id, adj, id_to_fact, depth=args.depth)
|
||||
elif args.command == "stats":
|
||||
result = query_stats(adj, id_to_fact)
|
||||
|
||||
result["elapsed_ms"] = round((time.time() - start) * 1000, 2)
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -22,95 +22,114 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from session_reader import extract_conversation, read_session
|
||||
|
||||
|
||||
def compute_hash(text: str) -> str:
|
||||
"""Content hash for deduplication."""
|
||||
return hashlib.sha256(text.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
def extract_pairs_from_conversation(conversation: list, session_id: str, model: str,
|
||||
min_ratio: float = 1.5,
|
||||
def extract_pairs_from_session(session_data: dict, min_ratio: float = 1.5,
|
||||
min_response_words: int = 20) -> list:
|
||||
"""Extract terse→rich pairs from a normalized conversation."""
|
||||
"""Extract terse→rich pairs from a single session object."""
|
||||
pairs = []
|
||||
conversations = session_data.get("conversations", [])
|
||||
session_id = session_data.get("id", "unknown")
|
||||
model = session_data.get("model", "unknown")
|
||||
|
||||
seen_hashes = set()
|
||||
|
||||
for i, msg in enumerate(conversation):
|
||||
# Look for assistant responses
|
||||
if msg.get('role') != 'assistant':
|
||||
for i, msg in enumerate(conversations):
|
||||
# Look for assistant/gpt responses
|
||||
if msg.get("from") not in ("gpt", "assistant"):
|
||||
continue
|
||||
|
||||
response_text = msg.get('content', '')
|
||||
response_text = msg.get("value", "")
|
||||
if not response_text or len(response_text.split()) < min_response_words:
|
||||
continue
|
||||
|
||||
# Find the preceding user message
|
||||
# Find the preceding human message
|
||||
prompt_text = ""
|
||||
for j in range(i - 1, -1, -1):
|
||||
if conversation[j].get('role') == 'user':
|
||||
prompt_text = conversation[j].get('content', '')
|
||||
if conversations[j].get("from") == "human":
|
||||
prompt_text = conversations[j].get("value", "")
|
||||
break
|
||||
|
||||
if not prompt_text:
|
||||
continue
|
||||
|
||||
# Filter: skip tool results, system messages embedded as human
|
||||
if prompt_text.startswith('{') and 'output' in prompt_text[:100]:
|
||||
continue
|
||||
if prompt_text.startswith('# SOUL.md') or prompt_text.startswith('You are'):
|
||||
continue
|
||||
if prompt_text.startswith("{") and "output" in prompt_text[:100]:
|
||||
continue # likely a tool result
|
||||
if prompt_text.startswith("# SOUL.md") or prompt_text.startswith("You are"):
|
||||
continue # system prompt leak
|
||||
|
||||
# Quality filters
|
||||
prompt_words = len(prompt_text.split())
|
||||
response_words = len(response_text.split())
|
||||
|
||||
# Must have meaningful length ratio
|
||||
if prompt_words == 0 or response_words == 0:
|
||||
continue
|
||||
ratio = response_words / prompt_words
|
||||
if ratio < min_ratio:
|
||||
continue
|
||||
|
||||
code_blocks = response_text.count('```')
|
||||
if code_blocks >= 4 and len(response_text.replace('```', '').strip()) < 50:
|
||||
# Skip responses that are mostly code
|
||||
code_blocks = response_text.count("```")
|
||||
if code_blocks >= 4 and len(response_text.replace("```", "").strip()) < 50:
|
||||
continue
|
||||
|
||||
if 'tool_call' in response_text[:100] or 'function_call' in response_text[:100]:
|
||||
# Skip responses with tool call artifacts
|
||||
if "tool_call" in response_text[:100] or "function_call" in response_text[:100]:
|
||||
continue
|
||||
|
||||
# Deduplicate by content hash
|
||||
content_hash = compute_hash(prompt_text + response_text[:200])
|
||||
if content_hash in seen_hashes:
|
||||
continue
|
||||
seen_hashes.add(content_hash)
|
||||
|
||||
# Clean up response: remove markdown headers if too many
|
||||
clean_response = response_text
|
||||
|
||||
pairs.append({
|
||||
'terse': prompt_text.strip(),
|
||||
'rich': clean_response.strip(),
|
||||
'source': session_id,
|
||||
'model': model,
|
||||
'prompt_words': prompt_words,
|
||||
'response_words': response_words,
|
||||
'ratio': round(ratio, 2),
|
||||
"terse": prompt_text.strip(),
|
||||
"rich": clean_response.strip(),
|
||||
"source": session_id,
|
||||
"model": model,
|
||||
"prompt_words": prompt_words,
|
||||
"response_words": response_words,
|
||||
"ratio": round(ratio, 2),
|
||||
})
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def extract_from_jsonl_file(filepath: str, **kwargs) -> list:
|
||||
"""Extract pairs from a session JSONL file."""
|
||||
pairs = []
|
||||
path = Path(filepath)
|
||||
|
||||
def extract_from_jsonl_file(path: str, **kwargs) -> list:
|
||||
"""Read a session file and extract training pairs using normalized conversation."""
|
||||
session_messages = read_session(path)
|
||||
if not session_messages:
|
||||
return []
|
||||
conversation = extract_conversation(session_messages)
|
||||
# Derive session_id and model from first real message metadata
|
||||
first_msg = next((m for m in session_messages if m.get('role') or m.get('from')), {})
|
||||
session_id = first_msg.get('meta_session_id', Path(path).name)
|
||||
model = first_msg.get('model', 'unknown')
|
||||
return extract_pairs_from_conversation(conversation, session_id, model, **kwargs)
|
||||
if not path.exists():
|
||||
print(f"Warning: {filepath} not found", file=sys.stderr)
|
||||
return pairs
|
||||
|
||||
content = path.read_text()
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
session = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
session_pairs = extract_pairs_from_session(session, **kwargs)
|
||||
pairs.extend(session_pairs)
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def deduplicate_pairs(pairs: list) -> list:
|
||||
|
||||
165
scripts/test_graph_query.py
Executable file
165
scripts/test_graph_query.py
Executable file
@@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for scripts/graph_query.py — Graph Query Engine.
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
|
||||
from graph_query import load_index, build_adjacency, query_neighbors, query_path, query_subgraph, query_stats
|
||||
|
||||
|
||||
def make_index(facts: list[dict], tmp_dir: Path) -> Path:
|
||||
index = {
|
||||
"version": 1,
|
||||
"last_updated": "2026-04-13T20:00:00Z",
|
||||
"total_facts": len(facts),
|
||||
"facts": facts,
|
||||
}
|
||||
path = tmp_dir / "index.json"
|
||||
with open(path, "w") as f:
|
||||
json.dump(index, f)
|
||||
return path
|
||||
|
||||
|
||||
def test_neighbors():
|
||||
"""Neighbor query returns directly connected facts."""
|
||||
facts = [
|
||||
{"id": "a", "fact": "A", "category": "fact", "related": ["b", "c"]},
|
||||
{"id": "b", "fact": "B", "category": "fact", "related": ["a"]},
|
||||
{"id": "c", "fact": "C", "category": "fact", "related": ["a"]},
|
||||
{"id": "d", "fact": "D", "category": "fact", "related": []},
|
||||
]
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
result = query_neighbors("a", adj, id_to_fact)
|
||||
neighbor_ids = {n["id"] for n in result["neighbors"]}
|
||||
assert neighbor_ids == {"b", "c"}, f"Expected b,c got {neighbor_ids}"
|
||||
assert result["count"] == 2
|
||||
print("PASS: neighbors")
|
||||
|
||||
|
||||
def test_path_found():
|
||||
"""Path query finds shortest path."""
|
||||
facts = [
|
||||
{"id": "a", "fact": "A", "related": ["b"]},
|
||||
{"id": "b", "fact": "B", "related": ["a", "c"]},
|
||||
{"id": "c", "fact": "C", "related": ["b", "d"]},
|
||||
{"id": "d", "fact": "D", "related": ["c"]},
|
||||
]
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
result = query_path("a", "d", adj)
|
||||
assert result["path"] == ["a", "b", "c", "d"], f"Got path {result['path']}"
|
||||
assert result["length"] == 3
|
||||
print("PASS: path_found")
|
||||
|
||||
|
||||
def test_path_not_found():
|
||||
"""Path query returns error when no path exists."""
|
||||
facts = [
|
||||
{"id": "a", "fact": "A", "related": ["b"]},
|
||||
{"id": "b", "fact": "B", "related": ["a"]},
|
||||
{"id": "c", "fact": "C", "related": ["d"]},
|
||||
{"id": "d", "fact": "D", "related": ["c"]},
|
||||
]
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
result = query_path("a", "c", adj, max_hops=5)
|
||||
assert result["path"] is None
|
||||
assert "error" in result
|
||||
print("PASS: path_not_found")
|
||||
|
||||
|
||||
def test_subgraph_extraction():
|
||||
"""Subgraph extraction returns nodes within depth."""
|
||||
facts = [
|
||||
{"id": "a", "fact": "A", "related": ["b", "c"]},
|
||||
{"id": "b", "fact": "B", "related": ["a", "d"]},
|
||||
{"id": "c", "fact": "C", "related": ["a"]},
|
||||
{"id": "d", "fact": "D", "related": ["b", "e"]},
|
||||
{"id": "e", "fact": "E", "related": ["d"]},
|
||||
]
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
result = query_subgraph("a", adj, id_to_fact, depth=1)
|
||||
node_ids = {n["id"] for n in result["nodes"]}
|
||||
assert node_ids == {"a", "b", "c"}, f"Got {node_ids}"
|
||||
assert result["node_count"] == 3
|
||||
print("PASS: subgraph_depth1")
|
||||
|
||||
|
||||
def test_subgraph_depth2():
|
||||
"""Depth-2 subgraph includes further nodes."""
|
||||
facts = [
|
||||
{"id": "a", "fact": "A", "related": ["b"]},
|
||||
{"id": "b", "fact": "B", "related": ["a", "c"]},
|
||||
{"id": "c", "fact": "C", "related": ["b", "d"]},
|
||||
{"id": "d", "fact": "D", "related": ["c"]},
|
||||
]
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
result = query_subgraph("a", adj, id_to_fact, depth=2)
|
||||
node_ids = {n["id"] for n in result["nodes"]}
|
||||
assert node_ids == {"a", "b", "c"}, f"Got {node_ids}"
|
||||
print("PASS: subgraph_depth2")
|
||||
|
||||
|
||||
def test_stats():
|
||||
"""Statistics query returns graph metrics."""
|
||||
facts = [
|
||||
{"id": "a", "fact": "A", "related": ["b"]},
|
||||
{"id": "b", "fact": "B", "related": ["a", "c"]},
|
||||
{"id": "c", "fact": "C", "related": ["b"]},
|
||||
]
|
||||
adj, id_to_fact = build_adjacency(facts)
|
||||
result = query_stats(adj, id_to_fact)
|
||||
assert result["statistics"]["total_facts"] == 3
|
||||
assert result["statistics"]["total_edges"] == 2 # undirected double-counted /2
|
||||
assert result["statistics"]["average_degree"] > 0
|
||||
print("PASS: stats")
|
||||
|
||||
|
||||
def test_cli_integration():
|
||||
"""CLI produces valid JSON with correct query types."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
import subprocess as sp
|
||||
tmp_dir = Path(tmp)
|
||||
facts = [
|
||||
{"id": "x", "fact": "X", "related": ["y"]},
|
||||
{"id": "y", "fact": "Y", "related": ["x", "z"]},
|
||||
{"id": "z", "fact": "Z", "related": ["y"]},
|
||||
]
|
||||
index_path = make_index(facts, tmp_dir)
|
||||
knowledge_dir = index_path.parent
|
||||
script_path = Path(__file__).resolve().parent / "graph_query.py"
|
||||
|
||||
result = sp.run(
|
||||
[sys.executable, str(script_path), "neighbors", "x", "--knowledge-dir", str(knowledge_dir)],
|
||||
capture_output=True, text=True, cwd=str(tmp_dir)
|
||||
)
|
||||
assert result.returncode == 0, f"neighbors failed: {result.stderr}"
|
||||
out = json.loads(result.stdout)
|
||||
assert out["query"] == "neighbors"
|
||||
assert out["fact_id"] == "x"
|
||||
assert out["count"] == 1
|
||||
|
||||
result = sp.run(
|
||||
[sys.executable, str(script_path), "path", "x", "z", "--knowledge-dir", str(knowledge_dir)],
|
||||
capture_output=True, text=True, cwd=str(tmp_dir)
|
||||
)
|
||||
assert result.returncode == 0, f"path failed: {result.stderr}"
|
||||
out = json.loads(result.stdout)
|
||||
assert out["path"] == ["x", "y", "z"]
|
||||
|
||||
print("PASS: cli_integration")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_neighbors()
|
||||
test_path_found()
|
||||
test_path_not_found()
|
||||
test_subgraph_extraction()
|
||||
test_subgraph_depth2()
|
||||
test_stats()
|
||||
test_cli_integration()
|
||||
print("\nAll graph_query tests passed!")
|
||||
@@ -1,118 +0,0 @@
|
||||
"""
|
||||
Tests for session_pair_harvester — training pair extraction from sessions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||
from session_pair_harvester import (
|
||||
extract_pairs_from_conversation,
|
||||
extract_from_jsonl_file,
|
||||
deduplicate_pairs,
|
||||
compute_hash,
|
||||
)
|
||||
|
||||
|
||||
class TestSessionPairHarvester(unittest.TestCase):
|
||||
def test_compute_hash_consistent(self):
|
||||
h1 = compute_hash("hello world")
|
||||
h2 = compute_hash("hello world")
|
||||
self.assertEqual(h1, h2)
|
||||
self.assertEqual(len(h1), 16)
|
||||
|
||||
def test_extract_simple_qa_pair(self):
|
||||
"""A simple user→assistant exchange produces one pair."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
{"role": "assistant", "content": "The capital of France is Paris. It is a major European city renowned for its art, fashion, gastronomy, cultural heritage, and historical significance. The city attracts millions of tourists annually."},
|
||||
]
|
||||
pairs = extract_pairs_from_conversation(conversation, "test_session", "test-model")
|
||||
self.assertEqual(len(pairs), 1)
|
||||
self.assertEqual(pairs[0]["terse"], "What is the capital of France?")
|
||||
self.assertIn("Paris", pairs[0]["rich"])
|
||||
self.assertEqual(pairs[0]["source"], "test_session")
|
||||
|
||||
def test_min_ratio_filter(self):
|
||||
"""Very short responses are filtered out."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "Yes"},
|
||||
{"role": "assistant", "content": "No."},
|
||||
]
|
||||
# Default min_ratio = 1.5, min_words = 20 for response
|
||||
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_response_words=3)
|
||||
self.assertEqual(len(pairs), 0)
|
||||
|
||||
def test_min_words_filter(self):
|
||||
"""Assistant responses below min word count are skipped."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "Explain the project architecture in detail"},
|
||||
{"role": "assistant", "content": "OK."},
|
||||
]
|
||||
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_response_words=5)
|
||||
self.assertEqual(len(pairs), 0)
|
||||
|
||||
def test_skip_non_assistant_messages(self):
|
||||
"""System and tool messages are ignored."""
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there! How can I help you today?"},
|
||||
]
|
||||
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_response_words=3)
|
||||
self.assertEqual(len(pairs), 1)
|
||||
self.assertEqual(pairs[0]["terse"], "Hello")
|
||||
|
||||
def test_multiple_pairs_from_one_session(self):
|
||||
"""A conversation with several Q&A turns yields multiple pairs."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "First question?"},
|
||||
{"role": "assistant", "content": "Here is a detailed and comprehensive answer that thoroughly explores multiple aspects of the subject. It provides background context and practical implications for the reader."},
|
||||
{"role": "user", "content": "Second?"},
|
||||
{"role": "assistant", "content": "Another comprehensive response with detailed examples. This includes practical code blocks and thorough explanations to ensure deep understanding of the topic at hand."},
|
||||
]
|
||||
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_ratio=1.0)
|
||||
self.assertEqual(len(pairs), 2)
|
||||
|
||||
def test_deduplication_removes_duplicates(self):
|
||||
"""Identical pairs across sessions are deduplicated."""
|
||||
pairs = [
|
||||
{"terse": "q1", "rich": "a1", "source": "s1", "model": "m"},
|
||||
{"terse": "q1", "rich": "a1", "source": "s2", "model": "m"},
|
||||
{"terse": "q2", "rich": "a2", "source": "s1", "model": "m"},
|
||||
]
|
||||
unique = deduplicate_pairs(pairs)
|
||||
self.assertEqual(len(unique), 2)
|
||||
sources = {p["source"] for p in unique}
|
||||
# First unique pair can be from either s1 or s2
|
||||
self.assertIn("s1", sources)
|
||||
|
||||
def test_integration_with_test_sessions(self):
|
||||
"""Harvester finds pairs in real test session files."""
|
||||
repo_root = Path(__file__).parent.parent
|
||||
test_sessions_dir = repo_root / "test_sessions"
|
||||
if not test_sessions_dir.exists():
|
||||
self.skipTest("test_sessions not found")
|
||||
|
||||
pairs = []
|
||||
for jsonl_file in sorted(test_sessions_dir.glob("*.jsonl")):
|
||||
pairs.extend(extract_from_jsonl_file(str(jsonl_file)))
|
||||
|
||||
self.assertGreater(len(pairs), 0, "Should extract at least one pair from test_sessions")
|
||||
for p in pairs:
|
||||
self.assertIn("terse", p)
|
||||
self.assertIn("rich", p)
|
||||
self.assertIn("source", p)
|
||||
self.assertIn("model", p)
|
||||
# Verify content exists
|
||||
self.assertGreater(len(p["terse"]), 0)
|
||||
self.assertGreater(len(p["rich"]), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user