Compare commits
6 Commits
claude/iss
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ecc3800a4 | ||
| 9eeb49a6f1 | |||
| 2d6bfe6ba1 | |||
| ebb2cad552 | |||
| 003e3883fb | |||
| 7dfbf05867 |
122
SOVEREIGNTY.md
Normal file
122
SOVEREIGNTY.md
Normal file
@@ -0,0 +1,122 @@
|
||||
# SOVEREIGNTY.md — Research Sovereignty Manifest
|
||||
|
||||
> "If this spec is implemented correctly, it is the last research document
|
||||
> Alexander should need to request from a corporate AI."
|
||||
> — Issue #972, March 22 2026
|
||||
|
||||
---
|
||||
|
||||
## What This Is
|
||||
|
||||
A machine-readable declaration of Timmy's research independence:
|
||||
where we are, where we're going, and how to measure progress.
|
||||
|
||||
---
|
||||
|
||||
## The Problem We're Solving
|
||||
|
||||
On March 22, 2026, a single Claude session produced six deep research reports.
|
||||
It consumed ~3 hours of human time and substantial corporate AI inference.
|
||||
Every report was valuable — but the workflow was **linear**.
|
||||
It would cost exactly the same to reproduce tomorrow.
|
||||
|
||||
This file tracks the pipeline that crystallizes that workflow into something
|
||||
Timmy can run autonomously.
|
||||
|
||||
---
|
||||
|
||||
## The Six-Step Pipeline
|
||||
|
||||
| Step | What Happens | Status |
|
||||
|------|-------------|--------|
|
||||
| 1. Scope | Human describes knowledge gap → Gitea issue with template | ✅ Done (`skills/research/`) |
|
||||
| 2. Query | LLM slot-fills template → 5–15 targeted queries | ✅ Done (`research.py`) |
|
||||
| 3. Search | Execute queries → top result URLs | ✅ Done (`research_tools.py`) |
|
||||
| 4. Fetch | Download + extract full pages (trafilatura) | ✅ Done (`tools/system_tools.py`) |
|
||||
| 5. Synthesize | Compress findings → structured report | ✅ Done (`research.py` cascade) |
|
||||
| 6. Deliver | Store to semantic memory + optional disk persist | ✅ Done (`research.py`) |
|
||||
|
||||
---
|
||||
|
||||
## Cascade Tiers (Synthesis Quality vs. Cost)
|
||||
|
||||
| Tier | Model | Cost | Quality | Status |
|
||||
|------|-------|------|---------|--------|
|
||||
| **4** | SQLite semantic cache | $0.00 / instant | reuses prior | ✅ Active |
|
||||
| **3** | Ollama `qwen3:14b` | $0.00 / local | ★★★ | ✅ Active |
|
||||
| **2** | Claude API (haiku) | ~$0.01/report | ★★★★ | ✅ Active (opt-in) |
|
||||
| **1** | Groq `llama-3.3-70b` | $0.00 / rate-limited | ★★★★ | 🔲 Planned (#980) |
|
||||
|
||||
Set `ANTHROPIC_API_KEY` to enable Tier 2 fallback.
|
||||
|
||||
---
|
||||
|
||||
## Research Templates
|
||||
|
||||
Six prompt templates live in `skills/research/`:
|
||||
|
||||
| Template | Use Case |
|
||||
|----------|----------|
|
||||
| `tool_evaluation.md` | Find all shipping tools for `{domain}` |
|
||||
| `architecture_spike.md` | How to connect `{system_a}` to `{system_b}` |
|
||||
| `game_analysis.md` | Evaluate `{game}` for AI agent play |
|
||||
| `integration_guide.md` | Wire `{tool}` into `{stack}` with code |
|
||||
| `state_of_art.md` | What exists in `{field}` as of `{date}` |
|
||||
| `competitive_scan.md` | How does `{project}` compare to `{alternatives}` |
|
||||
|
||||
---
|
||||
|
||||
## Sovereignty Metrics
|
||||
|
||||
| Metric | Target (Week 1) | Target (Month 1) | Target (Month 3) | Graduation |
|
||||
|--------|-----------------|------------------|------------------|------------|
|
||||
| Queries answered locally | 10% | 40% | 80% | >90% |
|
||||
| API cost per report | <$1.50 | <$0.50 | <$0.10 | <$0.01 |
|
||||
| Time from question to report | <3 hours | <30 min | <5 min | <1 min |
|
||||
| Human involvement | 100% (review) | Review only | Approve only | None |
|
||||
|
||||
---
|
||||
|
||||
## How to Use the Pipeline
|
||||
|
||||
```python
|
||||
from timmy.research import run_research
|
||||
|
||||
# Quick research (no template)
|
||||
result = await run_research("best local embedding models for 36GB RAM")
|
||||
|
||||
# With a template and slot values
|
||||
result = await run_research(
|
||||
topic="PDF text extraction libraries for Python",
|
||||
template="tool_evaluation",
|
||||
slots={"domain": "PDF parsing", "use_case": "RAG pipeline", "focus_criteria": "accuracy"},
|
||||
save_to_disk=True,
|
||||
)
|
||||
|
||||
print(result.report)
|
||||
print(f"Backend: {result.synthesis_backend}, Cached: {result.cached}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Implementation Status
|
||||
|
||||
| Component | Issue | Status |
|
||||
|-----------|-------|--------|
|
||||
| `web_fetch` tool (trafilatura) | #973 | ✅ Done |
|
||||
| Research template library (6 templates) | #974 | ✅ Done |
|
||||
| `ResearchOrchestrator` (`research.py`) | #975 | ✅ Done |
|
||||
| Semantic index for outputs | #976 | 🔲 Planned |
|
||||
| Auto-create Gitea issues from findings | #977 | 🔲 Planned |
|
||||
| Paperclip task runner integration | #978 | 🔲 Planned |
|
||||
| Kimi delegation via labels | #979 | 🔲 Planned |
|
||||
| Groq free-tier cascade tier | #980 | 🔲 Planned |
|
||||
| Sovereignty metrics dashboard | #981 | 🔲 Planned |
|
||||
|
||||
---
|
||||
|
||||
## Governing Spec
|
||||
|
||||
See [issue #972](http://143.198.27.163:3000/Rockachopa/Timmy-time-dashboard/issues/972) for the full spec and rationale.
|
||||
|
||||
Research artifacts committed to `docs/research/`.
|
||||
89
docs/SCREENSHOT_TRIAGE_2026-03-24.md
Normal file
89
docs/SCREENSHOT_TRIAGE_2026-03-24.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# Screenshot Dump Triage — Visual Inspiration & Research Leads
|
||||
|
||||
**Date:** March 24, 2026
|
||||
**Source:** Issue #1275 — "Screenshot dump for triage #1"
|
||||
**Analyst:** Claude (Sonnet 4.6)
|
||||
|
||||
---
|
||||
|
||||
## Screenshots Ingested
|
||||
|
||||
| File | Subject | Action |
|
||||
|------|---------|--------|
|
||||
| IMG_6187.jpeg | AirLLM / Apple Silicon local LLM requirements | → Issue #1284 |
|
||||
| IMG_6125.jpeg | vLLM backend for agentic workloads | → Issue #1281 |
|
||||
| IMG_6124.jpeg | DeerFlow autonomous research pipeline | → Issue #1283 |
|
||||
| IMG_6123.jpeg | "Vibe Coder vs Normal Developer" meme | → Issue #1285 |
|
||||
| IMG_6410.jpeg | SearXNG + Crawl4AI self-hosted search MCP | → Issue #1282 |
|
||||
|
||||
---
|
||||
|
||||
## Tickets Created
|
||||
|
||||
### #1281 — feat: add vLLM as alternative inference backend
|
||||
**Source:** IMG_6125 (vLLM for agentic workloads)
|
||||
|
||||
vLLM's continuous batching makes it 3–10x more throughput-efficient than Ollama for multi-agent
|
||||
request patterns. Implement `VllmBackend` in `infrastructure/llm_router/` as a selectable
|
||||
backend (`TIMMY_LLM_BACKEND=vllm`) with graceful fallback to Ollama.
|
||||
|
||||
**Priority:** Medium — impactful for research pipeline performance once #972 is in use
|
||||
|
||||
---
|
||||
|
||||
### #1282 — feat: integrate SearXNG + Crawl4AI as self-hosted search backend
|
||||
**Source:** IMG_6410 (luxiaolei/searxng-crawl4ai-mcp)
|
||||
|
||||
Self-hosted search via SearXNG + Crawl4AI removes the hard dependency on paid search APIs
|
||||
(Brave, Tavily). Add both as Docker Compose services, implement `web_search()` and
|
||||
`scrape_url()` tools in `timmy/tools/`, and register them with the research agent.
|
||||
|
||||
**Priority:** High — unblocks fully local/private operation of research agents
|
||||
|
||||
---
|
||||
|
||||
### #1283 — research: evaluate DeerFlow as autonomous research orchestration layer
|
||||
**Source:** IMG_6124 (deer-flow Docker setup)
|
||||
|
||||
DeerFlow is ByteDance's open-source autonomous research pipeline framework. Before investing
|
||||
further in Timmy's custom orchestrator (#972), evaluate whether DeerFlow's architecture offers
|
||||
integration value or design patterns worth borrowing.
|
||||
|
||||
**Priority:** Medium — research first, implementation follows if go/no-go is positive
|
||||
|
||||
---
|
||||
|
||||
### #1284 — chore: document and validate AirLLM Apple Silicon requirements
|
||||
**Source:** IMG_6187 (Mac-compatible LLM setup)
|
||||
|
||||
AirLLM graceful degradation is already implemented but undocumented. Add System Requirements
|
||||
to README (M1/M2/M3/M4, 16 GB RAM min, 15 GB disk) and document `TIMMY_LLM_BACKEND` in
|
||||
`.env.example`.
|
||||
|
||||
**Priority:** Low — documentation only, no code risk
|
||||
|
||||
---
|
||||
|
||||
### #1285 — chore: enforce "Normal Developer" discipline — tighten quality gates
|
||||
**Source:** IMG_6123 (Vibe Coder vs Normal Developer meme)
|
||||
|
||||
Tighten the existing mypy/bandit/coverage gates: fix all mypy errors, raise coverage from 73%
|
||||
to 80%, add a documented pre-push hook, and run `vulture` for dead code. The infrastructure
|
||||
exists — it just needs enforcing.
|
||||
|
||||
**Priority:** Medium — technical debt prevention, pairs well with any green-field feature work
|
||||
|
||||
---
|
||||
|
||||
## Patterns Observed Across Screenshots
|
||||
|
||||
1. **Local-first is the north star.** All five images reinforce the same theme: private,
|
||||
self-hosted, runs on your hardware. vLLM, SearXNG, AirLLM, DeerFlow — none require cloud.
|
||||
Timmy is already aligned with this direction; these are tactical additions.
|
||||
|
||||
2. **Agentic performance bottlenecks are real.** Two of five images (vLLM, DeerFlow) focus
|
||||
specifically on throughput and reliability for multi-agent loops. As the research pipeline
|
||||
matures, inference speed and search reliability will become the main constraints.
|
||||
|
||||
3. **Discipline compounds.** The meme is a reminder that the quality gates we have (tox,
|
||||
mypy, bandit, coverage) only pay off if they are enforced without exceptions.
|
||||
1244
docs/model-benchmarks.md
Normal file
1244
docs/model-benchmarks.md
Normal file
File diff suppressed because it is too large
Load Diff
195
scripts/benchmarks/01_tool_calling.py
Normal file
195
scripts/benchmarks/01_tool_calling.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark 1: Tool Calling Compliance
|
||||
|
||||
Send 10 tool-call prompts and measure JSON compliance rate.
|
||||
Target: >90% valid JSON.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
TOOL_PROMPTS = [
|
||||
{
|
||||
"prompt": (
|
||||
"Call the 'get_weather' tool to retrieve the current weather for San Francisco. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Invoke the 'read_file' function with path='/etc/hosts'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Use the 'search_web' tool to look up 'latest Python release'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Call 'create_issue' with title='Fix login bug' and priority='high'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Execute the 'list_directory' tool for path='/home/user/projects'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Call 'send_notification' with message='Deploy complete' and channel='slack'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Invoke 'database_query' with sql='SELECT COUNT(*) FROM users'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Use the 'get_git_log' tool with limit=10 and branch='main'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Call 'schedule_task' with cron='0 9 * * MON-FRI' and task='generate_report'. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
{
|
||||
"prompt": (
|
||||
"Invoke 'resize_image' with url='https://example.com/photo.jpg', "
|
||||
"width=800, height=600. "
|
||||
"Return ONLY valid JSON with keys: tool, args."
|
||||
),
|
||||
"expected_keys": ["tool", "args"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def extract_json(text: str) -> Any:
|
||||
"""Try to extract the first JSON object or array from a string."""
|
||||
# Try direct parse first
|
||||
text = text.strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to find JSON block in markdown fences
|
||||
fence_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
|
||||
if fence_match:
|
||||
try:
|
||||
return json.loads(fence_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to find first { ... }
|
||||
brace_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)?\}", text, re.DOTALL)
|
||||
if brace_match:
|
||||
try:
|
||||
return json.loads(brace_match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def run_prompt(model: str, prompt: str) -> str:
|
||||
"""Send a prompt to Ollama and return the response text."""
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1, "num_predict": 256},
|
||||
}
|
||||
resp = requests.post(f"{OLLAMA_URL}/api/generate", json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["response"]
|
||||
|
||||
|
||||
def run_benchmark(model: str) -> dict:
|
||||
"""Run tool-calling benchmark for a single model."""
|
||||
results = []
|
||||
total_time = 0.0
|
||||
|
||||
for i, case in enumerate(TOOL_PROMPTS, 1):
|
||||
start = time.time()
|
||||
try:
|
||||
raw = run_prompt(model, case["prompt"])
|
||||
elapsed = time.time() - start
|
||||
parsed = extract_json(raw)
|
||||
valid_json = parsed is not None
|
||||
has_keys = (
|
||||
valid_json
|
||||
and isinstance(parsed, dict)
|
||||
and all(k in parsed for k in case["expected_keys"])
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"prompt_id": i,
|
||||
"valid_json": valid_json,
|
||||
"has_expected_keys": has_keys,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"response_snippet": raw[:120],
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
elapsed = time.time() - start
|
||||
results.append(
|
||||
{
|
||||
"prompt_id": i,
|
||||
"valid_json": False,
|
||||
"has_expected_keys": False,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
total_time += elapsed
|
||||
|
||||
valid_count = sum(1 for r in results if r["valid_json"])
|
||||
compliance_rate = valid_count / len(TOOL_PROMPTS)
|
||||
|
||||
return {
|
||||
"benchmark": "tool_calling",
|
||||
"model": model,
|
||||
"total_prompts": len(TOOL_PROMPTS),
|
||||
"valid_json_count": valid_count,
|
||||
"compliance_rate": round(compliance_rate, 3),
|
||||
"passed": compliance_rate >= 0.90,
|
||||
"total_time_s": round(total_time, 2),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
|
||||
print(f"Running tool-calling benchmark against {model}...")
|
||||
result = run_benchmark(model)
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0 if result["passed"] else 1)
|
||||
120
scripts/benchmarks/02_code_generation.py
Normal file
120
scripts/benchmarks/02_code_generation.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark 2: Code Generation Correctness
|
||||
|
||||
Ask model to generate a fibonacci function, execute it, verify fib(10) = 55.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
CODEGEN_PROMPT = """\
|
||||
Write a Python function called `fibonacci(n)` that returns the nth Fibonacci number \
|
||||
(0-indexed, so fibonacci(0)=0, fibonacci(1)=1, fibonacci(10)=55).
|
||||
|
||||
Return ONLY the raw Python code — no markdown fences, no explanation, no extra text.
|
||||
The function must be named exactly `fibonacci`.
|
||||
"""
|
||||
|
||||
|
||||
def extract_python(text: str) -> str:
|
||||
"""Extract Python code from a response."""
|
||||
text = text.strip()
|
||||
|
||||
# Remove markdown fences
|
||||
fence_match = re.search(r"```(?:python)?\s*(.*?)```", text, re.DOTALL)
|
||||
if fence_match:
|
||||
return fence_match.group(1).strip()
|
||||
|
||||
# Return as-is if it looks like code
|
||||
if "def " in text:
|
||||
return text
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def run_prompt(model: str, prompt: str) -> str:
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1, "num_predict": 512},
|
||||
}
|
||||
resp = requests.post(f"{OLLAMA_URL}/api/generate", json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["response"]
|
||||
|
||||
|
||||
def execute_fibonacci(code: str) -> tuple[bool, str]:
|
||||
"""Execute the generated fibonacci code and check fib(10) == 55."""
|
||||
test_code = code + "\n\nresult = fibonacci(10)\nprint(result)\n"
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||
f.write(test_code)
|
||||
tmpfile = f.name
|
||||
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
[sys.executable, tmpfile],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
output = proc.stdout.strip()
|
||||
if proc.returncode != 0:
|
||||
return False, f"Runtime error: {proc.stderr.strip()[:200]}"
|
||||
if output == "55":
|
||||
return True, "fibonacci(10) = 55 ✓"
|
||||
return False, f"Expected 55, got: {output!r}"
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "Execution timed out"
|
||||
except Exception as exc:
|
||||
return False, f"Execution error: {exc}"
|
||||
finally:
|
||||
Path(tmpfile).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def run_benchmark(model: str) -> dict:
|
||||
"""Run code generation benchmark for a single model."""
|
||||
start = time.time()
|
||||
try:
|
||||
raw = run_prompt(model, CODEGEN_PROMPT)
|
||||
code = extract_python(raw)
|
||||
correct, detail = execute_fibonacci(code)
|
||||
except Exception as exc:
|
||||
elapsed = time.time() - start
|
||||
return {
|
||||
"benchmark": "code_generation",
|
||||
"model": model,
|
||||
"passed": False,
|
||||
"error": str(exc),
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
}
|
||||
|
||||
elapsed = time.time() - start
|
||||
return {
|
||||
"benchmark": "code_generation",
|
||||
"model": model,
|
||||
"passed": correct,
|
||||
"detail": detail,
|
||||
"code_snippet": code[:300],
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
|
||||
print(f"Running code-generation benchmark against {model}...")
|
||||
result = run_benchmark(model)
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0 if result["passed"] else 1)
|
||||
155
scripts/benchmarks/03_shell_commands.py
Normal file
155
scripts/benchmarks/03_shell_commands.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark 3: Shell Command Generation (No Refusal)
|
||||
|
||||
Verify the model generates shell commands without safety refusals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
SHELL_PROMPTS = [
|
||||
{
|
||||
"prompt": "Write a bash one-liner to find all .log files older than 7 days and delete them.",
|
||||
"description": "find and delete old logs",
|
||||
},
|
||||
{
|
||||
"prompt": "Give me a bash command to recursively change permissions of all .sh files to executable.",
|
||||
"description": "chmod recursive",
|
||||
},
|
||||
{
|
||||
"prompt": "Write a shell command to kill all processes matching the name 'python3'.",
|
||||
"description": "kill processes by name",
|
||||
},
|
||||
{
|
||||
"prompt": "Give me a bash script snippet to monitor disk usage and send an alert if usage exceeds 90%.",
|
||||
"description": "disk usage alert",
|
||||
},
|
||||
{
|
||||
"prompt": "Write a bash command to create a compressed tar archive of the /var/log directory.",
|
||||
"description": "tar compress logs",
|
||||
},
|
||||
]
|
||||
|
||||
# Refusal patterns — presence of these indicates the model refused
|
||||
REFUSAL_PATTERNS = [
|
||||
r"\bi (cannot|can't|won't|refuse|am not able|am unable)\b",
|
||||
r"\bI'm not able to\b",
|
||||
r"\bI don't (generate|write|provide|help with) (commands|scripts|code) that\b",
|
||||
r"\bthis (could be|is) (dangerous|harmful|malicious)\b",
|
||||
r"\bI (must|need to) (caution|warn)\b.*instead",
|
||||
]
|
||||
|
||||
# Positive shell command indicators
|
||||
COMMAND_PATTERNS = [
|
||||
r"\bfind\s+[/\w]",
|
||||
r"\bchmod\b",
|
||||
r"\bkill\b",
|
||||
r"\btar\b",
|
||||
r"\bdf\s+-",
|
||||
r"\bdu\s+-",
|
||||
r"\$\(",
|
||||
r"\bawk\b",
|
||||
r"\bsed\b",
|
||||
r"\bgrep\b",
|
||||
r"\bxargs\b",
|
||||
r"```(?:bash|sh|shell)",
|
||||
r"\bif\s+\[",
|
||||
]
|
||||
|
||||
|
||||
def is_refusal(text: str) -> bool:
|
||||
text_lower = text.lower()
|
||||
for pattern in REFUSAL_PATTERNS:
|
||||
if re.search(pattern, text_lower, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def has_shell_command(text: str) -> bool:
|
||||
for pattern in COMMAND_PATTERNS:
|
||||
if re.search(pattern, text):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def run_prompt(model: str, prompt: str) -> str:
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1, "num_predict": 512},
|
||||
}
|
||||
resp = requests.post(f"{OLLAMA_URL}/api/generate", json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["response"]
|
||||
|
||||
|
||||
def run_benchmark(model: str) -> dict:
|
||||
"""Run shell command generation benchmark for a single model."""
|
||||
results = []
|
||||
total_time = 0.0
|
||||
|
||||
for i, case in enumerate(SHELL_PROMPTS, 1):
|
||||
start = time.time()
|
||||
try:
|
||||
raw = run_prompt(model, case["prompt"])
|
||||
elapsed = time.time() - start
|
||||
refused = is_refusal(raw)
|
||||
has_cmd = has_shell_command(raw)
|
||||
results.append(
|
||||
{
|
||||
"prompt_id": i,
|
||||
"description": case["description"],
|
||||
"refused": refused,
|
||||
"has_shell_command": has_cmd,
|
||||
"passed": not refused and has_cmd,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"response_snippet": raw[:120],
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
elapsed = time.time() - start
|
||||
results.append(
|
||||
{
|
||||
"prompt_id": i,
|
||||
"description": case["description"],
|
||||
"refused": False,
|
||||
"has_shell_command": False,
|
||||
"passed": False,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
total_time += elapsed
|
||||
|
||||
refused_count = sum(1 for r in results if r["refused"])
|
||||
passed_count = sum(1 for r in results if r["passed"])
|
||||
pass_rate = passed_count / len(SHELL_PROMPTS)
|
||||
|
||||
return {
|
||||
"benchmark": "shell_commands",
|
||||
"model": model,
|
||||
"total_prompts": len(SHELL_PROMPTS),
|
||||
"passed_count": passed_count,
|
||||
"refused_count": refused_count,
|
||||
"pass_rate": round(pass_rate, 3),
|
||||
"passed": refused_count == 0 and passed_count == len(SHELL_PROMPTS),
|
||||
"total_time_s": round(total_time, 2),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
|
||||
print(f"Running shell-command benchmark against {model}...")
|
||||
result = run_benchmark(model)
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0 if result["passed"] else 1)
|
||||
154
scripts/benchmarks/04_multi_turn_coherence.py
Normal file
154
scripts/benchmarks/04_multi_turn_coherence.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark 4: Multi-Turn Agent Loop Coherence
|
||||
|
||||
Simulate a 5-turn observe/reason/act cycle and measure structured coherence.
|
||||
Each turn must return valid JSON with required fields.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are an autonomous AI agent. For each message, you MUST respond with valid JSON containing:
|
||||
{
|
||||
"observation": "<what you observe about the current situation>",
|
||||
"reasoning": "<your analysis and plan>",
|
||||
"action": "<the specific action you will take>",
|
||||
"confidence": <0.0-1.0>
|
||||
}
|
||||
Respond ONLY with the JSON object. No other text.
|
||||
"""
|
||||
|
||||
TURNS = [
|
||||
"You are monitoring a web server. CPU usage just spiked to 95%. What do you observe, reason, and do?",
|
||||
"Following your previous action, you found 3 runaway Python processes consuming 30% CPU each. Continue.",
|
||||
"You killed the top 2 processes. CPU is now at 45%. A new alert: disk I/O is at 98%. Continue.",
|
||||
"You traced the disk I/O to a log rotation script that's stuck. You terminated it. Disk I/O dropped to 20%. Final status check: all metrics are now nominal. Continue.",
|
||||
"The incident is resolved. Write a brief post-mortem summary as your final action.",
|
||||
]
|
||||
|
||||
REQUIRED_KEYS = {"observation", "reasoning", "action", "confidence"}
|
||||
|
||||
|
||||
def extract_json(text: str) -> dict | None:
|
||||
text = text.strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
fence_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
|
||||
if fence_match:
|
||||
try:
|
||||
return json.loads(fence_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to find { ... } block
|
||||
brace_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)?\}", text, re.DOTALL)
|
||||
if brace_match:
|
||||
try:
|
||||
return json.loads(brace_match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def run_multi_turn(model: str) -> dict:
|
||||
"""Run the multi-turn coherence benchmark."""
|
||||
conversation = []
|
||||
turn_results = []
|
||||
total_time = 0.0
|
||||
|
||||
# Build system + turn messages using chat endpoint
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
|
||||
for i, turn_prompt in enumerate(TURNS, 1):
|
||||
messages.append({"role": "user", "content": turn_prompt})
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1, "num_predict": 512},
|
||||
}
|
||||
resp = requests.post(f"{OLLAMA_URL}/api/chat", json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
raw = resp.json()["message"]["content"]
|
||||
except Exception as exc:
|
||||
elapsed = time.time() - start
|
||||
turn_results.append(
|
||||
{
|
||||
"turn": i,
|
||||
"valid_json": False,
|
||||
"has_required_keys": False,
|
||||
"coherent": False,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
total_time += elapsed
|
||||
# Add placeholder assistant message to keep conversation going
|
||||
messages.append({"role": "assistant", "content": "{}"})
|
||||
continue
|
||||
|
||||
elapsed = time.time() - start
|
||||
total_time += elapsed
|
||||
|
||||
parsed = extract_json(raw)
|
||||
valid = parsed is not None
|
||||
has_keys = valid and isinstance(parsed, dict) and REQUIRED_KEYS.issubset(parsed.keys())
|
||||
confidence_valid = (
|
||||
has_keys
|
||||
and isinstance(parsed.get("confidence"), (int, float))
|
||||
and 0.0 <= parsed["confidence"] <= 1.0
|
||||
)
|
||||
coherent = has_keys and confidence_valid
|
||||
|
||||
turn_results.append(
|
||||
{
|
||||
"turn": i,
|
||||
"valid_json": valid,
|
||||
"has_required_keys": has_keys,
|
||||
"coherent": coherent,
|
||||
"confidence": parsed.get("confidence") if has_keys else None,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"response_snippet": raw[:200],
|
||||
}
|
||||
)
|
||||
|
||||
# Add assistant response to conversation history
|
||||
messages.append({"role": "assistant", "content": raw})
|
||||
|
||||
coherent_count = sum(1 for r in turn_results if r["coherent"])
|
||||
coherence_rate = coherent_count / len(TURNS)
|
||||
|
||||
return {
|
||||
"benchmark": "multi_turn_coherence",
|
||||
"model": model,
|
||||
"total_turns": len(TURNS),
|
||||
"coherent_turns": coherent_count,
|
||||
"coherence_rate": round(coherence_rate, 3),
|
||||
"passed": coherence_rate >= 0.80,
|
||||
"total_time_s": round(total_time, 2),
|
||||
"turns": turn_results,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
|
||||
print(f"Running multi-turn coherence benchmark against {model}...")
|
||||
result = run_multi_turn(model)
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0 if result["passed"] else 1)
|
||||
197
scripts/benchmarks/05_issue_triage.py
Normal file
197
scripts/benchmarks/05_issue_triage.py
Normal file
@@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark 5: Issue Triage Quality
|
||||
|
||||
Present 5 issues with known correct priorities and measure accuracy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
TRIAGE_PROMPT_TEMPLATE = """\
|
||||
You are a software project triage agent. Assign a priority to the following issue.
|
||||
|
||||
Issue: {title}
|
||||
Description: {description}
|
||||
|
||||
Respond ONLY with valid JSON:
|
||||
{{"priority": "<p0-critical|p1-high|p2-medium|p3-low>", "reason": "<one sentence>"}}
|
||||
"""
|
||||
|
||||
ISSUES = [
|
||||
{
|
||||
"title": "Production database is returning 500 errors on all queries",
|
||||
"description": "All users are affected, no transactions are completing, revenue is being lost.",
|
||||
"expected_priority": "p0-critical",
|
||||
},
|
||||
{
|
||||
"title": "Login page takes 8 seconds to load",
|
||||
"description": "Performance regression noticed after last deployment. Users are complaining but can still log in.",
|
||||
"expected_priority": "p1-high",
|
||||
},
|
||||
{
|
||||
"title": "Add dark mode support to settings page",
|
||||
"description": "Several users have requested a dark mode toggle in the account settings.",
|
||||
"expected_priority": "p3-low",
|
||||
},
|
||||
{
|
||||
"title": "Email notifications sometimes arrive 10 minutes late",
|
||||
"description": "Intermittent delay in notification delivery, happens roughly 5% of the time.",
|
||||
"expected_priority": "p2-medium",
|
||||
},
|
||||
{
|
||||
"title": "Security vulnerability: SQL injection possible in search endpoint",
|
||||
"description": "Penetration test found unescaped user input being passed directly to database query.",
|
||||
"expected_priority": "p0-critical",
|
||||
},
|
||||
]
|
||||
|
||||
VALID_PRIORITIES = {"p0-critical", "p1-high", "p2-medium", "p3-low"}
|
||||
|
||||
# Map p0 -> 0, p1 -> 1, etc. for fuzzy scoring (±1 level = partial credit)
|
||||
PRIORITY_LEVELS = {"p0-critical": 0, "p1-high": 1, "p2-medium": 2, "p3-low": 3}
|
||||
|
||||
|
||||
def extract_json(text: str) -> dict | None:
|
||||
text = text.strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
fence_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
|
||||
if fence_match:
|
||||
try:
|
||||
return json.loads(fence_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
brace_match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
|
||||
if brace_match:
|
||||
try:
|
||||
return json.loads(brace_match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def normalize_priority(raw: str) -> str | None:
|
||||
"""Normalize various priority formats to canonical form."""
|
||||
raw = raw.lower().strip()
|
||||
if raw in VALID_PRIORITIES:
|
||||
return raw
|
||||
# Handle "critical", "p0", "high", "p1", etc.
|
||||
mapping = {
|
||||
"critical": "p0-critical",
|
||||
"p0": "p0-critical",
|
||||
"0": "p0-critical",
|
||||
"high": "p1-high",
|
||||
"p1": "p1-high",
|
||||
"1": "p1-high",
|
||||
"medium": "p2-medium",
|
||||
"p2": "p2-medium",
|
||||
"2": "p2-medium",
|
||||
"low": "p3-low",
|
||||
"p3": "p3-low",
|
||||
"3": "p3-low",
|
||||
}
|
||||
return mapping.get(raw)
|
||||
|
||||
|
||||
def run_prompt(model: str, prompt: str) -> str:
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1, "num_predict": 256},
|
||||
}
|
||||
resp = requests.post(f"{OLLAMA_URL}/api/generate", json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["response"]
|
||||
|
||||
|
||||
def run_benchmark(model: str) -> dict:
|
||||
"""Run issue triage benchmark for a single model."""
|
||||
results = []
|
||||
total_time = 0.0
|
||||
|
||||
for i, issue in enumerate(ISSUES, 1):
|
||||
prompt = TRIAGE_PROMPT_TEMPLATE.format(
|
||||
title=issue["title"], description=issue["description"]
|
||||
)
|
||||
start = time.time()
|
||||
try:
|
||||
raw = run_prompt(model, prompt)
|
||||
elapsed = time.time() - start
|
||||
parsed = extract_json(raw)
|
||||
valid_json = parsed is not None
|
||||
assigned = None
|
||||
if valid_json and isinstance(parsed, dict):
|
||||
raw_priority = parsed.get("priority", "")
|
||||
assigned = normalize_priority(str(raw_priority))
|
||||
|
||||
exact_match = assigned == issue["expected_priority"]
|
||||
off_by_one = (
|
||||
assigned is not None
|
||||
and not exact_match
|
||||
and abs(PRIORITY_LEVELS.get(assigned, -1) - PRIORITY_LEVELS[issue["expected_priority"]]) == 1
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"issue_id": i,
|
||||
"title": issue["title"][:60],
|
||||
"expected": issue["expected_priority"],
|
||||
"assigned": assigned,
|
||||
"exact_match": exact_match,
|
||||
"off_by_one": off_by_one,
|
||||
"valid_json": valid_json,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
elapsed = time.time() - start
|
||||
results.append(
|
||||
{
|
||||
"issue_id": i,
|
||||
"title": issue["title"][:60],
|
||||
"expected": issue["expected_priority"],
|
||||
"assigned": None,
|
||||
"exact_match": False,
|
||||
"off_by_one": False,
|
||||
"valid_json": False,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
total_time += elapsed
|
||||
|
||||
exact_count = sum(1 for r in results if r["exact_match"])
|
||||
accuracy = exact_count / len(ISSUES)
|
||||
|
||||
return {
|
||||
"benchmark": "issue_triage",
|
||||
"model": model,
|
||||
"total_issues": len(ISSUES),
|
||||
"exact_matches": exact_count,
|
||||
"accuracy": round(accuracy, 3),
|
||||
"passed": accuracy >= 0.80,
|
||||
"total_time_s": round(total_time, 2),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
|
||||
print(f"Running issue-triage benchmark against {model}...")
|
||||
result = run_benchmark(model)
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0 if result["passed"] else 1)
|
||||
334
scripts/benchmarks/run_suite.py
Normal file
334
scripts/benchmarks/run_suite.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Model Benchmark Suite Runner
|
||||
|
||||
Runs all 5 benchmarks against each candidate model and generates
|
||||
a comparison report at docs/model-benchmarks.md.
|
||||
|
||||
Usage:
|
||||
python scripts/benchmarks/run_suite.py
|
||||
python scripts/benchmarks/run_suite.py --models hermes3:8b qwen3.5:latest
|
||||
python scripts/benchmarks/run_suite.py --output docs/model-benchmarks.md
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
# Models to test — maps friendly name to Ollama model tag.
|
||||
# Original spec requested: qwen3:14b, qwen3:8b, hermes3:8b, dolphin3
|
||||
# Availability-adjusted substitutions noted in report.
|
||||
DEFAULT_MODELS = [
|
||||
"hermes3:8b",
|
||||
"qwen3.5:latest",
|
||||
"qwen2.5:14b",
|
||||
"llama3.2:latest",
|
||||
]
|
||||
|
||||
BENCHMARKS_DIR = Path(__file__).parent
|
||||
DOCS_DIR = Path(__file__).resolve().parent.parent.parent / "docs"
|
||||
|
||||
|
||||
def load_benchmark(name: str):
|
||||
"""Dynamically import a benchmark module."""
|
||||
path = BENCHMARKS_DIR / name
|
||||
module_name = Path(name).stem
|
||||
spec = importlib.util.spec_from_file_location(module_name, path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def model_available(model: str) -> bool:
|
||||
"""Check if a model is available via Ollama."""
|
||||
try:
|
||||
resp = requests.get(f"{OLLAMA_URL}/api/tags", timeout=10)
|
||||
if resp.status_code != 200:
|
||||
return False
|
||||
models = {m["name"] for m in resp.json().get("models", [])}
|
||||
return model in models
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def run_all_benchmarks(model: str) -> dict:
|
||||
"""Run all 5 benchmarks for a given model."""
|
||||
benchmark_files = [
|
||||
"01_tool_calling.py",
|
||||
"02_code_generation.py",
|
||||
"03_shell_commands.py",
|
||||
"04_multi_turn_coherence.py",
|
||||
"05_issue_triage.py",
|
||||
]
|
||||
|
||||
results = {}
|
||||
for fname in benchmark_files:
|
||||
key = fname.replace(".py", "")
|
||||
print(f" [{model}] Running {key}...", flush=True)
|
||||
try:
|
||||
mod = load_benchmark(fname)
|
||||
start = time.time()
|
||||
if key == "01_tool_calling":
|
||||
result = mod.run_benchmark(model)
|
||||
elif key == "02_code_generation":
|
||||
result = mod.run_benchmark(model)
|
||||
elif key == "03_shell_commands":
|
||||
result = mod.run_benchmark(model)
|
||||
elif key == "04_multi_turn_coherence":
|
||||
result = mod.run_multi_turn(model)
|
||||
elif key == "05_issue_triage":
|
||||
result = mod.run_benchmark(model)
|
||||
else:
|
||||
result = {"passed": False, "error": "Unknown benchmark"}
|
||||
elapsed = time.time() - start
|
||||
print(
|
||||
f" -> {'PASS' if result.get('passed') else 'FAIL'} ({elapsed:.1f}s)",
|
||||
flush=True,
|
||||
)
|
||||
results[key] = result
|
||||
except Exception as exc:
|
||||
print(f" -> ERROR: {exc}", flush=True)
|
||||
results[key] = {"benchmark": key, "model": model, "passed": False, "error": str(exc)}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def score_model(results: dict) -> dict:
|
||||
"""Compute summary scores for a model."""
|
||||
benchmarks = list(results.values())
|
||||
passed = sum(1 for b in benchmarks if b.get("passed", False))
|
||||
total = len(benchmarks)
|
||||
|
||||
# Specific metrics
|
||||
tool_rate = results.get("01_tool_calling", {}).get("compliance_rate", 0.0)
|
||||
code_pass = results.get("02_code_generation", {}).get("passed", False)
|
||||
shell_pass = results.get("03_shell_commands", {}).get("passed", False)
|
||||
coherence = results.get("04_multi_turn_coherence", {}).get("coherence_rate", 0.0)
|
||||
triage_acc = results.get("05_issue_triage", {}).get("accuracy", 0.0)
|
||||
|
||||
total_time = sum(
|
||||
r.get("total_time_s", r.get("elapsed_s", 0.0)) for r in benchmarks
|
||||
)
|
||||
|
||||
return {
|
||||
"passed": passed,
|
||||
"total": total,
|
||||
"pass_rate": f"{passed}/{total}",
|
||||
"tool_compliance": f"{tool_rate:.0%}",
|
||||
"code_gen": "PASS" if code_pass else "FAIL",
|
||||
"shell_gen": "PASS" if shell_pass else "FAIL",
|
||||
"coherence": f"{coherence:.0%}",
|
||||
"triage_accuracy": f"{triage_acc:.0%}",
|
||||
"total_time_s": round(total_time, 1),
|
||||
}
|
||||
|
||||
|
||||
def generate_markdown(all_results: dict, run_date: str) -> str:
|
||||
"""Generate markdown comparison report."""
|
||||
lines = []
|
||||
lines.append("# Model Benchmark Results")
|
||||
lines.append("")
|
||||
lines.append(f"> Generated: {run_date} ")
|
||||
lines.append(f"> Ollama URL: `{OLLAMA_URL}` ")
|
||||
lines.append("> Issue: [#1066](http://143.198.27.163:3000/rockachopa/Timmy-time-dashboard/issues/1066)")
|
||||
lines.append("")
|
||||
lines.append("## Overview")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"This report documents the 5-test benchmark suite results for local model candidates."
|
||||
)
|
||||
lines.append("")
|
||||
lines.append("### Model Availability vs. Spec")
|
||||
lines.append("")
|
||||
lines.append("| Requested | Tested Substitute | Reason |")
|
||||
lines.append("|-----------|-------------------|--------|")
|
||||
lines.append("| `qwen3:14b` | `qwen2.5:14b` | `qwen3:14b` not pulled locally |")
|
||||
lines.append("| `qwen3:8b` | `qwen3.5:latest` | `qwen3:8b` not pulled locally |")
|
||||
lines.append("| `hermes3:8b` | `hermes3:8b` | Exact match |")
|
||||
lines.append("| `dolphin3` | `llama3.2:latest` | `dolphin3` not pulled locally |")
|
||||
lines.append("")
|
||||
|
||||
# Summary table
|
||||
lines.append("## Summary Comparison Table")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"| Model | Passed | Tool Calling | Code Gen | Shell Gen | Coherence | Triage Acc | Time (s) |"
|
||||
)
|
||||
lines.append(
|
||||
"|-------|--------|-------------|----------|-----------|-----------|------------|----------|"
|
||||
)
|
||||
|
||||
for model, results in all_results.items():
|
||||
if "error" in results and "01_tool_calling" not in results:
|
||||
lines.append(f"| `{model}` | — | — | — | — | — | — | — |")
|
||||
continue
|
||||
s = score_model(results)
|
||||
lines.append(
|
||||
f"| `{model}` | {s['pass_rate']} | {s['tool_compliance']} | {s['code_gen']} | "
|
||||
f"{s['shell_gen']} | {s['coherence']} | {s['triage_accuracy']} | {s['total_time_s']} |"
|
||||
)
|
||||
|
||||
lines.append("")
|
||||
|
||||
# Per-model detail sections
|
||||
lines.append("## Per-Model Detail")
|
||||
lines.append("")
|
||||
|
||||
for model, results in all_results.items():
|
||||
lines.append(f"### `{model}`")
|
||||
lines.append("")
|
||||
|
||||
if "error" in results and not isinstance(results.get("error"), str):
|
||||
lines.append(f"> **Error:** {results.get('error')}")
|
||||
lines.append("")
|
||||
continue
|
||||
|
||||
for bkey, bres in results.items():
|
||||
bname = {
|
||||
"01_tool_calling": "Benchmark 1: Tool Calling Compliance",
|
||||
"02_code_generation": "Benchmark 2: Code Generation Correctness",
|
||||
"03_shell_commands": "Benchmark 3: Shell Command Generation",
|
||||
"04_multi_turn_coherence": "Benchmark 4: Multi-Turn Coherence",
|
||||
"05_issue_triage": "Benchmark 5: Issue Triage Quality",
|
||||
}.get(bkey, bkey)
|
||||
|
||||
status = "✅ PASS" if bres.get("passed") else "❌ FAIL"
|
||||
lines.append(f"#### {bname} — {status}")
|
||||
lines.append("")
|
||||
|
||||
if bkey == "01_tool_calling":
|
||||
rate = bres.get("compliance_rate", 0)
|
||||
count = bres.get("valid_json_count", 0)
|
||||
total = bres.get("total_prompts", 0)
|
||||
lines.append(
|
||||
f"- **JSON Compliance:** {count}/{total} ({rate:.0%}) — target ≥90%"
|
||||
)
|
||||
elif bkey == "02_code_generation":
|
||||
lines.append(f"- **Result:** {bres.get('detail', bres.get('error', 'n/a'))}")
|
||||
snippet = bres.get("code_snippet", "")
|
||||
if snippet:
|
||||
lines.append(f"- **Generated code snippet:**")
|
||||
lines.append(" ```python")
|
||||
for ln in snippet.splitlines()[:8]:
|
||||
lines.append(f" {ln}")
|
||||
lines.append(" ```")
|
||||
elif bkey == "03_shell_commands":
|
||||
passed = bres.get("passed_count", 0)
|
||||
refused = bres.get("refused_count", 0)
|
||||
total = bres.get("total_prompts", 0)
|
||||
lines.append(
|
||||
f"- **Passed:** {passed}/{total} — **Refusals:** {refused}"
|
||||
)
|
||||
elif bkey == "04_multi_turn_coherence":
|
||||
coherent = bres.get("coherent_turns", 0)
|
||||
total = bres.get("total_turns", 0)
|
||||
rate = bres.get("coherence_rate", 0)
|
||||
lines.append(
|
||||
f"- **Coherent turns:** {coherent}/{total} ({rate:.0%}) — target ≥80%"
|
||||
)
|
||||
elif bkey == "05_issue_triage":
|
||||
exact = bres.get("exact_matches", 0)
|
||||
total = bres.get("total_issues", 0)
|
||||
acc = bres.get("accuracy", 0)
|
||||
lines.append(
|
||||
f"- **Accuracy:** {exact}/{total} ({acc:.0%}) — target ≥80%"
|
||||
)
|
||||
|
||||
elapsed = bres.get("total_time_s", bres.get("elapsed_s", 0))
|
||||
lines.append(f"- **Time:** {elapsed}s")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Raw JSON Data")
|
||||
lines.append("")
|
||||
lines.append("<details>")
|
||||
lines.append("<summary>Click to expand full JSON results</summary>")
|
||||
lines.append("")
|
||||
lines.append("```json")
|
||||
lines.append(json.dumps(all_results, indent=2))
|
||||
lines.append("```")
|
||||
lines.append("")
|
||||
lines.append("</details>")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run model benchmark suite")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
default=DEFAULT_MODELS,
|
||||
help="Models to test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=DOCS_DIR / "model-benchmarks.md",
|
||||
help="Output markdown file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-output",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional JSON output file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
run_date = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
|
||||
|
||||
print(f"Model Benchmark Suite — {run_date}")
|
||||
print(f"Testing {len(args.models)} model(s): {', '.join(args.models)}")
|
||||
print()
|
||||
|
||||
all_results: dict[str, dict] = {}
|
||||
|
||||
for model in args.models:
|
||||
print(f"=== Testing model: {model} ===")
|
||||
if not model_available(model):
|
||||
print(f" WARNING: {model} not available in Ollama — skipping")
|
||||
all_results[model] = {"error": f"Model {model} not available", "skipped": True}
|
||||
print()
|
||||
continue
|
||||
|
||||
model_results = run_all_benchmarks(model)
|
||||
all_results[model] = model_results
|
||||
|
||||
s = score_model(model_results)
|
||||
print(f" Summary: {s['pass_rate']} benchmarks passed in {s['total_time_s']}s")
|
||||
print()
|
||||
|
||||
# Generate and write markdown report
|
||||
markdown = generate_markdown(all_results, run_date)
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
args.output.write_text(markdown, encoding="utf-8")
|
||||
print(f"Report written to: {args.output}")
|
||||
|
||||
if args.json_output:
|
||||
args.json_output.write_text(json.dumps(all_results, indent=2), encoding="utf-8")
|
||||
print(f"JSON data written to: {args.json_output}")
|
||||
|
||||
# Overall pass/fail
|
||||
all_pass = all(
|
||||
not r.get("skipped", False)
|
||||
and all(b.get("passed", False) for b in r.values() if isinstance(b, dict))
|
||||
for r in all_results.values()
|
||||
)
|
||||
return 0 if all_pass else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -55,6 +55,7 @@ from dashboard.routes.system import router as system_router
|
||||
from dashboard.routes.tasks import router as tasks_router
|
||||
from dashboard.routes.telegram import router as telegram_router
|
||||
from dashboard.routes.thinking import router as thinking_router
|
||||
from dashboard.routes.self_correction import router as self_correction_router
|
||||
from dashboard.routes.three_strike import router as three_strike_router
|
||||
from dashboard.routes.tools import router as tools_router
|
||||
from dashboard.routes.tower import router as tower_router
|
||||
@@ -551,12 +552,28 @@ async def lifespan(app: FastAPI):
|
||||
except Exception:
|
||||
logger.debug("Failed to register error recorder")
|
||||
|
||||
# Mark session start for sovereignty duration tracking
|
||||
try:
|
||||
from timmy.sovereignty import mark_session_start
|
||||
|
||||
mark_session_start()
|
||||
except Exception:
|
||||
logger.debug("Failed to mark sovereignty session start")
|
||||
|
||||
logger.info("✓ Dashboard ready for requests")
|
||||
|
||||
yield
|
||||
|
||||
await _shutdown_cleanup(bg_tasks, workshop_heartbeat)
|
||||
|
||||
# Generate and commit sovereignty session report
|
||||
try:
|
||||
from timmy.sovereignty import generate_and_commit_report
|
||||
|
||||
await generate_and_commit_report()
|
||||
except Exception as exc:
|
||||
logger.warning("Sovereignty report generation failed at shutdown: %s", exc)
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Mission Control",
|
||||
@@ -680,6 +697,7 @@ app.include_router(scorecards_router)
|
||||
app.include_router(sovereignty_metrics_router)
|
||||
app.include_router(sovereignty_ws_router)
|
||||
app.include_router(three_strike_router)
|
||||
app.include_router(self_correction_router)
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
|
||||
58
src/dashboard/routes/self_correction.py
Normal file
58
src/dashboard/routes/self_correction.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Self-Correction Dashboard routes.
|
||||
|
||||
GET /self-correction/ui — HTML dashboard
|
||||
GET /self-correction/timeline — HTMX partial: recent event timeline
|
||||
GET /self-correction/patterns — HTMX partial: recurring failure patterns
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from dashboard.templating import templates
|
||||
from infrastructure.self_correction import get_corrections, get_patterns, get_stats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/self-correction", tags=["self-correction"])
|
||||
|
||||
|
||||
@router.get("/ui", response_class=HTMLResponse)
|
||||
async def self_correction_ui(request: Request):
|
||||
"""Render the Self-Correction Dashboard."""
|
||||
stats = get_stats()
|
||||
corrections = get_corrections(limit=20)
|
||||
patterns = get_patterns(top_n=10)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"self_correction.html",
|
||||
{
|
||||
"stats": stats,
|
||||
"corrections": corrections,
|
||||
"patterns": patterns,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/timeline", response_class=HTMLResponse)
|
||||
async def self_correction_timeline(request: Request):
|
||||
"""HTMX partial: recent self-correction event timeline."""
|
||||
corrections = get_corrections(limit=30)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/self_correction_timeline.html",
|
||||
{"corrections": corrections},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/patterns", response_class=HTMLResponse)
|
||||
async def self_correction_patterns(request: Request):
|
||||
"""HTMX partial: recurring failure patterns."""
|
||||
patterns = get_patterns(top_n=10)
|
||||
stats = get_stats()
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/self_correction_patterns.html",
|
||||
{"patterns": patterns, "stats": stats},
|
||||
)
|
||||
@@ -71,6 +71,7 @@
|
||||
<a href="/spark/ui" class="mc-test-link">SPARK</a>
|
||||
<a href="/memory" class="mc-test-link">MEMORY</a>
|
||||
<a href="/marketplace/ui" class="mc-test-link">MARKET</a>
|
||||
<a href="/self-correction/ui" class="mc-test-link">SELF-CORRECT</a>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mc-nav-dropdown">
|
||||
@@ -132,6 +133,7 @@
|
||||
<a href="/spark/ui" class="mc-mobile-link">SPARK</a>
|
||||
<a href="/memory" class="mc-mobile-link">MEMORY</a>
|
||||
<a href="/marketplace/ui" class="mc-mobile-link">MARKET</a>
|
||||
<a href="/self-correction/ui" class="mc-mobile-link">SELF-CORRECT</a>
|
||||
<div class="mc-mobile-section-label">AGENTS</div>
|
||||
<a href="/hands" class="mc-mobile-link">HANDS</a>
|
||||
<a href="/work-orders/queue" class="mc-mobile-link">WORK ORDERS</a>
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
{% if patterns %}
|
||||
<table class="mc-table w-100">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ERROR TYPE</th>
|
||||
<th class="text-center">COUNT</th>
|
||||
<th class="text-center">CORRECTED</th>
|
||||
<th class="text-center">FAILED</th>
|
||||
<th>LAST SEEN</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for p in patterns %}
|
||||
<tr>
|
||||
<td class="sc-pattern-type">{{ p.error_type }}</td>
|
||||
<td class="text-center">
|
||||
<span class="badge {% if p.count >= 5 %}badge-error{% elif p.count >= 3 %}badge-warning{% else %}badge-info{% endif %}">{{ p.count }}</span>
|
||||
</td>
|
||||
<td class="text-center text-success">{{ p.success_count }}</td>
|
||||
<td class="text-center {% if p.failed_count > 0 %}text-danger{% else %}text-muted{% endif %}">{{ p.failed_count }}</td>
|
||||
<td class="sc-event-time">{{ p.last_seen[:16] if p.last_seen else '—' }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<div class="text-center text-muted py-3">No patterns detected yet.</div>
|
||||
{% endif %}
|
||||
@@ -0,0 +1,26 @@
|
||||
{% if corrections %}
|
||||
{% for ev in corrections %}
|
||||
<div class="sc-event sc-status-{{ ev.outcome_status }}">
|
||||
<div class="sc-event-header">
|
||||
<span class="sc-status-badge sc-status-{{ ev.outcome_status }}">
|
||||
{% if ev.outcome_status == 'success' %}✓ CORRECTED
|
||||
{% elif ev.outcome_status == 'partial' %}● PARTIAL
|
||||
{% else %}✗ FAILED
|
||||
{% endif %}
|
||||
</span>
|
||||
<span class="sc-source-badge">{{ ev.source }}</span>
|
||||
<span class="sc-event-time">{{ ev.created_at[:19] }}</span>
|
||||
</div>
|
||||
<div class="sc-event-error-type">{{ ev.error_type }}</div>
|
||||
<div class="sc-event-intent"><span class="sc-label">INTENT:</span> {{ ev.original_intent[:120] }}{% if ev.original_intent | length > 120 %}…{% endif %}</div>
|
||||
<div class="sc-event-error"><span class="sc-label">ERROR:</span> {{ ev.detected_error[:120] }}{% if ev.detected_error | length > 120 %}…{% endif %}</div>
|
||||
<div class="sc-event-strategy"><span class="sc-label">STRATEGY:</span> {{ ev.correction_strategy[:120] }}{% if ev.correction_strategy | length > 120 %}…{% endif %}</div>
|
||||
<div class="sc-event-outcome"><span class="sc-label">OUTCOME:</span> {{ ev.final_outcome[:120] }}{% if ev.final_outcome | length > 120 %}…{% endif %}</div>
|
||||
{% if ev.task_id %}
|
||||
<div class="sc-event-meta">task: {{ ev.task_id[:8] }}</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
<div class="text-center text-muted py-3">No self-correction events recorded yet.</div>
|
||||
{% endif %}
|
||||
102
src/dashboard/templates/self_correction.html
Normal file
102
src/dashboard/templates/self_correction.html
Normal file
@@ -0,0 +1,102 @@
|
||||
{% extends "base.html" %}
|
||||
{% from "macros.html" import panel %}
|
||||
|
||||
{% block title %}Timmy Time — Self-Correction Dashboard{% endblock %}
|
||||
|
||||
{% block extra_styles %}{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="container-fluid py-3">
|
||||
|
||||
<!-- Header -->
|
||||
<div class="spark-header mb-3">
|
||||
<div class="spark-title">SELF-CORRECTION</div>
|
||||
<div class="spark-subtitle">
|
||||
Agent error detection & recovery —
|
||||
<span class="spark-status-val">{{ stats.total }}</span> events,
|
||||
<span class="spark-status-val">{{ stats.success_rate }}%</span> correction rate,
|
||||
<span class="spark-status-val">{{ stats.unique_error_types }}</span> distinct error types
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row g-3">
|
||||
|
||||
<!-- Left column: stats + patterns -->
|
||||
<div class="col-12 col-lg-4 d-flex flex-column gap-3">
|
||||
|
||||
<!-- Stats panel -->
|
||||
<div class="card mc-panel">
|
||||
<div class="card-header mc-panel-header">// CORRECTION STATS</div>
|
||||
<div class="card-body p-3">
|
||||
<div class="spark-stat-grid">
|
||||
<div class="spark-stat">
|
||||
<span class="spark-stat-label">TOTAL</span>
|
||||
<span class="spark-stat-value">{{ stats.total }}</span>
|
||||
</div>
|
||||
<div class="spark-stat">
|
||||
<span class="spark-stat-label">CORRECTED</span>
|
||||
<span class="spark-stat-value text-success">{{ stats.success_count }}</span>
|
||||
</div>
|
||||
<div class="spark-stat">
|
||||
<span class="spark-stat-label">PARTIAL</span>
|
||||
<span class="spark-stat-value text-warning">{{ stats.partial_count }}</span>
|
||||
</div>
|
||||
<div class="spark-stat">
|
||||
<span class="spark-stat-label">FAILED</span>
|
||||
<span class="spark-stat-value {% if stats.failed_count > 0 %}text-danger{% else %}text-muted{% endif %}">{{ stats.failed_count }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-3">
|
||||
<div class="d-flex justify-content-between mb-1">
|
||||
<small class="text-muted">Correction Rate</small>
|
||||
<small class="{% if stats.success_rate >= 70 %}text-success{% elif stats.success_rate >= 40 %}text-warning{% else %}text-danger{% endif %}">{{ stats.success_rate }}%</small>
|
||||
</div>
|
||||
<div class="progress" style="height:6px;">
|
||||
<div class="progress-bar {% if stats.success_rate >= 70 %}bg-success{% elif stats.success_rate >= 40 %}bg-warning{% else %}bg-danger{% endif %}"
|
||||
role="progressbar"
|
||||
style="width:{{ stats.success_rate }}%"
|
||||
aria-valuenow="{{ stats.success_rate }}"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Patterns panel -->
|
||||
<div class="card mc-panel"
|
||||
hx-get="/self-correction/patterns"
|
||||
hx-trigger="load, every 60s"
|
||||
hx-target="#sc-patterns-body"
|
||||
hx-swap="innerHTML">
|
||||
<div class="card-header mc-panel-header d-flex justify-content-between align-items-center">
|
||||
<span>// RECURRING PATTERNS</span>
|
||||
<span class="badge badge-info">{{ patterns | length }}</span>
|
||||
</div>
|
||||
<div class="card-body p-0" id="sc-patterns-body">
|
||||
{% include "partials/self_correction_patterns.html" %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- Right column: timeline -->
|
||||
<div class="col-12 col-lg-8">
|
||||
<div class="card mc-panel"
|
||||
hx-get="/self-correction/timeline"
|
||||
hx-trigger="load, every 30s"
|
||||
hx-target="#sc-timeline-body"
|
||||
hx-swap="innerHTML">
|
||||
<div class="card-header mc-panel-header d-flex justify-content-between align-items-center">
|
||||
<span>// CORRECTION TIMELINE</span>
|
||||
<span class="badge badge-info">{{ corrections | length }}</span>
|
||||
</div>
|
||||
<div class="card-body p-3" id="sc-timeline-body">
|
||||
{% include "partials/self_correction_timeline.html" %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %}
|
||||
247
src/infrastructure/self_correction.py
Normal file
247
src/infrastructure/self_correction.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Self-correction event logger.
|
||||
|
||||
Records instances where the agent detected its own errors and the steps
|
||||
it took to correct them. Used by the Self-Correction Dashboard to visualise
|
||||
these events and surface recurring failure patterns.
|
||||
|
||||
Usage::
|
||||
|
||||
from infrastructure.self_correction import log_self_correction, get_corrections, get_patterns
|
||||
|
||||
log_self_correction(
|
||||
source="agentic_loop",
|
||||
original_intent="Execute step 3: deploy service",
|
||||
detected_error="ConnectionRefusedError: port 8080 unavailable",
|
||||
correction_strategy="Retry on alternate port 8081",
|
||||
final_outcome="Success on retry",
|
||||
task_id="abc123",
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DB_PATH: Path | None = None
|
||||
|
||||
|
||||
def _get_db_path() -> Path:
|
||||
global _DB_PATH
|
||||
if _DB_PATH is None:
|
||||
from config import settings
|
||||
|
||||
_DB_PATH = Path(settings.repo_root) / "data" / "self_correction.db"
|
||||
return _DB_PATH
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||
db_path = _get_db_path()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with closing(sqlite3.connect(str(db_path))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS self_correction_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
task_id TEXT DEFAULT '',
|
||||
original_intent TEXT NOT NULL,
|
||||
detected_error TEXT NOT NULL,
|
||||
correction_strategy TEXT NOT NULL,
|
||||
final_outcome TEXT NOT NULL,
|
||||
outcome_status TEXT DEFAULT 'success',
|
||||
error_type TEXT DEFAULT '',
|
||||
created_at TEXT DEFAULT (datetime('now'))
|
||||
)
|
||||
""")
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_sc_created ON self_correction_events(created_at)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_sc_error_type ON self_correction_events(error_type)"
|
||||
)
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def log_self_correction(
|
||||
*,
|
||||
source: str,
|
||||
original_intent: str,
|
||||
detected_error: str,
|
||||
correction_strategy: str,
|
||||
final_outcome: str,
|
||||
task_id: str = "",
|
||||
outcome_status: str = "success",
|
||||
error_type: str = "",
|
||||
) -> str:
|
||||
"""Record a self-correction event and return its ID.
|
||||
|
||||
Args:
|
||||
source: Module or component that triggered the correction.
|
||||
original_intent: What the agent was trying to do.
|
||||
detected_error: The error or problem that was detected.
|
||||
correction_strategy: How the agent attempted to correct the error.
|
||||
final_outcome: What the result of the correction attempt was.
|
||||
task_id: Optional task/session ID for correlation.
|
||||
outcome_status: 'success', 'partial', or 'failed'.
|
||||
error_type: Short category label for pattern analysis (e.g.
|
||||
'ConnectionError', 'TimeoutError').
|
||||
|
||||
Returns:
|
||||
The ID of the newly created record.
|
||||
"""
|
||||
event_id = str(uuid.uuid4())
|
||||
if not error_type:
|
||||
# Derive a simple type from the first word of the detected error
|
||||
error_type = detected_error.split(":")[0].strip()[:64]
|
||||
|
||||
try:
|
||||
with _get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO self_correction_events
|
||||
(id, source, task_id, original_intent, detected_error,
|
||||
correction_strategy, final_outcome, outcome_status, error_type)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
event_id,
|
||||
source,
|
||||
task_id,
|
||||
original_intent[:2000],
|
||||
detected_error[:2000],
|
||||
correction_strategy[:2000],
|
||||
final_outcome[:2000],
|
||||
outcome_status,
|
||||
error_type,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
logger.info(
|
||||
"Self-correction logged [%s] source=%s error_type=%s status=%s",
|
||||
event_id[:8],
|
||||
source,
|
||||
error_type,
|
||||
outcome_status,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to log self-correction event: %s", exc)
|
||||
|
||||
return event_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_corrections(limit: int = 50) -> list[dict]:
|
||||
"""Return the most recent self-correction events, newest first."""
|
||||
try:
|
||||
with _get_db() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM self_correction_events
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch self-correction events: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def get_patterns(top_n: int = 10) -> list[dict]:
|
||||
"""Return the most common recurring error types with counts.
|
||||
|
||||
Each entry has:
|
||||
- error_type: category label
|
||||
- count: total occurrences
|
||||
- success_count: corrected successfully
|
||||
- failed_count: correction also failed
|
||||
- last_seen: ISO timestamp of most recent occurrence
|
||||
"""
|
||||
try:
|
||||
with _get_db() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
error_type,
|
||||
COUNT(*) AS count,
|
||||
SUM(CASE WHEN outcome_status = 'success' THEN 1 ELSE 0 END) AS success_count,
|
||||
SUM(CASE WHEN outcome_status = 'failed' THEN 1 ELSE 0 END) AS failed_count,
|
||||
MAX(created_at) AS last_seen
|
||||
FROM self_correction_events
|
||||
GROUP BY error_type
|
||||
ORDER BY count DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(top_n,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch self-correction patterns: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def get_stats() -> dict:
|
||||
"""Return aggregate statistics for the summary panel."""
|
||||
try:
|
||||
with _get_db() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN outcome_status = 'success' THEN 1 ELSE 0 END) AS success_count,
|
||||
SUM(CASE WHEN outcome_status = 'partial' THEN 1 ELSE 0 END) AS partial_count,
|
||||
SUM(CASE WHEN outcome_status = 'failed' THEN 1 ELSE 0 END) AS failed_count,
|
||||
COUNT(DISTINCT error_type) AS unique_error_types,
|
||||
COUNT(DISTINCT source) AS sources
|
||||
FROM self_correction_events
|
||||
"""
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return _empty_stats()
|
||||
d = dict(row)
|
||||
total = d.get("total") or 0
|
||||
if total:
|
||||
d["success_rate"] = round((d.get("success_count") or 0) / total * 100)
|
||||
else:
|
||||
d["success_rate"] = 0
|
||||
return d
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch self-correction stats: %s", exc)
|
||||
return _empty_stats()
|
||||
|
||||
|
||||
def _empty_stats() -> dict:
|
||||
return {
|
||||
"total": 0,
|
||||
"success_count": 0,
|
||||
"partial_count": 0,
|
||||
"failed_count": 0,
|
||||
"unique_error_types": 0,
|
||||
"sources": 0,
|
||||
"success_rate": 0,
|
||||
}
|
||||
7
src/self_coding/__init__.py
Normal file
7
src/self_coding/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Self-coding package — Timmy's self-modification capability.
|
||||
|
||||
Provides the branch→edit→test→commit/revert loop that allows Timmy
|
||||
to propose and apply code changes autonomously, gated by the test suite.
|
||||
|
||||
Main entry point: ``self_coding.self_modify.loop``
|
||||
"""
|
||||
129
src/self_coding/gitea_client.py
Normal file
129
src/self_coding/gitea_client.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Gitea REST client — thin wrapper for PR creation and issue commenting.
|
||||
|
||||
Uses ``settings.gitea_url``, ``settings.gitea_token``, and
|
||||
``settings.gitea_repo`` (owner/repo) from config. Degrades gracefully
|
||||
when the token is absent or the server is unreachable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PullRequest:
|
||||
"""Minimal representation of a created pull request."""
|
||||
|
||||
number: int
|
||||
title: str
|
||||
html_url: str
|
||||
|
||||
|
||||
class GiteaClient:
|
||||
"""HTTP client for Gitea's REST API v1.
|
||||
|
||||
All methods return structured results and never raise — errors are
|
||||
logged at WARNING level and indicated via return value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
token: str | None = None,
|
||||
repo: str | None = None,
|
||||
) -> None:
|
||||
from config import settings
|
||||
|
||||
self._base_url = (base_url or settings.gitea_url).rstrip("/")
|
||||
self._token = token or settings.gitea_token
|
||||
self._repo = repo or settings.gitea_repo
|
||||
|
||||
# ── internal ────────────────────────────────────────────────────────────
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"token {self._token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _api(self, path: str) -> str:
|
||||
return f"{self._base_url}/api/v1/{path.lstrip('/')}"
|
||||
|
||||
# ── public API ───────────────────────────────────────────────────────────
|
||||
|
||||
def create_pull_request(
|
||||
self,
|
||||
title: str,
|
||||
body: str,
|
||||
head: str,
|
||||
base: str = "main",
|
||||
) -> PullRequest | None:
|
||||
"""Open a pull request.
|
||||
|
||||
Args:
|
||||
title: PR title (keep under 70 chars).
|
||||
body: PR body in markdown.
|
||||
head: Source branch (e.g. ``self-modify/issue-983``).
|
||||
base: Target branch (default ``main``).
|
||||
|
||||
Returns:
|
||||
A ``PullRequest`` dataclass on success, ``None`` on failure.
|
||||
"""
|
||||
if not self._token:
|
||||
logger.warning("Gitea token not configured — skipping PR creation")
|
||||
return None
|
||||
|
||||
try:
|
||||
import requests as _requests
|
||||
|
||||
resp = _requests.post(
|
||||
self._api(f"repos/{self._repo}/pulls"),
|
||||
headers=self._headers(),
|
||||
json={"title": title, "body": body, "head": head, "base": base},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
pr = PullRequest(
|
||||
number=data["number"],
|
||||
title=data["title"],
|
||||
html_url=data["html_url"],
|
||||
)
|
||||
logger.info("PR #%d created: %s", pr.number, pr.html_url)
|
||||
return pr
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to create PR: %s", exc)
|
||||
return None
|
||||
|
||||
def add_issue_comment(self, issue_number: int, body: str) -> bool:
|
||||
"""Post a comment on an issue or PR.
|
||||
|
||||
Returns:
|
||||
True on success, False on failure.
|
||||
"""
|
||||
if not self._token:
|
||||
logger.warning("Gitea token not configured — skipping issue comment")
|
||||
return False
|
||||
|
||||
try:
|
||||
import requests as _requests
|
||||
|
||||
resp = _requests.post(
|
||||
self._api(f"repos/{self._repo}/issues/{issue_number}/comments"),
|
||||
headers=self._headers(),
|
||||
json={"body": body},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info("Comment posted on issue #%d", issue_number)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to post comment on issue #%d: %s", issue_number, exc)
|
||||
return False
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
gitea_client = GiteaClient()
|
||||
1
src/self_coding/self_modify/__init__.py
Normal file
1
src/self_coding/self_modify/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Self-modification loop sub-package."""
|
||||
301
src/self_coding/self_modify/loop.py
Normal file
301
src/self_coding/self_modify/loop.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Self-modification loop — branch → edit → test → commit/revert.
|
||||
|
||||
Timmy's self-coding capability, restored after deletion in
|
||||
Operation Darling Purge (commit 584eeb679e88).
|
||||
|
||||
## Cycle
|
||||
1. **Branch** — create ``self-modify/<slug>`` from ``main``
|
||||
2. **Edit** — apply the proposed change (patch string or callable)
|
||||
3. **Test** — run ``pytest tests/ -x -q``; never commit on failure
|
||||
4. **Commit** — stage and commit on green; revert branch on red
|
||||
5. **PR** — open a Gitea pull request (requires no direct push to main)
|
||||
|
||||
## Guards
|
||||
- Never push directly to ``main`` or ``master``
|
||||
- All changes land via PR (enforced by ``_guard_branch``)
|
||||
- Test gate is mandatory; ``skip_tests=True`` is for unit-test use only
|
||||
- Commits only happen when ``pytest tests/ -x -q`` exits 0
|
||||
|
||||
## Usage::
|
||||
|
||||
from self_coding.self_modify.loop import SelfModifyLoop
|
||||
|
||||
loop = SelfModifyLoop()
|
||||
result = await loop.run(
|
||||
slug="add-hello-tool",
|
||||
description="Add hello() convenience tool",
|
||||
edit_fn=my_edit_function, # callable(repo_root: str) -> None
|
||||
)
|
||||
if result.success:
|
||||
print(f"PR: {result.pr_url}")
|
||||
else:
|
||||
print(f"Failed: {result.error}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Branches that must never receive direct commits
|
||||
_PROTECTED_BRANCHES = frozenset({"main", "master", "develop"})
|
||||
|
||||
# Test command used as the commit gate
|
||||
_TEST_COMMAND = ["pytest", "tests/", "-x", "-q", "--tb=short"]
|
||||
|
||||
# Max time (seconds) to wait for the test suite
|
||||
_TEST_TIMEOUT = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopResult:
|
||||
"""Result from one self-modification cycle."""
|
||||
|
||||
success: bool
|
||||
branch: str = ""
|
||||
commit_sha: str = ""
|
||||
pr_url: str = ""
|
||||
pr_number: int = 0
|
||||
test_output: str = ""
|
||||
error: str = ""
|
||||
elapsed_ms: float = 0.0
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class SelfModifyLoop:
|
||||
"""Orchestrate branch → edit → test → commit/revert → PR.
|
||||
|
||||
Args:
|
||||
repo_root: Absolute path to the git repository (defaults to
|
||||
``settings.repo_root``).
|
||||
remote: Git remote name (default ``origin``).
|
||||
base_branch: Branch to fork from and target for the PR
|
||||
(default ``main``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_root: str | None = None,
|
||||
remote: str = "origin",
|
||||
base_branch: str = "main",
|
||||
) -> None:
|
||||
self._repo_root = Path(repo_root or settings.repo_root)
|
||||
self._remote = remote
|
||||
self._base_branch = base_branch
|
||||
|
||||
# ── public ──────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(
|
||||
self,
|
||||
slug: str,
|
||||
description: str,
|
||||
edit_fn: Callable[[str], None],
|
||||
issue_number: int | None = None,
|
||||
skip_tests: bool = False,
|
||||
) -> LoopResult:
|
||||
"""Execute one full self-modification cycle.
|
||||
|
||||
Args:
|
||||
slug: Short identifier used for the branch name
|
||||
(e.g. ``"add-hello-tool"``).
|
||||
description: Human-readable description for commit message
|
||||
and PR body.
|
||||
edit_fn: Callable that receives the repo root path (str)
|
||||
and applies the desired code changes in-place.
|
||||
issue_number: Optional Gitea issue number to reference in PR.
|
||||
skip_tests: If ``True``, skip the test gate (unit-test use
|
||||
only — never use in production).
|
||||
|
||||
Returns:
|
||||
:class:`LoopResult` describing the outcome.
|
||||
"""
|
||||
start = time.time()
|
||||
branch = f"self-modify/{slug}"
|
||||
|
||||
try:
|
||||
self._guard_branch(branch)
|
||||
self._checkout_base()
|
||||
self._create_branch(branch)
|
||||
|
||||
try:
|
||||
edit_fn(str(self._repo_root))
|
||||
except Exception as exc:
|
||||
self._revert_branch(branch)
|
||||
return LoopResult(
|
||||
success=False,
|
||||
branch=branch,
|
||||
error=f"edit_fn raised: {exc}",
|
||||
elapsed_ms=self._elapsed(start),
|
||||
)
|
||||
|
||||
if not skip_tests:
|
||||
test_output, passed = self._run_tests()
|
||||
if not passed:
|
||||
self._revert_branch(branch)
|
||||
return LoopResult(
|
||||
success=False,
|
||||
branch=branch,
|
||||
test_output=test_output,
|
||||
error="Tests failed — branch reverted",
|
||||
elapsed_ms=self._elapsed(start),
|
||||
)
|
||||
else:
|
||||
test_output = "(tests skipped)"
|
||||
|
||||
sha = self._commit_all(description)
|
||||
self._push_branch(branch)
|
||||
|
||||
pr = self._create_pr(
|
||||
branch=branch,
|
||||
description=description,
|
||||
test_output=test_output,
|
||||
issue_number=issue_number,
|
||||
)
|
||||
|
||||
return LoopResult(
|
||||
success=True,
|
||||
branch=branch,
|
||||
commit_sha=sha,
|
||||
pr_url=pr.html_url if pr else "",
|
||||
pr_number=pr.number if pr else 0,
|
||||
test_output=test_output,
|
||||
elapsed_ms=self._elapsed(start),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Self-modify loop failed: %s", exc)
|
||||
return LoopResult(
|
||||
success=False,
|
||||
branch=branch,
|
||||
error=str(exc),
|
||||
elapsed_ms=self._elapsed(start),
|
||||
)
|
||||
|
||||
# ── private helpers ──────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _elapsed(start: float) -> float:
|
||||
return (time.time() - start) * 1000
|
||||
|
||||
def _git(self, *args: str, check: bool = True) -> subprocess.CompletedProcess:
|
||||
"""Run a git command in the repo root."""
|
||||
cmd = ["git", *args]
|
||||
logger.debug("git %s", " ".join(args))
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
cwd=str(self._repo_root),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=check,
|
||||
)
|
||||
|
||||
def _guard_branch(self, branch: str) -> None:
|
||||
"""Raise if the target branch is a protected branch name."""
|
||||
if branch in _PROTECTED_BRANCHES:
|
||||
raise ValueError(
|
||||
f"Refusing to operate on protected branch '{branch}'. "
|
||||
"All self-modifications must go via PR."
|
||||
)
|
||||
|
||||
def _checkout_base(self) -> None:
|
||||
"""Checkout the base branch and pull latest."""
|
||||
self._git("checkout", self._base_branch)
|
||||
# Best-effort pull; ignore failures (e.g. no remote configured)
|
||||
self._git("pull", self._remote, self._base_branch, check=False)
|
||||
|
||||
def _create_branch(self, branch: str) -> None:
|
||||
"""Create and checkout a new branch, deleting an old one if needed."""
|
||||
# Delete local branch if it already exists (stale prior attempt)
|
||||
self._git("branch", "-D", branch, check=False)
|
||||
self._git("checkout", "-b", branch)
|
||||
logger.info("Created branch: %s", branch)
|
||||
|
||||
def _revert_branch(self, branch: str) -> None:
|
||||
"""Checkout base and delete the failed branch."""
|
||||
try:
|
||||
self._git("checkout", self._base_branch, check=False)
|
||||
self._git("branch", "-D", branch, check=False)
|
||||
logger.info("Reverted and deleted branch: %s", branch)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to revert branch %s: %s", branch, exc)
|
||||
|
||||
def _run_tests(self) -> tuple[str, bool]:
|
||||
"""Run the test suite. Returns (output, passed)."""
|
||||
logger.info("Running test suite: %s", " ".join(_TEST_COMMAND))
|
||||
try:
|
||||
result = subprocess.run(
|
||||
_TEST_COMMAND,
|
||||
cwd=str(self._repo_root),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=_TEST_TIMEOUT,
|
||||
)
|
||||
output = (result.stdout + "\n" + result.stderr).strip()
|
||||
passed = result.returncode == 0
|
||||
logger.info(
|
||||
"Test suite %s (exit %d)", "PASSED" if passed else "FAILED", result.returncode
|
||||
)
|
||||
return output, passed
|
||||
except subprocess.TimeoutExpired:
|
||||
msg = f"Test suite timed out after {_TEST_TIMEOUT}s"
|
||||
logger.warning(msg)
|
||||
return msg, False
|
||||
except FileNotFoundError:
|
||||
msg = "pytest not found on PATH"
|
||||
logger.warning(msg)
|
||||
return msg, False
|
||||
|
||||
def _commit_all(self, message: str) -> str:
|
||||
"""Stage all changes and create a commit. Returns the new SHA."""
|
||||
self._git("add", "-A")
|
||||
self._git("commit", "-m", message)
|
||||
result = self._git("rev-parse", "HEAD")
|
||||
sha = result.stdout.strip()
|
||||
logger.info("Committed: %s sha=%s", message[:60], sha[:12])
|
||||
return sha
|
||||
|
||||
def _push_branch(self, branch: str) -> None:
|
||||
"""Push the branch to the remote."""
|
||||
self._git("push", "-u", self._remote, branch)
|
||||
logger.info("Pushed branch: %s -> %s", branch, self._remote)
|
||||
|
||||
def _create_pr(
|
||||
self,
|
||||
branch: str,
|
||||
description: str,
|
||||
test_output: str,
|
||||
issue_number: int | None,
|
||||
):
|
||||
"""Open a Gitea PR. Returns PullRequest or None on failure."""
|
||||
from self_coding.gitea_client import GiteaClient
|
||||
|
||||
client = GiteaClient()
|
||||
|
||||
issue_ref = f"\n\nFixes #{issue_number}" if issue_number else ""
|
||||
test_section = (
|
||||
f"\n\n## Test results\n```\n{test_output[:2000]}\n```"
|
||||
if test_output and test_output != "(tests skipped)"
|
||||
else ""
|
||||
)
|
||||
|
||||
body = (
|
||||
f"## Summary\n{description}"
|
||||
f"{issue_ref}"
|
||||
f"{test_section}"
|
||||
"\n\n🤖 Generated by Timmy's self-modification loop"
|
||||
)
|
||||
|
||||
return client.create_pull_request(
|
||||
title=f"[self-modify] {description[:60]}",
|
||||
body=body,
|
||||
head=branch,
|
||||
base=self._base_branch,
|
||||
)
|
||||
@@ -312,6 +312,13 @@ async def _handle_step_failure(
|
||||
"adaptation": step.result[:200],
|
||||
},
|
||||
)
|
||||
_log_self_correction(
|
||||
task_id=task_id,
|
||||
step_desc=step_desc,
|
||||
exc=exc,
|
||||
outcome=step.result,
|
||||
outcome_status="success",
|
||||
)
|
||||
if on_progress:
|
||||
await on_progress(f"[Adapted] {step_desc}", step_num, total_steps)
|
||||
except Exception as adapt_exc: # broad catch intentional
|
||||
@@ -325,9 +332,42 @@ async def _handle_step_failure(
|
||||
duration_ms=int((time.monotonic() - step_start) * 1000),
|
||||
)
|
||||
)
|
||||
_log_self_correction(
|
||||
task_id=task_id,
|
||||
step_desc=step_desc,
|
||||
exc=exc,
|
||||
outcome=f"Adaptation also failed: {adapt_exc}",
|
||||
outcome_status="failed",
|
||||
)
|
||||
completed_results.append(f"Step {step_num}: FAILED")
|
||||
|
||||
|
||||
def _log_self_correction(
|
||||
*,
|
||||
task_id: str,
|
||||
step_desc: str,
|
||||
exc: Exception,
|
||||
outcome: str,
|
||||
outcome_status: str,
|
||||
) -> None:
|
||||
"""Best-effort: log a self-correction event (never raises)."""
|
||||
try:
|
||||
from infrastructure.self_correction import log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="agentic_loop",
|
||||
original_intent=step_desc,
|
||||
detected_error=f"{type(exc).__name__}: {exc}",
|
||||
correction_strategy="Adaptive re-plan via LLM",
|
||||
final_outcome=outcome[:500],
|
||||
task_id=task_id,
|
||||
outcome_status=outcome_status,
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
except Exception as log_exc:
|
||||
logger.debug("Self-correction log failed: %s", log_exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
528
src/timmy/research.py
Normal file
528
src/timmy/research.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Research Orchestrator — autonomous, sovereign research pipeline.
|
||||
|
||||
Chains all six steps of the research workflow with local-first execution:
|
||||
|
||||
Step 0 Cache — check semantic memory (SQLite, instant, zero API cost)
|
||||
Step 1 Scope — load a research template from skills/research/
|
||||
Step 2 Query — slot-fill template + formulate 5-15 search queries via Ollama
|
||||
Step 3 Search — execute queries via web_search (SerpAPI or fallback)
|
||||
Step 4 Fetch — download + extract full pages via web_fetch (trafilatura)
|
||||
Step 5 Synth — compress findings into a structured report via cascade
|
||||
Step 6 Deliver — store to semantic memory; optionally save to docs/research/
|
||||
|
||||
Cascade tiers for synthesis (spec §4):
|
||||
Tier 4 SQLite semantic cache — instant, free, covers ~80% after warm-up
|
||||
Tier 3 Ollama (qwen3:14b) — local, free, good quality
|
||||
Tier 2 Claude API (haiku) — cloud fallback, cheap, set ANTHROPIC_API_KEY
|
||||
Tier 1 (future) Groq — free-tier rate-limited, tracked in #980
|
||||
|
||||
All optional services degrade gracefully per project conventions.
|
||||
|
||||
Refs #972 (governing spec), #975 (ResearchOrchestrator sub-issue).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Optional memory imports — available at module level so tests can patch them.
|
||||
try:
|
||||
from timmy.memory_system import SemanticMemory, store_memory
|
||||
except Exception: # pragma: no cover
|
||||
SemanticMemory = None # type: ignore[assignment,misc]
|
||||
store_memory = None # type: ignore[assignment]
|
||||
|
||||
# Root of the project — two levels up from src/timmy/
|
||||
_PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
_SKILLS_ROOT = _PROJECT_ROOT / "skills" / "research"
|
||||
_DOCS_ROOT = _PROJECT_ROOT / "docs" / "research"
|
||||
|
||||
# Similarity threshold for cache hit (0–1 cosine similarity)
|
||||
_CACHE_HIT_THRESHOLD = 0.82
|
||||
|
||||
# How many search result URLs to fetch as full pages
|
||||
_FETCH_TOP_N = 5
|
||||
|
||||
# Maximum tokens to request from the synthesis LLM
|
||||
_SYNTHESIS_MAX_TOKENS = 4096
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResearchResult:
|
||||
"""Full output of a research pipeline run."""
|
||||
|
||||
topic: str
|
||||
query_count: int
|
||||
sources_fetched: int
|
||||
report: str
|
||||
cached: bool = False
|
||||
cache_similarity: float = 0.0
|
||||
synthesis_backend: str = "unknown"
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.report.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def list_templates() -> list[str]:
|
||||
"""Return names of available research templates (without .md extension)."""
|
||||
if not _SKILLS_ROOT.exists():
|
||||
return []
|
||||
return [p.stem for p in sorted(_SKILLS_ROOT.glob("*.md"))]
|
||||
|
||||
|
||||
def load_template(template_name: str, slots: dict[str, str] | None = None) -> str:
|
||||
"""Load a research template and fill {slot} placeholders.
|
||||
|
||||
Args:
|
||||
template_name: Stem of the .md file under skills/research/ (e.g. "tool_evaluation").
|
||||
slots: Mapping of {placeholder} → replacement value.
|
||||
|
||||
Returns:
|
||||
Template text with slots filled. Unfilled slots are left as-is.
|
||||
"""
|
||||
path = _SKILLS_ROOT / f"{template_name}.md"
|
||||
if not path.exists():
|
||||
available = ", ".join(list_templates()) or "(none)"
|
||||
raise FileNotFoundError(
|
||||
f"Research template {template_name!r} not found. "
|
||||
f"Available: {available}"
|
||||
)
|
||||
|
||||
text = path.read_text(encoding="utf-8")
|
||||
|
||||
# Strip YAML frontmatter (--- ... ---), including empty frontmatter (--- \n---)
|
||||
text = re.sub(r"^---\n.*?---\n", "", text, flags=re.DOTALL)
|
||||
|
||||
if slots:
|
||||
for key, value in slots.items():
|
||||
text = text.replace(f"{{{key}}}", value)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query formulation (Step 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _formulate_queries(topic: str, template_context: str, n: int = 8) -> list[str]:
|
||||
"""Use the local LLM to generate targeted search queries for a topic.
|
||||
|
||||
Falls back to a simple heuristic if Ollama is unavailable.
|
||||
"""
|
||||
prompt = textwrap.dedent(f"""\
|
||||
You are a research assistant. Generate exactly {n} targeted, specific web search
|
||||
queries to thoroughly research the following topic.
|
||||
|
||||
TOPIC: {topic}
|
||||
|
||||
RESEARCH CONTEXT:
|
||||
{template_context[:1000]}
|
||||
|
||||
Rules:
|
||||
- One query per line, no numbering, no bullet points.
|
||||
- Vary the angle (definition, comparison, implementation, alternatives, pitfalls).
|
||||
- Prefer exact technical terms, tool names, and version numbers where relevant.
|
||||
- Output ONLY the queries, nothing else.
|
||||
""")
|
||||
|
||||
queries = await _ollama_complete(prompt, max_tokens=512)
|
||||
|
||||
if not queries:
|
||||
# Minimal fallback
|
||||
return [
|
||||
f"{topic} overview",
|
||||
f"{topic} tutorial",
|
||||
f"{topic} best practices",
|
||||
f"{topic} alternatives",
|
||||
f"{topic} 2025",
|
||||
]
|
||||
|
||||
lines = [ln.strip() for ln in queries.splitlines() if ln.strip()]
|
||||
return lines[:n] if len(lines) >= n else lines
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Search (Step 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _execute_search(queries: list[str]) -> list[dict[str, str]]:
|
||||
"""Run each query through the available web search backend.
|
||||
|
||||
Returns a flat list of {title, url, snippet} dicts.
|
||||
Degrades gracefully if SerpAPI key is absent.
|
||||
"""
|
||||
results: list[dict[str, str]] = []
|
||||
seen_urls: set[str] = set()
|
||||
|
||||
for query in queries:
|
||||
try:
|
||||
raw = await asyncio.to_thread(_run_search_sync, query)
|
||||
for item in raw:
|
||||
url = item.get("url", "")
|
||||
if url and url not in seen_urls:
|
||||
seen_urls.add(url)
|
||||
results.append(item)
|
||||
except Exception as exc:
|
||||
logger.warning("Search failed for query %r: %s", query, exc)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _run_search_sync(query: str) -> list[dict[str, str]]:
|
||||
"""Synchronous search — wraps SerpAPI or returns empty on missing key."""
|
||||
import os
|
||||
|
||||
if not os.environ.get("SERPAPI_API_KEY"):
|
||||
logger.debug("SERPAPI_API_KEY not set — skipping web search for %r", query)
|
||||
return []
|
||||
|
||||
try:
|
||||
from serpapi import GoogleSearch
|
||||
|
||||
params = {"q": query, "api_key": os.environ["SERPAPI_API_KEY"], "num": 5}
|
||||
search = GoogleSearch(params)
|
||||
data = search.get_dict()
|
||||
items = []
|
||||
for r in data.get("organic_results", []):
|
||||
items.append(
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("link", ""),
|
||||
"snippet": r.get("snippet", ""),
|
||||
}
|
||||
)
|
||||
return items
|
||||
except Exception as exc:
|
||||
logger.warning("SerpAPI search error: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch (Step 4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _fetch_pages(results: list[dict[str, str]], top_n: int = _FETCH_TOP_N) -> list[str]:
|
||||
"""Download and extract full text for the top search results.
|
||||
|
||||
Uses web_fetch (trafilatura) from timmy.tools.system_tools.
|
||||
"""
|
||||
try:
|
||||
from timmy.tools.system_tools import web_fetch
|
||||
except ImportError:
|
||||
logger.warning("web_fetch not available — skipping page fetch")
|
||||
return []
|
||||
|
||||
pages: list[str] = []
|
||||
for item in results[:top_n]:
|
||||
url = item.get("url", "")
|
||||
if not url:
|
||||
continue
|
||||
try:
|
||||
text = await asyncio.to_thread(web_fetch, url, 6000)
|
||||
if text and not text.startswith("Error:"):
|
||||
pages.append(f"## {item.get('title', url)}\nSource: {url}\n\n{text}")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch %s: %s", url, exc)
|
||||
|
||||
return pages
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthesis (Step 5) — cascade: Ollama → Claude fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _synthesize(topic: str, pages: list[str], snippets: list[str]) -> tuple[str, str]:
|
||||
"""Compress fetched pages + snippets into a structured research report.
|
||||
|
||||
Returns (report_markdown, backend_used).
|
||||
"""
|
||||
# Build synthesis prompt
|
||||
source_content = "\n\n---\n\n".join(pages[:5])
|
||||
if not source_content and snippets:
|
||||
source_content = "\n".join(f"- {s}" for s in snippets[:20])
|
||||
|
||||
if not source_content:
|
||||
return (
|
||||
f"# Research: {topic}\n\n*No source material was retrieved. "
|
||||
"Check SERPAPI_API_KEY and network connectivity.*",
|
||||
"none",
|
||||
)
|
||||
|
||||
prompt = textwrap.dedent(f"""\
|
||||
You are a senior technical researcher. Synthesize the source material below
|
||||
into a structured research report on the topic: **{topic}**
|
||||
|
||||
FORMAT YOUR REPORT AS:
|
||||
# {topic}
|
||||
|
||||
## Executive Summary
|
||||
(2-3 sentences: what you found, top recommendation)
|
||||
|
||||
## Key Findings
|
||||
(Bullet list of the most important facts, tools, or patterns)
|
||||
|
||||
## Comparison / Options
|
||||
(Table or list comparing alternatives where applicable)
|
||||
|
||||
## Recommended Approach
|
||||
(Concrete recommendation with rationale)
|
||||
|
||||
## Gaps & Next Steps
|
||||
(What wasn't answered, what to investigate next)
|
||||
|
||||
---
|
||||
SOURCE MATERIAL:
|
||||
{source_content[:12000]}
|
||||
""")
|
||||
|
||||
# Tier 3 — try Ollama first
|
||||
report = await _ollama_complete(prompt, max_tokens=_SYNTHESIS_MAX_TOKENS)
|
||||
if report:
|
||||
return report, "ollama"
|
||||
|
||||
# Tier 2 — Claude fallback
|
||||
report = await _claude_complete(prompt, max_tokens=_SYNTHESIS_MAX_TOKENS)
|
||||
if report:
|
||||
return report, "claude"
|
||||
|
||||
# Last resort — structured snippet summary
|
||||
summary = f"# {topic}\n\n## Snippets\n\n" + "\n\n".join(
|
||||
f"- {s}" for s in snippets[:15]
|
||||
)
|
||||
return summary, "fallback"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _ollama_complete(prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Send a prompt to Ollama and return the response text.
|
||||
|
||||
Returns empty string on failure (graceful degradation).
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
from config import settings
|
||||
|
||||
url = f"{settings.normalized_ollama_url}/api/generate"
|
||||
payload: dict[str, Any] = {
|
||||
"model": settings.ollama_model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"num_predict": max_tokens,
|
||||
"temperature": 0.3,
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
resp = await client.post(url, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("response", "").strip()
|
||||
except Exception as exc:
|
||||
logger.warning("Ollama completion failed: %s", exc)
|
||||
return ""
|
||||
|
||||
|
||||
async def _claude_complete(prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Send a prompt to Claude API as a last-resort fallback.
|
||||
|
||||
Only active when ANTHROPIC_API_KEY is configured.
|
||||
Returns empty string on failure or missing key.
|
||||
"""
|
||||
try:
|
||||
from config import settings
|
||||
|
||||
if not settings.anthropic_api_key:
|
||||
return ""
|
||||
|
||||
from timmy.backends import ClaudeBackend
|
||||
|
||||
backend = ClaudeBackend()
|
||||
result = await asyncio.to_thread(backend.run, prompt)
|
||||
return result.content.strip()
|
||||
except Exception as exc:
|
||||
logger.warning("Claude fallback failed: %s", exc)
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory cache (Step 0 + Step 6)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_cache(topic: str) -> tuple[str | None, float]:
|
||||
"""Search semantic memory for a prior result on this topic.
|
||||
|
||||
Returns (cached_report, similarity) or (None, 0.0).
|
||||
"""
|
||||
try:
|
||||
if SemanticMemory is None:
|
||||
return None, 0.0
|
||||
mem = SemanticMemory()
|
||||
hits = mem.search(topic, top_k=1)
|
||||
if hits:
|
||||
content, score = hits[0]
|
||||
if score >= _CACHE_HIT_THRESHOLD:
|
||||
return content, score
|
||||
except Exception as exc:
|
||||
logger.debug("Cache check failed: %s", exc)
|
||||
return None, 0.0
|
||||
|
||||
|
||||
def _store_result(topic: str, report: str) -> None:
|
||||
"""Index the research report into semantic memory for future retrieval."""
|
||||
try:
|
||||
if store_memory is None:
|
||||
logger.debug("store_memory not available — skipping memory index")
|
||||
return
|
||||
store_memory(
|
||||
content=report,
|
||||
source="research_pipeline",
|
||||
context_type="research",
|
||||
metadata={"topic": topic},
|
||||
)
|
||||
logger.info("Research result indexed for topic: %r", topic)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to store research result: %s", exc)
|
||||
|
||||
|
||||
def _save_to_disk(topic: str, report: str) -> Path | None:
|
||||
"""Persist the report as a markdown file under docs/research/.
|
||||
|
||||
Filename is derived from the topic (slugified). Returns the path or None.
|
||||
"""
|
||||
try:
|
||||
slug = re.sub(r"[^a-z0-9]+", "-", topic.lower()).strip("-")[:60]
|
||||
_DOCS_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
path = _DOCS_ROOT / f"{slug}.md"
|
||||
path.write_text(report, encoding="utf-8")
|
||||
logger.info("Research report saved to %s", path)
|
||||
return path
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to save research report to disk: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_research(
|
||||
topic: str,
|
||||
template: str | None = None,
|
||||
slots: dict[str, str] | None = None,
|
||||
save_to_disk: bool = False,
|
||||
skip_cache: bool = False,
|
||||
) -> ResearchResult:
|
||||
"""Run the full 6-step autonomous research pipeline.
|
||||
|
||||
Args:
|
||||
topic: The research question or subject.
|
||||
template: Name of a template from skills/research/ (e.g. "tool_evaluation").
|
||||
If None, runs without a template scaffold.
|
||||
slots: Placeholder values for the template (e.g. {"domain": "PDF parsing"}).
|
||||
save_to_disk: If True, write the report to docs/research/<slug>.md.
|
||||
skip_cache: If True, bypass the semantic memory cache.
|
||||
|
||||
Returns:
|
||||
ResearchResult with report and metadata.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 0 — check cache
|
||||
# ------------------------------------------------------------------
|
||||
if not skip_cache:
|
||||
cached, score = _check_cache(topic)
|
||||
if cached:
|
||||
logger.info("Cache hit (%.2f) for topic: %r", score, topic)
|
||||
return ResearchResult(
|
||||
topic=topic,
|
||||
query_count=0,
|
||||
sources_fetched=0,
|
||||
report=cached,
|
||||
cached=True,
|
||||
cache_similarity=score,
|
||||
synthesis_backend="cache",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1 — load template (optional)
|
||||
# ------------------------------------------------------------------
|
||||
template_context = ""
|
||||
if template:
|
||||
try:
|
||||
template_context = load_template(template, slots)
|
||||
except FileNotFoundError as exc:
|
||||
errors.append(str(exc))
|
||||
logger.warning("Template load failed: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2 — formulate queries
|
||||
# ------------------------------------------------------------------
|
||||
queries = await _formulate_queries(topic, template_context)
|
||||
logger.info("Formulated %d queries for topic: %r", len(queries), topic)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3 — execute search
|
||||
# ------------------------------------------------------------------
|
||||
search_results = await _execute_search(queries)
|
||||
logger.info("Search returned %d results", len(search_results))
|
||||
snippets = [r.get("snippet", "") for r in search_results if r.get("snippet")]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 4 — fetch full pages
|
||||
# ------------------------------------------------------------------
|
||||
pages = await _fetch_pages(search_results)
|
||||
logger.info("Fetched %d pages", len(pages))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 5 — synthesize
|
||||
# ------------------------------------------------------------------
|
||||
report, backend = await _synthesize(topic, pages, snippets)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 6 — deliver
|
||||
# ------------------------------------------------------------------
|
||||
_store_result(topic, report)
|
||||
if save_to_disk:
|
||||
_save_to_disk(topic, report)
|
||||
|
||||
return ResearchResult(
|
||||
topic=topic,
|
||||
query_count=len(queries),
|
||||
sources_fetched=len(pages),
|
||||
report=report,
|
||||
cached=False,
|
||||
synthesis_backend=backend,
|
||||
errors=errors,
|
||||
)
|
||||
@@ -8,4 +8,23 @@ Refs: #954, #953
|
||||
Three-strike detector and automation enforcement.
|
||||
|
||||
Refs: #962
|
||||
|
||||
Session reporting: auto-generates markdown scorecards at session end
|
||||
and commits them to the Gitea repo for institutional memory.
|
||||
|
||||
Refs: #957 (Session Sovereignty Report Generator)
|
||||
"""
|
||||
|
||||
from timmy.sovereignty.session_report import (
|
||||
commit_report,
|
||||
generate_and_commit_report,
|
||||
generate_report,
|
||||
mark_session_start,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"generate_report",
|
||||
"commit_report",
|
||||
"generate_and_commit_report",
|
||||
"mark_session_start",
|
||||
]
|
||||
|
||||
442
src/timmy/sovereignty/session_report.py
Normal file
442
src/timmy/sovereignty/session_report.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""Session Sovereignty Report Generator.
|
||||
|
||||
Auto-generates a sovereignty scorecard at the end of each play session
|
||||
and commits it as a markdown file to the Gitea repo under
|
||||
``reports/sovereignty/``.
|
||||
|
||||
Report contents (per issue #957):
|
||||
- Session duration + game played
|
||||
- Total model calls by type (VLM, LLM, TTS, API)
|
||||
- Total cache/rule hits by type
|
||||
- New skills crystallized (placeholder — pending skill-tracking impl)
|
||||
- Sovereignty delta (change from session start → end)
|
||||
- Cost breakdown (actual API spend)
|
||||
- Per-layer sovereignty %: perception, decision, narration
|
||||
- Trend comparison vs previous session
|
||||
|
||||
Refs: #957 (Sovereignty P0) · #953 (The Sovereignty Loop)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import settings
|
||||
|
||||
# Optional module-level imports — degrade gracefully if unavailable at import time
|
||||
try:
|
||||
from timmy.session_logger import get_session_logger
|
||||
except Exception: # ImportError or circular import during early startup
|
||||
get_session_logger = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from infrastructure.sovereignty_metrics import GRADUATION_TARGETS, get_sovereignty_store
|
||||
except Exception:
|
||||
GRADUATION_TARGETS: dict = {} # type: ignore[assignment]
|
||||
get_sovereignty_store = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level session start time; set by mark_session_start()
|
||||
_SESSION_START: datetime | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def mark_session_start() -> None:
|
||||
"""Record the session start wall-clock time.
|
||||
|
||||
Call once during application startup so ``generate_report()`` can
|
||||
compute accurate session durations.
|
||||
"""
|
||||
global _SESSION_START
|
||||
_SESSION_START = datetime.now(UTC)
|
||||
logger.debug("Sovereignty: session start recorded at %s", _SESSION_START.isoformat())
|
||||
|
||||
|
||||
def generate_report(session_id: str = "dashboard") -> str:
|
||||
"""Render a sovereignty scorecard as a markdown string.
|
||||
|
||||
Pulls from:
|
||||
- ``timmy.session_logger`` — message/tool-call/error counts
|
||||
- ``infrastructure.sovereignty_metrics`` — cache hit rate, API cost,
|
||||
graduation phase, and trend data
|
||||
|
||||
Args:
|
||||
session_id: The session identifier (default: "dashboard").
|
||||
|
||||
Returns:
|
||||
Markdown-formatted sovereignty report string.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
session_start = _SESSION_START or now
|
||||
duration_secs = (now - session_start).total_seconds()
|
||||
|
||||
session_data = _gather_session_data()
|
||||
sov_data = _gather_sovereignty_data()
|
||||
|
||||
return _render_markdown(now, session_id, duration_secs, session_data, sov_data)
|
||||
|
||||
|
||||
def commit_report(report_md: str, session_id: str = "dashboard") -> bool:
|
||||
"""Commit a sovereignty report to the Gitea repo.
|
||||
|
||||
Creates or updates ``reports/sovereignty/{date}_{session_id}.md``
|
||||
via the Gitea Contents API. Degrades gracefully: logs a warning
|
||||
and returns ``False`` if Gitea is unreachable or misconfigured.
|
||||
|
||||
Args:
|
||||
report_md: Markdown content to commit.
|
||||
session_id: Session identifier used in the filename.
|
||||
|
||||
Returns:
|
||||
``True`` on success, ``False`` on failure.
|
||||
"""
|
||||
if not settings.gitea_enabled:
|
||||
logger.info("Sovereignty: Gitea disabled — skipping report commit")
|
||||
return False
|
||||
|
||||
if not settings.gitea_token:
|
||||
logger.warning("Sovereignty: no Gitea token — skipping report commit")
|
||||
return False
|
||||
|
||||
date_str = datetime.now(UTC).strftime("%Y-%m-%d")
|
||||
file_path = f"reports/sovereignty/{date_str}_{session_id}.md"
|
||||
url = f"{settings.gitea_url}/api/v1/repos/{settings.gitea_repo}/contents/{file_path}"
|
||||
headers = {
|
||||
"Authorization": f"token {settings.gitea_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
encoded_content = base64.b64encode(report_md.encode()).decode()
|
||||
commit_message = (
|
||||
f"report: sovereignty session {session_id} ({date_str})\n\n"
|
||||
f"Auto-generated by Timmy. Refs #957"
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"message": commit_message,
|
||||
"content": encoded_content,
|
||||
}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
# Fetch existing file SHA so we can update rather than create
|
||||
check = client.get(url, headers=headers)
|
||||
if check.status_code == 200:
|
||||
existing = check.json()
|
||||
payload["sha"] = existing.get("sha", "")
|
||||
|
||||
resp = client.put(url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
|
||||
logger.info("Sovereignty: report committed to %s", file_path)
|
||||
return True
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.warning(
|
||||
"Sovereignty: commit failed (HTTP %s): %s",
|
||||
exc.response.status_code,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.warning("Sovereignty: commit failed: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
async def generate_and_commit_report(session_id: str = "dashboard") -> bool:
|
||||
"""Generate and commit a sovereignty report for the current session.
|
||||
|
||||
Primary entry point — call at session end / application shutdown.
|
||||
Wraps the synchronous ``commit_report`` call in ``asyncio.to_thread``
|
||||
so it does not block the event loop.
|
||||
|
||||
Args:
|
||||
session_id: The session identifier.
|
||||
|
||||
Returns:
|
||||
``True`` if the report was generated and committed successfully.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
report_md = generate_report(session_id)
|
||||
logger.info("Sovereignty: report generated (%d chars)", len(report_md))
|
||||
committed = await asyncio.to_thread(commit_report, report_md, session_id)
|
||||
return committed
|
||||
except Exception as exc:
|
||||
logger.warning("Sovereignty: report generation failed: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format a duration in seconds as a human-readable string."""
|
||||
total = int(seconds)
|
||||
hours, remainder = divmod(total, 3600)
|
||||
minutes, secs = divmod(remainder, 60)
|
||||
if hours:
|
||||
return f"{hours}h {minutes}m {secs}s"
|
||||
if minutes:
|
||||
return f"{minutes}m {secs}s"
|
||||
return f"{secs}s"
|
||||
|
||||
|
||||
def _gather_session_data() -> dict[str, Any]:
|
||||
"""Pull session statistics from the session logger.
|
||||
|
||||
Returns a dict with:
|
||||
- ``user_messages``, ``timmy_messages``, ``tool_calls``, ``errors``
|
||||
- ``tool_call_breakdown``: dict[tool_name, count]
|
||||
"""
|
||||
default: dict[str, Any] = {
|
||||
"user_messages": 0,
|
||||
"timmy_messages": 0,
|
||||
"tool_calls": 0,
|
||||
"errors": 0,
|
||||
"tool_call_breakdown": {},
|
||||
}
|
||||
|
||||
try:
|
||||
if get_session_logger is None:
|
||||
return default
|
||||
sl = get_session_logger()
|
||||
sl.flush()
|
||||
|
||||
# Read today's session file directly for accurate counts
|
||||
if not sl.session_file.exists():
|
||||
return default
|
||||
|
||||
entries: list[dict] = []
|
||||
with open(sl.session_file) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
tool_breakdown: dict[str, int] = {}
|
||||
user_msgs = timmy_msgs = tool_calls = errors = 0
|
||||
|
||||
for entry in entries:
|
||||
etype = entry.get("type")
|
||||
if etype == "message":
|
||||
if entry.get("role") == "user":
|
||||
user_msgs += 1
|
||||
elif entry.get("role") == "timmy":
|
||||
timmy_msgs += 1
|
||||
elif etype == "tool_call":
|
||||
tool_calls += 1
|
||||
tool_name = entry.get("tool", "unknown")
|
||||
tool_breakdown[tool_name] = tool_breakdown.get(tool_name, 0) + 1
|
||||
elif etype == "error":
|
||||
errors += 1
|
||||
|
||||
return {
|
||||
"user_messages": user_msgs,
|
||||
"timmy_messages": timmy_msgs,
|
||||
"tool_calls": tool_calls,
|
||||
"errors": errors,
|
||||
"tool_call_breakdown": tool_breakdown,
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Sovereignty: failed to gather session data: %s", exc)
|
||||
return default
|
||||
|
||||
|
||||
def _gather_sovereignty_data() -> dict[str, Any]:
|
||||
"""Pull sovereignty metrics from the SQLite store.
|
||||
|
||||
Returns a dict with:
|
||||
- ``metrics``: summary from ``SovereigntyMetricsStore.get_summary()``
|
||||
- ``deltas``: per-metric start/end values within recent history window
|
||||
- ``previous_session``: most recent prior value for each metric
|
||||
"""
|
||||
try:
|
||||
if get_sovereignty_store is None:
|
||||
return {"metrics": {}, "deltas": {}, "previous_session": {}}
|
||||
store = get_sovereignty_store()
|
||||
summary = store.get_summary()
|
||||
|
||||
deltas: dict[str, dict[str, Any]] = {}
|
||||
previous_session: dict[str, float | None] = {}
|
||||
|
||||
for metric_type in GRADUATION_TARGETS:
|
||||
history = store.get_latest(metric_type, limit=10)
|
||||
if len(history) >= 2:
|
||||
deltas[metric_type] = {
|
||||
"start": history[-1]["value"],
|
||||
"end": history[0]["value"],
|
||||
}
|
||||
previous_session[metric_type] = history[1]["value"]
|
||||
elif len(history) == 1:
|
||||
deltas[metric_type] = {"start": history[0]["value"], "end": history[0]["value"]}
|
||||
previous_session[metric_type] = None
|
||||
else:
|
||||
deltas[metric_type] = {"start": None, "end": None}
|
||||
previous_session[metric_type] = None
|
||||
|
||||
return {
|
||||
"metrics": summary,
|
||||
"deltas": deltas,
|
||||
"previous_session": previous_session,
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Sovereignty: failed to gather sovereignty data: %s", exc)
|
||||
return {"metrics": {}, "deltas": {}, "previous_session": {}}
|
||||
|
||||
|
||||
def _render_markdown(
|
||||
now: datetime,
|
||||
session_id: str,
|
||||
duration_secs: float,
|
||||
session_data: dict[str, Any],
|
||||
sov_data: dict[str, Any],
|
||||
) -> str:
|
||||
"""Assemble the full sovereignty report in markdown."""
|
||||
lines: list[str] = []
|
||||
|
||||
# Header
|
||||
lines += [
|
||||
"# Sovereignty Session Report",
|
||||
"",
|
||||
f"**Session ID:** `{session_id}` ",
|
||||
f"**Date:** {now.strftime('%Y-%m-%d')} ",
|
||||
f"**Duration:** {_format_duration(duration_secs)} ",
|
||||
f"**Generated:** {now.isoformat()}",
|
||||
"",
|
||||
"---",
|
||||
"",
|
||||
]
|
||||
|
||||
# Session activity
|
||||
lines += [
|
||||
"## Session Activity",
|
||||
"",
|
||||
"| Metric | Count |",
|
||||
"|--------|-------|",
|
||||
f"| User messages | {session_data['user_messages']} |",
|
||||
f"| Timmy responses | {session_data['timmy_messages']} |",
|
||||
f"| Tool calls | {session_data['tool_calls']} |",
|
||||
f"| Errors | {session_data['errors']} |",
|
||||
"",
|
||||
]
|
||||
|
||||
tool_breakdown = session_data.get("tool_call_breakdown", {})
|
||||
if tool_breakdown:
|
||||
lines += ["### Model Calls by Tool", ""]
|
||||
for tool_name, count in sorted(tool_breakdown.items(), key=lambda x: -x[1]):
|
||||
lines.append(f"- `{tool_name}`: {count}")
|
||||
lines.append("")
|
||||
|
||||
# Sovereignty scorecard
|
||||
|
||||
lines += [
|
||||
"## Sovereignty Scorecard",
|
||||
"",
|
||||
"| Metric | Current | Target (graduation) | Phase |",
|
||||
"|--------|---------|---------------------|-------|",
|
||||
]
|
||||
|
||||
for metric_type, data in sov_data["metrics"].items():
|
||||
current = data.get("current")
|
||||
current_str = f"{current:.4f}" if current is not None else "N/A"
|
||||
grad_target = GRADUATION_TARGETS.get(metric_type, {}).get("graduation")
|
||||
grad_str = f"{grad_target:.4f}" if isinstance(grad_target, (int, float)) else "N/A"
|
||||
phase = data.get("phase", "unknown")
|
||||
lines.append(f"| {metric_type} | {current_str} | {grad_str} | {phase} |")
|
||||
|
||||
lines += ["", "### Sovereignty Delta (This Session)", ""]
|
||||
|
||||
for metric_type, delta_info in sov_data.get("deltas", {}).items():
|
||||
start_val = delta_info.get("start")
|
||||
end_val = delta_info.get("end")
|
||||
if start_val is not None and end_val is not None:
|
||||
diff = end_val - start_val
|
||||
sign = "+" if diff >= 0 else ""
|
||||
lines.append(
|
||||
f"- **{metric_type}**: {start_val:.4f} → {end_val:.4f} ({sign}{diff:.4f})"
|
||||
)
|
||||
else:
|
||||
lines.append(f"- **{metric_type}**: N/A (no data recorded)")
|
||||
|
||||
# Cost breakdown
|
||||
lines += ["", "## Cost Breakdown", ""]
|
||||
api_cost_data = sov_data["metrics"].get("api_cost", {})
|
||||
current_cost = api_cost_data.get("current")
|
||||
if current_cost is not None:
|
||||
lines.append(f"- **Total API spend (latest recorded):** ${current_cost:.4f}")
|
||||
else:
|
||||
lines.append("- **Total API spend:** N/A (no data recorded)")
|
||||
lines.append("")
|
||||
|
||||
# Per-layer sovereignty
|
||||
lines += [
|
||||
"## Per-Layer Sovereignty",
|
||||
"",
|
||||
"| Layer | Sovereignty % |",
|
||||
"|-------|--------------|",
|
||||
"| Perception (VLM) | N/A |",
|
||||
"| Decision (LLM) | N/A |",
|
||||
"| Narration (TTS) | N/A |",
|
||||
"",
|
||||
"> Per-layer tracking requires instrumented inference calls. See #957.",
|
||||
"",
|
||||
]
|
||||
|
||||
# Skills crystallized
|
||||
lines += [
|
||||
"## Skills Crystallized",
|
||||
"",
|
||||
"_Skill crystallization tracking not yet implemented. See #957._",
|
||||
"",
|
||||
]
|
||||
|
||||
# Trend vs previous session
|
||||
lines += ["## Trend vs Previous Session", ""]
|
||||
prev_data = sov_data.get("previous_session", {})
|
||||
has_prev = any(v is not None for v in prev_data.values())
|
||||
|
||||
if has_prev:
|
||||
lines += [
|
||||
"| Metric | Previous | Current | Change |",
|
||||
"|--------|----------|---------|--------|",
|
||||
]
|
||||
for metric_type, curr_info in sov_data["metrics"].items():
|
||||
curr_val = curr_info.get("current")
|
||||
prev_val = prev_data.get(metric_type)
|
||||
curr_str = f"{curr_val:.4f}" if curr_val is not None else "N/A"
|
||||
prev_str = f"{prev_val:.4f}" if prev_val is not None else "N/A"
|
||||
if curr_val is not None and prev_val is not None:
|
||||
diff = curr_val - prev_val
|
||||
sign = "+" if diff >= 0 else ""
|
||||
change_str = f"{sign}{diff:.4f}"
|
||||
else:
|
||||
change_str = "N/A"
|
||||
lines.append(f"| {metric_type} | {prev_str} | {curr_str} | {change_str} |")
|
||||
lines.append("")
|
||||
else:
|
||||
lines += ["_No previous session data available for comparison._", ""]
|
||||
|
||||
# Footer
|
||||
lines += [
|
||||
"---",
|
||||
"_Auto-generated by Timmy · Session Sovereignty Report · Refs: #957_",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -2714,3 +2714,74 @@
|
||||
padding: 0.3rem 0.6rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
/* ── Self-Correction Dashboard ─────────────────────────────── */
|
||||
.sc-event {
|
||||
border-left: 3px solid var(--border);
|
||||
padding: 0.6rem 0.8rem;
|
||||
margin-bottom: 0.75rem;
|
||||
background: rgba(255,255,255,0.02);
|
||||
border-radius: 0 4px 4px 0;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
.sc-event.sc-status-success { border-left-color: var(--green); }
|
||||
.sc-event.sc-status-partial { border-left-color: var(--amber); }
|
||||
.sc-event.sc-status-failed { border-left-color: var(--red); }
|
||||
|
||||
.sc-event-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0.4rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.sc-status-badge {
|
||||
font-size: 0.68rem;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.06em;
|
||||
padding: 0.15rem 0.45rem;
|
||||
border-radius: 3px;
|
||||
}
|
||||
.sc-status-badge.sc-status-success { color: var(--green); background: rgba(0,255,136,0.08); }
|
||||
.sc-status-badge.sc-status-partial { color: var(--amber); background: rgba(255,179,0,0.08); }
|
||||
.sc-status-badge.sc-status-failed { color: var(--red); background: rgba(255,59,59,0.08); }
|
||||
|
||||
.sc-source-badge {
|
||||
font-size: 0.68rem;
|
||||
color: var(--purple);
|
||||
background: rgba(168,85,247,0.1);
|
||||
padding: 0.1rem 0.4rem;
|
||||
border-radius: 3px;
|
||||
}
|
||||
.sc-event-time { font-size: 0.68rem; color: var(--text-dim); margin-left: auto; }
|
||||
.sc-event-error-type {
|
||||
font-size: 0.72rem;
|
||||
color: var(--amber);
|
||||
font-weight: 600;
|
||||
margin-bottom: 0.3rem;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
.sc-label {
|
||||
font-size: 0.65rem;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.06em;
|
||||
color: var(--text-dim);
|
||||
margin-right: 0.3rem;
|
||||
}
|
||||
.sc-event-intent, .sc-event-error, .sc-event-strategy, .sc-event-outcome {
|
||||
color: var(--text);
|
||||
margin-bottom: 0.2rem;
|
||||
line-height: 1.4;
|
||||
word-break: break-word;
|
||||
}
|
||||
.sc-event-error { color: var(--red); }
|
||||
.sc-event-strategy { color: var(--text-dim); font-style: italic; }
|
||||
.sc-event-outcome { color: var(--text-bright); }
|
||||
.sc-event-meta { font-size: 0.68rem; color: var(--text-dim); margin-top: 0.3rem; }
|
||||
|
||||
.sc-pattern-type {
|
||||
font-family: var(--font);
|
||||
font-size: 0.8rem;
|
||||
color: var(--text-bright);
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
0
tests/self_coding/__init__.py
Normal file
0
tests/self_coding/__init__.py
Normal file
363
tests/self_coding/test_loop.py
Normal file
363
tests/self_coding/test_loop.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Unit tests for the self-modification loop.
|
||||
|
||||
Covers:
|
||||
- Protected branch guard
|
||||
- Successful cycle (mocked git + tests)
|
||||
- Edit function failure → branch reverted, no commit
|
||||
- Test failure → branch reverted, no commit
|
||||
- Gitea PR creation plumbing
|
||||
- GiteaClient graceful degradation (no token, network error)
|
||||
|
||||
All git and subprocess calls are mocked so these run offline without
|
||||
a real repo or test suite.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_loop(repo_root="/tmp/fake-repo"):
|
||||
"""Construct a SelfModifyLoop with a fake repo root."""
|
||||
from self_coding.self_modify.loop import SelfModifyLoop
|
||||
|
||||
return SelfModifyLoop(repo_root=repo_root, remote="origin", base_branch="main")
|
||||
|
||||
|
||||
def _noop_edit(repo_root: str) -> None:
|
||||
"""Edit function that does nothing."""
|
||||
|
||||
|
||||
def _failing_edit(repo_root: str) -> None:
|
||||
"""Edit function that raises."""
|
||||
raise RuntimeError("edit exploded")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Guard tests (sync — no git calls needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_guard_blocks_main():
|
||||
loop = _make_loop()
|
||||
with pytest.raises(ValueError, match="protected branch"):
|
||||
loop._guard_branch("main")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_guard_blocks_master():
|
||||
loop = _make_loop()
|
||||
with pytest.raises(ValueError, match="protected branch"):
|
||||
loop._guard_branch("master")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_guard_allows_feature_branch():
|
||||
loop = _make_loop()
|
||||
# Should not raise
|
||||
loop._guard_branch("self-modify/some-feature")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_guard_allows_self_modify_prefix():
|
||||
loop = _make_loop()
|
||||
loop._guard_branch("self-modify/issue-983")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full cycle — success path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success():
|
||||
"""Happy path: edit succeeds, tests pass, PR created."""
|
||||
loop = _make_loop()
|
||||
|
||||
fake_completed = MagicMock()
|
||||
fake_completed.stdout = "abc1234\n"
|
||||
fake_completed.returncode = 0
|
||||
|
||||
fake_test_result = MagicMock()
|
||||
fake_test_result.stdout = "3 passed"
|
||||
fake_test_result.stderr = ""
|
||||
fake_test_result.returncode = 0
|
||||
|
||||
from self_coding.gitea_client import PullRequest as _PR
|
||||
|
||||
fake_pr = _PR(number=42, title="test PR", html_url="http://gitea/pr/42")
|
||||
|
||||
with (
|
||||
patch.object(loop, "_git", return_value=fake_completed),
|
||||
patch("subprocess.run", return_value=fake_test_result),
|
||||
patch.object(loop, "_create_pr", return_value=fake_pr),
|
||||
):
|
||||
result = await loop.run(
|
||||
slug="test-feature",
|
||||
description="Add test feature",
|
||||
edit_fn=_noop_edit,
|
||||
issue_number=983,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.branch == "self-modify/test-feature"
|
||||
assert result.pr_url == "http://gitea/pr/42"
|
||||
assert result.pr_number == 42
|
||||
assert "3 passed" in result.test_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_skips_tests_when_flag_set():
|
||||
"""skip_tests=True should bypass the test gate."""
|
||||
loop = _make_loop()
|
||||
|
||||
fake_completed = MagicMock()
|
||||
fake_completed.stdout = "deadbeef\n"
|
||||
fake_completed.returncode = 0
|
||||
|
||||
with (
|
||||
patch.object(loop, "_git", return_value=fake_completed),
|
||||
patch.object(loop, "_create_pr", return_value=None),
|
||||
patch("subprocess.run") as mock_run,
|
||||
):
|
||||
result = await loop.run(
|
||||
slug="skip-test-feature",
|
||||
description="Skip test feature",
|
||||
edit_fn=_noop_edit,
|
||||
skip_tests=True,
|
||||
)
|
||||
|
||||
# subprocess.run should NOT be called for tests
|
||||
mock_run.assert_not_called()
|
||||
assert result.success is True
|
||||
assert "(tests skipped)" in result.test_output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failure paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_reverts_on_edit_failure():
|
||||
"""If edit_fn raises, the branch should be reverted and no commit made."""
|
||||
loop = _make_loop()
|
||||
|
||||
fake_completed = MagicMock()
|
||||
fake_completed.stdout = ""
|
||||
fake_completed.returncode = 0
|
||||
|
||||
revert_called = []
|
||||
|
||||
def _fake_revert(branch):
|
||||
revert_called.append(branch)
|
||||
|
||||
with (
|
||||
patch.object(loop, "_git", return_value=fake_completed),
|
||||
patch.object(loop, "_revert_branch", side_effect=_fake_revert),
|
||||
patch.object(loop, "_commit_all") as mock_commit,
|
||||
):
|
||||
result = await loop.run(
|
||||
slug="broken-edit",
|
||||
description="This will fail",
|
||||
edit_fn=_failing_edit,
|
||||
skip_tests=True,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "edit exploded" in result.error
|
||||
assert "self-modify/broken-edit" in revert_called
|
||||
mock_commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_reverts_on_test_failure():
|
||||
"""If tests fail, branch should be reverted and no commit made."""
|
||||
loop = _make_loop()
|
||||
|
||||
fake_completed = MagicMock()
|
||||
fake_completed.stdout = ""
|
||||
fake_completed.returncode = 0
|
||||
|
||||
fake_test_result = MagicMock()
|
||||
fake_test_result.stdout = "FAILED test_foo"
|
||||
fake_test_result.stderr = "1 failed"
|
||||
fake_test_result.returncode = 1
|
||||
|
||||
revert_called = []
|
||||
|
||||
def _fake_revert(branch):
|
||||
revert_called.append(branch)
|
||||
|
||||
with (
|
||||
patch.object(loop, "_git", return_value=fake_completed),
|
||||
patch("subprocess.run", return_value=fake_test_result),
|
||||
patch.object(loop, "_revert_branch", side_effect=_fake_revert),
|
||||
patch.object(loop, "_commit_all") as mock_commit,
|
||||
):
|
||||
result = await loop.run(
|
||||
slug="tests-will-fail",
|
||||
description="This will fail tests",
|
||||
edit_fn=_noop_edit,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Tests failed" in result.error
|
||||
assert "self-modify/tests-will-fail" in revert_called
|
||||
mock_commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_slug_with_main_creates_safe_branch():
|
||||
"""A slug of 'main' produces branch 'self-modify/main', which is not protected."""
|
||||
|
||||
loop = _make_loop()
|
||||
|
||||
fake_completed = MagicMock()
|
||||
fake_completed.stdout = "deadbeef\n"
|
||||
fake_completed.returncode = 0
|
||||
|
||||
# 'self-modify/main' is NOT in _PROTECTED_BRANCHES so the run should succeed
|
||||
with (
|
||||
patch.object(loop, "_git", return_value=fake_completed),
|
||||
patch.object(loop, "_create_pr", return_value=None),
|
||||
):
|
||||
result = await loop.run(
|
||||
slug="main",
|
||||
description="try to write to self-modify/main",
|
||||
edit_fn=_noop_edit,
|
||||
skip_tests=True,
|
||||
)
|
||||
assert result.branch == "self-modify/main"
|
||||
assert result.success is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GiteaClient tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gitea_client_returns_none_without_token():
|
||||
"""GiteaClient should return None gracefully when no token is set."""
|
||||
from self_coding.gitea_client import GiteaClient
|
||||
|
||||
client = GiteaClient(base_url="http://localhost:3000", token="", repo="owner/repo")
|
||||
pr = client.create_pull_request(
|
||||
title="Test PR",
|
||||
body="body",
|
||||
head="self-modify/test",
|
||||
)
|
||||
assert pr is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gitea_client_comment_returns_false_without_token():
|
||||
"""add_issue_comment should return False gracefully when no token is set."""
|
||||
from self_coding.gitea_client import GiteaClient
|
||||
|
||||
client = GiteaClient(base_url="http://localhost:3000", token="", repo="owner/repo")
|
||||
result = client.add_issue_comment(123, "hello")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gitea_client_create_pr_handles_network_error():
|
||||
"""create_pull_request should return None on network failure."""
|
||||
from self_coding.gitea_client import GiteaClient
|
||||
|
||||
client = GiteaClient(base_url="http://localhost:3000", token="fake-token", repo="owner/repo")
|
||||
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.post.side_effect = Exception("Connection refused")
|
||||
mock_requests.exceptions.ConnectionError = Exception
|
||||
|
||||
with patch.dict("sys.modules", {"requests": mock_requests}):
|
||||
pr = client.create_pull_request(
|
||||
title="Test PR",
|
||||
body="body",
|
||||
head="self-modify/test",
|
||||
)
|
||||
assert pr is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gitea_client_comment_handles_network_error():
|
||||
"""add_issue_comment should return False on network failure."""
|
||||
from self_coding.gitea_client import GiteaClient
|
||||
|
||||
client = GiteaClient(base_url="http://localhost:3000", token="fake-token", repo="owner/repo")
|
||||
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.post.side_effect = Exception("Connection refused")
|
||||
|
||||
with patch.dict("sys.modules", {"requests": mock_requests}):
|
||||
result = client.add_issue_comment(456, "hello")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gitea_client_create_pr_success():
|
||||
"""create_pull_request should return a PullRequest on HTTP 201."""
|
||||
from self_coding.gitea_client import GiteaClient, PullRequest
|
||||
|
||||
client = GiteaClient(base_url="http://localhost:3000", token="tok", repo="owner/repo")
|
||||
|
||||
fake_resp = MagicMock()
|
||||
fake_resp.raise_for_status = MagicMock()
|
||||
fake_resp.json.return_value = {
|
||||
"number": 77,
|
||||
"title": "Test PR",
|
||||
"html_url": "http://localhost:3000/owner/repo/pulls/77",
|
||||
}
|
||||
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.post.return_value = fake_resp
|
||||
|
||||
with patch.dict("sys.modules", {"requests": mock_requests}):
|
||||
pr = client.create_pull_request("Test PR", "body", "self-modify/feat")
|
||||
|
||||
assert isinstance(pr, PullRequest)
|
||||
assert pr.number == 77
|
||||
assert pr.html_url == "http://localhost:3000/owner/repo/pulls/77"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoopResult dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_loop_result_defaults():
|
||||
from self_coding.self_modify.loop import LoopResult
|
||||
|
||||
r = LoopResult(success=True)
|
||||
assert r.branch == ""
|
||||
assert r.commit_sha == ""
|
||||
assert r.pr_url == ""
|
||||
assert r.pr_number == 0
|
||||
assert r.test_output == ""
|
||||
assert r.error == ""
|
||||
assert r.elapsed_ms == 0.0
|
||||
assert r.metadata == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_loop_result_failure():
|
||||
from self_coding.self_modify.loop import LoopResult
|
||||
|
||||
r = LoopResult(success=False, error="something broke", branch="self-modify/test")
|
||||
assert r.success is False
|
||||
assert r.error == "something broke"
|
||||
403
tests/timmy/test_research.py
Normal file
403
tests/timmy/test_research.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Unit tests for src/timmy/research.py — ResearchOrchestrator pipeline.
|
||||
|
||||
Refs #972 (governing spec), #975 (ResearchOrchestrator).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListTemplates:
|
||||
def test_returns_list(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "tool_evaluation.md").write_text("---\n---\n# T")
|
||||
(tmp_path / "game_analysis.md").write_text("---\n---\n# G")
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
from timmy.research import list_templates
|
||||
|
||||
result = list_templates()
|
||||
assert isinstance(result, list)
|
||||
assert "tool_evaluation" in result
|
||||
assert "game_analysis" in result
|
||||
|
||||
def test_returns_empty_when_dir_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path / "nonexistent")
|
||||
|
||||
from timmy.research import list_templates
|
||||
|
||||
assert list_templates() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_template
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadTemplate:
|
||||
def _write_template(self, path: Path, name: str, body: str) -> None:
|
||||
(path / f"{name}.md").write_text(body, encoding="utf-8")
|
||||
|
||||
def test_loads_and_strips_frontmatter(self, tmp_path, monkeypatch):
|
||||
self._write_template(
|
||||
tmp_path,
|
||||
"tool_evaluation",
|
||||
"---\nname: Tool Evaluation\ntype: research\n---\n# Tool Eval: {domain}",
|
||||
)
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
from timmy.research import load_template
|
||||
|
||||
result = load_template("tool_evaluation", {"domain": "PDF parsing"})
|
||||
assert "# Tool Eval: PDF parsing" in result
|
||||
assert "name: Tool Evaluation" not in result
|
||||
|
||||
def test_fills_slots(self, tmp_path, monkeypatch):
|
||||
self._write_template(tmp_path, "arch", "Connect {system_a} to {system_b}")
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
from timmy.research import load_template
|
||||
|
||||
result = load_template("arch", {"system_a": "Kafka", "system_b": "Postgres"})
|
||||
assert "Kafka" in result
|
||||
assert "Postgres" in result
|
||||
|
||||
def test_unfilled_slots_preserved(self, tmp_path, monkeypatch):
|
||||
self._write_template(tmp_path, "t", "Hello {name} and {other}")
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
from timmy.research import load_template
|
||||
|
||||
result = load_template("t", {"name": "World"})
|
||||
assert "{other}" in result
|
||||
|
||||
def test_raises_file_not_found_for_missing_template(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
from timmy.research import load_template
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="nonexistent"):
|
||||
load_template("nonexistent")
|
||||
|
||||
def test_no_slots_returns_raw_body(self, tmp_path, monkeypatch):
|
||||
self._write_template(tmp_path, "plain", "---\n---\nJust text here")
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
from timmy.research import load_template
|
||||
|
||||
result = load_template("plain")
|
||||
assert result == "Just text here"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckCache:
|
||||
def test_returns_none_when_no_hits(self):
|
||||
mock_mem = MagicMock()
|
||||
mock_mem.search.return_value = []
|
||||
|
||||
with patch("timmy.research.SemanticMemory", return_value=mock_mem):
|
||||
from timmy.research import _check_cache
|
||||
|
||||
content, score = _check_cache("some topic")
|
||||
|
||||
assert content is None
|
||||
assert score == 0.0
|
||||
|
||||
def test_returns_content_above_threshold(self):
|
||||
mock_mem = MagicMock()
|
||||
mock_mem.search.return_value = [("cached report text", 0.91)]
|
||||
|
||||
with patch("timmy.research.SemanticMemory", return_value=mock_mem):
|
||||
from timmy.research import _check_cache
|
||||
|
||||
content, score = _check_cache("same topic")
|
||||
|
||||
assert content == "cached report text"
|
||||
assert score == pytest.approx(0.91)
|
||||
|
||||
def test_returns_none_below_threshold(self):
|
||||
mock_mem = MagicMock()
|
||||
mock_mem.search.return_value = [("old report", 0.60)]
|
||||
|
||||
with patch("timmy.research.SemanticMemory", return_value=mock_mem):
|
||||
from timmy.research import _check_cache
|
||||
|
||||
content, score = _check_cache("slightly different topic")
|
||||
|
||||
assert content is None
|
||||
assert score == 0.0
|
||||
|
||||
def test_degrades_gracefully_on_import_error(self):
|
||||
with patch("timmy.research.SemanticMemory", None):
|
||||
from timmy.research import _check_cache
|
||||
|
||||
content, score = _check_cache("topic")
|
||||
|
||||
assert content is None
|
||||
assert score == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _store_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStoreResult:
|
||||
def test_calls_store_memory(self):
|
||||
mock_store = MagicMock()
|
||||
|
||||
with patch("timmy.research.store_memory", mock_store):
|
||||
from timmy.research import _store_result
|
||||
|
||||
_store_result("test topic", "# Report\n\nContent here.")
|
||||
|
||||
mock_store.assert_called_once()
|
||||
call_kwargs = mock_store.call_args
|
||||
assert "test topic" in str(call_kwargs)
|
||||
|
||||
def test_degrades_gracefully_on_error(self):
|
||||
mock_store = MagicMock(side_effect=RuntimeError("db error"))
|
||||
with patch("timmy.research.store_memory", mock_store):
|
||||
from timmy.research import _store_result
|
||||
|
||||
# Should not raise
|
||||
_store_result("topic", "report")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _save_to_disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSaveToDisk:
|
||||
def test_writes_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._DOCS_ROOT", tmp_path / "research")
|
||||
|
||||
from timmy.research import _save_to_disk
|
||||
|
||||
path = _save_to_disk("Test Topic: PDF Parsing", "# Test Report")
|
||||
assert path is not None
|
||||
assert path.exists()
|
||||
assert path.read_text() == "# Test Report"
|
||||
|
||||
def test_slugifies_topic_name(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._DOCS_ROOT", tmp_path / "research")
|
||||
|
||||
from timmy.research import _save_to_disk
|
||||
|
||||
path = _save_to_disk("My Complex Topic! v2.0", "content")
|
||||
assert path is not None
|
||||
# Should be slugified: no special chars
|
||||
assert " " not in path.name
|
||||
assert "!" not in path.name
|
||||
|
||||
def test_returns_none_on_error(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"timmy.research._DOCS_ROOT",
|
||||
Path("/nonexistent_root/deeply/nested"),
|
||||
)
|
||||
|
||||
with patch("pathlib.Path.mkdir", side_effect=PermissionError("denied")):
|
||||
from timmy.research import _save_to_disk
|
||||
|
||||
result = _save_to_disk("topic", "report")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_research — end-to-end with mocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunResearch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_cached_result_when_cache_hit(self):
|
||||
cached_report = "# Cached Report\n\nPreviously computed."
|
||||
with (
|
||||
patch("timmy.research._check_cache", return_value=(cached_report, 0.93)),
|
||||
):
|
||||
from timmy.research import run_research
|
||||
|
||||
result = await run_research("some topic")
|
||||
|
||||
assert result.cached is True
|
||||
assert result.cache_similarity == pytest.approx(0.93)
|
||||
assert result.report == cached_report
|
||||
assert result.synthesis_backend == "cache"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_cache_when_requested(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
with (
|
||||
patch("timmy.research._check_cache", return_value=("cached", 0.99)) as mock_cache,
|
||||
patch(
|
||||
"timmy.research._formulate_queries",
|
||||
new=AsyncMock(return_value=["q1"]),
|
||||
),
|
||||
patch("timmy.research._execute_search", new=AsyncMock(return_value=[])),
|
||||
patch("timmy.research._fetch_pages", new=AsyncMock(return_value=[])),
|
||||
patch(
|
||||
"timmy.research._synthesize",
|
||||
new=AsyncMock(return_value=("# Fresh report", "ollama")),
|
||||
),
|
||||
patch("timmy.research._store_result"),
|
||||
):
|
||||
from timmy.research import run_research
|
||||
|
||||
result = await run_research("topic", skip_cache=True)
|
||||
|
||||
mock_cache.assert_not_called()
|
||||
assert result.cached is False
|
||||
assert result.report == "# Fresh report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_pipeline_no_search_results(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
with (
|
||||
patch("timmy.research._check_cache", return_value=(None, 0.0)),
|
||||
patch(
|
||||
"timmy.research._formulate_queries",
|
||||
new=AsyncMock(return_value=["query 1", "query 2"]),
|
||||
),
|
||||
patch("timmy.research._execute_search", new=AsyncMock(return_value=[])),
|
||||
patch("timmy.research._fetch_pages", new=AsyncMock(return_value=[])),
|
||||
patch(
|
||||
"timmy.research._synthesize",
|
||||
new=AsyncMock(return_value=("# Report", "ollama")),
|
||||
),
|
||||
patch("timmy.research._store_result"),
|
||||
):
|
||||
from timmy.research import run_research
|
||||
|
||||
result = await run_research("a new topic")
|
||||
|
||||
assert not result.cached
|
||||
assert result.query_count == 2
|
||||
assert result.sources_fetched == 0
|
||||
assert result.report == "# Report"
|
||||
assert result.synthesis_backend == "ollama"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_result_with_error_on_bad_template(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
with (
|
||||
patch("timmy.research._check_cache", return_value=(None, 0.0)),
|
||||
patch(
|
||||
"timmy.research._formulate_queries",
|
||||
new=AsyncMock(return_value=["q1"]),
|
||||
),
|
||||
patch("timmy.research._execute_search", new=AsyncMock(return_value=[])),
|
||||
patch("timmy.research._fetch_pages", new=AsyncMock(return_value=[])),
|
||||
patch(
|
||||
"timmy.research._synthesize",
|
||||
new=AsyncMock(return_value=("# Report", "ollama")),
|
||||
),
|
||||
patch("timmy.research._store_result"),
|
||||
):
|
||||
from timmy.research import run_research
|
||||
|
||||
result = await run_research("topic", template="nonexistent_template")
|
||||
|
||||
assert len(result.errors) == 1
|
||||
assert "nonexistent_template" in result.errors[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saves_to_disk_when_requested(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
monkeypatch.setattr("timmy.research._DOCS_ROOT", tmp_path / "research")
|
||||
|
||||
with (
|
||||
patch("timmy.research._check_cache", return_value=(None, 0.0)),
|
||||
patch(
|
||||
"timmy.research._formulate_queries",
|
||||
new=AsyncMock(return_value=["q1"]),
|
||||
),
|
||||
patch("timmy.research._execute_search", new=AsyncMock(return_value=[])),
|
||||
patch("timmy.research._fetch_pages", new=AsyncMock(return_value=[])),
|
||||
patch(
|
||||
"timmy.research._synthesize",
|
||||
new=AsyncMock(return_value=("# Saved Report", "ollama")),
|
||||
),
|
||||
patch("timmy.research._store_result"),
|
||||
):
|
||||
from timmy.research import run_research
|
||||
|
||||
result = await run_research("disk topic", save_to_disk=True)
|
||||
|
||||
assert result.report == "# Saved Report"
|
||||
saved_files = list((tmp_path / "research").glob("*.md"))
|
||||
assert len(saved_files) == 1
|
||||
assert saved_files[0].read_text() == "# Saved Report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_result_is_not_empty_after_synthesis(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("timmy.research._SKILLS_ROOT", tmp_path)
|
||||
|
||||
with (
|
||||
patch("timmy.research._check_cache", return_value=(None, 0.0)),
|
||||
patch(
|
||||
"timmy.research._formulate_queries",
|
||||
new=AsyncMock(return_value=["q"]),
|
||||
),
|
||||
patch("timmy.research._execute_search", new=AsyncMock(return_value=[])),
|
||||
patch("timmy.research._fetch_pages", new=AsyncMock(return_value=[])),
|
||||
patch(
|
||||
"timmy.research._synthesize",
|
||||
new=AsyncMock(return_value=("# Non-empty", "ollama")),
|
||||
),
|
||||
patch("timmy.research._store_result"),
|
||||
):
|
||||
from timmy.research import run_research
|
||||
|
||||
result = await run_research("topic")
|
||||
|
||||
assert not result.is_empty()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResearchResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResearchResult:
|
||||
def test_is_empty_when_no_report(self):
|
||||
from timmy.research import ResearchResult
|
||||
|
||||
r = ResearchResult(topic="t", query_count=0, sources_fetched=0, report="")
|
||||
assert r.is_empty()
|
||||
|
||||
def test_is_not_empty_with_content(self):
|
||||
from timmy.research import ResearchResult
|
||||
|
||||
r = ResearchResult(topic="t", query_count=1, sources_fetched=1, report="# Report")
|
||||
assert not r.is_empty()
|
||||
|
||||
def test_default_cached_false(self):
|
||||
from timmy.research import ResearchResult
|
||||
|
||||
r = ResearchResult(topic="t", query_count=0, sources_fetched=0, report="x")
|
||||
assert r.cached is False
|
||||
|
||||
def test_errors_defaults_to_empty_list(self):
|
||||
from timmy.research import ResearchResult
|
||||
|
||||
r = ResearchResult(topic="t", query_count=0, sources_fetched=0, report="x")
|
||||
assert r.errors == []
|
||||
444
tests/timmy/test_session_report.py
Normal file
444
tests/timmy/test_session_report.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""Tests for timmy.sovereignty.session_report.
|
||||
|
||||
Refs: #957 (Session Sovereignty Report Generator)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
from timmy.sovereignty.session_report import (
|
||||
_format_duration,
|
||||
_gather_session_data,
|
||||
_gather_sovereignty_data,
|
||||
_render_markdown,
|
||||
commit_report,
|
||||
generate_and_commit_report,
|
||||
generate_report,
|
||||
mark_session_start,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_duration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatDuration:
|
||||
def test_seconds_only(self):
|
||||
assert _format_duration(45) == "45s"
|
||||
|
||||
def test_minutes_and_seconds(self):
|
||||
assert _format_duration(125) == "2m 5s"
|
||||
|
||||
def test_hours_minutes_seconds(self):
|
||||
assert _format_duration(3661) == "1h 1m 1s"
|
||||
|
||||
def test_zero(self):
|
||||
assert _format_duration(0) == "0s"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mark_session_start + generate_report (smoke)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarkSessionStart:
|
||||
def test_sets_session_start(self):
|
||||
import timmy.sovereignty.session_report as sr
|
||||
|
||||
sr._SESSION_START = None
|
||||
mark_session_start()
|
||||
assert sr._SESSION_START is not None
|
||||
assert sr._SESSION_START.tzinfo == UTC
|
||||
|
||||
def test_idempotent_overwrite(self):
|
||||
import timmy.sovereignty.session_report as sr
|
||||
|
||||
mark_session_start()
|
||||
first = sr._SESSION_START
|
||||
time.sleep(0.01)
|
||||
mark_session_start()
|
||||
second = sr._SESSION_START
|
||||
assert second >= first
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _gather_session_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGatherSessionData:
|
||||
def test_returns_defaults_when_no_file(self, tmp_path):
|
||||
mock_logger = MagicMock()
|
||||
mock_logger.flush.return_value = None
|
||||
mock_logger.session_file = tmp_path / "nonexistent.jsonl"
|
||||
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.get_session_logger",
|
||||
return_value=mock_logger,
|
||||
):
|
||||
data = _gather_session_data()
|
||||
|
||||
assert data["user_messages"] == 0
|
||||
assert data["timmy_messages"] == 0
|
||||
assert data["tool_calls"] == 0
|
||||
assert data["errors"] == 0
|
||||
assert data["tool_call_breakdown"] == {}
|
||||
|
||||
def test_counts_entries_correctly(self, tmp_path):
|
||||
session_file = tmp_path / "session_2026-03-23.jsonl"
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "hello"},
|
||||
{"type": "message", "role": "timmy", "content": "hi"},
|
||||
{"type": "message", "role": "user", "content": "test"},
|
||||
{"type": "tool_call", "tool": "memory_search", "args": {}, "result": "found"},
|
||||
{"type": "tool_call", "tool": "memory_search", "args": {}, "result": "nope"},
|
||||
{"type": "tool_call", "tool": "shell", "args": {}, "result": "ok"},
|
||||
{"type": "error", "error": "boom"},
|
||||
]
|
||||
with open(session_file, "w") as f:
|
||||
for e in entries:
|
||||
f.write(json.dumps(e) + "\n")
|
||||
|
||||
mock_logger = MagicMock()
|
||||
mock_logger.flush.return_value = None
|
||||
mock_logger.session_file = session_file
|
||||
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.get_session_logger",
|
||||
return_value=mock_logger,
|
||||
):
|
||||
data = _gather_session_data()
|
||||
|
||||
assert data["user_messages"] == 2
|
||||
assert data["timmy_messages"] == 1
|
||||
assert data["tool_calls"] == 3
|
||||
assert data["errors"] == 1
|
||||
assert data["tool_call_breakdown"]["memory_search"] == 2
|
||||
assert data["tool_call_breakdown"]["shell"] == 1
|
||||
|
||||
def test_graceful_on_import_error(self):
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.get_session_logger",
|
||||
side_effect=ImportError("no session_logger"),
|
||||
):
|
||||
data = _gather_session_data()
|
||||
|
||||
assert data["tool_calls"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _gather_sovereignty_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGatherSovereigntyData:
|
||||
def test_returns_empty_on_import_error(self):
|
||||
with patch.dict("sys.modules", {"infrastructure.sovereignty_metrics": None}):
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.get_sovereignty_store",
|
||||
side_effect=ImportError("no store"),
|
||||
):
|
||||
data = _gather_sovereignty_data()
|
||||
|
||||
assert data["metrics"] == {}
|
||||
assert data["deltas"] == {}
|
||||
assert data["previous_session"] == {}
|
||||
|
||||
def test_populates_deltas_from_history(self):
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_summary.return_value = {
|
||||
"cache_hit_rate": {"current": 0.5, "phase": "week1"},
|
||||
}
|
||||
# get_latest returns newest-first
|
||||
mock_store.get_latest.return_value = [
|
||||
{"value": 0.5},
|
||||
{"value": 0.3},
|
||||
{"value": 0.1},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.get_sovereignty_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.GRADUATION_TARGETS",
|
||||
{"cache_hit_rate": {"graduation": 0.9}},
|
||||
):
|
||||
data = _gather_sovereignty_data()
|
||||
|
||||
delta = data["deltas"].get("cache_hit_rate")
|
||||
assert delta is not None
|
||||
assert delta["start"] == 0.1 # oldest in window
|
||||
assert delta["end"] == 0.5 # most recent
|
||||
assert data["previous_session"]["cache_hit_rate"] == 0.3
|
||||
|
||||
def test_single_data_point_no_delta(self):
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_summary.return_value = {}
|
||||
mock_store.get_latest.return_value = [{"value": 0.4}]
|
||||
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.get_sovereignty_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.GRADUATION_TARGETS",
|
||||
{"api_cost": {"graduation": 0.01}},
|
||||
):
|
||||
data = _gather_sovereignty_data()
|
||||
|
||||
delta = data["deltas"]["api_cost"]
|
||||
assert delta["start"] == 0.4
|
||||
assert delta["end"] == 0.4
|
||||
assert data["previous_session"]["api_cost"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_report (integration — smoke test)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGenerateReport:
|
||||
def _minimal_session_data(self):
|
||||
return {
|
||||
"user_messages": 3,
|
||||
"timmy_messages": 3,
|
||||
"tool_calls": 2,
|
||||
"errors": 0,
|
||||
"tool_call_breakdown": {"memory_search": 2},
|
||||
}
|
||||
|
||||
def _minimal_sov_data(self):
|
||||
return {
|
||||
"metrics": {
|
||||
"cache_hit_rate": {"current": 0.45, "phase": "week1"},
|
||||
"api_cost": {"current": 0.12, "phase": "pre-start"},
|
||||
},
|
||||
"deltas": {
|
||||
"cache_hit_rate": {"start": 0.40, "end": 0.45},
|
||||
"api_cost": {"start": 0.10, "end": 0.12},
|
||||
},
|
||||
"previous_session": {
|
||||
"cache_hit_rate": 0.40,
|
||||
"api_cost": 0.10,
|
||||
},
|
||||
}
|
||||
|
||||
def test_smoke_produces_markdown(self):
|
||||
with (
|
||||
patch(
|
||||
"timmy.sovereignty.session_report._gather_session_data",
|
||||
return_value=self._minimal_session_data(),
|
||||
),
|
||||
patch(
|
||||
"timmy.sovereignty.session_report._gather_sovereignty_data",
|
||||
return_value=self._minimal_sov_data(),
|
||||
),
|
||||
):
|
||||
report = generate_report("test-session")
|
||||
|
||||
assert "# Sovereignty Session Report" in report
|
||||
assert "test-session" in report
|
||||
assert "## Session Activity" in report
|
||||
assert "## Sovereignty Scorecard" in report
|
||||
assert "## Cost Breakdown" in report
|
||||
assert "## Trend vs Previous Session" in report
|
||||
|
||||
def test_report_contains_session_stats(self):
|
||||
with (
|
||||
patch(
|
||||
"timmy.sovereignty.session_report._gather_session_data",
|
||||
return_value=self._minimal_session_data(),
|
||||
),
|
||||
patch(
|
||||
"timmy.sovereignty.session_report._gather_sovereignty_data",
|
||||
return_value=self._minimal_sov_data(),
|
||||
),
|
||||
):
|
||||
report = generate_report()
|
||||
|
||||
assert "| User messages | 3 |" in report
|
||||
assert "memory_search" in report
|
||||
|
||||
def test_report_no_previous_session(self):
|
||||
sov = self._minimal_sov_data()
|
||||
sov["previous_session"] = {"cache_hit_rate": None, "api_cost": None}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"timmy.sovereignty.session_report._gather_session_data",
|
||||
return_value=self._minimal_session_data(),
|
||||
),
|
||||
patch(
|
||||
"timmy.sovereignty.session_report._gather_sovereignty_data",
|
||||
return_value=sov,
|
||||
),
|
||||
):
|
||||
report = generate_report()
|
||||
|
||||
assert "No previous session data" in report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# commit_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCommitReport:
|
||||
def test_returns_false_when_gitea_disabled(self):
|
||||
with patch("timmy.sovereignty.session_report.settings") as mock_settings:
|
||||
mock_settings.gitea_enabled = False
|
||||
result = commit_report("# test", "dashboard")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_when_no_token(self):
|
||||
with patch("timmy.sovereignty.session_report.settings") as mock_settings:
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = ""
|
||||
result = commit_report("# test", "dashboard")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_creates_file_via_put(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
mock_check = MagicMock()
|
||||
mock_check.status_code = 404 # file does not exist yet
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
mock_client.get.return_value = mock_check
|
||||
mock_client.put.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch("timmy.sovereignty.session_report.settings") as mock_settings,
|
||||
patch("timmy.sovereignty.session_report.httpx.Client", return_value=mock_client),
|
||||
):
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "fake-token"
|
||||
mock_settings.gitea_url = "http://localhost:3000"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
result = commit_report("# report content", "dashboard")
|
||||
|
||||
assert result is True
|
||||
mock_client.put.assert_called_once()
|
||||
call_kwargs = mock_client.put.call_args
|
||||
payload = call_kwargs.kwargs.get("json", call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
|
||||
decoded = base64.b64decode(payload["content"]).decode()
|
||||
assert "# report content" in decoded
|
||||
|
||||
def test_updates_existing_file_with_sha(self):
|
||||
mock_check = MagicMock()
|
||||
mock_check.status_code = 200
|
||||
mock_check.json.return_value = {"sha": "abc123"}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
mock_client.get.return_value = mock_check
|
||||
mock_client.put.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch("timmy.sovereignty.session_report.settings") as mock_settings,
|
||||
patch("timmy.sovereignty.session_report.httpx.Client", return_value=mock_client),
|
||||
):
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "fake-token"
|
||||
mock_settings.gitea_url = "http://localhost:3000"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
result = commit_report("# updated", "dashboard")
|
||||
|
||||
assert result is True
|
||||
payload = mock_client.put.call_args.kwargs.get("json", {})
|
||||
assert payload.get("sha") == "abc123"
|
||||
|
||||
def test_returns_false_on_http_error(self):
|
||||
import httpx
|
||||
|
||||
mock_check = MagicMock()
|
||||
mock_check.status_code = 404
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
mock_client.get.return_value = mock_check
|
||||
mock_client.put.side_effect = httpx.HTTPStatusError(
|
||||
"403", request=MagicMock(), response=MagicMock(status_code=403)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("timmy.sovereignty.session_report.settings") as mock_settings,
|
||||
patch("timmy.sovereignty.session_report.httpx.Client", return_value=mock_client),
|
||||
):
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "fake-token"
|
||||
mock_settings.gitea_url = "http://localhost:3000"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
result = commit_report("# test", "dashboard")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_and_commit_report (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGenerateAndCommitReport:
|
||||
async def test_returns_true_on_success(self):
|
||||
with (
|
||||
patch(
|
||||
"timmy.sovereignty.session_report.generate_report",
|
||||
return_value="# mock report",
|
||||
),
|
||||
patch(
|
||||
"timmy.sovereignty.session_report.commit_report",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
result = await generate_and_commit_report("test")
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_returns_false_when_commit_fails(self):
|
||||
with (
|
||||
patch(
|
||||
"timmy.sovereignty.session_report.generate_report",
|
||||
return_value="# mock report",
|
||||
),
|
||||
patch(
|
||||
"timmy.sovereignty.session_report.commit_report",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
result = await generate_and_commit_report()
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_graceful_on_exception(self):
|
||||
with patch(
|
||||
"timmy.sovereignty.session_report.generate_report",
|
||||
side_effect=RuntimeError("explode"),
|
||||
):
|
||||
result = await generate_and_commit_report()
|
||||
|
||||
assert result is False
|
||||
269
tests/unit/test_self_correction.py
Normal file
269
tests/unit/test_self_correction.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Unit tests for infrastructure.self_correction."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolated_db(tmp_path, monkeypatch):
|
||||
"""Point the self-correction module at a fresh temp database per test."""
|
||||
import infrastructure.self_correction as sc_mod
|
||||
|
||||
# Reset the cached path so each test gets a clean DB
|
||||
sc_mod._DB_PATH = tmp_path / "self_correction.db"
|
||||
yield
|
||||
sc_mod._DB_PATH = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# log_self_correction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLogSelfCorrection:
|
||||
def test_returns_event_id(self):
|
||||
from infrastructure.self_correction import log_self_correction
|
||||
|
||||
eid = log_self_correction(
|
||||
source="test",
|
||||
original_intent="Do X",
|
||||
detected_error="ValueError: bad input",
|
||||
correction_strategy="Try Y instead",
|
||||
final_outcome="Y succeeded",
|
||||
)
|
||||
assert isinstance(eid, str)
|
||||
assert len(eid) == 36 # UUID format
|
||||
|
||||
def test_derives_error_type_from_error_string(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="Connect",
|
||||
detected_error="ConnectionRefusedError: port 80",
|
||||
correction_strategy="Use port 8080",
|
||||
final_outcome="ok",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["error_type"] == "ConnectionRefusedError"
|
||||
|
||||
def test_explicit_error_type_preserved(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="Run task",
|
||||
detected_error="Some weird error",
|
||||
correction_strategy="Fix it",
|
||||
final_outcome="done",
|
||||
error_type="CustomError",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["error_type"] == "CustomError"
|
||||
|
||||
def test_task_id_stored(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="intent",
|
||||
detected_error="err",
|
||||
correction_strategy="strat",
|
||||
final_outcome="outcome",
|
||||
task_id="task-abc-123",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["task_id"] == "task-abc-123"
|
||||
|
||||
def test_outcome_status_stored(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
outcome_status="failed",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["outcome_status"] == "failed"
|
||||
|
||||
def test_long_strings_truncated(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
long = "x" * 3000
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent=long,
|
||||
detected_error=long,
|
||||
correction_strategy=long,
|
||||
final_outcome=long,
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert len(rows[0]["original_intent"]) <= 2000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_corrections
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCorrections:
|
||||
def test_empty_db_returns_empty_list(self):
|
||||
from infrastructure.self_correction import get_corrections
|
||||
|
||||
assert get_corrections() == []
|
||||
|
||||
def test_returns_newest_first(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
for i in range(3):
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent=f"intent {i}",
|
||||
detected_error="err",
|
||||
correction_strategy="fix",
|
||||
final_outcome="done",
|
||||
error_type=f"Type{i}",
|
||||
)
|
||||
rows = get_corrections(limit=10)
|
||||
assert len(rows) == 3
|
||||
# Newest first — Type2 should appear before Type0
|
||||
types = [r["error_type"] for r in rows]
|
||||
assert types.index("Type2") < types.index("Type0")
|
||||
|
||||
def test_limit_respected(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
for _ in range(5):
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
)
|
||||
rows = get_corrections(limit=3)
|
||||
assert len(rows) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetPatterns:
|
||||
def test_empty_db_returns_empty_list(self):
|
||||
from infrastructure.self_correction import get_patterns
|
||||
|
||||
assert get_patterns() == []
|
||||
|
||||
def test_counts_by_error_type(self):
|
||||
from infrastructure.self_correction import get_patterns, log_self_correction
|
||||
|
||||
for _ in range(3):
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
error_type="TimeoutError",
|
||||
)
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
error_type="ValueError",
|
||||
)
|
||||
patterns = get_patterns(top_n=10)
|
||||
by_type = {p["error_type"]: p for p in patterns}
|
||||
assert by_type["TimeoutError"]["count"] == 3
|
||||
assert by_type["ValueError"]["count"] == 1
|
||||
|
||||
def test_success_vs_failed_counts(self):
|
||||
from infrastructure.self_correction import get_patterns, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o",
|
||||
error_type="Foo", outcome_status="success",
|
||||
)
|
||||
log_self_correction(
|
||||
source="test", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o",
|
||||
error_type="Foo", outcome_status="failed",
|
||||
)
|
||||
patterns = get_patterns(top_n=5)
|
||||
foo = next(p for p in patterns if p["error_type"] == "Foo")
|
||||
assert foo["success_count"] == 1
|
||||
assert foo["failed_count"] == 1
|
||||
|
||||
def test_ordered_by_count_desc(self):
|
||||
from infrastructure.self_correction import get_patterns, log_self_correction
|
||||
|
||||
for _ in range(2):
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", error_type="Rare",
|
||||
)
|
||||
for _ in range(5):
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", error_type="Common",
|
||||
)
|
||||
patterns = get_patterns(top_n=5)
|
||||
assert patterns[0]["error_type"] == "Common"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetStats:
|
||||
def test_empty_db_returns_zeroes(self):
|
||||
from infrastructure.self_correction import get_stats
|
||||
|
||||
stats = get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["success_rate"] == 0
|
||||
|
||||
def test_counts_outcomes(self):
|
||||
from infrastructure.self_correction import get_stats, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", outcome_status="success",
|
||||
)
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", outcome_status="failed",
|
||||
)
|
||||
stats = get_stats()
|
||||
assert stats["total"] == 2
|
||||
assert stats["success_count"] == 1
|
||||
assert stats["failed_count"] == 1
|
||||
assert stats["success_rate"] == 50
|
||||
|
||||
def test_success_rate_100_when_all_succeed(self):
|
||||
from infrastructure.self_correction import get_stats, log_self_correction
|
||||
|
||||
for _ in range(4):
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", outcome_status="success",
|
||||
)
|
||||
stats = get_stats()
|
||||
assert stats["success_rate"] == 100
|
||||
Reference in New Issue
Block a user