Compare commits


2 Commits

Author SHA1 Message Date
2a4e73aa03 Merge pull request 'fix: session_pair_harvester uses role/content format (#91)' (#240) from step35/91-feat-session-transcript-trai into main
Some checks failed
Test / pytest (push) Failing after 31s
2026-05-04 00:23:19 +00:00
Alex Payne
b1a728f5f4 feat: fix session_pair_harvester to use role/content format (#91)
Some checks failed
Test / pytest (pull_request) Failing after 8s
- Harvester used old message fields (from/value) but Hermes sessions use role/content
- Import session_reader to normalize conversations properly
- Update extract function to operate on normalized role/content messages
- Change predecessor lookup from "human"/"gpt" to "user"/"assistant"
- Add comprehensive smoke tests (8 tests, all pass)
- Verify extraction from test_sessions: 11 pairs, avg ratio 8.13
2026-04-26 00:19:56 -04:00
4 changed files with 155 additions and 321 deletions
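The heart of the fix is a message-schema change, which the diffs below implement. A minimal illustration (field names taken from the commit message and the diff; the surrounding session structure is simplified):

# Old shape the harvester expected (from/value fields, "human"/"gpt" speakers):
old_msg = {"from": "human", "value": "What does the harvester do?"}

# Shape Hermes sessions actually use, which session_reader normalizes to:
new_msg = {"role": "user", "content": "What does the harvester do?"}

# The predecessor lookup changes accordingly: "human" -> "user", "gpt" -> "assistant"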

View File: scripts/doc_freshness.py

@@ -1,176 +0,0 @@
#!/usr/bin/env python3
"""
Doc Freshness Checker — Issue #104

Compare docs to code. Flag docs that reference removed functions or outdated APIs.

Usage:
    python3 scripts/doc_freshness.py [--root .] [--docs-dir .] [--json]

Outputs:
    Human-readable report by default listing missing references.
    JSON output with --json for machine consumption.
"""
import argparse
import ast
import json
import os
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Set, List, Tuple, Dict, Any


def collect_python_symbols(repo_root: str) -> Set[str]:
    """Collect all top-level function and class names from Python files."""
    symbols: Set[str] = set()
    for root, dirs, files in os.walk(repo_root):
        # Skip irrelevant dirs
        dirs[:] = [d for d in dirs if d not in ['.git', '__pycache__', '.venv', 'venv', 'node_modules']]
        for fname in files:
            if fname.endswith('.py'):
                path = os.path.join(root, fname)
                try:
                    with open(path, 'r', encoding='utf-8') as f:
                        tree = ast.parse(f.read())
                        for node in ast.walk(tree):
                            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                                symbols.add(node.name)
                except Exception:
                    # Skip unparsable files
                    pass
    return symbols


def extract_doc_references(docs_dir: str) -> List[Tuple[str, str, int]]:
    """
    Walk markdown files and extract function/class references.

    Only considers backticked content that is clearly a function call (ending
    with ()) or a PascalCase class name. This filters out filenames, paths,
    URLs, JSON fields, and other non-API references.
    """
    refs: List[Tuple[str, str, int]] = []
    backtick_pat = re.compile(r'`([^`]+)`')
    func_pat = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
    class_pat = re.compile(r'^[A-Z][a-zA-Z0-9_]*$')
    for root, dirs, files in os.walk(docs_dir):
        dirs[:] = [d for d in dirs if d != '.git']
        for fname in files:
            if not fname.endswith('.md'):
                continue
            path = os.path.join(root, fname)
            rel_path = os.path.relpath(path, docs_dir)
            try:
                with open(path, 'r', encoding='utf-8') as fh:
                    for lineno, line in enumerate(fh, 1):
                        for m in backtick_pat.finditer(line):
                            raw = m.group(1).strip()
                            # Function call: ends with ()
                            if raw.endswith('()'):
                                name = raw[:-2].strip()
                                if func_pat.fullmatch(name):
                                    refs.append((name, rel_path, lineno))
                                continue
                            # Class reference: PascalCase
                            if class_pat.fullmatch(raw):
                                refs.append((raw, rel_path, lineno))
            except Exception:
                pass
    return refs


def check_doc_freshness(repo_root: str, docs_dir: str) -> Dict[str, Any]:
    """Run the full check and return structured results."""
    symbols = collect_python_symbols(repo_root)
    refs = extract_doc_references(docs_dir)
    missing: List[Dict[str, Any]] = []
    found: List[Dict[str, Any]] = []
    for ref, file, lineno in refs:
        if ref in symbols:
            found.append({"reference": ref, "file": file, "line": lineno})
        else:
            missing.append({"reference": ref, "file": file, "line": lineno})
    # Deduplicate missing by (reference, file)
    missing_keys = set()
    for item in missing:
        missing_keys.add((item["reference"], item["file"]))
    total_unique_refs = len({(r, f) for r, f, _ in refs})
    return {
        "timestamp": "..",  # filled by main
        "repo_root": repo_root,
        "docs_dir": docs_dir,
        "total_unique_references": total_unique_refs,
        "defined_symbols": len(symbols),
        "missing": missing,
        "found": found,
        "missing_count": len(missing_keys),
        "found_count": total_unique_refs - len(missing_keys),
    }


def format_report(result: Dict[str, Any]) -> str:
    """Format check results as a human-readable report."""
    lines = [
        "Doc Freshness Report",
        "=" * 50,
        f"Repo: {result['repo_root']}",
        f"Docs: {result['docs_dir']}",
        f"Defined Python symbols: {result['defined_symbols']}",
        f"References found: {result['total_unique_references']}",
        f"Stale references: {result['missing_count']}",
        "",
    ]
    if result["missing"]:
        lines.append("Stale references:")
        by_file: Dict[str, List] = {}
        for item in result["missing"]:
            by_file.setdefault(item["file"], []).append(item)
        for fname in sorted(by_file):
            lines.append(f"\n  {fname}:")
            for item in by_file[fname]:
                lines.append(f"    line {item['line']}: {item['reference']}")
    else:
        lines.append("All references are current.")
    lines.append("")
    lines.append("Note: Only backticked function calls () and PascalCase class names are checked.")
    return "\n".join(lines)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Doc Freshness Checker — compare docs to code")
    parser.add_argument("--root", default=".", help="Repository root (code location)")
    parser.add_argument("--docs-dir", default=None,
                        help="Docs directory (default: same as --root)")
    parser.add_argument("--json", action="store_true", help="Machine-readable output")
    args = parser.parse_args()
    docs_dir = args.docs_dir or args.root
    result = check_doc_freshness(args.root, docs_dir)
    result["timestamp"] = datetime.now(timezone.utc).isoformat()
    if args.json:
        print(json.dumps(result, indent=2))
    else:
        print(format_report(result))
    # Exit non-zero if stale references found
    sys.exit(1 if result["missing_count"] > 0 else 0)


if __name__ == "__main__":
    main()
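Although this checker is deleted in this compare, its intended CI usage follows from main(): it exits non-zero when stale references exist. A minimal sketch of programmatic use (the "docs" path is an assumed layout, not something this repo is shown to have):

import doc_freshness as df

result = df.check_doc_freshness(repo_root=".", docs_dir="docs")  # "docs" is illustrative
print(df.format_report(result))
# main() exits 1 when result["missing_count"] > 0, so a CI job can gate on the exit code.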

View File: scripts/session_pair_harvester.py

@@ -22,114 +22,95 @@ import sys
 from pathlib import Path
 from typing import Optional
+from session_reader import extract_conversation, read_session
 def compute_hash(text: str) -> str:
     """Content hash for deduplication."""
     return hashlib.sha256(text.encode()).hexdigest()[:16]
-def extract_pairs_from_session(session_data: dict, min_ratio: float = 1.5,
+def extract_pairs_from_conversation(conversation: list, session_id: str, model: str,
+                                    min_ratio: float = 1.5,
                                     min_response_words: int = 20) -> list:
-    """Extract terse→rich pairs from a single session object."""
+    """Extract terse→rich pairs from a normalized conversation."""
     pairs = []
-    conversations = session_data.get("conversations", [])
-    session_id = session_data.get("id", "unknown")
-    model = session_data.get("model", "unknown")
     seen_hashes = set()
-    for i, msg in enumerate(conversations):
-        # Look for assistant/gpt responses
-        if msg.get("from") not in ("gpt", "assistant"):
+    for i, msg in enumerate(conversation):
+        # Look for assistant responses
+        if msg.get('role') != 'assistant':
             continue
-        response_text = msg.get("value", "")
+        response_text = msg.get('content', '')
         if not response_text or len(response_text.split()) < min_response_words:
             continue
-        # Find the preceding human message
+        # Find the preceding user message
         prompt_text = ""
         for j in range(i - 1, -1, -1):
-            if conversations[j].get("from") == "human":
-                prompt_text = conversations[j].get("value", "")
+            if conversation[j].get('role') == 'user':
+                prompt_text = conversation[j].get('content', '')
                 break
         if not prompt_text:
             continue
         # Filter: skip tool results, system messages embedded as human
-        if prompt_text.startswith("{") and "output" in prompt_text[:100]:
-            continue  # likely a tool result
-        if prompt_text.startswith("# SOUL.md") or prompt_text.startswith("You are"):
-            continue  # system prompt leak
+        if prompt_text.startswith('{') and 'output' in prompt_text[:100]:
+            continue
+        if prompt_text.startswith('# SOUL.md') or prompt_text.startswith('You are'):
+            continue
         # Quality filters
         prompt_words = len(prompt_text.split())
         response_words = len(response_text.split())
         # Must have meaningful length ratio
         if prompt_words == 0 or response_words == 0:
             continue
         ratio = response_words / prompt_words
         if ratio < min_ratio:
             continue
         # Skip responses that are mostly code
-        code_blocks = response_text.count("```")
-        if code_blocks >= 4 and len(response_text.replace("```", "").strip()) < 50:
+        code_blocks = response_text.count('```')
+        if code_blocks >= 4 and len(response_text.replace('```', '').strip()) < 50:
             continue
         # Skip responses with tool call artifacts
-        if "tool_call" in response_text[:100] or "function_call" in response_text[:100]:
+        if 'tool_call' in response_text[:100] or 'function_call' in response_text[:100]:
             continue
         # Deduplicate by content hash
         content_hash = compute_hash(prompt_text + response_text[:200])
         if content_hash in seen_hashes:
             continue
         seen_hashes.add(content_hash)
         # Clean up response: remove markdown headers if too many
         clean_response = response_text
         pairs.append({
-            "terse": prompt_text.strip(),
-            "rich": clean_response.strip(),
-            "source": session_id,
-            "model": model,
-            "prompt_words": prompt_words,
-            "response_words": response_words,
-            "ratio": round(ratio, 2),
+            'terse': prompt_text.strip(),
+            'rich': clean_response.strip(),
+            'source': session_id,
+            'model': model,
+            'prompt_words': prompt_words,
+            'response_words': response_words,
+            'ratio': round(ratio, 2),
         })
     return pairs
-def extract_from_jsonl_file(filepath: str, **kwargs) -> list:
-    """Extract pairs from a session JSONL file."""
-    pairs = []
-    path = Path(filepath)
-    if not path.exists():
-        print(f"Warning: {filepath} not found", file=sys.stderr)
-        return pairs
-    content = path.read_text()
-    lines = content.strip().split("\n")
-    for line in lines:
-        line = line.strip()
-        if not line:
-            continue
-        try:
-            session = json.loads(line)
-        except json.JSONDecodeError:
-            continue
-        session_pairs = extract_pairs_from_session(session, **kwargs)
-        pairs.extend(session_pairs)
-    return pairs
+def extract_from_jsonl_file(path: str, **kwargs) -> list:
+    """Read a session file and extract training pairs using normalized conversation."""
+    session_messages = read_session(path)
+    if not session_messages:
+        return []
+    conversation = extract_conversation(session_messages)
+    # Derive session_id and model from first real message metadata
+    first_msg = next((m for m in session_messages if m.get('role') or m.get('from')), {})
+    session_id = first_msg.get('meta_session_id', Path(path).name)
+    model = first_msg.get('model', 'unknown')
+    return extract_pairs_from_conversation(conversation, session_id, model, **kwargs)
 def deduplicate_pairs(pairs: list) -> list:
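session_reader itself is not part of this compare. From the call sites above, read_session(path) presumably parses one JSONL session file into raw message dicts, and extract_conversation() normalizes them to role/content turns. A sketch of that assumed contract, for orientation only:

# Assumed session_reader contract (inferred from call sites; not shown in this diff):
from typing import Dict, List

def read_session(path: str) -> List[Dict]:
    """Assumed: parse a session JSONL file into raw message dicts,
    which may carry metadata such as meta_session_id and model."""
    ...

def extract_conversation(messages: List[Dict]) -> List[Dict]:
    """Assumed: normalize raw messages to
    [{'role': 'user' | 'assistant' | 'system', 'content': str}, ...]."""
    ...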

View File

@@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""Tests for scripts/doc_freshness.py — Issue #104."""
import os
import sys
import tempfile
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))

import doc_freshness as df


def test_collect_python_symbols():
    """Should collect function and class names from Python files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a simple Python file
        py_path = os.path.join(tmpdir, "sample.py")
        with open(py_path, "w") as f:
            f.write('''
def my_func():
    pass

class MyClass:
    def method(self):
        pass

async def my_async():
    pass
''')
        symbols = df.collect_python_symbols(tmpdir)
        assert "my_func" in symbols
        assert "MyClass" in symbols
        assert "my_async" in symbols
        # method (inside class) is also collected and should be considered valid
        assert "method" in symbols
    print("PASS: test_collect_python_symbols")


def test_extract_doc_references_function_and_class():
    """Should extract only function calls () and PascalCase class refs."""
    with tempfile.TemporaryDirectory() as tmpdir:
        docs = os.path.join(tmpdir, "docs")
        os.makedirs(docs)
        md_path = os.path.join(docs, "test.md")
        with open(md_path, "w") as f:
            f.write('''
# Test
`call_this()` is a function.
`SomeClass` is a class.
`not_a_function` (lowercase, no parens) should be ignored.
`filename.py` should be ignored.
`https://example.com` ignored.
''')
        refs = df.extract_doc_references(docs)
        names = [r[0] for r in refs]
        assert "call_this" in names
        assert "SomeClass" in names
        assert "not_a_function" not in names
        assert "filename" not in names  # filename.py filtered
        assert "https" not in names
    print("PASS: test_extract_doc_references_function_and_class")


def test_check_doc_freshness_missing_detection():
    """Should detect missing symbols."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Code with one function
        code_dir = os.path.join(tmpdir, "code")
        os.makedirs(code_dir)
        with open(os.path.join(code_dir, "a.py"), "w") as f:
            f.write("def existing_func(): pass\n")
        # Docs reference existing_func and missing_func
        docs_dir = os.path.join(tmpdir, "docs")
        os.makedirs(docs_dir)
        with open(os.path.join(docs_dir, "readme.md"), "w") as f:
            f.write("`existing_func()` and `missing_func()` are mentioned.")
        result = df.check_doc_freshness(code_dir, docs_dir)
        assert result["missing_count"] == 1
        assert result["found_count"] == 1
    print("PASS: test_check_doc_freshness_missing_detection")


if __name__ == "__main__":
    test_collect_python_symbols()
    test_extract_doc_references_function_and_class()
    test_check_doc_freshness_missing_detection()
    print("All tests passed!")

View File

@@ -0,0 +1,118 @@
"""
Tests for session_pair_harvester — training pair extraction from sessions.
"""
import json
import tempfile
import unittest
from pathlib import Path
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
from session_pair_harvester import (
extract_pairs_from_conversation,
extract_from_jsonl_file,
deduplicate_pairs,
compute_hash,
)
class TestSessionPairHarvester(unittest.TestCase):
def test_compute_hash_consistent(self):
h1 = compute_hash("hello world")
h2 = compute_hash("hello world")
self.assertEqual(h1, h2)
self.assertEqual(len(h1), 16)
def test_extract_simple_qa_pair(self):
"""A simple user→assistant exchange produces one pair."""
conversation = [
{"role": "user", "content": "What is the capital of France?"},
{"role": "assistant", "content": "The capital of France is Paris. It is a major European city renowned for its art, fashion, gastronomy, cultural heritage, and historical significance. The city attracts millions of tourists annually."},
]
pairs = extract_pairs_from_conversation(conversation, "test_session", "test-model")
self.assertEqual(len(pairs), 1)
self.assertEqual(pairs[0]["terse"], "What is the capital of France?")
self.assertIn("Paris", pairs[0]["rich"])
self.assertEqual(pairs[0]["source"], "test_session")
def test_min_ratio_filter(self):
"""Very short responses are filtered out."""
conversation = [
{"role": "user", "content": "Yes"},
{"role": "assistant", "content": "No."},
]
# Default min_ratio = 1.5, min_words = 20 for response
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_response_words=3)
self.assertEqual(len(pairs), 0)
def test_min_words_filter(self):
"""Assistant responses below min word count are skipped."""
conversation = [
{"role": "user", "content": "Explain the project architecture in detail"},
{"role": "assistant", "content": "OK."},
]
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_response_words=5)
self.assertEqual(len(pairs), 0)
def test_skip_non_assistant_messages(self):
"""System and tool messages are ignored."""
conversation = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there! How can I help you today?"},
]
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_response_words=3)
self.assertEqual(len(pairs), 1)
self.assertEqual(pairs[0]["terse"], "Hello")
def test_multiple_pairs_from_one_session(self):
"""A conversation with several Q&A turns yields multiple pairs."""
conversation = [
{"role": "user", "content": "First question?"},
{"role": "assistant", "content": "Here is a detailed and comprehensive answer that thoroughly explores multiple aspects of the subject. It provides background context and practical implications for the reader."},
{"role": "user", "content": "Second?"},
{"role": "assistant", "content": "Another comprehensive response with detailed examples. This includes practical code blocks and thorough explanations to ensure deep understanding of the topic at hand."},
]
pairs = extract_pairs_from_conversation(conversation, "s", "m", min_ratio=1.0)
self.assertEqual(len(pairs), 2)
def test_deduplication_removes_duplicates(self):
"""Identical pairs across sessions are deduplicated."""
pairs = [
{"terse": "q1", "rich": "a1", "source": "s1", "model": "m"},
{"terse": "q1", "rich": "a1", "source": "s2", "model": "m"},
{"terse": "q2", "rich": "a2", "source": "s1", "model": "m"},
]
unique = deduplicate_pairs(pairs)
self.assertEqual(len(unique), 2)
sources = {p["source"] for p in unique}
# First unique pair can be from either s1 or s2
self.assertIn("s1", sources)
def test_integration_with_test_sessions(self):
"""Harvester finds pairs in real test session files."""
repo_root = Path(__file__).parent.parent
test_sessions_dir = repo_root / "test_sessions"
if not test_sessions_dir.exists():
self.skipTest("test_sessions not found")
pairs = []
for jsonl_file in sorted(test_sessions_dir.glob("*.jsonl")):
pairs.extend(extract_from_jsonl_file(str(jsonl_file)))
self.assertGreater(len(pairs), 0, "Should extract at least one pair from test_sessions")
for p in pairs:
self.assertIn("terse", p)
self.assertIn("rich", p)
self.assertIn("source", p)
self.assertIn("model", p)
# Verify content exists
self.assertGreater(len(p["terse"]), 0)
self.assertGreater(len(p["rich"]), 0)
if __name__ == "__main__":
unittest.main()
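For a sense of the data structure these tests assert on, a small end-to-end sketch (the conversation text is made up; the field values follow from the extraction logic in the diff above):

conversation = [
    {"role": "user", "content": "Summarize the harvester change"},
    {"role": "assistant", "content": (
        "The harvester now consumes normalized role/content messages instead of "
        "the legacy from/value fields, so it pairs each assistant reply with the "
        "preceding user turn and keeps quality filters such as the length ratio."
    )},
]
pairs = extract_pairs_from_conversation(conversation, "demo_session", "demo-model")
# pairs[0] looks like:
# {'terse': 'Summarize the harvester change', 'rich': 'The harvester now ...',
#  'source': 'demo_session', 'model': 'demo-model',
#  'prompt_words': 4, 'response_words': 33, 'ratio': 8.25}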