Compare commits (7 commits): claude/iss... → gemini/iss...

| SHA1 |
|---|
| 393cf3a2e1 |
| 0331e0e5bb |
| 1be1324a0d |
| 32a5b092d0 |
| 6f404c99f2 |
| 300d9575f1 |
| 510d890eb2 |
docs/research/bannerlord-vm-setup.md (230 lines, new file)
@@ -0,0 +1,230 @@
# Bannerlord Windows VM Setup Guide

**Issue:** #1098
**Parent Epic:** #1091 (Project Bannerlord)
**Date:** 2026-03-23
**Status:** Reference

---

## Overview

This document covers provisioning the Windows VM that hosts Bannerlord + the GABS mod, verifying the GABS TCP JSON-RPC server, and confirming connectivity from Hermes.

Architecture reminder:

```
Timmy (Qwen3 on Ollama, Hermes M3 Max)
  → GABS TCP/JSON-RPC (port 4825)
  → Bannerlord.GABS C# mod
  → Game API + Harmony
  → Bannerlord (Windows VM)
```

---

## 1. Provision Windows VM

### Minimum Spec

| Resource | Minimum | Recommended |
|----------|---------|-------------|
| CPU | 4 cores | 8 cores |
| RAM | 16 GB | 32 GB |
| Disk | 100 GB SSD | 150 GB SSD |
| OS | Windows Server 2022 / Windows 11 | Windows 11 |
| Network | Private VLAN to Hermes | Private VLAN to Hermes |

### Hetzner (preferred)

```bash
# Hetzner Cloud CLI — create CX41 (4 vCPU, 16 GB RAM, 160 GB SSD)
hcloud server create \
  --name bannerlord-vm \
  --type cx41 \
  --image windows-server-2022 \
  --location nbg1 \
  --ssh-key your-key
```

### DigitalOcean alternative

```
Droplet: General Purpose 4 vCPU / 16 GB / 100 GB SSD
Image: Windows Server 2022
Region: Same region as Hermes
```

### Post-provision

1. Enable RDP (port 3389) for initial setup only — close it after configuration
2. Open port 4825 TCP inbound from the Hermes IP only
3. Add a specific Windows Firewall allow rule for 4825 (preferred over disabling the firewall):

```powershell
New-NetFirewallRule -DisplayName "GABS TCP" -Direction Inbound `
    -Protocol TCP -LocalPort 4825 -Action Allow
```

---

## 2. Install Steam + Bannerlord

### Steam installation

1. Download the Steam installer from store.steampowered.com
2. Install silently:

```powershell
.\SteamSetup.exe /S
```

3. Log in with a dedicated Steam account (not a personal one)

### Bannerlord installation

```powershell
# Install Bannerlord (App ID: 261550) via SteamCMD
steamcmd +login <user> <pass> +app_update 261550 validate +quit
```

### Pin game version

GABS requires a specific Bannerlord version. To pin it and prevent auto-updates:

1. Right-click Bannerlord in Steam → Properties → Updates
2. Set "Automatic Updates" to "Only update this game when I launch it"
3. Record the current version in `docs/research/bannerlord-vm-setup.md` after installation

```powershell
# Check installed version
Get-Content "C:\Program Files (x86)\Steam\steamapps\appmanifest_261550.acf" |
    Select-String "buildid"
```

---

## 3. Install GABS Mod

### Source

- NexusMods: https://www.nexusmods.com/mountandblade2bannerlord/mods/10419
- GitHub: https://github.com/BUTR/Bannerlord.GABS
- AGENTS.md: https://github.com/BUTR/Bannerlord.GABS/blob/master/AGENTS.md

### Installation via Vortex (NexusMods)

1. Install Vortex Mod Manager
2. Download the GABS mod package from NexusMods
3. Install via Vortex — it handles the Modules/ directory layout automatically
4. Enable it in the mod list and set the load order after Harmony

### Manual installation

```powershell
# Copy mod to Bannerlord Modules directory
$BannerlordPath = "C:\Program Files (x86)\Steam\steamapps\common\Mount & Blade II Bannerlord"
Copy-Item -Recurse ".\Bannerlord.GABS" "$BannerlordPath\Modules\Bannerlord.GABS"
```

### Required dependencies

- **Harmony** (BUTR.Harmony) — must load before GABS
- **ButterLib** — utility library

Install both via the same method as GABS.

### GABS configuration

The GABS TCP server listens on `0.0.0.0:4825` by default. To confirm or override:

```
%APPDATA%\Mount and Blade II Bannerlord\Configs\Bannerlord.GABS\settings.json
```

Expected defaults:

```json
{
  "ServerHost": "0.0.0.0",
  "ServerPort": 4825,
  "LogLevel": "Information"
}
```
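When recording the settings for the reproducibility checklist, the deployed file can be diffed against these defaults with a small Python check (illustrative sketch; the function name is an assumption, the keys are the defaults above):

```python
import json

EXPECTED = {"ServerHost": "0.0.0.0", "ServerPort": 4825, "LogLevel": "Information"}


def check_settings(text: str) -> list[str]:
    """Return a list of mismatches between settings.json text and expected defaults."""
    cfg = json.loads(text)
    return [
        f"{key}: expected {expected!r}, got {cfg.get(key)!r}"
        for key, expected in EXPECTED.items()
        if cfg.get(key) != expected
    ]


# An empty list means the deployed settings match the documented defaults.
print(check_settings(json.dumps(EXPECTED)))  # → []
```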
---

## 4. Verify GABS TCP Server

### Start Bannerlord with GABS

Launch Bannerlord with the mod enabled. GABS starts its TCP server during game initialisation. Watch the game log for:

```
[GABS] TCP server listening on 0.0.0.0:4825
```

Log location:

```
%APPDATA%\Mount and Blade II Bannerlord\logs\rgl_log_*.txt
```

### Local connectivity check (on VM)

```powershell
# Verify port is listening
netstat -an | findstr 4825

# Quick TCP probe
Test-NetConnection -ComputerName localhost -Port 4825
```

### Send a test JSON-RPC call

```powershell
$msg = '{"jsonrpc":"2.0","method":"ping","id":1}'
$client = New-Object System.Net.Sockets.TcpClient("localhost", 4825)
$stream = $client.GetStream()
$writer = New-Object System.IO.StreamWriter($stream)
$writer.AutoFlush = $true
$writer.WriteLine($msg)
$reader = New-Object System.IO.StreamReader($stream)
$response = $reader.ReadLine()
Write-Host "Response: $response"
$client.Close()
```

Expected response shape:

```json
{"jsonrpc":"2.0","result":{"status":"ok"},"id":1}
```

---

## 5. Test Connectivity from Hermes

Use `scripts/test_gabs_connectivity.py` (checked in with this issue):

```bash
# From Hermes (M3 Max)
python scripts/test_gabs_connectivity.py --host <VM_IP> --port 4825
```

The script tests:

1. TCP socket connection
2. JSON-RPC ping round-trip
3. `get_game_state` call
4. Response latency (target < 100 ms on LAN)

---

## 6. Firewall / Network Summary

| Source | Destination | Port | Protocol | Purpose |
|--------|-------------|------|----------|---------|
| Hermes (local) | Bannerlord VM | 4825 | TCP | GABS JSON-RPC |
| Admin workstation | Bannerlord VM | 3389 | TCP | RDP setup (disable after) |

---

## 7. Reproducibility Checklist

After completing setup, record:

- [ ] VM provider + region + instance type
- [ ] Windows version + build number
- [ ] Steam account used (non-personal, credentials in secrets manager)
- [ ] Bannerlord app version (buildid from appmanifest)
- [ ] GABS version (from NexusMods or GitHub release tag)
- [ ] Harmony version
- [ ] ButterLib version
- [ ] GABS settings.json contents
- [ ] VM IP address (update Timmy config)
- [ ] Connectivity test output from `test_gabs_connectivity.py`

---

## References

- GABS GitHub: https://github.com/BUTR/Bannerlord.GABS
- GABS AGENTS.md: https://github.com/BUTR/Bannerlord.GABS/blob/master/AGENTS.md
- NexusMods page: https://www.nexusmods.com/mountandblade2bannerlord/mods/10419
- Parent Epic: #1091
- Connectivity test script: `scripts/test_gabs_connectivity.py`
scripts/export_trajectories.py (333 lines, new file)
@@ -0,0 +1,333 @@
#!/usr/bin/env python3
"""Export Timmy session logs as LoRA training data (ChatML JSONL).

Reads session JSONL files written by ``SessionLogger`` and converts them into
conversation pairs suitable for fine-tuning with ``mlx_lm.lora``.

Output format — one JSON object per line::

    {"messages": [
        {"role": "system", "content": "<Timmy system prompt>"},
        {"role": "user", "content": "<user turn>"},
        {"role": "assistant", "content": "<timmy response, with tool calls embedded>"}
    ]}

Tool calls that appear between a user turn and the next assistant message are
embedded in the assistant content using the Hermes 4 ``<tool_call>`` XML format
so the fine-tuned model learns both when to call tools and what JSON to emit.

Usage::

    # Export all session logs (default paths)
    python scripts/export_trajectories.py

    # Custom source / destination
    python scripts/export_trajectories.py \\
        --logs-dir ~/custom-logs \\
        --output ~/timmy-training-data.jsonl \\
        --min-turns 2 \\
        --verbose

Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 3 of 7)
Refs: #1103
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# ── Constants ─────────────────────────────────────────────────────────────────

TIMMY_SYSTEM_PROMPT = (
    "You are Timmy, Alexander's personal AI agent running on a local Mac. "
    "You are concise, direct, and action-oriented. "
    "You have access to a broad set of tools — use them proactively. "
    "When you need to call a tool, output it in this format:\n"
    "<tool_call>\n"
    '{"name": "function_name", "arguments": {"param": "value"}}\n'
    "</tool_call>\n\n"
    "Always provide structured, accurate responses."
)
# ── Entry grouping ─────────────────────────────────────────────────────────────


def _load_entries(logs_dir: Path) -> list[dict[str, Any]]:
    """Load all session log entries, sorted chronologically."""
    entries: list[dict[str, Any]] = []
    log_files = sorted(logs_dir.glob("session_*.jsonl"))
    for log_file in log_files:
        try:
            with open(log_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.warning("Skipping malformed line in %s", log_file.name)
        except OSError as exc:
            logger.warning("Cannot read %s: %s", log_file, exc)
    return entries


def _format_tool_call(entry: dict[str, Any]) -> str:
    """Render a tool_call entry as a Hermes 4 <tool_call> XML block."""
    payload = {"name": entry.get("tool", "unknown"), "arguments": entry.get("args", {})}
    return f"<tool_call>\n{json.dumps(payload)}\n</tool_call>"


def _format_tool_result(entry: dict[str, Any]) -> str:
    """Render a tool result observation."""
    # Serialize the whole payload with json.dumps so tool names containing
    # quotes or backslashes are escaped correctly.
    payload = {"name": entry.get("tool", "unknown"), "result": entry.get("result", "")}
    return f"<tool_response>\n{json.dumps(payload)}\n</tool_response>"


def _group_into_turns(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Group raw session entries into (user_text, assistant_parts) turn pairs.

    Returns a list of dicts with keys:
        ``user``      - user message content
        ``assistant`` - assembled assistant content (responses + tool calls)
    """
    turns: list[dict[str, Any]] = []
    pending_user: str | None = None
    assistant_parts: list[str] = []

    for entry in entries:
        etype = entry.get("type", "")
        role = entry.get("role", "")

        if etype == "message" and role == "user":
            # Flush any open turn; a user message that received no assistant
            # response is discarded.
            if pending_user is not None and assistant_parts:
                turns.append(
                    {
                        "user": pending_user,
                        "assistant": "\n".join(assistant_parts).strip(),
                    }
                )
            pending_user = entry.get("content", "").strip()
            assistant_parts = []

        elif etype == "message" and role == "timmy":
            if pending_user is not None:
                content = entry.get("content", "").strip()
                if content:
                    assistant_parts.append(content)

        elif etype == "tool_call":
            if pending_user is not None:
                assistant_parts.append(_format_tool_call(entry))
                # Also append tool result as context so model learns the full loop
                if entry.get("result"):
                    assistant_parts.append(_format_tool_result(entry))

        # decision / error entries are skipped — they are metadata, not conversation

    # Flush final open turn
    if pending_user is not None and assistant_parts:
        turns.append(
            {
                "user": pending_user,
                "assistant": "\n".join(assistant_parts).strip(),
            }
        )

    return turns
# ── Conversion ────────────────────────────────────────────────────────────────


def turns_to_training_examples(
    turns: list[dict[str, Any]],
    system_prompt: str = TIMMY_SYSTEM_PROMPT,
    min_assistant_len: int = 10,
) -> list[dict[str, Any]]:
    """Convert grouped turns into mlx-lm training examples.

    Each example has a ``messages`` list in ChatML order:
    ``[system, user, assistant]``.

    Args:
        turns: Output of ``_group_into_turns``.
        system_prompt: System prompt prepended to every example.
        min_assistant_len: Skip examples where the assistant turn is shorter
            than this many characters (filters out empty/trivial turns).

    Returns:
        List of training example dicts.
    """
    examples: list[dict[str, Any]] = []
    for turn in turns:
        assistant_text = turn.get("assistant", "").strip()
        user_text = turn.get("user", "").strip()
        if not user_text or len(assistant_text) < min_assistant_len:
            continue
        examples.append(
            {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                    {"role": "assistant", "content": assistant_text},
                ]
            }
        )
    return examples


def export_training_data(
    logs_dir: Path,
    output_path: Path,
    min_turns: int = 1,
    min_assistant_len: int = 10,
    verbose: bool = False,
) -> int:
    """Full export pipeline: load → group → convert → write.

    Args:
        logs_dir: Directory containing ``session_*.jsonl`` files.
        output_path: Destination ``.jsonl`` file for training data.
        min_turns: Minimum number of turns required (used for logging only).
        min_assistant_len: Minimum assistant response length to include.
        verbose: Print progress to stdout.

    Returns:
        Number of training examples written.
    """
    if verbose:
        print(f"Loading session logs from: {logs_dir}")

    entries = _load_entries(logs_dir)
    if verbose:
        print(f"  Loaded {len(entries)} raw entries")

    turns = _group_into_turns(entries)
    if verbose:
        print(f"  Grouped into {len(turns)} conversation turns")

    examples = turns_to_training_examples(turns, min_assistant_len=min_assistant_len)
    if verbose:
        print(f"  Generated {len(examples)} training examples")

    if not examples:
        print("WARNING: No training examples generated. Check that session logs exist.")
        return 0

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        for ex in examples:
            f.write(json.dumps(ex) + "\n")

    if verbose:
        print(f"  Wrote {len(examples)} examples → {output_path}")

    return len(examples)
# ── CLI ───────────────────────────────────────────────────────────────────────


def _default_logs_dir() -> Path:
    """Return default logs directory (repo root / logs)."""
    # Walk up from this script to find repo root (contains pyproject.toml)
    candidate = Path(__file__).resolve().parent
    for _ in range(5):
        candidate = candidate.parent
        if (candidate / "pyproject.toml").exists():
            return candidate / "logs"
    return Path.home() / "logs"


def _default_output_path() -> Path:
    return Path.home() / "timmy-training-data.jsonl"


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        description="Export Timmy session logs as LoRA training data (ChatML JSONL)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        default=_default_logs_dir(),
        help="Directory containing session_*.jsonl files (default: <repo>/logs)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=_default_output_path(),
        help="Output JSONL path (default: ~/timmy-training-data.jsonl)",
    )
    parser.add_argument(
        "--min-turns",
        type=int,
        default=1,
        help="Minimum turns to process (informational, default: 1)",
    )
    parser.add_argument(
        "--min-assistant-len",
        type=int,
        default=10,
        help="Minimum assistant response length in chars (default: 10)",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print progress information",
    )

    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.WARNING,
        format="%(levelname)s: %(message)s",
    )

    if not args.logs_dir.exists():
        print(f"ERROR: Logs directory not found: {args.logs_dir}")
        print("Run the Timmy dashboard first to generate session logs.")
        return 1

    count = export_training_data(
        logs_dir=args.logs_dir,
        output_path=args.output,
        min_turns=args.min_turns,
        min_assistant_len=args.min_assistant_len,
        verbose=args.verbose,
    )

    if count > 0:
        print(f"Exported {count} training examples to: {args.output}")
        print()
        print("Next steps:")
        print("  mkdir -p ~/timmy-lora-training")
        print(f"  cp {args.output} ~/timmy-lora-training/train.jsonl")
        print("  python scripts/lora_finetune.py --data ~/timmy-lora-training")
    else:
        print("No training examples exported.")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
scripts/lora_finetune.py (399 lines, new file)
@@ -0,0 +1,399 @@
#!/usr/bin/env python3
"""LoRA fine-tuning launcher for Hermes 4 on Timmy trajectory data.

Wraps ``mlx_lm.lora`` with project-specific defaults and pre-flight checks.
Requires Apple Silicon (M-series) and the ``mlx-lm`` package.

Usage::

    # Minimal — uses defaults (expects data in ~/timmy-lora-training/)
    python scripts/lora_finetune.py

    # Custom model path and data
    python scripts/lora_finetune.py \\
        --model /path/to/hermes4-mlx \\
        --data ~/timmy-lora-training \\
        --iters 500 \\
        --adapter-path ~/timmy-lora-adapter

    # Dry run (print command, don't execute)
    python scripts/lora_finetune.py --dry-run

    # After training, test with the adapter
    python scripts/lora_finetune.py --test \\
        --prompt "List the open PRs on the Timmy Time Dashboard repo"

    # Fuse adapter into base model for Ollama import
    python scripts/lora_finetune.py --fuse \\
        --save-path ~/timmy-fused-model

Typical workflow::

    # 1. Export trajectories
    python scripts/export_trajectories.py --verbose

    # 2. Prepare training dir
    mkdir -p ~/timmy-lora-training
    cp ~/timmy-training-data.jsonl ~/timmy-lora-training/train.jsonl

    # 3. Fine-tune
    python scripts/lora_finetune.py --verbose

    # 4. Test
    python scripts/lora_finetune.py --test

    # 5. Fuse + import to Ollama
    python scripts/lora_finetune.py --fuse
    ollama create timmy-hermes4 -f Modelfile.timmy-hermes4

Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 4 of 7)
Refs: #1103
"""

from __future__ import annotations

import argparse
import platform
import shutil
import subprocess
import sys
from pathlib import Path

# ── Defaults ──────────────────────────────────────────────────────────────────

DEFAULT_DATA_DIR = Path.home() / "timmy-lora-training"
DEFAULT_ADAPTER_PATH = Path.home() / "timmy-lora-adapter"
DEFAULT_FUSED_PATH = Path.home() / "timmy-fused-model"

# mlx-lm model path — local HuggingFace checkout of Hermes 4 in MLX format.
# Set MLX_HERMES4_PATH env var or pass --model to override.
DEFAULT_MODEL_PATH_ENV = "MLX_HERMES4_PATH"

# Training hyperparameters (conservative for a 36 GB M3 Max)
DEFAULT_BATCH_SIZE = 1
DEFAULT_LORA_LAYERS = 16
DEFAULT_ITERS = 1000
DEFAULT_LEARNING_RATE = 1e-5

# Test prompt used after training
DEFAULT_TEST_PROMPT = (
    "List the open PRs on the Timmy Time Dashboard repo and triage them by priority."
)
# ── Pre-flight checks ─────────────────────────────────────────────────────────


def _check_apple_silicon() -> bool:
    """Return True if running on Apple Silicon."""
    return platform.system() == "Darwin" and platform.machine() == "arm64"


def _check_mlx_lm() -> bool:
    """Return True if mlx-lm is installed and mlx_lm.lora is runnable."""
    return shutil.which("mlx_lm.lora") is not None or _can_import("mlx_lm")


def _can_import(module: str) -> bool:
    try:
        import importlib

        importlib.import_module(module)
        return True
    except ImportError:
        return False


def _resolve_model_path(model_arg: str | None) -> str | None:
    """Resolve model path from arg or environment variable."""
    if model_arg:
        return model_arg
    import os

    env_path = os.environ.get(DEFAULT_MODEL_PATH_ENV)
    if env_path:
        return env_path
    return None


def _preflight(model_path: str | None, data_dir: Path, verbose: bool) -> list[str]:
    """Run pre-flight checks and return a list of warnings (empty = all OK)."""
    warnings: list[str] = []

    if not _check_apple_silicon():
        warnings.append(
            "Not running on Apple Silicon. mlx-lm requires an M-series Mac.\n"
            "  Alternative: use Unsloth on Google Colab / RunPod / Modal."
        )

    if not _check_mlx_lm():
        warnings.append("mlx-lm not found. Install with:\n  pip install mlx-lm")

    if model_path is None:
        warnings.append(
            f"No model path specified. Set {DEFAULT_MODEL_PATH_ENV} or pass --model.\n"
            "  Download Hermes 4 in MLX format from HuggingFace:\n"
            "  https://huggingface.co/collections/NousResearch/hermes-4-collection-68a7\n"
            "  or convert the GGUF:\n"
            "  mlx_lm.convert --hf-path NousResearch/Hermes-4-14B --mlx-path ~/hermes4-mlx"
        )
    elif not Path(model_path).exists():
        warnings.append(f"Model path does not exist: {model_path}")

    train_file = data_dir / "train.jsonl"
    if not train_file.exists():
        warnings.append(
            f"Training data not found: {train_file}\n"
            "  Generate it with:\n"
            "  python scripts/export_trajectories.py --verbose\n"
            f"  mkdir -p {data_dir}\n"
            f"  cp ~/timmy-training-data.jsonl {train_file}"
        )

    if verbose and not warnings:
        print("Pre-flight checks: all OK")

    return warnings
# ── Command builders ──────────────────────────────────────────────────────────


def _build_train_cmd(
    model_path: str,
    data_dir: Path,
    adapter_path: Path,
    batch_size: int,
    lora_layers: int,
    iters: int,
    learning_rate: float,
) -> list[str]:
    return [
        sys.executable, "-m", "mlx_lm.lora",
        "--model", model_path,
        "--train",
        "--data", str(data_dir),
        "--batch-size", str(batch_size),
        "--lora-layers", str(lora_layers),
        "--iters", str(iters),
        "--learning-rate", str(learning_rate),
        "--adapter-path", str(adapter_path),
    ]


def _build_test_cmd(
    model_path: str,
    adapter_path: Path,
    prompt: str,
) -> list[str]:
    return [
        sys.executable, "-m", "mlx_lm.generate",
        "--model", model_path,
        "--adapter-path", str(adapter_path),
        "--prompt", prompt,
        "--max-tokens", "512",
    ]


def _build_fuse_cmd(
    model_path: str,
    adapter_path: Path,
    save_path: Path,
) -> list[str]:
    return [
        sys.executable, "-m", "mlx_lm.fuse",
        "--model", model_path,
        "--adapter-path", str(adapter_path),
        "--save-path", str(save_path),
    ]


# ── Runner ─────────────────────────────────────────────────────────────────────


def _run(cmd: list[str], dry_run: bool, verbose: bool) -> int:
    """Print and optionally execute a command."""
    print("\nCommand:")
    print("  " + " \\\n    ".join(cmd))
    if dry_run:
        print("\n(dry-run — not executing)")
        return 0

    print()
    result = subprocess.run(cmd)
    return result.returncode
# ── Main ──────────────────────────────────────────────────────────────────────


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        description="LoRA fine-tuning launcher for Hermes 4 (AutoLoRA Step 4)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Mode flags (mutually exclusive)
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "--test",
        action="store_true",
        help="Run inference test with trained adapter instead of training",
    )
    mode.add_argument(
        "--fuse",
        action="store_true",
        help="Fuse adapter into base model (for Ollama import)",
    )

    # Paths
    parser.add_argument(
        "--model",
        default=None,
        help=f"Path to local MLX model (or set {DEFAULT_MODEL_PATH_ENV} env var)",
    )
    parser.add_argument(
        "--data",
        type=Path,
        default=DEFAULT_DATA_DIR,
        help=f"Training data directory (default: {DEFAULT_DATA_DIR})",
    )
    parser.add_argument(
        "--adapter-path",
        type=Path,
        default=DEFAULT_ADAPTER_PATH,
        help=f"LoRA adapter output path (default: {DEFAULT_ADAPTER_PATH})",
    )
    parser.add_argument(
        "--save-path",
        type=Path,
        default=DEFAULT_FUSED_PATH,
        help=f"Fused model output path (default: {DEFAULT_FUSED_PATH})",
    )

    # Hyperparameters
    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help=f"Training batch size (default: {DEFAULT_BATCH_SIZE}; keep at 1 if OOM)",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        default=DEFAULT_LORA_LAYERS,
        help=f"Number of LoRA layers (default: {DEFAULT_LORA_LAYERS}; reduce if OOM)",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=DEFAULT_ITERS,
        help=f"Training iterations (default: {DEFAULT_ITERS})",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=DEFAULT_LEARNING_RATE,
        help=f"Learning rate (default: {DEFAULT_LEARNING_RATE})",
    )

    # Misc
    parser.add_argument(
        "--prompt",
        default=DEFAULT_TEST_PROMPT,
        help="Prompt for --test mode",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print command without executing",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print extra progress information",
    )
    parser.add_argument(
        "--skip-preflight",
        action="store_true",
        help="Skip pre-flight checks (useful in CI)",
    )

    args = parser.parse_args(argv)
    model_path = _resolve_model_path(args.model)

    # ── Pre-flight ──────────────────────────────────────────────────────────
    if not args.skip_preflight:
        warnings = _preflight(model_path, args.data, args.verbose)
        if warnings:
            for w in warnings:
                print(f"WARNING: {w}\n")
            if not args.dry_run:
                print("Aborting due to pre-flight warnings. Use --dry-run to see commands anyway.")
                return 1

    if model_path is None:
        # Allow dry-run without a model for documentation purposes
        model_path = "<path-to-hermes4-mlx>"
|
||||||
|
|
||||||
|
# ── Mode dispatch ────────────────────────────────────────────────────────
|
||||||
|
if args.test:
|
||||||
|
print(f"Testing fine-tuned model with adapter: {args.adapter_path}")
|
||||||
|
cmd = _build_test_cmd(model_path, args.adapter_path, args.prompt)
|
||||||
|
return _run(cmd, args.dry_run, args.verbose)
|
||||||
|
|
||||||
|
if args.fuse:
|
||||||
|
print(f"Fusing adapter {args.adapter_path} into base model → {args.save_path}")
|
||||||
|
cmd = _build_fuse_cmd(model_path, args.adapter_path, args.save_path)
|
||||||
|
rc = _run(cmd, args.dry_run, args.verbose)
|
||||||
|
if rc == 0 and not args.dry_run:
|
||||||
|
print(
|
||||||
|
f"\nFused model saved to: {args.save_path}\n"
|
||||||
|
"To import into Ollama:\n"
|
||||||
|
f" ollama create timmy-hermes4 -f Modelfile.hermes4-14b\n"
|
||||||
|
" (edit Modelfile to point FROM to the fused GGUF path)"
|
||||||
|
)
|
||||||
|
return rc
|
||||||
|
|
||||||
|
# Default: train
|
||||||
|
print(f"Starting LoRA fine-tuning")
|
||||||
|
print(f" Model: {model_path}")
|
||||||
|
print(f" Data: {args.data}")
|
||||||
|
print(f" Adapter path: {args.adapter_path}")
|
||||||
|
print(f" Iterations: {args.iters}")
|
||||||
|
print(f" Batch size: {args.batch_size}")
|
||||||
|
print(f" LoRA layers: {args.lora_layers}")
|
||||||
|
print(f" Learning rate:{args.learning_rate}")
|
||||||
|
print()
|
||||||
|
print("Estimated time: 2-8 hours on M3 Max (depends on dataset size).")
|
||||||
|
print("If OOM: reduce --lora-layers to 8 or --batch-size stays at 1.")
|
||||||
|
|
||||||
|
cmd = _build_train_cmd(
|
||||||
|
model_path=model_path,
|
||||||
|
data_dir=args.data,
|
||||||
|
adapter_path=args.adapter_path,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
lora_layers=args.lora_layers,
|
||||||
|
iters=args.iters,
|
||||||
|
learning_rate=args.learning_rate,
|
||||||
|
)
|
||||||
|
rc = _run(cmd, args.dry_run, args.verbose)
|
||||||
|
|
||||||
|
if rc == 0 and not args.dry_run:
|
||||||
|
print(
|
||||||
|
f"\nTraining complete! Adapter saved to: {args.adapter_path}\n"
|
||||||
|
"Test with:\n"
|
||||||
|
f" python scripts/lora_finetune.py --test\n"
|
||||||
|
"Then fuse + import to Ollama:\n"
|
||||||
|
f" python scripts/lora_finetune.py --fuse"
|
||||||
|
)
|
||||||
|
|
||||||
|
return rc
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
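The `_build_train_cmd` helper that `main()` calls is outside this hunk. As a rough, self-contained sketch of what that argv assembly can look like — the `mlx_lm.lora` flag names (notably `--num-layers`) are assumptions about the MLX CLI, not taken from this diff:

```python
import sys

def build_train_cmd(model_path, data_dir, adapter_path, *,
                    batch_size=1, lora_layers=16, iters=600, learning_rate=1e-5):
    """Assemble an mlx_lm.lora training invocation as an argv list (sketch)."""
    return [
        sys.executable, "-m", "mlx_lm.lora",   # run the LoRA trainer as a module
        "--model", str(model_path),
        "--train",
        "--data", str(data_dir),
        "--adapter-path", str(adapter_path),
        "--batch-size", str(batch_size),
        "--num-layers", str(lora_layers),      # flag name is an assumption
        "--iters", str(iters),
        "--learning-rate", str(learning_rate),
    ]

cmd = build_train_cmd("<path-to-hermes4-mlx>", "data/lora", "adapters/hermes4")
print("--train" in cmd, cmd[2])  # → True mlx_lm.lora
```

Building the argv as a list (rather than a shell string) avoids quoting bugs and lets `--dry-run` print exactly what `subprocess` would execute.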
244
scripts/test_gabs_connectivity.py
Normal file
@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""GABS TCP connectivity and JSON-RPC smoke test.

Tests connectivity from Hermes to the Bannerlord.GABS TCP server running on the
Windows VM. Covers:
1. TCP socket connection (port 4825 reachable)
2. JSON-RPC ping round-trip
3. get_game_state call (game must be running)
4. Latency — target < 100 ms on LAN

Usage:
    python scripts/test_gabs_connectivity.py --host 10.0.0.50
    python scripts/test_gabs_connectivity.py --host 10.0.0.50 --port 4825 --timeout 5

Refs: #1098 (Bannerlord Infra — Windows VM Setup + GABS Mod Installation)
Epic: #1091 (Project Bannerlord)
"""

from __future__ import annotations

import argparse
import json
import socket
import sys
import time
from typing import Any

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 4825
DEFAULT_TIMEOUT = 5  # seconds
LATENCY_TARGET_MS = 100.0


# ── Low-level TCP helpers ─────────────────────────────────────────────────────


def _tcp_connect(host: str, port: int, timeout: float) -> socket.socket:
    """Open a TCP connection and return the socket. Raises on failure."""
    sock = socket.create_connection((host, port), timeout=timeout)
    sock.settimeout(timeout)
    return sock


def _send_recv(sock: socket.socket, payload: dict[str, Any]) -> dict[str, Any]:
    """Send a newline-delimited JSON-RPC request and return the parsed response."""
    raw = json.dumps(payload) + "\n"
    sock.sendall(raw.encode())

    buf = b""
    while b"\n" not in buf:
        chunk = sock.recv(4096)
        if not chunk:
            raise ConnectionError("Connection closed before response received")
        buf += chunk

    line = buf.split(b"\n", 1)[0]
    return json.loads(line.decode())


def _rpc(sock: socket.socket, method: str, params: dict | None = None, req_id: int = 1) -> dict[str, Any]:
    """Build and send a JSON-RPC 2.0 request, return the response dict."""
    payload: dict[str, Any] = {
        "jsonrpc": "2.0",
        "method": method,
        "id": req_id,
    }
    if params:
        payload["params"] = params
    return _send_recv(sock, payload)


# ── Test cases ────────────────────────────────────────────────────────────────


def test_tcp_connection(host: str, port: int, timeout: float) -> tuple[bool, socket.socket | None]:
    """PASS: TCP connection to host:port succeeds."""
    print(f"\n[1/4] TCP connection → {host}:{port}")
    try:
        t0 = time.monotonic()
        sock = _tcp_connect(host, port, timeout)
        elapsed_ms = (time.monotonic() - t0) * 1000
        print(f"  ✓ Connected ({elapsed_ms:.1f} ms)")
        return True, sock
    except OSError as exc:
        print(f"  ✗ Connection failed: {exc}")
        print("  Checklist:")
        print("    - Is Bannerlord running with GABS mod enabled?")
        print(f"    - Is port {port} open in Windows Firewall?")
        print(f"    - Is the VM IP correct? (got: {host})")
        return False, None


def test_ping(sock: socket.socket) -> bool:
    """PASS: JSON-RPC ping returns a 2.0 response."""
    print("\n[2/4] JSON-RPC ping")
    try:
        t0 = time.monotonic()
        resp = _rpc(sock, "ping", req_id=1)
        elapsed_ms = (time.monotonic() - t0) * 1000
        if resp.get("jsonrpc") == "2.0" and "error" not in resp:
            print(f"  ✓ Ping OK ({elapsed_ms:.1f} ms): {json.dumps(resp)}")
            return True
        print(f"  ✗ Unexpected response ({elapsed_ms:.1f} ms): {json.dumps(resp)}")
        return False
    except Exception as exc:
        print(f"  ✗ Ping failed: {exc}")
        return False


def test_game_state(sock: socket.socket) -> bool:
    """PASS: get_game_state returns a result (game must be in a campaign)."""
    print("\n[3/4] get_game_state call")
    try:
        t0 = time.monotonic()
        resp = _rpc(sock, "get_game_state", req_id=2)
        elapsed_ms = (time.monotonic() - t0) * 1000
        if "error" in resp:
            code = resp["error"].get("code", "?")
            msg = resp["error"].get("message", "")
            if code == -32601:
                # Method not found — GABS version may not expose this method
                print(f"  ~ Method not available ({elapsed_ms:.1f} ms): {msg}")
                print("    This is acceptable if game is not yet in a campaign.")
                return True
            print(f"  ✗ RPC error ({elapsed_ms:.1f} ms) [{code}]: {msg}")
            return False
        result = resp.get("result", {})
        print(f"  ✓ Game state received ({elapsed_ms:.1f} ms):")
        for k, v in result.items():
            print(f"    {k}: {v}")
        return True
    except Exception as exc:
        print(f"  ✗ get_game_state failed: {exc}")
        return False


def test_latency(host: str, port: int, timeout: float, iterations: int = 5) -> bool:
    """PASS: Average round-trip latency is under LATENCY_TARGET_MS."""
    print(f"\n[4/4] Latency test ({iterations} pings, target < {LATENCY_TARGET_MS:.0f} ms)")
    try:
        times: list[float] = []
        for i in range(iterations):
            sock = _tcp_connect(host, port, timeout)
            try:
                t0 = time.monotonic()
                _rpc(sock, "ping", req_id=i + 10)
                times.append((time.monotonic() - t0) * 1000)
            finally:
                sock.close()

        avg_ms = sum(times) / len(times)
        min_ms = min(times)
        max_ms = max(times)
        print(f"  avg={avg_ms:.1f} ms  min={min_ms:.1f} ms  max={max_ms:.1f} ms")

        if avg_ms <= LATENCY_TARGET_MS:
            print(f"  ✓ Latency within target ({avg_ms:.1f} ms ≤ {LATENCY_TARGET_MS:.0f} ms)")
            return True
        print(
            f"  ✗ Latency too high ({avg_ms:.1f} ms > {LATENCY_TARGET_MS:.0f} ms)\n"
            "    Check network path between Hermes and the VM."
        )
        return False
    except Exception as exc:
        print(f"  ✗ Latency test failed: {exc}")
        return False


# ── Main ──────────────────────────────────────────────────────────────────────


def main() -> int:
    parser = argparse.ArgumentParser(description="GABS TCP connectivity smoke test")
    parser.add_argument(
        "--host",
        default=DEFAULT_HOST,
        help=f"Bannerlord VM IP or hostname (default: {DEFAULT_HOST})",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"GABS TCP port (default: {DEFAULT_PORT})",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=DEFAULT_TIMEOUT,
        help=f"Socket timeout in seconds (default: {DEFAULT_TIMEOUT})",
    )
    args = parser.parse_args()

    print("=" * 60)
    print("GABS Connectivity Test Suite")
    print(f"Target: {args.host}:{args.port}")
    print(f"Timeout: {args.timeout}s")
    print("=" * 60)

    results: dict[str, bool] = {}

    # Test 1: TCP connection (gate — skip remaining if unreachable)
    ok, sock = test_tcp_connection(args.host, args.port, args.timeout)
    results["tcp_connection"] = ok
    if not ok:
        _print_summary(results)
        return 1

    # Tests 2–3 reuse the same socket
    try:
        results["ping"] = test_ping(sock)
        results["game_state"] = test_game_state(sock)
    finally:
        sock.close()

    # Test 4: latency uses fresh connections
    results["latency"] = test_latency(args.host, args.port, args.timeout)

    return _print_summary(results)


def _print_summary(results: dict[str, bool]) -> int:
    passed = sum(results.values())
    total = len(results)
    print("\n" + "=" * 60)
    print(f"Results: {passed}/{total} passed")
    print("=" * 60)
    for name, ok in results.items():
        icon = "✓" if ok else "✗"
        print(f"  {icon} {name}")

    if passed == total:
        print("\n✓ GABS connectivity verified. Timmy can reach the game.")
        print("  Next step: run benchmark level 0 (JSON compliance check).")
    elif not results.get("tcp_connection"):
        print("\n✗ TCP connection failed. VM/firewall setup incomplete.")
        print("  See docs/research/bannerlord-vm-setup.md for checklist.")
    else:
        print("\n~ Partial pass — review failures above.")

    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())
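The newline-delimited framing used by `_send_recv` (one JSON object per line, in both directions) can be exercised locally without the game. A minimal stdlib-only sketch — the mock server and its `pong` reply are illustrative stand-ins, not part of GABS:

```python
import json
import socket
import threading

def start_mock_gabs(host="127.0.0.1"):
    """Accept one connection, answer one request with a 'pong' result."""
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind((host, 0))          # port 0 → OS picks a free port
    srv.listen(1)

    def serve():
        conn, _ = srv.accept()
        with conn:
            buf = b""
            while b"\n" not in buf:          # read until one full line arrives
                buf += conn.recv(4096)
            req = json.loads(buf.split(b"\n", 1)[0])
            resp = {"jsonrpc": "2.0", "result": "pong", "id": req["id"]}
            conn.sendall((json.dumps(resp) + "\n").encode())
        srv.close()

    threading.Thread(target=serve, daemon=True).start()
    return srv.getsockname()[1]  # the bound port

port = start_mock_gabs()
sock = socket.create_connection(("127.0.0.1", port), timeout=5)
sock.sendall((json.dumps({"jsonrpc": "2.0", "method": "ping", "id": 1}) + "\n").encode())
buf = b""
while b"\n" not in buf:
    buf += sock.recv(4096)
sock.close()
print(json.loads(buf.split(b"\n", 1)[0])["result"])  # → pong
```

Pointing `test_gabs_connectivity.py --host 127.0.0.1 --port <mock port>` at such a loopback server is a cheap way to validate the script itself before the VM exists.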
@@ -1,11 +0,0 @@
"""Bannerlord campaign agent — M2: Basic Campaign Actions.

Provides GABS integration (TCP JSON-RPC, port 4825) and the observe →
decide → act loop for autonomous campaign play: move, trade, recruit,
and engage bandits.

Key GABS tools: party/move_to_settlement, inventory/buy_item,
party/recruit_all, party/engage_party.

Done-condition: party grows from 20 → 100 troops, gold reaches 10 000 denars.
"""
@@ -1,200 +0,0 @@
"""Bannerlord M2 campaign action primitives.

Wraps the four key GABS tools for the M2 milestone:
- party/move_to_settlement → move the party to a named settlement
- inventory/buy_item → purchase trade goods
- party/recruit_all → hire all available recruits
- party/engage_party → engage a nearby bandit party

All functions are async and return an ``ActionResult`` that is compatible
with the ``WorldInterface`` contract.

Error handling follows Pattern 3 (Feature Disable): if GABS rejects an
action, log a warning and return a FAILURE result — never raise.
"""

from __future__ import annotations

import logging
from enum import StrEnum
from typing import TYPE_CHECKING

from infrastructure.world.types import ActionResult, ActionStatus

if TYPE_CHECKING:
    from bannerlord.gabs_client import GabsClient

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# GABS method names — canonical reference
# ---------------------------------------------------------------------------


class GabsTool(StrEnum):
    """GABS JSON-RPC method names for the M2 action set."""

    MOVE_TO_SETTLEMENT = "party/move_to_settlement"
    BUY_ITEM = "inventory/buy_item"
    RECRUIT_ALL = "party/recruit_all"
    ENGAGE_PARTY = "party/engage_party"


# ---------------------------------------------------------------------------
# Action functions
# ---------------------------------------------------------------------------


async def move_to_settlement(
    client: "GabsClient",
    settlement_id: str,
    *,
    settlement_name: str = "",
) -> ActionResult:
    """Move the party to a target settlement.

    Parameters
    ----------
    client:
        Connected ``GabsClient`` instance.
    settlement_id:
        GABS settlement identifier (e.g. ``"town_A1"``).
    settlement_name:
        Human-readable name for logging only.
    """
    label = settlement_name or settlement_id
    try:
        result = await client.call(
            GabsTool.MOVE_TO_SETTLEMENT,
            {"settlement_id": settlement_id},
        )
        logger.info("MOVE → %s: %s", label, result)
        return ActionResult(
            status=ActionStatus.SUCCESS,
            message=f"Moving to {label}",
            data=result or {},
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("MOVE → %s failed: %s", label, exc)
        return ActionResult(
            status=ActionStatus.FAILURE,
            message=f"Move to {label} failed: {exc}",
            data={},
        )


async def buy_item(
    client: "GabsClient",
    item_id: str,
    quantity: int,
    *,
    settlement_id: str = "",
) -> ActionResult:
    """Purchase trade goods from the current or specified settlement.

    Parameters
    ----------
    client:
        Connected ``GabsClient`` instance.
    item_id:
        Item identifier (e.g. ``"grain"``, ``"iron"``, ``"wool"``).
    quantity:
        Number of units to purchase.
    settlement_id:
        Optional target settlement; empty means current location.
    """
    try:
        params: dict = {"item_id": item_id, "quantity": quantity}
        if settlement_id:
            params["settlement_id"] = settlement_id

        result = await client.call(GabsTool.BUY_ITEM, params)
        logger.info("BUY %dx %s: %s", quantity, item_id, result)
        return ActionResult(
            status=ActionStatus.SUCCESS,
            message=f"Purchased {quantity}x {item_id}",
            data=result or {},
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("BUY %dx %s failed: %s", quantity, item_id, exc)
        return ActionResult(
            status=ActionStatus.FAILURE,
            message=f"Buy {item_id} failed: {exc}",
            data={},
        )


async def recruit_all(
    client: "GabsClient",
    *,
    settlement_id: str = "",
) -> ActionResult:
    """Recruit all available troops at the current or specified settlement.

    Parameters
    ----------
    client:
        Connected ``GabsClient`` instance.
    settlement_id:
        Optional settlement to recruit from; empty means current.
    """
    try:
        params: dict = {}
        if settlement_id:
            params["settlement_id"] = settlement_id

        result = await client.call(GabsTool.RECRUIT_ALL, params)
        recruited = (result or {}).get("recruited", "?")
        logger.info("RECRUIT_ALL: recruited %s troops", recruited)
        return ActionResult(
            status=ActionStatus.SUCCESS,
            message=f"Recruited {recruited} troops",
            data=result or {},
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("RECRUIT_ALL failed: %s", exc)
        return ActionResult(
            status=ActionStatus.FAILURE,
            message=f"Recruit all failed: {exc}",
            data={},
        )


async def engage_party(
    client: "GabsClient",
    party_id: str,
    *,
    party_name: str = "",
) -> ActionResult:
    """Engage a nearby party (typically a bandit gang) in combat.

    Auto-resolve is expected at high Tactics skill — the agent relies
    on GABS to handle the battle outcome.

    Parameters
    ----------
    client:
        Connected ``GabsClient`` instance.
    party_id:
        GABS party identifier of the target.
    party_name:
        Human-readable name for logging only.
    """
    label = party_name or party_id
    try:
        result = await client.call(GabsTool.ENGAGE_PARTY, {"party_id": party_id})
        outcome = (result or {}).get("outcome", "unknown")
        logger.info("ENGAGE %s: %s", label, outcome)
        return ActionResult(
            status=ActionStatus.SUCCESS,
            message=f"Engaged {label}: {outcome}",
            data=result or {},
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("ENGAGE %s failed: %s", label, exc)
        return ActionResult(
            status=ActionStatus.FAILURE,
            message=f"Engage {label} failed: {exc}",
            data={},
        )
@@ -1,316 +0,0 @@
"""Bannerlord M2 campaign action loop.

Implements the observe → decide → act → wait pipeline described in
issue #1094. The loop runs until the M2 victory conditions are met
(100 troops + 10 000 gold) or until stopped externally.

Architecture:
    CampaignLoop.run()
        while not m2_complete:
            state = gabs.get_game_state()        # observe
            decision = decide(state)             # decide (local Qwen3)
            result = dispatch(decision, gabs)    # act (GABS)
            await asyncio.sleep(tick_seconds)    # wait

Error handling:
    - GABS connection failures → log + retry with backoff (max 3 attempts)
    - LLM failures → WAIT action (graceful degradation)
    - Action failures → log + continue to next tick

Progress tracking:
    Loop publishes heartbeat events via the event bus so the dashboard
    can display live party size and gold.
"""

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any

from bannerlord.campaign_actions import buy_item, engage_party, move_to_settlement, recruit_all
from bannerlord.campaign_state import parse_campaign_state
from bannerlord.decision import M2Action, decide
from bannerlord.gabs_client import GabsClient
from config import settings
from infrastructure.world.types import ActionResult, ActionStatus

logger = logging.getLogger(__name__)

_MAX_RECONNECT_ATTEMPTS = 3
_RECONNECT_DELAY = 5.0  # seconds


# ---------------------------------------------------------------------------
# Progress snapshot (emitted each tick)
# ---------------------------------------------------------------------------


@dataclass
class TickResult:
    """Summary of one campaign tick."""

    tick: int
    timestamp: str
    party_size: int
    gold: int
    action: str
    action_status: str
    reasoning: str
    duration_ms: int
    m2_complete: bool = False
    error: str = ""


# ---------------------------------------------------------------------------
# Campaign loop
# ---------------------------------------------------------------------------


class CampaignLoop:
    """Runs the Bannerlord M2 autonomous campaign.

    Parameters
    ----------
    gabs_host:
        Override GABS server host.
    gabs_port:
        Override GABS server port.
    tick_seconds:
        Real-time pause between in-game ticks.
    on_tick:
        Optional async callback invoked after each tick with the
        ``TickResult``. Used by the dashboard for live updates.
    max_ticks:
        Hard cap for testing / benchmarking. 0 = unlimited.
    """

    def __init__(
        self,
        *,
        gabs_host: str | None = None,
        gabs_port: int | None = None,
        tick_seconds: float | None = None,
        on_tick=None,
        max_ticks: int = 0,
    ) -> None:
        self._host = gabs_host or settings.gabs_host
        self._port = gabs_port or settings.gabs_port
        self._tick_seconds = tick_seconds if tick_seconds is not None else settings.bannerlord_tick_seconds
        self._on_tick = on_tick
        self._max_ticks = max_ticks
        self._running = False
        self.history: list[TickResult] = []

    # -- public API --------------------------------------------------------

    @property
    def is_running(self) -> bool:
        return self._running

    def stop(self) -> None:
        """Signal the loop to stop after the current tick."""
        self._running = False
        logger.info("CampaignLoop stop requested")

    async def run(self) -> list[TickResult]:
        """Start the campaign loop.

        Returns the list of tick results (for testing / benchmarking).
        Runs until M2 complete, externally stopped, or max_ticks reached.
        """
        self._running = True
        logger.info(
            "CampaignLoop starting — gabs=%s:%d tick=%.1fs",
            self._host,
            self._port,
            self._tick_seconds,
        )

        client = GabsClient(host=self._host, port=self._port)
        try:
            await self._connect_with_retry(client)
        except RuntimeError as exc:
            logger.error("CampaignLoop: could not connect to GABS — aborting: %s", exc)
            self._running = False
            return self.history

        tick_num = 0
        try:
            while self._running:
                tick_num += 1
                if self._max_ticks > 0 and tick_num > self._max_ticks:
                    logger.info("CampaignLoop: max_ticks=%d reached", self._max_ticks)
                    break

                result = await self._run_tick(client, tick_num)
                self.history.append(result)

                await self._emit(result)

                if result.m2_complete:
                    logger.info(
                        "M2 COMPLETE! Party=%d troops, Gold=%d denars",
                        result.party_size,
                        result.gold,
                    )
                    break

                if result.error and not self._running:
                    break

                await asyncio.sleep(self._tick_seconds)

        finally:
            await client.disconnect()
            self._running = False
            logger.info("CampaignLoop stopped after %d ticks", tick_num)

        return self.history

    # -- internal: single tick ---------------------------------------------

    async def _run_tick(self, client: GabsClient, tick_num: int) -> TickResult:
        """Execute one observe → decide → act cycle."""
        start = time.monotonic()

        # 1. Observe
        raw_state = await client.get_game_state()
        state = parse_campaign_state(raw_state)
        state = _override_tick(state, tick_num)

        # 2. Decide
        decision = await decide(state)

        # 3. Act
        action_result = await self._dispatch(decision, client)

        duration_ms = int((time.monotonic() - start) * 1000)

        return TickResult(
            tick=tick_num,
            timestamp=datetime.now(UTC).isoformat(),
            party_size=state.party.party_size,
            gold=state.economy.gold,
            action=decision.action,
            action_status=action_result.status.value,
            reasoning=decision.reasoning,
            duration_ms=duration_ms,
            m2_complete=state.m2_complete,
        )

    async def _dispatch(self, decision: Any, client: GabsClient) -> ActionResult:
        """Route the decision to the correct GABS action function."""
        action = decision.action

        if action == M2Action.MOVE:
            if not decision.settlement_id:
                logger.warning("MOVE decision has no settlement_id — skipping")
                return ActionResult(
                    status=ActionStatus.FAILURE,
                    message="MOVE missing settlement_id",
                )
            return await move_to_settlement(
                client,
                decision.settlement_id,
                settlement_name=decision.settlement_name,
            )

        elif action == M2Action.TRADE:
            if not decision.item_id:
                logger.warning("TRADE decision has no item_id — skipping")
                return ActionResult(
                    status=ActionStatus.FAILURE,
                    message="TRADE missing item_id",
                )
            return await buy_item(
                client,
                decision.item_id,
                decision.quantity,
                settlement_id=decision.settlement_id,
            )

        elif action == M2Action.RECRUIT:
            return await recruit_all(
                client,
                settlement_id=decision.settlement_id,
            )

        elif action == M2Action.ENGAGE:
            if not decision.party_id:
                logger.warning("ENGAGE decision has no party_id — skipping")
                return ActionResult(
                    status=ActionStatus.FAILURE,
                    message="ENGAGE missing party_id",
                )
            return await engage_party(
                client,
                decision.party_id,
                party_name=decision.party_name,
            )

        else:  # WAIT or unknown
            logger.debug("Tick %s: WAIT — %s", decision.action, decision.reasoning)
            return ActionResult(
                status=ActionStatus.NOOP,
                message=f"WAIT: {decision.reasoning}",
            )

    # -- internal: connectivity --------------------------------------------

    async def _connect_with_retry(self, client: GabsClient) -> None:
        """Try to connect, retrying up to _MAX_RECONNECT_ATTEMPTS times."""
        for attempt in range(1, _MAX_RECONNECT_ATTEMPTS + 1):
            try:
                await client.connect()
                return
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "GABS connect attempt %d/%d failed: %s",
                    attempt,
                    _MAX_RECONNECT_ATTEMPTS,
                    exc,
                )
                if attempt < _MAX_RECONNECT_ATTEMPTS:
                    await asyncio.sleep(_RECONNECT_DELAY)

        raise RuntimeError(
            f"Could not connect to GABS at {self._host}:{self._port} "
            f"after {_MAX_RECONNECT_ATTEMPTS} attempts"
        )

    # -- internal: event emission ------------------------------------------

    async def _emit(self, result: TickResult) -> None:
        """Emit tick data to the event bus (best-effort)."""
        try:
            from infrastructure.events.bus import event_bus  # noqa: PLC0415

            await event_bus.publish(
                "bannerlord.tick",
                {
                    "tick": result.tick,
                    "party_size": result.party_size,
                    "gold": result.gold,
                    "action": result.action,
                    "action_status": result.action_status,
                    "m2_complete": result.m2_complete,
                    "duration_ms": result.duration_ms,
                },
            )
        except Exception as exc:  # noqa: BLE001
            logger.debug("CampaignLoop emit skipped: %s", exc)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _override_tick(state: Any, tick_num: int) -> Any:
    """Set the tick counter from the loop (GABS may not provide it)."""
    if state.tick == 0:
        state.tick = tick_num
    return state
@@ -1,213 +0,0 @@
"""Bannerlord campaign state models.

Parses the raw GABS ``game/get_state`` payload into typed models and
tracks the M2 progress counters: party size and gold accumulation.

Done-condition (from issue #1094):
    party_size >= 100 AND gold >= 10_000
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any

logger = logging.getLogger(__name__)

# M2 victory conditions
M2_TROOP_GOAL = 100
M2_GOLD_GOAL = 10_000


@dataclass
class PartyState:
    """Current party composition and position."""

    party_size: int = 0
    wounded: int = 0
    prisoners: int = 0
    food_days: float = 0.0
    morale: float = 100.0
    current_settlement: str = ""
    speed: float = 0.0


@dataclass
class EconomyState:
    """Current gold and trade state."""

    gold: int = 0
    daily_income: int = 0
    daily_expenses: int = 0

    @property
    def net_income(self) -> int:
        return self.daily_income - self.daily_expenses


@dataclass
class NearbyParty:
    """A nearby lord/bandit party visible on the map."""

    party_id: str
    name: str
    faction: str
    is_hostile: bool
    troop_count: int
    distance: float


@dataclass
class Settlement:
    """A settlement visible or reachable from the current position."""

    settlement_id: str
    name: str
    faction: str
    is_friendly: bool
    distance: float
    has_recruits: bool = False
    has_trade_goods: bool = False


@dataclass
class CampaignState:
    """Full parsed snapshot of the GABS game state.

    Built from the raw ``dict`` returned by ``GabsClient.get_game_state()``.
    """

    tick: int = 0
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
    party: PartyState = field(default_factory=PartyState)
    economy: EconomyState = field(default_factory=EconomyState)
    nearby_parties: list[NearbyParty] = field(default_factory=list)
    settlements: list[Settlement] = field(default_factory=list)
    raw: dict[str, Any] = field(default_factory=dict)

    # -- M2 progress -------------------------------------------------------

    @property
    def troops_progress(self) -> str:
        """Human-readable M2 troop progress."""
        return f"{self.party.party_size}/{M2_TROOP_GOAL}"

    @property
    def gold_progress(self) -> str:
        """Human-readable M2 gold progress."""
        return f"{self.economy.gold:,}/{M2_GOLD_GOAL:,}"

    @property
    def m2_complete(self) -> bool:
        """True when both M2 victory conditions are met."""
        return self.party.party_size >= M2_TROOP_GOAL and self.economy.gold >= M2_GOLD_GOAL

    # -- hostile detection -------------------------------------------------

    def hostile_bandits_nearby(self, max_distance: float = 5.0) -> list[NearbyParty]:
        """Return hostile bandit parties within *max_distance* map units."""
        return [
            p
            for p in self.nearby_parties
            if p.is_hostile and "bandit" in p.faction.lower() and p.distance <= max_distance
        ]

    def nearest_settlement(self, *, friendly_only: bool = False) -> Settlement | None:
        """Return the closest (optionally friendly) settlement."""
        candidates = [s for s in self.settlements if not friendly_only or s.is_friendly]
        if not candidates:
            return None
        return min(candidates, key=lambda s: s.distance)

    def nearest_recruit_settlement(self) -> Settlement | None:
        """Return the nearest settlement that has recruits available."""
        candidates = [s for s in self.settlements if s.has_recruits]
        if not candidates:
            return None
        return min(candidates, key=lambda s: s.distance)


# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------


def parse_campaign_state(raw: dict[str, Any]) -> CampaignState:
    """Build a ``CampaignState`` from the raw GABS state dict.

    Unknown / missing fields are silently defaulted so the parser never
    crashes when GABS returns partial data.
    """
    if not raw:
        logger.debug("parse_campaign_state: empty payload — returning default state")
        return CampaignState(raw=raw)

    # -- party -------------------------------------------------------------
    party_raw = raw.get("party", {})
    party = PartyState(
        party_size=int(party_raw.get("size", 0)),
        wounded=int(party_raw.get("wounded", 0)),
        prisoners=int(party_raw.get("prisoners", 0)),
        food_days=float(party_raw.get("food_days", 0.0)),
        morale=float(party_raw.get("morale", 100.0)),
        current_settlement=str(party_raw.get("current_settlement", "")),
        speed=float(party_raw.get("speed", 0.0)),
    )

    # -- economy -----------------------------------------------------------
    economy_raw = raw.get("economy", {})
    economy = EconomyState(
        gold=int(economy_raw.get("gold", 0)),
        daily_income=int(economy_raw.get("daily_income", 0)),
        daily_expenses=int(economy_raw.get("daily_expenses", 0)),
    )

    # -- nearby parties ----------------------------------------------------
    nearby_parties = []
    for p in raw.get("nearby_parties", []):
        try:
            if not isinstance(p, dict) or not p.get("id"):
                logger.debug("Skipping malformed nearby_party entry: missing id")
                continue
            nearby_parties.append(
                NearbyParty(
                    party_id=str(p.get("id", "")),
                    name=str(p.get("name", "")),
                    faction=str(p.get("faction", "")),
                    is_hostile=bool(p.get("is_hostile", False)),
                    troop_count=int(p.get("troop_count", 0)),
                    distance=float(p.get("distance", 999.0)),
                )
            )
        except (KeyError, ValueError, TypeError, AttributeError) as exc:
            logger.debug("Skipping malformed nearby_party entry: %s", exc)

    # -- settlements -------------------------------------------------------
    settlements = []
    for s in raw.get("settlements", []):
        try:
            settlements.append(
                Settlement(
                    settlement_id=str(s.get("id", "")),
                    name=str(s.get("name", "")),
                    faction=str(s.get("faction", "")),
                    is_friendly=bool(s.get("is_friendly", False)),
                    distance=float(s.get("distance", 999.0)),
                    has_recruits=bool(s.get("has_recruits", False)),
                    has_trade_goods=bool(s.get("has_trade_goods", False)),
                )
            )
        except (KeyError, ValueError, TypeError, AttributeError) as exc:
            logger.debug("Skipping malformed settlement entry: %s", exc)

    return CampaignState(
        tick=int(raw.get("tick", 0)),
        timestamp=datetime.now(UTC),
        party=party,
        economy=economy,
        nearby_parties=nearby_parties,
        settlements=settlements,
        raw=raw,
    )
@@ -1,284 +0,0 @@
"""LLM-powered campaign decision engine for Bannerlord M2.

Builds a structured prompt from the current ``CampaignState`` and asks
the local Qwen3 model (via Ollama) to choose one action from the M2
action vocabulary. Returns a ``CampaignDecision`` model with the
chosen action and its parameters.

The decision model is intentionally simple for M2:
    MOVE    → move to a named settlement
    TRADE   → buy a trade item
    RECRUIT → hire troops at current/nearby settlement
    ENGAGE  → fight a nearby bandit party
    WAIT    → idle (e.g. low food, waiting for morale to recover)

Qwen3 responds in JSON mode with temperature=0.1 for deterministic play.
"""

from __future__ import annotations

import json
import logging
from enum import StrEnum
from typing import Any

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Decision schema
# ---------------------------------------------------------------------------


class M2Action(StrEnum):
    """Vocabulary of actions available in the M2 milestone."""

    MOVE = "MOVE"
    TRADE = "TRADE"
    RECRUIT = "RECRUIT"
    ENGAGE = "ENGAGE"
    WAIT = "WAIT"


class CampaignDecision:
    """Parsed LLM decision for one campaign tick.

    Attributes
    ----------
    action:
        One of the ``M2Action`` values.
    settlement_id:
        Target settlement ID (for MOVE / RECRUIT / TRADE).
    settlement_name:
        Human-readable settlement name (for logging).
    item_id:
        Trade item to buy (for TRADE).
    quantity:
        Trade quantity (for TRADE).
    party_id:
        Target party ID (for ENGAGE).
    party_name:
        Human-readable party name (for ENGAGE / logging).
    reasoning:
        LLM's brief explanation of the choice.
    """

    def __init__(
        self,
        action: M2Action = M2Action.WAIT,
        *,
        settlement_id: str = "",
        settlement_name: str = "",
        item_id: str = "",
        quantity: int = 1,
        party_id: str = "",
        party_name: str = "",
        reasoning: str = "",
    ) -> None:
        self.action = action
        self.settlement_id = settlement_id
        self.settlement_name = settlement_name
        self.item_id = item_id
        self.quantity = quantity
        self.party_id = party_id
        self.party_name = party_name
        self.reasoning = reasoning

    def __repr__(self) -> str:
        return (
            f"CampaignDecision(action={self.action!r}, "
            f"reasoning={self.reasoning[:60]!r})"
        )


# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------


def build_decision_prompt(state: "Any") -> list[dict[str, str]]:
    """Return an OpenAI-style message list for the decision LLM.

    Parameters
    ----------
    state:
        A ``CampaignState`` instance.
    """
    # Build a compact context block
    party = state.party
    econ = state.economy
    ctx_lines = [
        f"Campaign tick: {state.tick}",
        f"Party size: {party.party_size} troops ({party.wounded} wounded)",
        f"Food: {party.food_days:.1f} days remaining",
        f"Morale: {party.morale:.0f}/100",
        f"Gold: {econ.gold:,} denars (net {econ.net_income:+d}/day)",
        f"Current location: {party.current_settlement or 'travelling'}",
        "",
        "== M2 GOALS ==",
        f"Troops: {state.troops_progress} (need 100)",
        f"Gold: {state.gold_progress} (need 10,000)",
        "",
    ]

    # Nearby bandits
    bandits = state.hostile_bandits_nearby()
    if bandits:
        ctx_lines.append("== NEARBY HOSTILE BANDITS ==")
        for b in bandits[:3]:
            ctx_lines.append(
                f"  - {b.name} (id={b.party_id}, {b.troop_count} troops, "
                f"{b.distance:.1f} away)"
            )
        ctx_lines.append("")

    # Settlements
    settlements = state.settlements[:5]
    if settlements:
        ctx_lines.append("== REACHABLE SETTLEMENTS ==")
        for s in settlements:
            flags = []
            if s.has_recruits:
                flags.append("recruits")
            if s.has_trade_goods:
                flags.append("trade")
            if not s.is_friendly:
                flags.append("hostile-faction")
            flag_str = f" [{', '.join(flags)}]" if flags else ""
            ctx_lines.append(
                f"  - {s.name} (id={s.settlement_id}, "
                f"{s.distance:.1f} away{flag_str})"
            )
        ctx_lines.append("")

    context = "\n".join(ctx_lines)

    system_prompt = (
        "You are the campaign manager for Timmy, an autonomous Bannerlord agent. "
        "Your job is to choose the single best action for this campaign tick. "
        "Respond ONLY with a JSON object — no prose, no markdown fences.\n\n"
        "JSON schema:\n"
        '{\n'
        '  "action": "MOVE|TRADE|RECRUIT|ENGAGE|WAIT",\n'
        '  "settlement_id": "<id or empty>",\n'
        '  "settlement_name": "<name or empty>",\n'
        '  "item_id": "<item or empty>",\n'
        '  "quantity": <int>,\n'
        '  "party_id": "<id or empty>",\n'
        '  "party_name": "<name or empty>",\n'
        '  "reasoning": "<one sentence>"\n'
        "}\n\n"
        "Priority rules:\n"
        "1. ENGAGE bandits only if they are weak (< 15 troops) and we have > 25 troops.\n"
        "2. RECRUIT when a nearby settlement has recruits and party < 80 troops.\n"
        "3. TRADE when gold < 5000 and a settlement has trade goods.\n"
        "4. MOVE toward the nearest settlement with recruits or trade goods.\n"
        "5. WAIT only if food < 1 day or morale < 40."
    )

    user_prompt = f"Current game state:\n\n{context}\nChoose the best action."

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


# ---------------------------------------------------------------------------
# Response parser
# ---------------------------------------------------------------------------


def parse_decision(raw_response: str) -> CampaignDecision:
    """Parse the LLM JSON response into a ``CampaignDecision``.

    Falls back to ``WAIT`` on any parse error so the loop never crashes.
    """
    # Strip accidental markdown code fences
    text = raw_response.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        text = "\n".join(
            line for line in lines if not line.startswith("```")
        ).strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        logger.warning("Decision parse error (bad JSON): %s | raw=%r", exc, raw_response[:200])
        return CampaignDecision(action=M2Action.WAIT, reasoning="parse error")

    try:
        action_str = str(data.get("action", "WAIT")).upper()
        try:
            action = M2Action(action_str)
        except ValueError:
            logger.warning("Unknown action %r — defaulting to WAIT", action_str)
            action = M2Action.WAIT

        return CampaignDecision(
            action=action,
            settlement_id=str(data.get("settlement_id", "")),
            settlement_name=str(data.get("settlement_name", "")),
            item_id=str(data.get("item_id", "")),
            quantity=max(1, int(data.get("quantity", 1))),
            party_id=str(data.get("party_id", "")),
            party_name=str(data.get("party_name", "")),
            reasoning=str(data.get("reasoning", "")),
        )
    except (KeyError, ValueError, TypeError) as exc:
        logger.warning("Decision parse error (bad fields): %s", exc)
        return CampaignDecision(action=M2Action.WAIT, reasoning=f"field error: {exc}")


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------


async def decide(state: "Any") -> CampaignDecision:
    """Ask the local LLM to choose a campaign action.

    Uses the cascade router (Ollama → Claude fallback) configured in
    ``config/providers.yaml``. Gracefully returns WAIT on any LLM failure.

    Parameters
    ----------
    state:
        A ``CampaignState`` instance.

    Returns
    -------
    CampaignDecision
        The chosen action and its parameters.
    """
    from config import settings

    messages = build_decision_prompt(state)
    model = settings.bannerlord_model

    try:
        from infrastructure.router import get_router

        router = get_router()
        response = await router.complete(
            messages=messages,
            model=model,
            temperature=0.1,
        )
        raw_text: str = response.get("content", "")
        decision = parse_decision(raw_text)
        logger.info(
            "Decision [tick=%d]: %s — %s",
            state.tick,
            decision.action,
            decision.reasoning,
        )
        return decision

    except Exception as exc:  # noqa: BLE001
        logger.warning("Decision LLM call failed: %s — defaulting to WAIT", exc)
        return CampaignDecision(
            action=M2Action.WAIT,
            reasoning=f"LLM unavailable: {exc}",
        )
@@ -1,195 +0,0 @@
"""GABS TCP/JSON-RPC client for Bannerlord.

Connects to the GABS C# mod (Bannerlord.GABS) over TCP on port 4825
and dispatches JSON-RPC 2.0 requests. All I/O is async; synchronous
callers must wrap in ``asyncio.to_thread()``.

Architecture:
    Bannerlord (Windows VM) ← GABS C# mod ← TCP:4825 ← this client

Usage::

    async with GabsClient() as client:
        state = await client.get_game_state()
        result = await client.call("party/move_to_settlement",
                                   {"settlement_id": "town_A1"})
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any

from config import settings

logger = logging.getLogger(__name__)

# JSON-RPC framing: each message is newline-delimited UTF-8 JSON.
_ENCODING = "utf-8"
_NEWLINE = b"\n"
_DEFAULT_TIMEOUT = 30.0


class GabsError(Exception):
    """Raised when GABS returns a JSON-RPC error response."""

    def __init__(self, code: int, message: str, data: Any = None) -> None:
        super().__init__(f"GABS error {code}: {message}")
        self.code = code
        self.data = data


class GabsClient:
    """Async TCP JSON-RPC 2.0 client for the GABS Bannerlord mod.

    Parameters
    ----------
    host:
        GABS server host (Windows VM IP or ``localhost`` for port-forwarded).
    port:
        GABS server port (default 4825).
    timeout:
        Per-call timeout in seconds.
    """

    def __init__(
        self,
        *,
        host: str | None = None,
        port: int | None = None,
        timeout: float = _DEFAULT_TIMEOUT,
    ) -> None:
        self._host = host or settings.gabs_host
        self._port = port or settings.gabs_port
        self._timeout = timeout
        self._reader: asyncio.StreamReader | None = None
        self._writer: asyncio.StreamWriter | None = None
        self._req_id = 0
        self._connected = False

    # -- lifecycle ---------------------------------------------------------

    async def connect(self) -> None:
        """Open the TCP connection to GABS."""
        try:
            self._reader, self._writer = await asyncio.wait_for(
                asyncio.open_connection(self._host, self._port),
                timeout=self._timeout,
            )
            self._connected = True
            logger.info("GabsClient connected to %s:%d", self._host, self._port)
        except (OSError, asyncio.TimeoutError) as exc:
            logger.warning("GabsClient could not connect to GABS: %s", exc)
            self._connected = False
            raise

    async def disconnect(self) -> None:
        """Close the TCP connection."""
        if self._writer is not None:
            try:
                self._writer.close()
                await self._writer.wait_closed()
            except Exception as exc:  # noqa: BLE001
                logger.debug("GabsClient disconnect error (ignored): %s", exc)
        self._connected = False
        self._reader = None
        self._writer = None
        logger.info("GabsClient disconnected")

    @property
    def is_connected(self) -> bool:
        return self._connected

    # -- context manager ---------------------------------------------------

    async def __aenter__(self) -> "GabsClient":
        await self.connect()
        return self

    async def __aexit__(self, *_: Any) -> None:
        await self.disconnect()

    # -- public API --------------------------------------------------------

    async def call(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Call a GABS tool and return the result.

        Parameters
        ----------
        method:
            GABS tool name, e.g. ``"party/move_to_settlement"``.
        params:
            Tool parameters dict.

        Returns
        -------
        Any
            The ``result`` field from the JSON-RPC response.

        Raises
        ------
        GabsError
            If GABS returns an error response.
        RuntimeError
            If not connected.
        """
        if not self._connected or self._writer is None or self._reader is None:
            raise RuntimeError("GabsClient is not connected — call connect() first")

        self._req_id += 1
        request = {
            "jsonrpc": "2.0",
            "id": self._req_id,
            "method": method,
            "params": params or {},
        }

        raw = json.dumps(request).encode(_ENCODING) + _NEWLINE

        try:
            self._writer.write(raw)
            await asyncio.wait_for(self._writer.drain(), timeout=self._timeout)

            line = await asyncio.wait_for(
                self._reader.readline(), timeout=self._timeout
            )
        except asyncio.TimeoutError as exc:
            raise RuntimeError(f"GABS call '{method}' timed out after {self._timeout}s") from exc
        except (OSError, ConnectionResetError) as exc:
            self._connected = False
            raise RuntimeError(f"GABS connection lost during '{method}': {exc}") from exc

        response = json.loads(line.decode(_ENCODING))

        if "error" in response:
            err = response["error"]
            raise GabsError(
                code=err.get("code", -1),
                message=err.get("message", "unknown error"),
                data=err.get("data"),
            )

        return response.get("result")

    async def get_game_state(self) -> dict[str, Any]:
        """Return the full game state snapshot from GABS.

        Returns an empty dict and logs a warning if GABS is unreachable.
        """
        try:
            result = await self.call("game/get_state")
            return result if isinstance(result, dict) else {}
        except (GabsError, RuntimeError) as exc:
            logger.warning("GABS get_game_state failed: %s", exc)
            return {}

    async def ping(self) -> bool:
        """Return True if GABS responds to a ping."""
        try:
            await self.call("game/ping")
            return True
        except Exception as exc:  # noqa: BLE001
            logger.debug("GABS ping failed: %s", exc)
            return False
@@ -374,17 +374,6 @@ class Settings(BaseSettings):
     error_feedback_enabled: bool = True  # Auto-create bug report tasks
     error_dedup_window_seconds: int = 300  # 5-min dedup window
 
-    # ── Bannerlord / GABS ─────────────────────────────────────────────
-    # GABS (Bannerlord Agent Bridge System) TCP/JSON-RPC server.
-    # Runs inside the Windows VM hosting Bannerlord.
-    # Override with GABS_HOST / GABS_PORT env vars.
-    gabs_host: str = "localhost"
-    gabs_port: int = 4825
-    # Decision model for the Bannerlord campaign agent (Qwen3 preferred).
-    bannerlord_model: str = "qwen3:14b"
-    # Campaign-tick interval in seconds (real-time pause between in-game days).
-    bannerlord_tick_seconds: float = 5.0
-
     # ── Scripture / Biblical Integration ──────────────────────────────
     # Enable the biblical text module.
     scripture_enabled: bool = True
@@ -196,7 +196,7 @@ async def get_evening_ritual_form(request: Request, db: Session = Depends(get_db
|
|||||||
if not journal_entry:
|
if not journal_entry:
|
||||||
raise HTTPException(status_code=404, detail="No journal entry for today")
|
raise HTTPException(status_code=404, detail="No journal entry for today")
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}
|
request, "calm/evening_ritual_form.html", {"journal_entry": journal_entry}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -257,8 +257,9 @@ async def create_new_task(
|
|||||||
# After creating a new task, we might need to re-evaluate NOW/NEXT/LATER, but for simplicity
|
# After creating a new task, we might need to re-evaluate NOW/NEXT/LATER, but for simplicity
|
||||||
# and given the spec, new tasks go to LATER. Promotion happens on completion/deferral.
|
     # and given the spec, new tasks go to LATER. Promotion happens on completion/deferral.
     return templates.TemplateResponse(
+        request,
         "calm/partials/later_count.html",
-        {"request": request, "later_tasks_count": len(get_later_tasks(db))},
+        {"later_tasks_count": len(get_later_tasks(db))},
     )
 
 
@@ -287,9 +288,9 @@ async def start_task(
     promote_tasks(db)
 
     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -316,9 +317,9 @@ async def complete_task(
     promote_tasks(db)
 
     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -345,9 +346,9 @@ async def defer_task(
     promote_tasks(db)
 
     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -360,8 +361,7 @@ async def get_later_tasks_list(request: Request, db: Session = Depends(get_db)):
     """Render the expandable list of LATER tasks."""
     later_tasks = get_later_tasks(db)
     return templates.TemplateResponse(
-        "calm/partials/later_tasks_list.html",
-        {"request": request, "later_tasks": later_tasks},
+        request, "calm/partials/later_tasks_list.html", {"later_tasks": later_tasks}
     )
 
 
@@ -404,9 +404,9 @@ async def reorder_tasks(
 
     # Re-render the relevant parts of the UI
     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -40,9 +40,9 @@ async def tools_page(request: Request):
     total_calls = 0
 
     return templates.TemplateResponse(
+        request,
         "tools.html",
         {
-            "request": request,
             "available_tools": available_tools,
             "agent_tools": agent_tools,
             "total_calls": total_calls,
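The hunks above all make the same mechanical change: newer Starlette releases expect `TemplateResponse(request, name, context)` with the request as the first positional argument, rather than smuggled inside the context dict. A framework-free sketch of the two shapes (function names here are illustrative, not from the diff):

```python
def render_old(name: str, context: dict) -> tuple:
    # Old style: the request object rides inside the context dict.
    request = context.pop("request")
    return (request, name, context)

def render_new(request: object, name: str, context: dict) -> tuple:
    # New style: the request is an explicit first positional argument.
    return (request, name, context)

req = object()
old = render_old("calm/partials/later_count.html", {"request": req, "later_tasks_count": 3})
new = render_new(req, "calm/partials/later_count.html", {"later_tasks_count": 3})
assert old == new  # both call styles carry the same information
```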
@@ -16,6 +16,8 @@ from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any
 
+from src.config import settings
+
 logger = logging.getLogger(__name__)
 
 
@@ -102,7 +104,7 @@ class EventBus:
         self._persistence_db_path.parent.mkdir(parents=True, exist_ok=True)
         with closing(sqlite3.connect(str(self._persistence_db_path))) as conn:
             conn.execute("PRAGMA journal_mode=WAL")
-            conn.execute("PRAGMA busy_timeout=5000")
+            conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
             conn.executescript(_EVENTS_SCHEMA)
             conn.commit()
 
@@ -114,7 +116,7 @@ class EventBus:
             return
         with closing(sqlite3.connect(str(self._persistence_db_path))) as conn:
             conn.row_factory = sqlite3.Row
-            conn.execute("PRAGMA busy_timeout=5000")
+            conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
             yield conn
 
     def _persist_event(self, event: Event) -> None:
@@ -18,6 +18,8 @@ from datetime import UTC, datetime
 from enum import StrEnum
 from pathlib import Path
 
+from src.config import settings
+
 logger = logging.getLogger(__name__)
 
 DB_PATH = Path("data/swarm.db")
@@ -68,7 +70,7 @@ def _get_conn() -> Generator[sqlite3.Connection, None, None]:
     with closing(sqlite3.connect(str(DB_PATH))) as conn:
         conn.row_factory = sqlite3.Row
         conn.execute("PRAGMA journal_mode=WAL")
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         conn.execute("""
             CREATE TABLE IF NOT EXISTS custom_models (
                 name TEXT PRIMARY KEY,
@@ -485,18 +485,26 @@ class CascadeRouter:
     def _quota_allows_cloud(self, provider: Provider) -> bool:
         """Check quota before routing to a cloud provider.
 
-        Uses the metabolic protocol: cloud calls are gated by 5-hour quota.
+        Uses the metabolic protocol via select_model(): cloud calls are only
+        allowed when the quota monitor recommends a cloud model (BURST tier).
         Returns True (allow cloud) if quota monitor is unavailable or returns None.
         """
         if _quota_monitor is None:
             return True
         try:
-            # Map provider type to task_value heuristic
-            task_value = "high"  # conservative default
-            status = _quota_monitor.check()
-            if status is None:
-                return True  # No credentials — caller decides based on config
-            return _quota_monitor.should_use_cloud(task_value)
+            suggested = _quota_monitor.select_model("high")
+            # Cloud is allowed only when select_model recommends the cloud model
+            allows = suggested == "claude-sonnet-4-6"
+            if not allows:
+                status = _quota_monitor.check()
+                tier = status.recommended_tier.value if status else "unknown"
+                logger.info(
+                    "Metabolic protocol: %s tier — downshifting %s to local (%s)",
+                    tier,
+                    provider.name,
+                    suggested,
+                )
+            return allows
         except Exception as exc:
             logger.warning("Quota check failed, allowing cloud: %s", exc)
             return True
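The new `_quota_allows_cloud` delegates the gating decision entirely to the monitor's `select_model()`. A self-contained sketch of that logic (the `select_model` name and `"claude-sonnet-4-6"` come from the diff; the stand-in monitor and its threshold are illustrative assumptions):

```python
class FakeQuotaMonitor:
    """Stand-in for the real quota monitor; threshold is an assumption."""

    def __init__(self, quota_remaining: float) -> None:
        self.quota_remaining = quota_remaining  # fraction of the 5-hour window left

    def select_model(self, task_value: str) -> str:
        if self.quota_remaining > 0.5 and task_value == "high":
            return "claude-sonnet-4-6"  # BURST tier: recommend the cloud model
        return "local-model"            # any other tier: stay local

def quota_allows_cloud(monitor) -> bool:
    # Mirrors the new logic: cloud only when select_model() picks the cloud model.
    if monitor is None:
        return True  # no monitor configured, caller decides
    return monitor.select_model("high") == "claude-sonnet-4-6"

assert quota_allows_cloud(FakeQuotaMonitor(0.8)) is True    # quota healthy
assert quota_allows_cloud(FakeQuotaMonitor(0.1)) is False   # downshift to local
assert quota_allows_cloud(None) is True                     # monitor unavailable
```

Compared with the old `should_use_cloud(task_value)` path, this keeps a single source of truth: whatever model the monitor would select is also what gates the routing.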
@@ -1,234 +0,0 @@
-"""Bannerlord world adapter — bridges GABS to the WorldInterface contract.
-
-Allows the existing ``Heartbeat`` loop to drive the Bannerlord campaign
-by treating it as just another game world. Wraps the async ``GabsClient``
-for synchronous use (the ``Heartbeat`` calls ``observe()`` and ``act()``
-synchronously).
-
-Async callers should use ``CampaignLoop`` directly — it is more efficient
-and handles the full M2 logic natively.
-
-Usage::
-
-    adapter = BannerlordWorldAdapter()
-    adapter.connect()
-    heartbeat = Heartbeat(world=adapter, interval=5.0)
-    await heartbeat.run_once()
-    adapter.disconnect()
-"""
-
-from __future__ import annotations
-
-import asyncio
-import logging
-
-from infrastructure.world.interface import WorldInterface
-from infrastructure.world.types import (
-    ActionResult,
-    ActionStatus,
-    CommandInput,
-    PerceptionOutput,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class BannerlordWorldAdapter(WorldInterface):
-    """WorldInterface adapter for Bannerlord via GABS.
-
-    Wraps ``GabsClient`` and ``CampaignState`` to present the Bannerlord
-    campaign map as a ``WorldInterface``-compatible world.
-
-    Parameters
-    ----------
-    host:
-        Override GABS server host (defaults to ``settings.gabs_host``).
-    port:
-        Override GABS server port (defaults to ``settings.gabs_port``).
-    """
-
-    def __init__(
-        self,
-        *,
-        host: str | None = None,
-        port: int | None = None,
-    ) -> None:
-        from config import settings
-
-        self._host = host or settings.gabs_host
-        self._port = port or settings.gabs_port
-        self._connected = False
-        self._client = None
-        self._loop: asyncio.AbstractEventLoop | None = None
-
-    # -- lifecycle ---------------------------------------------------------
-
-    def connect(self) -> None:
-        """Open the GABS TCP connection (synchronous wrapper)."""
-        from bannerlord.gabs_client import GabsClient
-
-        self._client = GabsClient(host=self._host, port=self._port)
-        try:
-            self._loop = asyncio.get_event_loop()
-        except RuntimeError:
-            self._loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(self._loop)
-
-        try:
-            self._loop.run_until_complete(self._client.connect())
-            self._connected = True
-            logger.info("BannerlordWorldAdapter connected to GABS")
-        except Exception as exc:  # noqa: BLE001
-            logger.warning("BannerlordWorldAdapter: GABS connect failed: %s", exc)
-            self._connected = False
-
-    def disconnect(self) -> None:
-        """Close the GABS TCP connection (synchronous wrapper)."""
-        if self._client is not None and self._loop is not None:
-            try:
-                self._loop.run_until_complete(self._client.disconnect())
-            except Exception as exc:  # noqa: BLE001
-                logger.debug("BannerlordWorldAdapter disconnect error: %s", exc)
-        self._connected = False
-
-    @property
-    def is_connected(self) -> bool:
-        return self._connected
-
-    # -- core contract -----------------------------------------------------
-
-    def observe(self) -> PerceptionOutput:
-        """Poll GABS for current game state and return structured perception."""
-        from bannerlord.campaign_state import parse_campaign_state
-
-        if not self._connected or self._client is None or self._loop is None:
-            return PerceptionOutput(
-                location="disconnected",
-                entities=[],
-                events=["gabs_disconnected"],
-                raw={"error": "GABS not connected"},
-            )
-
-        try:
-            raw = self._loop.run_until_complete(self._client.get_game_state())
-            state = parse_campaign_state(raw)
-
-            # Build entities list from settlements and nearby parties
-            entities: list[str] = []
-            for s in state.settlements[:5]:
-                entities.append(f"settlement:{s.name}")
-            for p in state.nearby_parties[:3]:
-                prefix = "hostile" if p.is_hostile else "friendly"
-                entities.append(f"{prefix}_party:{p.name}")
-
-            # Build events list
-            events: list[str] = []
-            if state.party.food_days < 2.0:
-                events.append("low_food")
-            if state.party.morale < 40:
-                events.append("low_morale")
-            if state.hostile_bandits_nearby():
-                events.append("bandits_nearby")
-            if state.m2_complete:
-                events.append("m2_complete")
-
-            location = state.party.current_settlement or "campaign_map"
-
-            return PerceptionOutput(
-                location=location,
-                entities=entities,
-                events=events,
-                raw=raw,
-            )
-
-        except Exception as exc:  # noqa: BLE001
-            logger.warning("BannerlordWorldAdapter.observe() failed: %s", exc)
-            return PerceptionOutput(
-                location="unknown",
-                entities=[],
-                events=[f"observe_error:{exc}"],
-                raw={"error": str(exc)},
-            )
-
-    def act(self, command: CommandInput) -> ActionResult:
-        """Dispatch a campaign command to GABS.
-
-        Recognized ``command.action`` values:
-        - ``"move"`` → party/move_to_settlement (target = settlement_id)
-        - ``"trade"`` → inventory/buy_item (target = item_id)
-        - ``"recruit"`` → party/recruit_all
-        - ``"engage"`` → party/engage_party (target = party_id)
-
-        Parameters
-        ----------
-        command:
-            WorldInterface ``CommandInput`` with action, target, parameters.
-        """
-        if not self._connected or self._client is None or self._loop is None:
-            return ActionResult(
-                status=ActionStatus.FAILURE,
-                message="GABS not connected",
-            )
-
-        try:
-            return self._loop.run_until_complete(self._async_act(command))
-        except Exception as exc:  # noqa: BLE001
-            logger.warning("BannerlordWorldAdapter.act() failed: %s", exc)
-            return ActionResult(
-                status=ActionStatus.FAILURE,
-                message=f"act failed: {exc}",
-            )
-
-    async def _async_act(self, command: CommandInput) -> ActionResult:
-        """Async implementation of act()."""
-        from bannerlord.campaign_actions import (
-            buy_item,
-            engage_party,
-            move_to_settlement,
-            recruit_all,
-        )
-
-        action = command.action.lower()
-        params = command.parameters
-
-        if action == "move":
-            settlement_id = command.target or params.get("settlement_id", "")
-            return await move_to_settlement(
-                self._client,
-                settlement_id,
-                settlement_name=params.get("settlement_name", ""),
-            )
-
-        elif action == "trade":
-            item_id = command.target or params.get("item_id", "")
-            quantity = int(params.get("quantity", 1))
-            return await buy_item(
-                self._client,
-                item_id,
-                quantity,
-                settlement_id=params.get("settlement_id", ""),
-            )
-
-        elif action == "recruit":
-            return await recruit_all(
-                self._client,
-                settlement_id=params.get("settlement_id", ""),
-            )
-
-        elif action == "engage":
-            party_id = command.target or params.get("party_id", "")
-            return await engage_party(
-                self._client,
-                party_id,
-                party_name=params.get("party_name", ""),
-            )
-
-        else:
-            return ActionResult(
-                status=ActionStatus.NOOP,
-                message=f"Unknown action: {command.action}",
-            )
-
-    def speak(self, message: str, target: str | None = None) -> None:
-        """Log the message — GABS has no chat mechanism in M2."""
-        logger.info("BannerlordWorldAdapter.speak: %r (target=%r)", message, target)
@@ -22,6 +22,8 @@ from dataclasses import dataclass
 from datetime import UTC, datetime
 from pathlib import Path
 
+from src.config import settings
+
 logger = logging.getLogger(__name__)
 
 DB_PATH = Path("data/spark.db")
@@ -47,7 +49,7 @@ def _get_conn() -> Generator[sqlite3.Connection, None, None]:
     with closing(sqlite3.connect(str(DB_PATH))) as conn:
         conn.row_factory = sqlite3.Row
         conn.execute("PRAGMA journal_mode=WAL")
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         conn.execute("""
             CREATE TABLE IF NOT EXISTS spark_predictions (
                 id TEXT PRIMARY KEY,
@@ -19,6 +19,8 @@ from dataclasses import dataclass
 from datetime import UTC, datetime
 from pathlib import Path
 
+from src.config import settings
+
 logger = logging.getLogger(__name__)
 
 DB_PATH = Path("data/spark.db")
@@ -63,7 +65,7 @@ def _get_conn() -> Generator[sqlite3.Connection, None, None]:
    with closing(sqlite3.connect(str(DB_PATH))) as conn:
         conn.row_factory = sqlite3.Row
         conn.execute("PRAGMA journal_mode=WAL")
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         conn.execute("""
             CREATE TABLE IF NOT EXISTS spark_events (
                 id TEXT PRIMARY KEY,
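All four `busy_timeout` hunks share one shape: the hard-coded `PRAGMA busy_timeout=5000` becomes an f-string reading `settings.db_busy_timeout_ms`. A minimal runnable sketch of that pattern (the `Settings` class here is a stand-in for the project's `src.config.settings`, and the 5000 ms default is assumed to mirror the old constant):

```python
import sqlite3
from contextlib import closing
from dataclasses import dataclass

@dataclass
class Settings:
    # Assumption: the new field defaults to the previously hard-coded 5000 ms.
    db_busy_timeout_ms: int = 5000

settings = Settings()

with closing(sqlite3.connect(":memory:")) as conn:
    # Apply the configurable timeout, exactly as the hunks do.
    conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
    # Querying the pragma with no value reports the active timeout back.
    timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
```

Centralizing the value in settings means every SQLite connection (events, swarm, spark predictions, spark events) picks up one tunable instead of four scattered literals.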
@@ -13,8 +13,8 @@ from dataclasses import dataclass
 import httpx
 
 from config import settings
+from timmy.research_tools import get_llm_client, google_web_search
 from timmy.research_triage import triage_research_report
-from timmy.research_tools import google_web_search, get_llm_client
 
 logger = logging.getLogger(__name__)
 
@@ -6,7 +6,6 @@ import logging
 import os
 from typing import Any
 
-from config import settings
 from serpapi import GoogleSearch
 
 logger = logging.getLogger(__name__)
@@ -1,102 +0,0 @@
-"""Unit tests for bannerlord.campaign_actions."""
-
-from __future__ import annotations
-
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-from bannerlord.campaign_actions import (
-    GabsTool,
-    buy_item,
-    engage_party,
-    move_to_settlement,
-    recruit_all,
-)
-from infrastructure.world.types import ActionStatus
-
-
-def _mock_client(return_value=None, raise_exc=None):
-    """Build a mock GabsClient."""
-    client = MagicMock()
-    if raise_exc is not None:
-        client.call = AsyncMock(side_effect=raise_exc)
-    else:
-        client.call = AsyncMock(return_value=return_value)
-    return client
-
-
-class TestMoveToSettlement:
-    async def test_success(self):
-        client = _mock_client({"eta_days": 2})
-        result = await move_to_settlement(client, "town_A1", settlement_name="Marunath")
-        assert result.status == ActionStatus.SUCCESS
-        client.call.assert_called_once_with(
-            GabsTool.MOVE_TO_SETTLEMENT, {"settlement_id": "town_A1"}
-        )
-
-    async def test_failure_on_gabs_error(self):
-        client = _mock_client(raise_exc=RuntimeError("GABS timeout"))
-        result = await move_to_settlement(client, "town_A1")
-        assert result.status == ActionStatus.FAILURE
-        assert "GABS timeout" in result.message
-
-    async def test_uses_settlement_id_as_label_when_no_name(self):
-        client = _mock_client({})
-        result = await move_to_settlement(client, "town_B2")
-        assert result.status == ActionStatus.SUCCESS
-        assert "town_B2" in result.message
-
-
-class TestBuyItem:
-    async def test_success(self):
-        client = _mock_client({"cost": 100})
-        result = await buy_item(client, "grain", 5)
-        assert result.status == ActionStatus.SUCCESS
-        assert "grain" in result.message
-        client.call.assert_called_once_with(
-            GabsTool.BUY_ITEM, {"item_id": "grain", "quantity": 5}
-        )
-
-    async def test_includes_settlement_id_when_given(self):
-        client = _mock_client({})
-        await buy_item(client, "iron", 2, settlement_id="town_A1")
-        call_params = client.call.call_args[0][1]
-        assert call_params["settlement_id"] == "town_A1"
-
-    async def test_failure_logged_gracefully(self):
-        client = _mock_client(raise_exc=Exception("inventory full"))
-        result = await buy_item(client, "wool", 10)
-        assert result.status == ActionStatus.FAILURE
-
-
-class TestRecruitAll:
-    async def test_success(self):
-        client = _mock_client({"recruited": 15})
-        result = await recruit_all(client)
-        assert result.status == ActionStatus.SUCCESS
-        assert "15" in result.message
-
-    async def test_success_with_settlement(self):
-        client = _mock_client({"recruited": 8})
-        result = await recruit_all(client, settlement_id="town_A1")
-        call_params = client.call.call_args[0][1]
-        assert call_params["settlement_id"] == "town_A1"
-
-    async def test_failure_graceful(self):
-        client = _mock_client(raise_exc=RuntimeError("no recruits"))
-        result = await recruit_all(client)
-        assert result.status == ActionStatus.FAILURE
-
-
-class TestEngageParty:
-    async def test_success(self):
-        client = _mock_client({"outcome": "victory", "loot": 200})
-        result = await engage_party(client, "bandit_1", party_name="Forest Bandits")
-        assert result.status == ActionStatus.SUCCESS
-        assert "victory" in result.message
-
-    async def test_failure_graceful(self):
-        client = _mock_client(raise_exc=RuntimeError("party not found"))
-        result = await engage_party(client, "bandit_1")
-        assert result.status == ActionStatus.FAILURE
@@ -1,200 +0,0 @@
-"""Unit tests for bannerlord.campaign_loop."""
-
-from __future__ import annotations
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from bannerlord.campaign_loop import CampaignLoop, TickResult
-from bannerlord.decision import CampaignDecision, M2Action
-from infrastructure.world.types import ActionResult, ActionStatus
-
-
-def _make_game_state(*, troops: int = 30, gold: int = 2000) -> dict:
-    return {
-        "tick": 0,
-        "party": {
-            "size": troops,
-            "wounded": 0,
-            "food_days": 5.0,
-            "morale": 80.0,
-            "current_settlement": "town_A1",
-        },
-        "economy": {"gold": gold, "daily_income": 200, "daily_expenses": 150},
-        "nearby_parties": [],
-        "settlements": [
-            {
-                "id": "town_A1",
-                "name": "Marunath",
-                "faction": "aserai",
-                "is_friendly": True,
-                "distance": 0.0,
-                "has_recruits": True,
-                "has_trade_goods": False,
-            }
-        ],
-    }
-
-
-class TestCampaignLoopDispatch:
-    """Tests for the internal _dispatch() routing."""
-
-    def _loop(self) -> CampaignLoop:
-        return CampaignLoop(tick_seconds=0.0, max_ticks=1)
-
-    async def test_dispatch_move(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(
-            action=M2Action.MOVE,
-            settlement_id="town_A1",
-            settlement_name="Marunath",
-        )
-
-        with patch("bannerlord.campaign_loop.move_to_settlement", new_callable=AsyncMock) as mock_move:
-            mock_move.return_value = ActionResult(status=ActionStatus.SUCCESS, message="ok")
-            await loop._dispatch(decision, client)
-            mock_move.assert_called_once_with(client, "town_A1", settlement_name="Marunath")
-
-    async def test_dispatch_recruit(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(
-            action=M2Action.RECRUIT,
-            settlement_id="town_A1",
-        )
-
-        with patch("bannerlord.campaign_loop.recruit_all", new_callable=AsyncMock) as mock_recruit:
-            mock_recruit.return_value = ActionResult(status=ActionStatus.SUCCESS, message="15 recruited")
-            await loop._dispatch(decision, client)
-            mock_recruit.assert_called_once()
-
-    async def test_dispatch_engage(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(
-            action=M2Action.ENGAGE,
-            party_id="bandit_1",
-            party_name="Forest Bandits",
-        )
-
-        with patch("bannerlord.campaign_loop.engage_party", new_callable=AsyncMock) as mock_engage:
-            mock_engage.return_value = ActionResult(status=ActionStatus.SUCCESS, message="victory")
-            await loop._dispatch(decision, client)
-            mock_engage.assert_called_once_with(client, "bandit_1", party_name="Forest Bandits")
-
-    async def test_dispatch_trade(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(
-            action=M2Action.TRADE,
-            item_id="grain",
-            quantity=5,
-        )
-
-        with patch("bannerlord.campaign_loop.buy_item", new_callable=AsyncMock) as mock_buy:
-            mock_buy.return_value = ActionResult(status=ActionStatus.SUCCESS, message="bought")
-            await loop._dispatch(decision, client)
-            mock_buy.assert_called_once_with(client, "grain", 5, settlement_id="")
-
-    async def test_dispatch_wait_returns_noop(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(action=M2Action.WAIT, reasoning="low food")
-        result = await loop._dispatch(decision, client)
-        assert result.status == ActionStatus.NOOP
-
-    async def test_dispatch_move_missing_settlement_id(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(action=M2Action.MOVE, settlement_id="")
-        result = await loop._dispatch(decision, client)
-        assert result.status == ActionStatus.FAILURE
-
-    async def test_dispatch_engage_missing_party_id(self):
-        loop = self._loop()
-        client = MagicMock()
-        decision = CampaignDecision(action=M2Action.ENGAGE, party_id="")
-        result = await loop._dispatch(decision, client)
-        assert result.status == ActionStatus.FAILURE
-
-
-class TestCampaignLoopRun:
-    """Integration-level tests for the full run() loop (mocked GABS)."""
-
-    async def test_run_stops_at_max_ticks(self):
-        """Loop respects max_ticks and returns correct number of results."""
-        game_state = _make_game_state()
-
-        with (
-            patch("bannerlord.campaign_loop.GabsClient") as MockClient,
-            patch("bannerlord.campaign_loop.decide", new_callable=AsyncMock) as mock_decide,
-            patch("bannerlord.campaign_loop.move_to_settlement", new_callable=AsyncMock) as mock_move,
-        ):
-            # Setup fake client
-            fake_client = AsyncMock()
-            fake_client.get_game_state = AsyncMock(return_value=game_state)
-            fake_client.connect = AsyncMock()
-            fake_client.disconnect = AsyncMock()
-            MockClient.return_value = fake_client
-
-            mock_decide.return_value = CampaignDecision(
-                action=M2Action.MOVE,
-                settlement_id="town_B1",
-                settlement_name="Epicrotea",
-                reasoning="moving",
-            )
-            mock_move.return_value = ActionResult(status=ActionStatus.SUCCESS, message="ok")
-
-            loop = CampaignLoop(tick_seconds=0.0, max_ticks=3)
-            results = await loop.run()
-
-            assert len(results) == 3
-            assert all(isinstance(r, TickResult) for r in results)
-
-    async def test_run_stops_when_m2_complete(self):
-        """Loop exits early when M2 conditions are met."""
-        # State with M2 already complete
-        game_state = _make_game_state(troops=100, gold=10000)
-
-        with (
-            patch("bannerlord.campaign_loop.GabsClient") as MockClient,
-            patch("bannerlord.campaign_loop.decide", new_callable=AsyncMock) as mock_decide,
-        ):
-            fake_client = AsyncMock()
-            fake_client.get_game_state = AsyncMock(return_value=game_state)
-            fake_client.connect = AsyncMock()
-            fake_client.disconnect = AsyncMock()
-            MockClient.return_value = fake_client
-
-            mock_decide.return_value = CampaignDecision(
-                action=M2Action.WAIT,
-                reasoning="done",
-            )
-
-            loop = CampaignLoop(tick_seconds=0.0, max_ticks=10)
-            results = await loop.run()
-
-            # Should exit after first tick (m2_complete = True)
-            assert len(results) == 1
-            assert results[0].m2_complete is True
-
-    async def test_run_aborts_on_connect_failure(self):
-        """Loop returns empty history if GABS cannot be reached."""
-        with patch("bannerlord.campaign_loop.GabsClient") as MockClient:
-            fake_client = AsyncMock()
-            fake_client.connect = AsyncMock(side_effect=OSError("refused"))
-            fake_client.disconnect = AsyncMock()
-            MockClient.return_value = fake_client
-
-            loop = CampaignLoop(tick_seconds=0.0, max_ticks=5)
-            results = await loop.run()
-
-            assert results == []
-
-    def test_stop_sets_running_false(self):
-        loop = CampaignLoop()
-        loop._running = True
-        loop.stop()
-        assert not loop.is_running
@@ -1,150 +0,0 @@
"""Unit tests for bannerlord.campaign_state."""

from __future__ import annotations

import pytest

from bannerlord.campaign_state import (
    M2_GOLD_GOAL,
    M2_TROOP_GOAL,
    CampaignState,
    NearbyParty,
    Settlement,
    parse_campaign_state,
)


class TestParseCampaignState:
    def test_empty_dict_returns_defaults(self):
        state = parse_campaign_state({})
        assert state.party.party_size == 0
        assert state.economy.gold == 0
        assert state.nearby_parties == []
        assert state.settlements == []

    def test_full_payload_parsed(self):
        raw = {
            "tick": 5,
            "party": {
                "size": 30,
                "wounded": 2,
                "prisoners": 1,
                "food_days": 3.5,
                "morale": 75.0,
                "current_settlement": "town_A1",
                "speed": 5.2,
            },
            "economy": {
                "gold": 4500,
                "daily_income": 200,
                "daily_expenses": 150,
            },
            "nearby_parties": [
                {
                    "id": "bandit_1",
                    "name": "Forest Bandits",
                    "faction": "bandit",
                    "is_hostile": True,
                    "troop_count": 10,
                    "distance": 3.0,
                }
            ],
            "settlements": [
                {
                    "id": "town_A1",
                    "name": "Marunath",
                    "faction": "aserai",
                    "is_friendly": True,
                    "distance": 0.0,
                    "has_recruits": True,
                    "has_trade_goods": False,
                }
            ],
        }
        state = parse_campaign_state(raw)

        assert state.tick == 5
        assert state.party.party_size == 30
        assert state.party.wounded == 2
        assert state.economy.gold == 4500
        assert state.economy.net_income == 50
        assert len(state.nearby_parties) == 1
        assert state.nearby_parties[0].name == "Forest Bandits"
        assert len(state.settlements) == 1
        assert state.settlements[0].name == "Marunath"

    def test_malformed_entries_skipped(self):
        raw = {
            "nearby_parties": [{"id": "ok", "name": "Good", "faction": "bandit",
                                "is_hostile": True, "troop_count": 5, "distance": 2.0},
                               {"bad": "data"}],
            "settlements": [None, "not_a_dict"],
        }
        state = parse_campaign_state(raw)
        assert len(state.nearby_parties) == 1
        assert state.settlements == []


class TestCampaignStateProperties:
    def _make_state(self, *, troops: int, gold: int) -> CampaignState:
        state = CampaignState()
        state.party.party_size = troops
        state.economy.gold = gold
        return state

    def test_m2_not_complete_by_default(self):
        state = self._make_state(troops=20, gold=0)
        assert not state.m2_complete

    def test_m2_complete_when_both_goals_met(self):
        state = self._make_state(troops=M2_TROOP_GOAL, gold=M2_GOLD_GOAL)
        assert state.m2_complete

    def test_m2_not_complete_if_only_troops_met(self):
        state = self._make_state(troops=M2_TROOP_GOAL, gold=M2_GOLD_GOAL - 1)
        assert not state.m2_complete

    def test_m2_not_complete_if_only_gold_met(self):
        state = self._make_state(troops=M2_TROOP_GOAL - 1, gold=M2_GOLD_GOAL)
        assert not state.m2_complete

    def test_troops_progress_string(self):
        state = self._make_state(troops=45, gold=0)
        assert state.troops_progress == f"45/{M2_TROOP_GOAL}"

    def test_gold_progress_string(self):
        state = self._make_state(troops=0, gold=3000)
        assert "3,000" in state.gold_progress

    def test_hostile_bandits_nearby_filter(self):
        state = CampaignState()
        state.nearby_parties = [
            NearbyParty("b1", "Bandits", "bandit", True, 10, 2.0),
            NearbyParty("l1", "Lord", "empire", False, 50, 1.0),
            NearbyParty("b2", "Far Bandits", "bandit", True, 5, 10.0),
        ]
        nearby = state.hostile_bandits_nearby(max_distance=5.0)
        assert len(nearby) == 1
        assert nearby[0].party_id == "b1"

    def test_nearest_settlement_returns_closest(self):
        state = CampaignState()
        state.settlements = [
            Settlement("s1", "Far Town", "empire", True, 10.0),
            Settlement("s2", "Near Town", "empire", True, 2.0),
        ]
        nearest = state.nearest_settlement()
        assert nearest.settlement_id == "s2"

    def test_nearest_recruit_settlement(self):
        state = CampaignState()
        state.settlements = [
            Settlement("s1", "Town A", "empire", True, 5.0, has_recruits=False),
            Settlement("s2", "Town B", "empire", True, 8.0, has_recruits=True),
        ]
        recruit = state.nearest_recruit_settlement()
        assert recruit.settlement_id == "s2"

    def test_nearest_settlement_none_when_empty(self):
        state = CampaignState()
        assert state.nearest_settlement() is None
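The goal check these tests exercise can be sketched with a pair of properties. This is an illustrative reconstruction from the assertions above, not the project's actual `campaign_state` module, and the goal constants here are assumed placeholder values.

```python
# Hypothetical sketch of the M2 goal check; constants are placeholders,
# not the project's real M2_TROOP_GOAL / M2_GOLD_GOAL values.
from dataclasses import dataclass, field

M2_TROOP_GOAL = 50    # assumed
M2_GOLD_GOAL = 5000   # assumed


@dataclass
class PartyState:
    party_size: int = 0


@dataclass
class EconomyState:
    gold: int = 0


@dataclass
class CampaignState:
    party: PartyState = field(default_factory=PartyState)
    economy: EconomyState = field(default_factory=EconomyState)

    @property
    def m2_complete(self) -> bool:
        # Both goals must be met simultaneously, matching the
        # "only_troops_met" / "only_gold_met" negative tests.
        return (self.party.party_size >= M2_TROOP_GOAL
                and self.economy.gold >= M2_GOLD_GOAL)

    @property
    def troops_progress(self) -> str:
        return f"{self.party.party_size}/{M2_TROOP_GOAL}"

    @property
    def gold_progress(self) -> str:
        # Thousands separator, so "3,000" appears in the progress string
        return f"{self.economy.gold:,}/{M2_GOLD_GOAL:,}"
```

The AND of the two thresholds is what makes the four `m2_*` tests above mutually distinguishing: each negative case satisfies exactly one goal.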
@@ -1,154 +0,0 @@
|
|||||||
"""Unit tests for bannerlord.decision."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from bannerlord.campaign_state import (
|
|
||||||
CampaignState,
|
|
||||||
EconomyState,
|
|
||||||
NearbyParty,
|
|
||||||
PartyState,
|
|
||||||
Settlement,
|
|
||||||
)
|
|
||||||
from bannerlord.decision import (
|
|
||||||
M2Action,
|
|
||||||
CampaignDecision,
|
|
||||||
build_decision_prompt,
|
|
||||||
parse_decision,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_state(
|
|
||||||
*,
|
|
||||||
troops: int = 30,
|
|
||||||
gold: int = 2000,
|
|
||||||
food_days: float = 5.0,
|
|
||||||
morale: float = 80.0,
|
|
||||||
settlements: list | None = None,
|
|
||||||
nearby_parties: list | None = None,
|
|
||||||
) -> CampaignState:
|
|
||||||
state = CampaignState()
|
|
||||||
state.party = PartyState(
|
|
||||||
party_size=troops,
|
|
||||||
food_days=food_days,
|
|
||||||
morale=morale,
|
|
||||||
)
|
|
||||||
state.economy = EconomyState(gold=gold, daily_income=200, daily_expenses=150)
|
|
||||||
state.settlements = settlements or []
|
|
||||||
state.nearby_parties = nearby_parties or []
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
class TestBuildDecisionPrompt:
|
|
||||||
def test_returns_two_messages(self):
|
|
||||||
state = _make_state()
|
|
||||||
messages = build_decision_prompt(state)
|
|
||||||
assert len(messages) == 2
|
|
||||||
assert messages[0]["role"] == "system"
|
|
||||||
assert messages[1]["role"] == "user"
|
|
||||||
|
|
||||||
def test_user_message_includes_party_info(self):
|
|
||||||
state = _make_state(troops=45, gold=3000)
|
|
||||||
messages = build_decision_prompt(state)
|
|
||||||
user_content = messages[1]["content"]
|
|
||||||
assert "45" in user_content
|
|
||||||
assert "3,000" in user_content
|
|
||||||
|
|
||||||
def test_bandits_appear_in_prompt_when_nearby(self):
|
|
||||||
state = _make_state(
|
|
||||||
nearby_parties=[NearbyParty("b1", "Forest Bandits", "bandit", True, 10, 2.0)]
|
|
||||||
)
|
|
||||||
messages = build_decision_prompt(state)
|
|
||||||
user_content = messages[1]["content"]
|
|
||||||
assert "Forest Bandits" in user_content
|
|
||||||
|
|
||||||
def test_settlements_appear_in_prompt(self):
|
|
||||||
state = _make_state(
|
|
||||||
settlements=[Settlement("s1", "Marunath", "aserai", True, 3.0, has_recruits=True)]
|
|
||||||
)
|
|
||||||
messages = build_decision_prompt(state)
|
|
||||||
user_content = messages[1]["content"]
|
|
||||||
assert "Marunath" in user_content
|
|
||||||
|
|
||||||
def test_system_prompt_contains_action_vocabulary(self):
|
|
||||||
state = _make_state()
|
|
||||||
messages = build_decision_prompt(state)
|
|
||||||
system = messages[0]["content"]
|
|
||||||
for action in ("MOVE", "TRADE", "RECRUIT", "ENGAGE", "WAIT"):
|
|
||||||
assert action in system
|
|
||||||
|
|
||||||
|
|
||||||
class TestParseDecision:
|
|
||||||
def test_valid_move_decision(self):
|
|
||||||
raw = json.dumps({
|
|
||||||
"action": "MOVE",
|
|
||||||
"settlement_id": "town_A1",
|
|
||||||
"settlement_name": "Marunath",
|
|
||||||
"item_id": "",
|
|
||||||
"quantity": 1,
|
|
||||||
"party_id": "",
|
|
||||||
"party_name": "",
|
|
||||||
"reasoning": "Moving to recruit troops",
|
|
||||||
})
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.action == M2Action.MOVE
|
|
||||||
assert decision.settlement_id == "town_A1"
|
|
||||||
assert decision.settlement_name == "Marunath"
|
|
||||||
|
|
||||||
def test_valid_recruit_decision(self):
|
|
||||||
raw = json.dumps({
|
|
||||||
"action": "RECRUIT",
|
|
||||||
"settlement_id": "town_A1",
|
|
||||||
"settlement_name": "Marunath",
|
|
||||||
"item_id": "",
|
|
||||||
"quantity": 1,
|
|
||||||
"party_id": "",
|
|
||||||
"party_name": "",
|
|
||||||
"reasoning": "Has recruits available",
|
|
||||||
})
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.action == M2Action.RECRUIT
|
|
||||||
|
|
||||||
def test_valid_engage_decision(self):
|
|
||||||
raw = json.dumps({
|
|
||||||
"action": "ENGAGE",
|
|
||||||
"settlement_id": "",
|
|
||||||
"settlement_name": "",
|
|
||||||
"item_id": "",
|
|
||||||
"quantity": 1,
|
|
||||||
"party_id": "bandit_1",
|
|
||||||
"party_name": "Forest Bandits",
|
|
||||||
"reasoning": "Weak bandits — easy XP",
|
|
||||||
})
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.action == M2Action.ENGAGE
|
|
||||||
assert decision.party_id == "bandit_1"
|
|
||||||
|
|
||||||
def test_wait_on_invalid_json(self):
|
|
||||||
decision = parse_decision("not json at all")
|
|
||||||
assert decision.action == M2Action.WAIT
|
|
||||||
|
|
||||||
def test_wait_on_unknown_action(self):
|
|
||||||
raw = json.dumps({"action": "TELEPORT", "reasoning": "hack"})
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.action == M2Action.WAIT
|
|
||||||
|
|
||||||
def test_strips_markdown_fences(self):
|
|
||||||
raw = '```json\n{"action": "WAIT", "reasoning": "low food"}\n```'
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.action == M2Action.WAIT
|
|
||||||
|
|
||||||
def test_quantity_minimum_one(self):
|
|
||||||
raw = json.dumps({"action": "TRADE", "item_id": "grain", "quantity": -5, "reasoning": "x"})
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.quantity == 1
|
|
||||||
|
|
||||||
def test_missing_optional_fields_default_to_empty(self):
|
|
||||||
raw = json.dumps({"action": "WAIT", "reasoning": "resting"})
|
|
||||||
decision = parse_decision(raw)
|
|
||||||
assert decision.settlement_id == ""
|
|
||||||
assert decision.party_id == ""
|
|
||||||
assert decision.item_id == ""
|
|
||||||
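The fallback behaviour these `TestParseDecision` cases pin down can be sketched as follows. This is a reconstruction from the assertions, not the real `bannerlord.decision` module: fence handling, field names, and defaulting are inferred, so treat every detail as an assumption.

```python
# Sketch of parse_decision's tolerant parsing: strip ```json fences,
# fall back to WAIT on bad JSON or unknown actions, clamp quantity >= 1.
# Names mirror the tests above; the real implementation may differ.
import json
from dataclasses import dataclass
from enum import Enum


class M2Action(Enum):
    MOVE = "MOVE"
    TRADE = "TRADE"
    RECRUIT = "RECRUIT"
    ENGAGE = "ENGAGE"
    WAIT = "WAIT"


@dataclass
class CampaignDecision:
    action: M2Action
    settlement_id: str = ""
    party_id: str = ""
    item_id: str = ""
    quantity: int = 1
    reasoning: str = ""


def parse_decision(raw: str) -> CampaignDecision:
    text = raw.strip()
    if text.startswith("```"):
        # Drop markdown fences the model may wrap its JSON in
        text = text.strip("`")
        if "\n" in text:
            text = text.split("\n", 1)[1]   # drop "json" language tag line
        if "\n" in text:
            text = text.rsplit("\n", 1)[0]  # drop trailing newline fragment
    try:
        data = json.loads(text)
        action = M2Action(data["action"])
    except (json.JSONDecodeError, KeyError, ValueError, TypeError):
        return CampaignDecision(action=M2Action.WAIT, reasoning="unparseable response")
    return CampaignDecision(
        action=action,
        settlement_id=data.get("settlement_id", ""),
        party_id=data.get("party_id", ""),
        item_id=data.get("item_id", ""),
        quantity=max(1, int(data.get("quantity", 1))),
        reasoning=data.get("reasoning", ""),
    )
```

Defaulting to `WAIT` on any parse failure keeps the campaign loop safe: a garbled model response never produces a random action.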
@@ -1,120 +0,0 @@
"""Unit tests for bannerlord.gabs_client."""

from __future__ import annotations

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from bannerlord.gabs_client import GabsClient, GabsError


class TestGabsClientCall:
    """Tests for GabsClient.call() using mock StreamReader/Writer."""

    def _make_client(self, response: dict) -> GabsClient:
        """Return a pre-connected GabsClient with mocked I/O."""
        client = GabsClient(host="localhost", port=4825, timeout=5.0)
        client._connected = True

        writer = MagicMock()
        writer.write = MagicMock()
        writer.drain = AsyncMock()

        raw_response = json.dumps(response).encode() + b"\n"
        reader = MagicMock()
        reader.readline = AsyncMock(return_value=raw_response)

        client._reader = reader
        client._writer = writer
        return client

    async def test_successful_call_returns_result(self):
        client = self._make_client({"jsonrpc": "2.0", "id": 1, "result": {"status": "ok"}})
        result = await client.call("game/ping")
        assert result == {"status": "ok"}

    async def test_error_response_raises_gabs_error(self):
        client = self._make_client({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32601, "message": "Method not found"},
        })
        with pytest.raises(GabsError) as exc_info:
            await client.call("unknown/method")
        assert exc_info.value.code == -32601

    async def test_not_connected_raises_runtime_error(self):
        client = GabsClient()
        with pytest.raises(RuntimeError, match="not connected"):
            await client.call("game/ping")

    async def test_request_id_increments(self):
        client = self._make_client({"jsonrpc": "2.0", "id": 1, "result": {}})
        await client.call("game/ping")
        # Reset reader for second call
        client._reader.readline = AsyncMock(
            return_value=json.dumps({"jsonrpc": "2.0", "id": 2, "result": {}}).encode() + b"\n"
        )
        await client.call("game/ping")
        assert client._req_id == 2

    async def test_get_game_state_returns_empty_on_error(self):
        client = GabsClient()
        client._connected = True

        writer = MagicMock()
        writer.write = MagicMock()
        writer.drain = AsyncMock()
        reader = MagicMock()
        reader.readline = AsyncMock(side_effect=OSError("connection reset"))

        client._reader = reader
        client._writer = writer

        result = await client.get_game_state()
        assert result == {}

    async def test_ping_returns_true_on_success(self):
        client = self._make_client({"jsonrpc": "2.0", "id": 1, "result": "pong"})
        result = await client.ping()
        assert result is True

    async def test_ping_returns_false_on_failure(self):
        client = GabsClient()
        result = await client.ping()
        assert result is False


class TestGabsClientLifecycle:
    async def test_connect_failure_sets_not_connected(self):
        client = GabsClient(host="localhost", port=9999, timeout=0.1)
        with pytest.raises(Exception):
            await client.connect()
        assert not client.is_connected

    async def test_context_manager_calls_connect_and_disconnect(self):
        client = GabsClient()
        connect_called = False
        disconnect_called = False

        async def _fake_connect():
            nonlocal connect_called
            connect_called = True
            client._connected = True

        async def _fake_disconnect():
            nonlocal disconnect_called
            disconnect_called = True
            client._connected = False

        client.connect = _fake_connect
        client.disconnect = _fake_disconnect

        async with client as c:
            assert c is client
            assert connect_called

        assert disconnect_called
@@ -6,8 +6,8 @@ import time
 from pathlib import Path

 import pytest

-from infrastructure.db_pool import ConnectionPool
+from src.config import settings
+from src.infrastructure.db_pool import ConnectionPool


 class TestConnectionPoolInit:
@@ -330,9 +330,9 @@ class TestPragmaApplication:
         """busy_timeout pragma set on a pooled connection persists."""
         pool = ConnectionPool(tmp_path / "test.db")
         conn = pool.get_connection()
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
-        assert timeout == 5000
+        assert timeout == settings.db_busy_timeout_ms
         pool.close_connection()

     def test_pragmas_apply_per_connection(self, tmp_path):
@@ -664,10 +664,10 @@ class TestVllmMlxProvider:
         )
         router.providers = [provider]

-        # Quota monitor returns False (block cloud) — vllm_mlx should still be tried
+        # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
         with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
-            mock_qm.check.return_value = object()
-            mock_qm.should_use_cloud.return_value = False
+            mock_qm.select_model.return_value = "qwen3:14b"
+            mock_qm.check.return_value = None

             with patch.object(router, "_call_vllm_mlx") as mock_call:
                 mock_call.return_value = {
@@ -681,6 +681,115 @@ class TestVllmMlxProvider:
         assert result["content"] == "Local MLX response"


class TestMetabolicProtocol:
    """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING."""

    def _make_anthropic_provider(self) -> "Provider":
        return Provider(
            name="anthropic-primary",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="test-key",
            models=[{"name": "claude-sonnet-4-6", "default": True}],
        )

    async def test_cloud_provider_allowed_in_burst_tier(self):
        """BURST tier (quota healthy): cloud provider is tried."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns cloud model → BURST tier
            mock_qm.select_model.return_value = "claude-sonnet-4-6"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hard question"}],
                )

            mock_call.assert_called_once()
            assert result["content"] == "Cloud response"

    async def test_cloud_provider_skipped_in_active_tier(self):
        """ACTIVE tier (5-hour >= 50%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns local 14B → ACTIVE tier
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "question"}],
                    )

            mock_call.assert_not_called()

    async def test_cloud_provider_skipped_in_resting_tier(self):
        """RESTING tier (7-day >= 80%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns local 8B → RESTING tier
            mock_qm.select_model.return_value = "qwen3:8b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "simple question"}],
                    )

            mock_call.assert_not_called()

    async def test_local_provider_always_tried_regardless_of_quota(self):
        """Local (ollama/vllm_mlx) providers bypass the metabolic protocol."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "qwen3:14b", "default": True}],
        )
        router.providers = [provider]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:8b"  # RESTING tier

            with patch.object(router, "_call_ollama") as mock_call:
                mock_call.return_value = {"content": "Local response", "model": "qwen3:14b"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

            mock_call.assert_called_once()
            assert result["content"] == "Local response"

    async def test_no_quota_monitor_allows_cloud(self):
        """When quota monitor is None (unavailable), cloud providers are allowed."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor", None):
            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "question"}],
                )

            mock_call.assert_called_once()
            assert result["content"] == "Cloud response"


class TestCascadeRouterReload:
    """Test hot-reload of providers.yaml."""
285
tests/scripts/test_export_trajectories.py
Normal file
@@ -0,0 +1,285 @@
"""Unit tests for scripts/export_trajectories.py.

Tests trajectory conversion logic — no I/O, no Ollama, no mlx.
"""

from __future__ import annotations

import json
from pathlib import Path

import pytest

import scripts.export_trajectories as et

# ── Fixtures ──────────────────────────────────────────────────────────────────


@pytest.fixture()
def simple_session(tmp_path: Path) -> Path:
    """Write a minimal session JSONL file and return the logs dir."""
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    entries = [
        {"type": "message", "role": "user", "content": "What time is it?", "timestamp": "2026-03-01T10:00:00"},
        {"type": "message", "role": "timmy", "content": "It is 10:00 AM.", "timestamp": "2026-03-01T10:00:01"},
        {"type": "message", "role": "user", "content": "Thanks!", "timestamp": "2026-03-01T10:00:05"},
        {"type": "message", "role": "timmy", "content": "You're welcome!", "timestamp": "2026-03-01T10:00:06"},
    ]
    session_file = logs_dir / "session_2026-03-01.jsonl"
    session_file.write_text("\n".join(json.dumps(e) for e in entries) + "\n")
    return logs_dir


@pytest.fixture()
def tool_call_session(tmp_path: Path) -> Path:
    """Write a session JSONL with tool calls."""
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    entries = [
        {"type": "message", "role": "user", "content": "Read CLAUDE.md", "timestamp": "2026-03-01T10:00:00"},
        {
            "type": "tool_call",
            "tool": "read_file",
            "args": {"path": "CLAUDE.md"},
            "result": "# CLAUDE.md content here",
            "timestamp": "2026-03-01T10:00:01",
        },
        {"type": "message", "role": "timmy", "content": "Here is the content.", "timestamp": "2026-03-01T10:00:02"},
    ]
    session_file = logs_dir / "session_2026-03-01.jsonl"
    session_file.write_text("\n".join(json.dumps(e) for e in entries) + "\n")
    return logs_dir


# ── _load_entries ─────────────────────────────────────────────────────────────


@pytest.mark.unit
def test_load_entries_returns_all(simple_session: Path) -> None:
    entries = et._load_entries(simple_session)
    assert len(entries) == 4


@pytest.mark.unit
def test_load_entries_skips_malformed(tmp_path: Path) -> None:
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    session = logs_dir / "session_2026-03-01.jsonl"
    session.write_text(
        '{"type": "message", "role": "user", "content": "hi"}\n'
        "NOT_JSON\n"
        '{"type": "message", "role": "timmy", "content": "hello"}\n'
    )
    entries = et._load_entries(logs_dir)
    assert len(entries) == 2  # malformed line skipped


@pytest.mark.unit
def test_load_entries_empty_dir(tmp_path: Path) -> None:
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    entries = et._load_entries(logs_dir)
    assert entries == []


@pytest.mark.unit
def test_load_entries_multiple_files(tmp_path: Path) -> None:
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    for day in ("2026-03-01", "2026-03-02"):
        entry = {"type": "message", "role": "user", "content": f"day {day}"}
        (logs_dir / f"session_{day}.jsonl").write_text(json.dumps(entry) + "\n")
    entries = et._load_entries(logs_dir)
    assert len(entries) == 2


# ── _format_tool_call ─────────────────────────────────────────────────────────


@pytest.mark.unit
def test_format_tool_call_structure() -> None:
    entry = {
        "type": "tool_call",
        "tool": "read_file",
        "args": {"path": "/tmp/foo.txt"},
        "result": "file contents",
    }
    result = et._format_tool_call(entry)
    assert result.startswith("<tool_call>")
    assert result.endswith("</tool_call>")
    payload = json.loads(result.split("\n")[1])
    assert payload["name"] == "read_file"
    assert payload["arguments"]["path"] == "/tmp/foo.txt"


@pytest.mark.unit
def test_format_tool_call_missing_tool() -> None:
    entry = {"type": "tool_call", "args": {}}
    result = et._format_tool_call(entry)
    assert "unknown" in result


# ── _group_into_turns ─────────────────────────────────────────────────────────


@pytest.mark.unit
def test_group_basic_conversation() -> None:
    entries = [
        {"type": "message", "role": "user", "content": "hello"},
        {"type": "message", "role": "timmy", "content": "hi there"},
        {"type": "message", "role": "user", "content": "bye"},
        {"type": "message", "role": "timmy", "content": "goodbye"},
    ]
    turns = et._group_into_turns(entries)
    assert len(turns) == 2
    assert turns[0]["user"] == "hello"
    assert turns[0]["assistant"] == "hi there"
    assert turns[1]["user"] == "bye"
    assert turns[1]["assistant"] == "goodbye"


@pytest.mark.unit
def test_group_with_tool_call() -> None:
    entries = [
        {"type": "message", "role": "user", "content": "check the file"},
        {"type": "tool_call", "tool": "read_file", "args": {"path": "x"}, "result": "content"},
        {"type": "message", "role": "timmy", "content": "Done."},
    ]
    turns = et._group_into_turns(entries)
    assert len(turns) == 1
    assert "<tool_call>" in turns[0]["assistant"]
    assert "Done." in turns[0]["assistant"]


@pytest.mark.unit
def test_group_skips_user_without_response() -> None:
    """User message with no timmy response should not create a turn."""
    entries = [
        {"type": "message", "role": "user", "content": "hello"},
        # No timmy response
        {"type": "message", "role": "user", "content": "are you there?"},
        {"type": "message", "role": "timmy", "content": "Yes!"},
    ]
    turns = et._group_into_turns(entries)
    assert len(turns) == 1
    assert turns[0]["user"] == "are you there?"


@pytest.mark.unit
def test_group_ignores_errors_and_decisions() -> None:
    entries = [
        {"type": "message", "role": "user", "content": "hello"},
        {"type": "error", "error": "something failed"},
        {"type": "decision", "decision": "retry"},
        {"type": "message", "role": "timmy", "content": "Got it."},
    ]
    turns = et._group_into_turns(entries)
    assert len(turns) == 1
    assert "error" not in turns[0]["assistant"]
    assert "retry" not in turns[0]["assistant"]


@pytest.mark.unit
def test_group_empty_entries() -> None:
    assert et._group_into_turns([]) == []


# ── turns_to_training_examples ────────────────────────────────────────────────


@pytest.mark.unit
def test_training_examples_structure() -> None:
    turns = [{"user": "hello", "assistant": "hi there, how can I help?"}]
    examples = et.turns_to_training_examples(turns)
    assert len(examples) == 1
    msgs = examples[0]["messages"]
    assert msgs[0]["role"] == "system"
    assert msgs[1]["role"] == "user"
    assert msgs[1]["content"] == "hello"
    assert msgs[2]["role"] == "assistant"
    assert msgs[2]["content"] == "hi there, how can I help?"


@pytest.mark.unit
def test_training_examples_filters_short_responses() -> None:
    turns = [
        {"user": "hello", "assistant": "ok"},  # too short
        {"user": "hello", "assistant": "This is a longer response that passes."},
    ]
    examples = et.turns_to_training_examples(turns, min_assistant_len=10)
    assert len(examples) == 1
    assert examples[0]["messages"][2]["content"] == "This is a longer response that passes."


@pytest.mark.unit
def test_training_examples_filters_empty_user() -> None:
    turns = [{"user": "", "assistant": "some response here"}]
    examples = et.turns_to_training_examples(turns)
    assert len(examples) == 0


@pytest.mark.unit
def test_training_examples_uses_custom_system_prompt() -> None:
    turns = [{"user": "hi", "assistant": "hello there!"}]
    examples = et.turns_to_training_examples(turns, system_prompt="Custom prompt.")
    assert examples[0]["messages"][0]["content"] == "Custom prompt."


# ── export_training_data (integration-style, uses tmp_path) ──────────────────


@pytest.mark.unit
def test_export_training_data_writes_jsonl(simple_session: Path, tmp_path: Path) -> None:
    output = tmp_path / "train.jsonl"
    count = et.export_training_data(logs_dir=simple_session, output_path=output)
    assert count == 2
    assert output.exists()
    lines = [
        json.loads(line) for line in output.read_text().splitlines() if line.strip()
    ]
    assert len(lines) == 2
    for line in lines:
        assert "messages" in line
        roles = [m["role"] for m in line["messages"]]
        assert roles == ["system", "user", "assistant"]


@pytest.mark.unit
def test_export_training_data_with_tool_calls(tool_call_session: Path, tmp_path: Path) -> None:
    output = tmp_path / "train.jsonl"
    count = et.export_training_data(logs_dir=tool_call_session, output_path=output)
    assert count == 1
    line = json.loads(output.read_text().strip())
    assistant_content = line["messages"][2]["content"]
    assert "<tool_call>" in assistant_content
    assert "read_file" in assistant_content


@pytest.mark.unit
def test_export_training_data_returns_zero_for_empty_logs(tmp_path: Path) -> None:
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    output = tmp_path / "train.jsonl"
    count = et.export_training_data(logs_dir=logs_dir, output_path=output)
    assert count == 0
    assert not output.exists()
|
|
||||||
|
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_cli_missing_logs_dir(tmp_path: Path) -> None:
|
||||||
|
rc = et.main(["--logs-dir", str(tmp_path / "nonexistent"), "--output", str(tmp_path / "out.jsonl")])
|
||||||
|
assert rc == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_cli_exports_and_returns_zero(simple_session: Path, tmp_path: Path) -> None:
|
||||||
|
output = tmp_path / "out.jsonl"
|
||||||
|
rc = et.main([
|
||||||
|
"--logs-dir", str(simple_session),
|
||||||
|
"--output", str(output),
|
||||||
|
])
|
||||||
|
assert rc == 0
|
||||||
|
assert output.exists()
|
||||||
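The schema these tests pin down is one JSON object per JSONL line, each carrying a three-message chat in the order system, user, assistant. A minimal stdlib-only sketch of building such a record (the `make_example` helper is hypothetical, not part of the module under test):

```python
import json


def make_example(user: str, assistant: str, system_prompt: str = "You are Timmy.") -> dict:
    """Build one chat-format training record (hypothetical helper).

    Mirrors the schema the tests assert: roles [system, user, assistant].
    """
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ]
    }


# One JSONL line per example, as test_export_training_data_writes_jsonl expects
record = make_example("hello", "hi there, how can I help?")
line = json.dumps(record)
roles = [m["role"] for m in record["messages"]]
```

`line` is then appended to the dataset file; a reader recovers the record with `json.loads(line)`.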
546 tests/unit/test_retrain_loop.py (new file)
@@ -0,0 +1,546 @@
"""Unit tests for the AutoLoRA continuous improvement loop.

Covers trajectory extraction, quality filtering, dataset management,
and the retrain orchestrator.

Refs: #1105
"""

from __future__ import annotations

import json
from datetime import UTC, datetime, timedelta
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import Trajectory, TrajectoryExporter

# ── Fixtures ─────────────────────────────────────────────────────────────────


def _ts(offset_minutes: int = 0) -> str:
    """Return an ISO timestamp offset from now."""
    return (datetime.now(tz=UTC) + timedelta(minutes=offset_minutes)).isoformat()


def _make_session_log(entries: list[dict], date_str: str, tmp_path: Path) -> Path:
    """Write session JSONL entries to a temp log file."""
    log_dir = tmp_path / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"session_{date_str}.jsonl"
    with open(log_file, "w") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")
    return log_file


def _user_msg(content: str, offset: int = 0) -> dict:
    return {"type": "message", "role": "user", "content": content, "timestamp": _ts(offset)}


def _timmy_msg(content: str, confidence: float | None = None, offset: int = 0) -> dict:
    entry = {"type": "message", "role": "timmy", "content": content, "timestamp": _ts(offset)}
    if confidence is not None:
        entry["confidence"] = confidence
    return entry


def _tool_call(tool: str = "bash", result: str = "ok", offset: int = 0) -> dict:
    return {
        "type": "tool_call",
        "tool": tool,
        "args": {},
        "result": result,
        "timestamp": _ts(offset),
    }


def _error_entry(msg: str = "Something failed", offset: int = 0) -> dict:
    return {"type": "error", "error": msg, "timestamp": _ts(offset)}


def _decision_entry(decision: str = "Use approach A", offset: int = 0) -> dict:
    return {"type": "decision", "decision": decision, "timestamp": _ts(offset)}

# ── Trajectory dataclass tests ────────────────────────────────────────────────


class TestTrajectory:
    def test_message_count(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hi"), _timmy_msg("hello")],
        )
        assert t.message_count == 2

    def test_tool_call_count(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call(), _tool_call()],
        )
        assert t.tool_call_count == 2

    def test_has_successful_tool_call_when_no_errors(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call()],
            errors=[],
        )
        assert t.has_successful_tool_call is True

    def test_has_successful_tool_call_false_when_errors(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call()],
            errors=[_error_entry()],
        )
        assert t.has_successful_tool_call is False

    def test_is_multi_step(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        assert t.is_multi_step is True

    def test_is_not_multi_step_single_message(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("hello")],
            tool_calls=[],
        )
        assert t.is_multi_step is False

    def test_to_chat_format_ordering(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("question", offset=0), _timmy_msg("answer", offset=2)],
            tool_calls=[_tool_call(offset=1)],
        )
        chat = t.to_chat_format()
        roles = [m["role"] for m in chat]
        assert "user" in roles
        assert "assistant" in roles

    def test_to_chat_format_empty_content_skipped(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg(""), _timmy_msg("response")],
        )
        chat = t.to_chat_format()
        # Empty user message should be skipped
        assert all(m["content"] for m in chat)

# ── TrajectoryExporter tests ──────────────────────────────────────────────────


class TestTrajectoryExporter:
    def test_export_empty_logs_dir(self, tmp_path):
        (tmp_path / "logs").mkdir()
        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        assert result == []

    def test_export_reads_session_files(self, tmp_path):
        # Write a session file for this week
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")
        entries = [
            _user_msg("tell me about Python"),
            _timmy_msg("Python is great"),
        ]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        assert len(result) >= 1

    def test_export_skips_old_sessions(self, tmp_path):
        # Write a session file for 3 weeks ago
        three_weeks_ago = datetime.now(tz=UTC) - timedelta(weeks=3)
        date_str = three_weeks_ago.strftime("%Y-%m-%d")
        entries = [_user_msg("old message"), _timmy_msg("old response")]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        # Request current week — should not include 3-week-old data
        result = exporter.export_week(weeks_ago=0)
        assert result == []

    def test_export_segments_by_gap(self, tmp_path):
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")

        # Two conversations separated by 10 minutes
        t1 = (today - timedelta(minutes=15)).isoformat()
        t2 = (today - timedelta(minutes=14)).isoformat()
        t3 = (today - timedelta(minutes=2)).isoformat()
        t4 = (today - timedelta(minutes=1)).isoformat()

        entries = [
            {"type": "message", "role": "user", "content": "first q", "timestamp": t1},
            {"type": "message", "role": "timmy", "content": "first a", "timestamp": t2},
            {"type": "message", "role": "user", "content": "second q", "timestamp": t3},
            {"type": "message", "role": "timmy", "content": "second a", "timestamp": t4},
        ]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        # Should have at least 1 trajectory (may be 1 or 2 depending on segmentation)
        assert len(result) >= 1

    def test_handles_malformed_log_file(self, tmp_path):
        log_dir = tmp_path / "logs"
        log_dir.mkdir()
        today = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        (log_dir / f"session_{today}.jsonl").write_text("not json\n{}\n")

        exporter = TrajectoryExporter(logs_dir=log_dir, repo_root=tmp_path)
        # Should not raise, just return empty or partial results
        result = exporter.export_week(weeks_ago=0)
        assert isinstance(result, list)

# ── QualityFilter tests ───────────────────────────────────────────────────────


class TestQualityFilter:
    def _make_high_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do task"), _timmy_msg("done", confidence=0.9)],
            tool_calls=[_tool_call(), _tool_call()],
            errors=[],
            decisions=[_decision_entry()],
        )

    def _make_medium_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hello"), _timmy_msg("hi")],
            tool_calls=[],
            errors=[],
        )

    def _make_low_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("oops")],  # No user message
            errors=[_error_entry()],
        )

    def test_high_quality_classification(self):
        qf = QualityFilter()
        result = qf.assess(self._make_high_quality())
        assert result.quality == TrajectoryQuality.HIGH
        assert result.score >= 4.0
        assert result.is_trainable

    def test_medium_quality_classification(self):
        qf = QualityFilter()
        result = qf.assess(self._make_medium_quality())
        assert result.quality == TrajectoryQuality.MEDIUM
        assert result.is_trainable

    def test_low_quality_no_user_message(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("random")],
        )
        result = qf.assess(t)
        assert result.quality == TrajectoryQuality.LOW
        assert not result.is_trainable

    def test_error_penalizes_score(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("go"), _timmy_msg("fail")],
            tool_calls=[_tool_call()],
            errors=[_error_entry(), _error_entry()],
        )
        result = qf.assess(t)
        assert result.score < qf.assess(self._make_high_quality()).score

    def test_low_confidence_penalizes_score(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("q"), _timmy_msg("a", confidence=0.2)],
        )
        result = qf.assess(t)
        assert result.score < 1.0

    def test_filter_returns_stats(self):
        qf = QualityFilter()
        trajectories = [
            self._make_high_quality(),
            self._make_medium_quality(),
            self._make_low_quality(),
        ]
        trainable, stats = qf.filter(trajectories)
        assert stats["total"] == 3
        assert stats["accepted"] == len(trainable)
        assert stats["high"] + stats["medium"] + stats["low"] == 3

    def test_filter_empty_list(self):
        qf = QualityFilter()
        trainable, stats = qf.filter([])
        assert trainable == []
        assert stats["total"] == 0
        assert stats["accepted"] == 0

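The behavior these tests imply (score of at least 4.0 means HIGH, a trajectory with no user message is LOW and untrainable) can be sketched as a plain enum-plus-threshold scheme. The cutoffs below are illustrative assumptions, not the actual constants inside `QualityFilter`:

```python
from enum import Enum


class Quality(Enum):
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


def classify(score: float, has_user_message: bool,
             high_cutoff: float = 4.0, low_cutoff: float = 1.0) -> Quality:
    """Illustrative threshold scheme; cutoffs are assumptions, not the real filter's."""
    if not has_user_message:
        # Without a user turn there is nothing to learn a response from
        return Quality.LOW
    if score >= high_cutoff:
        return Quality.HIGH
    if score >= low_cutoff:
        return Quality.MEDIUM
    return Quality.LOW
```

Under this scheme only HIGH and MEDIUM trajectories would count as trainable, matching what the assertions above require.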
# ── TrainingDataset tests ─────────────────────────────────────────────────────


class TestTrainingDataset:
    def _make_result(self, quality=TrajectoryQuality.HIGH, score=5.0) -> object:
        from timmy_automations.retrain.quality_filter import QualityResult

        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(-5),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        return QualityResult(trajectory=t, quality=quality, score=score, reasons=[])

    def test_count_empty_dataset(self, tmp_path):
        ds = TrainingDataset(
            dataset_path=".loop/retrain/training_data.jsonl",
            repo_root=tmp_path,
        )
        assert ds.count() == 0

    def test_append_adds_examples(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        result = ds.append([self._make_result()], "2026-W12")
        assert result.new_examples == 1
        assert result.total_examples == 1
        assert ds.count() == 1

    def test_append_idempotent(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r = self._make_result()
        ds.append([r], "2026-W12")
        result2 = ds.append([r], "2026-W12")
        # Same trajectory shouldn't be added twice
        assert result2.new_examples == 0
        assert ds.count() == 1

    def test_append_different_weeks(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r1 = self._make_result()
        ds.append([r1], "2026-W11")
        ds.append([r1], "2026-W12")
        # Different week tags = different records
        assert ds.count() == 2

    def test_dataset_file_is_valid_jsonl(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        with open(ds.dataset_path) as f:
            lines = [line.strip() for line in f if line.strip()]
        assert len(lines) == 1
        record = json.loads(lines[0])
        assert "messages" in record
        assert "week" in record
        assert "quality" in record

    def test_index_updated_after_append(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        index_path = tmp_path / ".loop" / "retrain" / "dataset_index.json"
        assert index_path.exists()
        index = json.loads(index_path.read_text())
        assert index["total_examples"] == 1
        assert "2026-W12" in index["weeks"]

# ── TrainingLog tests ─────────────────────────────────────────────────────────


class TestTrainingLog:
    def _make_metrics(self, iteration: int = 1) -> CycleMetrics:
        return CycleMetrics(
            iteration=iteration,
            week="2026-W12",
            ran_at=datetime.now(tz=UTC).isoformat(),
            trajectories_total=10,
            trajectories_high=5,
            trajectories_medium=3,
            trajectories_low=2,
            trajectories_accepted=8,
            examples_added=5,
            dataset_total=5,
            train_status="completed",
            train_loss=1.2345,
            train_duration_seconds=120.5,
            adapter_path=".loop/retrain/adapters/iter_0001/adapters.npz",
            model_name="hermes4-14b-ft-0001",
            notes="First fine-tune cycle complete",
        )

    def test_next_iteration_starts_at_1(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        assert log.next_iteration() == 1

    def test_next_iteration_increments(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        assert log.next_iteration() == 2

    def test_record_creates_log_file(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        assert log.log_path.exists()

    def test_load_all_returns_records(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        log.record(self._make_metrics(iteration=2))
        entries = log.load_all()
        assert len(entries) == 2
        assert entries[0]["iteration"] == 1

    def test_latest_returns_last_entry(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        log.record(self._make_metrics(iteration=2))
        latest = log.latest()
        assert latest is not None
        assert latest["iteration"] == 2

    def test_latest_returns_none_when_empty(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        assert log.latest() is None

    def test_summary_markdown_written(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        summary_path = tmp_path / ".loop" / "retrain" / "training_log.md"
        assert summary_path.exists()
        content = summary_path.read_text()
        assert "AutoLoRA Training Log" in content
        assert "2026-W12" in content
        assert "completed" in content

    def test_skill_accuracy_in_summary(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        m = self._make_metrics()
        m.skill_accuracy = {"tool_calling": 0.85, "reasoning": 0.72}
        log.record(m)
        content = (tmp_path / ".loop" / "retrain" / "training_log.md").read_text()
        assert "tool_calling" in content
        assert "reasoning" in content

# ── RetrainOrchestrator integration tests ─────────────────────────────────────

# tests/unit/ → tests/ → repo root
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent


class TestRetrainOrchestrator:
    def test_run_dry_run_no_data(self, tmp_path):
        """Dry run with no session logs should complete without errors."""
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        result = orc.run(weeks_ago=0)
        assert result.train_status in ("skipped",)
        assert result.examples_added == 0
        assert result.iteration == 1

    def test_run_creates_log_entry(self, tmp_path):
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        orc.run(weeks_ago=0)
        log = TrainingLog(repo_root=tmp_path)
        entries = log.load_all()
        assert len(entries) == 1

    def test_run_with_session_data(self, tmp_path):
        """Run with actual session data — should export, filter, and log."""
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")
        entries = [
            _user_msg("deploy the service", offset=-10),
            _tool_call("bash", "deployed successfully", offset=-9),
            _tool_call("bash", "health check ok", offset=-8),
            _timmy_msg("Service deployed and healthy", confidence=0.92, offset=-7),
            _user_msg("run the tests", offset=-6),
            _tool_call("bash", "All tests passed", offset=-5),
            _timmy_msg("All 42 tests passed", confidence=0.95, offset=-4),
        ]
        _make_session_log(entries, date_str, tmp_path)

        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        result = orc.run(weeks_ago=0)

        assert result.trajectories_exported >= 1
        assert result.iteration == 1
        # In dry_run mode, fine-tune is skipped but trajectories should be processed
        assert result.train_status == "skipped"

    def test_iteration_increments_on_second_run(self, tmp_path):
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        r1 = orc.run(weeks_ago=0)
        r2 = orc.run(weeks_ago=0)
        assert r2.iteration == r1.iteration + 1

    def test_automations_json_has_retrain_entry(self):
        """Verify the retrain automation is registered in automations.json."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        assert config_path.exists()
        manifest = json.loads(config_path.read_text())
        ids = [a["id"] for a in manifest.get("automations", [])]
        assert "retrain" in ids

    def test_retrain_automation_config(self):
        """Verify retrain automation has correct schedule and config."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        manifest = json.loads(config_path.read_text())
        retrain = next(a for a in manifest["automations"] if a["id"] == "retrain")
        assert retrain["schedule"] == "weekly_sunday"
        assert retrain["trigger"] == "scheduled"
        assert retrain["config"]["base_model"] == "hermes4-14b"
        assert retrain["config"]["weeks_ago"] == 1
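The dataset tests above key appends by week tags like `2026-W12`. With the stdlib alone, an ISO year-week tag can be derived from a date as follows (a sketch; the module's own tagging code is not shown in this diff):

```python
from datetime import date


def week_tag(d: date) -> str:
    """ISO year-week tag, e.g. 2026-W12.

    %G is the ISO year and %V the ISO week number; they must be used
    together, since the ISO year can differ from the calendar year
    around New Year.
    """
    return d.strftime("%G-W%V")


print(week_tag(date(2026, 3, 17)))  # → 2026-W12
```

Pairing `%Y` with `%V` instead would silently produce wrong tags for late-December and early-January sessions.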
@@ -4,7 +4,7 @@
   "_health_snapshot": {
     "note": "Quick health check before coding — CI, P0/P1 issues, flakiness"
   },
-  "last_updated": "2026-03-21",
+  "last_updated": "2026-03-23",
   "automations": [
     {
       "id": "cycle_retro",
@@ -268,6 +268,36 @@
         "ci_timeout_seconds": 5
       },
       "outputs": []
+    },
+    {
+      "id": "retrain",
+      "name": "AutoLoRA Continuous Improvement Loop",
+      "description": "Weekly sovereignty loop — exports trajectories, filters quality, appends to training dataset, triggers LoRA fine-tune, loads new adapter, and logs iteration metrics",
+      "script": "timmy_automations/retrain/retrain.py",
+      "category": "autolora",
+      "enabled": true,
+      "trigger": "scheduled",
+      "schedule": "weekly_sunday",
+      "executable": "python3",
+      "epic": "#1091",
+      "pipeline": "AutoLoRA Sovereignty Loop (Step 6 of 7)",
+      "config": {
+        "weeks_ago": 1,
+        "base_model": "hermes4-14b",
+        "dry_run": false,
+        "logs_dir": "logs",
+        "dataset_path": ".loop/retrain/training_data.jsonl",
+        "adapter_dir": ".loop/retrain/adapters",
+        "training_log_path": ".loop/retrain/training_log.jsonl",
+        "training_summary_path": ".loop/retrain/training_log.md"
+      },
+      "outputs": [
+        ".loop/retrain/training_data.jsonl",
+        ".loop/retrain/dataset_index.json",
+        ".loop/retrain/training_log.jsonl",
+        ".loop/retrain/training_log.md",
+        ".loop/retrain/adapters/"
+      ]
+    }
   ]
 }
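Consumers of this manifest select an automation by its `id` and read its `config` block. Against a trimmed in-memory stand-in for automations.json, that lookup is:

```python
import json

# Trimmed stand-in for the real manifest; only the fields used here
manifest_text = json.dumps({
    "automations": [
        {"id": "cycle_retro", "trigger": "scheduled"},
        {
            "id": "retrain",
            "trigger": "scheduled",
            "schedule": "weekly_sunday",
            "config": {"base_model": "hermes4-14b", "weeks_ago": 1},
        },
    ]
})

manifest = json.loads(manifest_text)
# next() raises StopIteration if the entry is missing, which surfaces
# a misconfigured manifest immediately
retrain = next(a for a in manifest["automations"] if a["id"] == "retrain")
print(retrain["schedule"], retrain["config"]["weeks_ago"])
```

A scheduler would iterate `manifest["automations"]`, keep entries with `"enabled": true` and `"trigger": "scheduled"`, and dispatch each by its `schedule` value.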
26 timmy_automations/retrain/__init__.py (new file)
@@ -0,0 +1,26 @@
"""AutoLoRA continuous improvement loop — sovereignty engine for Timmy.

Implements the weekly retrain cycle:
    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator, RetrainResult
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

__all__ = [
    "QualityFilter",
    "RetrainOrchestrator",
    "RetrainResult",
    "TrainingDataset",
    "TrainingLog",
    "TrajectoryExporter",
    "TrajectoryQuality",
]
262 timmy_automations/retrain/lora_trainer.py (new file)
@@ -0,0 +1,262 @@
"""LoRA trainer — triggers fine-tune job and loads the resulting adapter.

Supports two backends:
1. mlx-lm (default, Apple Silicon) — `mlx_lm.lora` CLI
2. Ollama create (adapter packaging into a new Ollama model)

Graceful degradation: if neither backend is available, logs a warning
and returns a skipped result — the rest of the loop continues.

Refs: #1105
"""

from __future__ import annotations

import json
import logging
import os
import shutil
import subprocess
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

logger = logging.getLogger(__name__)

_DEFAULT_BASE_MODEL = "hermes4-14b"
_DEFAULT_ADAPTER_DIR = ".loop/retrain/adapters"
_MLX_LM_BIN = "mlx_lm.lora"
_OLLAMA_BIN = "ollama"


@dataclass
class TrainResult:
    """Result of a LoRA fine-tune run."""

    status: str  # "completed" | "skipped" | "failed"
    adapter_path: str | None
    model_name: str | None
    iteration: int
    duration_seconds: float
    message: str
    train_loss: float | None = None


class LoRATrainer:
    """Orchestrates LoRA fine-tuning and adapter loading.

    Workflow:
    1. Run mlx_lm.lora fine-tune on the training dataset
    2. Save the resulting adapter to .loop/retrain/adapters/<iteration>/
    3. Create (or update) an Ollama model that uses the new adapter
    """

    def __init__(
        self,
        base_model: str = _DEFAULT_BASE_MODEL,
        adapter_dir: str | Path | None = None,
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._base_model = base_model
        self._adapter_dir = self._repo_root / (adapter_dir or _DEFAULT_ADAPTER_DIR)
        self._adapter_dir.mkdir(parents=True, exist_ok=True)
        self._dry_run = dry_run

    def train(self, dataset_path: Path, iteration: int) -> TrainResult:
        """Run LoRA fine-tuning on the dataset.

        Args:
            dataset_path: Path to the JSONL training dataset.
            iteration: Current fine-tune iteration number (used for naming).

        Returns:
            TrainResult with status, adapter path, and metrics.
        """
        started = datetime.now(tz=UTC)

        if not dataset_path.exists() or dataset_path.stat().st_size == 0:
            return TrainResult(
                status="skipped",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=0.0,
                message="Training dataset is empty — skipping fine-tune",
            )

        if self._dry_run:
            logger.info("[dry-run] Would fine-tune %s on %s", self._base_model, dataset_path)
            adapter_path = self._adapter_dir / f"iter_{iteration:04d}" / "adapters.npz"
            return TrainResult(
                status="skipped",
                adapter_path=str(adapter_path),
                model_name=f"{self._base_model}-ft-{iteration:04d}",
                iteration=iteration,
                duration_seconds=0.0,
                message="dry-run mode — no training performed",
            )

        # Determine which backend is available
        if shutil.which(_MLX_LM_BIN):
            return self._train_mlx(dataset_path, iteration, started)
        else:
            logger.warning(
                "%s not found — skipping LoRA fine-tune (install mlx-lm to enable)",
                _MLX_LM_BIN,
            )
            return TrainResult(
                status="skipped",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=0.0,
                message=(
                    f"{_MLX_LM_BIN} not available. "
                    "Install mlx-lm on Apple Silicon to enable LoRA fine-tuning."
                ),
            )

    def _train_mlx(
        self, dataset_path: Path, iteration: int, started: datetime
    ) -> TrainResult:
        """Run mlx_lm.lora fine-tune."""
        adapter_out = self._adapter_dir / f"iter_{iteration:04d}"
        adapter_out.mkdir(parents=True, exist_ok=True)

        cmd = [
            _MLX_LM_BIN,
            "--model", self._base_model,
            "--data", str(dataset_path),
            "--adapter-path", str(adapter_out),
            "--train",
            "--iters", "100",
            "--batch-size", "1",
            "--learning-rate", "1e-5",
        ]

        logger.info("Starting mlx-lm LoRA fine-tune: iteration %d", iteration)
|
||||||
|
logger.info("Command: %s", " ".join(cmd))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=3600, # 1 hour max
|
||||||
|
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
duration = (datetime.now(tz=UTC) - started).total_seconds()
|
||||||
|
return TrainResult(
|
||||||
|
status="failed",
|
||||||
|
adapter_path=None,
|
||||||
|
model_name=None,
|
||||||
|
iteration=iteration,
|
||||||
|
duration_seconds=duration,
|
||||||
|
message="Fine-tune timed out after 1 hour",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
duration = (datetime.now(tz=UTC) - started).total_seconds()
|
||||||
|
return TrainResult(
|
||||||
|
status="failed",
|
||||||
|
adapter_path=None,
|
||||||
|
model_name=None,
|
||||||
|
iteration=iteration,
|
||||||
|
duration_seconds=duration,
|
||||||
|
message=f"Fine-tune subprocess error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
duration = (datetime.now(tz=UTC) - started).total_seconds()
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
logger.error("mlx-lm fine-tune failed: %s", result.stderr[:500])
|
||||||
|
return TrainResult(
|
||||||
|
status="failed",
|
||||||
|
adapter_path=None,
|
||||||
|
model_name=None,
|
||||||
|
iteration=iteration,
|
||||||
|
duration_seconds=duration,
|
||||||
|
message=f"mlx_lm.lora exited {result.returncode}: {result.stderr[:300]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse final train loss from stdout if available
|
||||||
|
train_loss = _parse_train_loss(result.stdout)
|
||||||
|
|
||||||
|
adapter_file = adapter_out / "adapters.npz"
|
||||||
|
model_name = f"{self._base_model}-ft-{iteration:04d}"
|
||||||
|
|
||||||
|
# Attempt to register with Ollama
|
||||||
|
ollama_ok = self._register_ollama_adapter(adapter_out, model_name)
|
||||||
|
if not ollama_ok:
|
||||||
|
logger.warning("Ollama adapter registration failed — adapter saved locally")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Fine-tune complete: iteration=%d loss=%.4f duration=%.1fs adapter=%s",
|
||||||
|
iteration,
|
||||||
|
train_loss or 0.0,
|
||||||
|
duration,
|
||||||
|
adapter_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainResult(
|
||||||
|
status="completed",
|
||||||
|
adapter_path=str(adapter_file),
|
||||||
|
model_name=model_name,
|
||||||
|
iteration=iteration,
|
||||||
|
duration_seconds=duration,
|
||||||
|
message=f"LoRA fine-tune completed successfully in {duration:.0f}s",
|
||||||
|
train_loss=train_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_ollama_adapter(self, adapter_dir: Path, model_name: str) -> bool:
|
||||||
|
"""Create an Ollama model entry for the new adapter.
|
||||||
|
|
||||||
|
Writes a minimal Modelfile and runs `ollama create`.
|
||||||
|
"""
|
||||||
|
if not shutil.which(_OLLAMA_BIN):
|
||||||
|
logger.debug("Ollama not found — skipping adapter registration")
|
||||||
|
return False
|
||||||
|
|
||||||
|
modelfile_content = (
|
||||||
|
f"FROM {self._base_model}\n"
|
||||||
|
f"ADAPTER {adapter_dir}\n"
|
||||||
|
)
|
||||||
|
modelfile_path = adapter_dir / "Modelfile"
|
||||||
|
try:
|
||||||
|
modelfile_path.write_text(modelfile_content)
|
||||||
|
result = subprocess.run(
|
||||||
|
[_OLLAMA_BIN, "create", model_name, "-f", str(modelfile_path)],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
if result.returncode == 0:
|
||||||
|
logger.info("Ollama model registered: %s", model_name)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning("ollama create failed: %s", result.stderr[:200])
|
||||||
|
return False
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Ollama adapter registration error: %s", exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_train_loss(stdout: str) -> float | None:
|
||||||
|
"""Extract the final training loss from mlx-lm stdout."""
|
||||||
|
loss: float | None = None
|
||||||
|
for line in stdout.splitlines():
|
||||||
|
line_lower = line.lower()
|
||||||
|
if "train loss" in line_lower or "loss:" in line_lower:
|
||||||
|
parts = line.split()
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if "loss" in part.lower() and i + 1 < len(parts):
|
||||||
|
try:
|
||||||
|
loss = float(parts[i + 1].strip(",:"))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return loss
|
||||||
172
timmy_automations/retrain/quality_filter.py
Normal file
@@ -0,0 +1,172 @@
"""Quality filter — keeps only high-value trajectories for LoRA training.

Criteria for a high-quality training example:
1. Tool calls succeeded (tool calls present, no error entries)
2. Multi-step tasks completed (≥2 messages + ≥1 tool call)
3. No low-confidence signals (confidence < 0.5 on any Timmy message)
4. Minimum meaningful exchange (≥1 user message + ≥1 Timmy message)

Refs: #1105
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import StrEnum

from timmy_automations.retrain.trajectory_exporter import Trajectory

logger = logging.getLogger(__name__)

_MIN_CONFIDENCE = 0.5


class TrajectoryQuality(StrEnum):
    """Quality classification for a trajectory."""

    HIGH = "high"      # Multi-step + tool success — ideal training data
    MEDIUM = "medium"  # Single exchange, no errors — acceptable
    LOW = "low"        # Error-prone or trivial — skip


@dataclass
class QualityResult:
    """Result of quality assessment for a single trajectory."""

    trajectory: Trajectory
    quality: TrajectoryQuality
    score: float
    reasons: list[str]

    @property
    def is_trainable(self) -> bool:
        return self.quality in (TrajectoryQuality.HIGH, TrajectoryQuality.MEDIUM)


class QualityFilter:
    """Filters trajectories to keep only those worth training on.

    Scoring:
    - +1 pt: base score for any valid clean exchange (no errors)
    - +3 pts: multi-step task (≥2 messages + ≥1 tool call)
    - +2 pts: tool calls present and no errors
    - +1 pt: decision recorded (deliberate choice made)
    - -2 pts: any error entry
    - -1 pt: any low-confidence response (confidence < 0.5)

    HIGH ≥ 4, MEDIUM 1–3, LOW ≤ 0
    """

    def __init__(self, min_confidence: float = _MIN_CONFIDENCE):
        self._min_confidence = min_confidence

    def assess(self, trajectory: Trajectory) -> QualityResult:
        """Score and classify a single trajectory."""
        score = 0.0
        reasons: list[str] = []

        # Minimum viable exchange check
        user_msgs = [m for m in trajectory.messages if m.get("role") == "user"]
        timmy_msgs = [m for m in trajectory.messages if m.get("role") == "timmy"]

        if not user_msgs or not timmy_msgs:
            return QualityResult(
                trajectory=trajectory,
                quality=TrajectoryQuality.LOW,
                score=0.0,
                reasons=["Missing user or assistant messages — not a valid exchange"],
            )

        # Multi-step bonus
        if trajectory.is_multi_step:
            score += 3.0
            reasons.append(
                f"Multi-step task: {trajectory.message_count} messages, "
                f"{trajectory.tool_call_count} tool calls"
            )

        # Base score for any clean exchange (user + timmy, no tool call required)
        if trajectory.error_count == 0:
            score += 1.0
            reasons.append("Clean exchange (no errors)")

        # Tool call quality
        if trajectory.tool_call_count > 0:
            if trajectory.error_count == 0:
                score += 2.0
                reasons.append(
                    f"All {trajectory.tool_call_count} tool call(s) succeeded"
                )
            else:
                score -= 2.0
                reasons.append(
                    f"{trajectory.error_count} error(s) during {trajectory.tool_call_count} tool call(s)"
                )
        elif trajectory.error_count > 0:
            score -= 2.0
            reasons.append(f"{trajectory.error_count} error(s) with no tool calls")

        # Decision bonus
        if trajectory.decisions:
            score += 1.0
            reasons.append(f"Decisions recorded: {len(trajectory.decisions)}")

        # Confidence penalty
        low_conf = [
            m
            for m in timmy_msgs
            if m.get("confidence") is not None
            and m["confidence"] < self._min_confidence
        ]
        if low_conf:
            score -= len(low_conf)
            reasons.append(
                f"{len(low_conf)} low-confidence response(s) (threshold={self._min_confidence})"
            )

        # Classify
        if score >= 4.0:
            quality = TrajectoryQuality.HIGH
        elif score >= 1.0:
            quality = TrajectoryQuality.MEDIUM
        else:
            quality = TrajectoryQuality.LOW

        return QualityResult(
            trajectory=trajectory,
            quality=quality,
            score=score,
            reasons=reasons,
        )

    def filter(
        self, trajectories: list[Trajectory]
    ) -> tuple[list[QualityResult], dict[str, int]]:
        """Assess all trajectories and return trainable ones with stats.

        Returns:
            (trainable_results, stats_dict) where stats_dict has keys
            'total', 'high', 'medium', 'low', 'accepted'.
        """
        results = [self.assess(t) for t in trajectories]
        trainable = [r for r in results if r.is_trainable]

        stats = {
            "total": len(results),
            "high": sum(1 for r in results if r.quality == TrajectoryQuality.HIGH),
            "medium": sum(1 for r in results if r.quality == TrajectoryQuality.MEDIUM),
            "low": sum(1 for r in results if r.quality == TrajectoryQuality.LOW),
            "accepted": len(trainable),
        }

        logger.info(
            "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)",
            stats["accepted"],
            stats["total"],
            stats["high"],
            stats["medium"],
            stats["low"],
        )

        return trainable, stats
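The scoring rubric can be exercised without a real `Trajectory` by feeding the same arithmetic a stand-in object with matching attributes. A sketch; the stub dataclass and sample values are illustrative, not part of the module:

```python
from dataclasses import dataclass, field


@dataclass
class StubTrajectory:
    # Mirrors only the attributes QualityFilter.assess reads.
    messages: list = field(default_factory=list)
    message_count: int = 0
    tool_call_count: int = 0
    error_count: int = 0
    decisions: list = field(default_factory=list)
    is_multi_step: bool = False


def rubric_score(t: StubTrajectory, min_confidence: float = 0.5) -> float:
    """Same arithmetic as QualityFilter.assess, without classification."""
    score = 0.0
    if t.is_multi_step:
        score += 3.0                                   # multi-step bonus
    if t.error_count == 0:
        score += 1.0                                   # clean-exchange base
    if t.tool_call_count > 0:
        score += 2.0 if t.error_count == 0 else -2.0   # tool-call quality
    elif t.error_count > 0:
        score -= 2.0                                   # errors with no tools
    if t.decisions:
        score += 1.0                                   # decision bonus
    timmy = [m for m in t.messages if m.get("role") == "timmy"]
    low = [m for m in timmy
           if m.get("confidence") is not None and m["confidence"] < min_confidence]
    return score - len(low)                            # confidence penalty


# Clean multi-step run with tools and one decision: 3 + 1 + 2 + 1 = 7 → HIGH
good = StubTrajectory(
    messages=[{"role": "user"}, {"role": "timmy", "confidence": 0.9}],
    message_count=2, tool_call_count=2, error_count=0,
    decisions=["chose_branch"], is_multi_step=True,
)
print(rubric_score(good))
```

With the thresholds above, the same stub with one error, no tool calls, and a 0.3-confidence reply lands at -3, which classifies as LOW and is dropped from the training set.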
292
timmy_automations/retrain/retrain.py
Normal file
@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""AutoLoRA continuous improvement loop — the sovereignty retrain script.

Implements the weekly retrain cycle end-to-end:
Work → Record trajectories → Export weekly → Filter quality
→ LoRA fine-tune → Load adapter → Model improves → Repeat forever

Run:
    python3 timmy_automations/retrain/retrain.py
    python3 timmy_automations/retrain/retrain.py --dry-run
    python3 timmy_automations/retrain/retrain.py --weeks-ago 1

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

# Allow running directly from repo root
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from timmy_automations.retrain.lora_trainer import LoRATrainer
from timmy_automations.retrain.quality_filter import QualityFilter
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("retrain")


@dataclass
class RetrainResult:
    """Result of a complete retrain cycle."""

    iteration: int
    week: str
    trajectories_exported: int
    trajectories_accepted: int
    examples_added: int
    dataset_total: int
    train_status: str
    adapter_path: str | None
    model_name: str | None
    train_loss: float | None
    duration_seconds: float
    notes: str


class RetrainOrchestrator:
    """Orchestrates the complete AutoLoRA continuous improvement loop.

    Step 1: Export this week's conversation trajectories from session logs
    Step 2: Filter for high-quality exchanges
    Step 3: Append to the training dataset
    Step 4: Trigger LoRA fine-tune
    Step 5: Load the new adapter (via Ollama)
    Step 6: Log iteration, loss, skill accuracy
    """

    def __init__(
        self,
        base_model: str = "hermes4-14b",
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        if repo_root is None:
            repo_root = _REPO_ROOT
        self._repo_root = Path(repo_root)
        self._dry_run = dry_run

        self.exporter = TrajectoryExporter(repo_root=self._repo_root)
        self.quality_filter = QualityFilter()
        self.dataset = TrainingDataset(repo_root=self._repo_root)
        self.trainer = LoRATrainer(
            base_model=base_model,
            repo_root=self._repo_root,
            dry_run=dry_run,
        )
        self.log = TrainingLog(repo_root=self._repo_root)

    def run(self, weeks_ago: int = 1) -> RetrainResult:
        """Execute one complete retrain cycle.

        Args:
            weeks_ago: Which week to process. 0 = current week (partial),
                1 = last week (default, Sunday night run), etc.

        Returns:
            RetrainResult with full cycle summary.
        """
        started = datetime.now(tz=UTC)
        iteration = self.log.next_iteration()

        # Determine ISO week tag
        from datetime import timedelta
        now = datetime.now(tz=UTC)
        target_date = now - timedelta(weeks=weeks_ago)
        week_tag = f"{target_date.year}-W{target_date.isocalendar().week:02d}"

        logger.info(
            "=== AutoLoRA Retrain Cycle %d | Week: %s | dry_run=%s ===",
            iteration,
            week_tag,
            self._dry_run,
        )

        # Step 1: Export trajectories
        logger.info("Step 1: Exporting trajectories for %s...", week_tag)
        trajectories = self.exporter.export_week(weeks_ago=weeks_ago)
        logger.info("Exported %d raw trajectories", len(trajectories))

        # Step 2: Quality filter
        logger.info("Step 2: Applying quality filter...")
        trainable, filter_stats = self.quality_filter.filter(trajectories)
        logger.info(
            "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)",
            filter_stats["accepted"],
            filter_stats["total"],
            filter_stats["high"],
            filter_stats["medium"],
            filter_stats["low"],
        )

        # Step 3: Append to dataset
        logger.info("Step 3: Appending to training dataset...")
        append_result = self.dataset.append(trainable, week_tag)
        logger.info(
            "Dataset: +%d new examples (%d total)",
            append_result.new_examples,
            append_result.total_examples,
        )

        # Step 4: LoRA fine-tune
        logger.info("Step 4: Triggering LoRA fine-tune (iteration=%d)...", iteration)
        train_result = self.trainer.train(
            dataset_path=self.dataset.dataset_path,
            iteration=iteration,
        )
        logger.info(
            "Train result: status=%s loss=%s duration=%.1fs",
            train_result.status,
            train_result.train_loss,
            train_result.duration_seconds,
        )

        # Step 5 & 6: Log cycle
        duration = (datetime.now(tz=UTC) - started).total_seconds()
        metrics = CycleMetrics(
            iteration=iteration,
            week=week_tag,
            ran_at=started.isoformat(),
            trajectories_total=filter_stats["total"],
            trajectories_high=filter_stats["high"],
            trajectories_medium=filter_stats["medium"],
            trajectories_low=filter_stats["low"],
            trajectories_accepted=filter_stats["accepted"],
            examples_added=append_result.new_examples,
            dataset_total=append_result.total_examples,
            train_status=train_result.status,
            train_loss=train_result.train_loss,
            train_duration_seconds=train_result.duration_seconds,
            adapter_path=train_result.adapter_path,
            model_name=train_result.model_name,
            notes=train_result.message,
        )
        self.log.record(metrics)

        result = RetrainResult(
            iteration=iteration,
            week=week_tag,
            trajectories_exported=len(trajectories),
            trajectories_accepted=filter_stats["accepted"],
            examples_added=append_result.new_examples,
            dataset_total=append_result.total_examples,
            train_status=train_result.status,
            adapter_path=train_result.adapter_path,
            model_name=train_result.model_name,
            train_loss=train_result.train_loss,
            duration_seconds=duration,
            notes=train_result.message,
        )

        logger.info(
            "=== Cycle %d complete: status=%s examples_added=%d total=%.1fs ===",
            iteration,
            train_result.status,
            append_result.new_examples,
            duration,
        )

        return result


def _print_result(result: RetrainResult, as_json: bool = False) -> None:
    """Print cycle result to stdout."""
    if as_json:
        print(
            json.dumps(
                {
                    "iteration": result.iteration,
                    "week": result.week,
                    "trajectories_exported": result.trajectories_exported,
                    "trajectories_accepted": result.trajectories_accepted,
                    "examples_added": result.examples_added,
                    "dataset_total": result.dataset_total,
                    "train_status": result.train_status,
                    "adapter_path": result.adapter_path,
                    "model_name": result.model_name,
                    "train_loss": result.train_loss,
                    "duration_seconds": result.duration_seconds,
                    "notes": result.notes,
                },
                indent=2,
            )
        )
        return

    print(f"\n{'='*60}")
    print(f" AutoLoRA Retrain — Cycle {result.iteration}")
    print(f" Week: {result.week}")
    print(f"{'='*60}")
    print(f" Trajectories: {result.trajectories_exported} exported, {result.trajectories_accepted} accepted")
    print(f" Dataset: +{result.examples_added} examples ({result.dataset_total} total)")
    print(f" Fine-tune: {result.train_status}")
    if result.train_loss is not None:
        print(f" Train loss: {result.train_loss:.4f}")
    if result.model_name:
        print(f" New model: {result.model_name}")
    if result.adapter_path:
        print(f" Adapter: {result.adapter_path}")
    print(f" Duration: {result.duration_seconds:.1f}s")
    print(f" Notes: {result.notes}")
    print(f"{'='*60}\n")


def main() -> int:
    parser = argparse.ArgumentParser(
        description="AutoLoRA continuous improvement loop — sovereignty engine for Timmy"
    )
    parser.add_argument(
        "--weeks-ago",
        type=int,
        default=1,
        help="Which week to process: 0=current (partial), 1=last week (default)",
    )
    parser.add_argument(
        "--base-model",
        default="hermes4-14b",
        help="Ollama base model name (default: hermes4-14b)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Export and filter trajectories but skip actual fine-tuning",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help="Output result as JSON",
    )
    args = parser.parse_args()

    orchestrator = RetrainOrchestrator(
        base_model=args.base_model,
        dry_run=args.dry_run,
    )
    result = orchestrator.run(weeks_ago=args.weeks_ago)
    _print_result(result, as_json=args.as_json)

    # Exit 0 even on skipped/failed training — the loop must continue
    return 0


if __name__ == "__main__":
    sys.exit(main())
180
timmy_automations/retrain/training_dataset.py
Normal file
@@ -0,0 +1,180 @@
"""Training dataset manager — appends filtered trajectories to a JSONL training file.

Maintains a growing dataset of high-quality conversation examples in the
chat-format expected by mlx-lm / HuggingFace fine-tuning pipelines.

Output format (one JSON object per line):
{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityResult

logger = logging.getLogger(__name__)

_DEFAULT_DATASET_PATH = ".loop/retrain/training_data.jsonl"
_DEFAULT_INDEX_PATH = ".loop/retrain/dataset_index.json"


@dataclass
class AppendResult:
    """Result of appending trajectories to the training dataset."""

    new_examples: int
    total_examples: int
    dataset_path: str
    week_tag: str


class TrainingDataset:
    """Manages the LoRA training dataset file.

    Each entry is a chat-format example:
    {"messages": [...], "week": "2026-W12", "quality": "high", "added_at": "..."}
    """

    def __init__(
        self,
        dataset_path: str | Path | None = None,
        index_path: str | Path | None = None,
        repo_root: str | Path | None = None,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._dataset_path = self._repo_root / (
            dataset_path or _DEFAULT_DATASET_PATH
        )
        self._index_path = self._repo_root / (
            index_path or _DEFAULT_INDEX_PATH
        )

        self._dataset_path.parent.mkdir(parents=True, exist_ok=True)

    @property
    def dataset_path(self) -> Path:
        return self._dataset_path

    def count(self) -> int:
        """Return the number of examples currently in the dataset."""
        if not self._dataset_path.exists():
            return 0
        count = 0
        with open(self._dataset_path) as f:
            for line in f:
                if line.strip():
                    count += 1
        return count

    def append(
        self, quality_results: list[QualityResult], week_tag: str
    ) -> AppendResult:
        """Append high-quality trajectories to the training dataset.

        Deduplicates by (week_tag, session_date, started_at) so re-running
        the export for the same week is idempotent.

        Args:
            quality_results: Filtered, trainable quality results.
            week_tag: ISO week string e.g. "2026-W12".

        Returns:
            AppendResult with counts.
        """
        existing_keys = self._load_existing_keys()
        new_count = 0
        added_at = datetime.now(tz=UTC).isoformat()

        with open(self._dataset_path, "a") as f:
            for result in quality_results:
                traj = result.trajectory
                dedup_key = (
                    f"{week_tag}|{traj.session_date}|{traj.started_at}"
                )
                if dedup_key in existing_keys:
                    logger.debug("Skipping duplicate trajectory: %s", dedup_key)
                    continue

                chat_messages = traj.to_chat_format()
                if len(chat_messages) < 2:
                    logger.debug(
                        "Skipping trajectory with %d chat messages (need ≥2)",
                        len(chat_messages),
                    )
                    continue

                record = {
                    "messages": chat_messages,
                    "week": week_tag,
                    "quality": result.quality.value,
                    "score": result.score,
                    "session_date": traj.session_date,
                    "started_at": traj.started_at,
                    "tool_calls": traj.tool_call_count,
                    "added_at": added_at,
                }
                f.write(json.dumps(record) + "\n")
                existing_keys.add(dedup_key)
                new_count += 1

        total = self.count()
        self._update_index(week_tag, new_count, total)
        logger.info(
            "Dataset: appended %d new examples (total=%d)", new_count, total
        )

        return AppendResult(
            new_examples=new_count,
            total_examples=total,
            dataset_path=str(self._dataset_path),
            week_tag=week_tag,
        )

    def _load_existing_keys(self) -> set[str]:
        """Load deduplication keys from the existing dataset."""
        keys: set[str] = set()
        if not self._dataset_path.exists():
            return keys
        with open(self._dataset_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    week = record.get("week", "")
                    session_date = record.get("session_date", "")
                    started_at = record.get("started_at", "")
                    keys.add(f"{week}|{session_date}|{started_at}")
                except json.JSONDecodeError:
                    continue
        return keys

    def _update_index(self, week_tag: str, new_count: int, total: int) -> None:
        """Update the dataset index JSON with latest run metadata."""
        index: dict = {}
        if self._index_path.exists():
            try:
                index = json.loads(self._index_path.read_text())
            except (json.JSONDecodeError, OSError):
                index = {}

        index.setdefault("weeks", {})
        index["weeks"][week_tag] = {
            "examples_added": new_count,
            "updated_at": datetime.now(tz=UTC).isoformat(),
        }
        index["total_examples"] = total
        index["last_updated"] = datetime.now(tz=UTC).isoformat()

        self._index_path.write_text(json.dumps(index, indent=2))
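One dataset line, built the same way `append()` builds it, and the dedup key that `_load_existing_keys` later reconstructs from that line. The message contents and timestamps are made up for illustration:

```python
import json

record = {
    "messages": [
        {"role": "user", "content": "Summarize yesterday's session log."},
        {"role": "assistant", "content": "Three tasks completed, no errors."},
    ],
    "week": "2026-W12",
    "quality": "high",
    "score": 7.0,
    "session_date": "2026-03-20",
    "started_at": "2026-03-20T10:00:00+00:00",
    "tool_calls": 2,
    "added_at": "2026-03-22T23:00:00+00:00",
}
line = json.dumps(record)  # one JSONL line, as written by append()

# Round-trip: rebuild the idempotency key the way _load_existing_keys does.
parsed = json.loads(line)
key = f"{parsed['week']}|{parsed['session_date']}|{parsed['started_at']}"
print(key)
```

Because the key is derived from fields stored inside each record, re-running the same weekly export produces identical keys and the duplicates are skipped rather than appended twice.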
183
timmy_automations/retrain/training_log.py
Normal file
@@ -0,0 +1,183 @@
"""Training log — records each fine-tune cycle with metrics and skill deltas.

Writes to .loop/retrain/training_log.jsonl (one entry per cycle) and
maintains a human-readable .loop/retrain/training_log.md summary.

Each log entry captures:
- Iteration count
- Week processed
- Quality filter stats
- Examples added to dataset
- LoRA train result (loss, duration, adapter path)
- Skill accuracy deltas (from smoke tests)

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_DEFAULT_LOG_PATH = ".loop/retrain/training_log.jsonl"
_DEFAULT_SUMMARY_PATH = ".loop/retrain/training_log.md"


@dataclass
class CycleMetrics:
    """Metrics for a single retrain cycle."""

    iteration: int
    week: str
    ran_at: str

    # Quality filter
    trajectories_total: int = 0
    trajectories_high: int = 0
    trajectories_medium: int = 0
    trajectories_low: int = 0
    trajectories_accepted: int = 0

    # Dataset
    examples_added: int = 0
    dataset_total: int = 0

    # Training
    train_status: str = "skipped"
    train_loss: float | None = None
    train_duration_seconds: float = 0.0
    adapter_path: str | None = None
    model_name: str | None = None

    # Skill accuracy (optional, from smoke tests)
    skill_accuracy: dict[str, float] = field(default_factory=dict)
    skill_delta: dict[str, float] = field(default_factory=dict)

    # Human-readable summary
    notes: str = ""


class TrainingLog:
    """Persistent log of all retrain cycles."""

    def __init__(
        self,
        log_path: str | Path | None = None,
        summary_path: str | Path | None = None,
        repo_root: str | Path | None = None,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._log_path = self._repo_root / (log_path or _DEFAULT_LOG_PATH)
        self._summary_path = self._repo_root / (summary_path or _DEFAULT_SUMMARY_PATH)
        self._log_path.parent.mkdir(parents=True, exist_ok=True)

    @property
    def log_path(self) -> Path:
        return self._log_path

    def next_iteration(self) -> int:
        """Return the next iteration number (1-indexed)."""
        entries = self.load_all()
        if not entries:
            return 1
        return max(e.get("iteration", 0) for e in entries) + 1

    def record(self, metrics: CycleMetrics) -> None:
        """Append a cycle metrics record to the log."""
        entry = asdict(metrics)
        with open(self._log_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

        self._update_summary(metrics)
        logger.info(
            "Training log: iteration=%d week=%s status=%s examples_added=%d",
            metrics.iteration,
            metrics.week,
            metrics.train_status,
            metrics.examples_added,
        )

    def load_all(self) -> list[dict[str, Any]]:
        """Load all cycle records from the log."""
        if not self._log_path.exists():
            return []
        entries: list[dict[str, Any]] = []
        with open(self._log_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.debug("Skipping malformed log entry")
        return entries

    def latest(self) -> dict[str, Any] | None:
        """Return the most recent cycle record."""
        entries = self.load_all()
        return entries[-1] if entries else None

    def _update_summary(self, metrics: CycleMetrics) -> None:
        """Rewrite the markdown summary with all cycles."""
        all_entries = self.load_all()

        lines = [
            "# AutoLoRA Training Log\n",
            f"*Updated: {datetime.now(tz=UTC).isoformat()}*\n",
            f"*Total iterations: {len(all_entries)}*\n",
            "",
            "## Cycles\n",
            "| # | Week | Status | Loss | Examples | Duration |",
            "|---|------|--------|------|----------|----------|",
        ]

        for entry in reversed(all_entries[-20:]):  # Last 20 cycles
            loss = f"{entry.get('train_loss', 0.0) or 0.0:.4f}" if entry.get("train_loss") else "—"
            lines.append(
                f"| {entry.get('iteration', '?')} "
                f"| {entry.get('week', '?')} "
                f"| {entry.get('train_status', '?')} "
                f"| {loss} "
                f"| +{entry.get('examples_added', 0)} ({entry.get('dataset_total', 0)} total) "
                f"| {entry.get('train_duration_seconds', 0.0):.0f}s |"
            )

        lines.append("")
        lines.append("## Skill Accuracy Over Time\n")

        # Collect all unique skills
        all_skills: set[str] = set()
        for entry in all_entries:
            all_skills.update(entry.get("skill_accuracy", {}).keys())

        if all_skills:
            skill_header = "| # | Week | " + " | ".join(sorted(all_skills)) + " |"
            skill_sep = "|---|------|" + "|".join("---" for _ in all_skills) + "|"
            lines.extend([skill_header, skill_sep])
            for entry in reversed(all_entries[-10:]):
                acc = entry.get("skill_accuracy", {})
                row = f"| {entry.get('iteration', '?')} | {entry.get('week', '?')} | "
                row += " | ".join(
                    f"{acc.get(s, 0.0):.0%}" if s in acc else "—"
                    for s in sorted(all_skills)
                )
                row += " |"
                lines.append(row)
        else:
            lines.append("*No skill accuracy data yet — run smoke tests after fine-tuning.*")

        lines.append("")
        if metrics.notes:
            lines.append(f"## Latest Notes\n\n{metrics.notes}\n")

        self._summary_path.write_text("\n".join(lines))
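The persistence model behind `TrainingLog` is plain append-only JSONL: one `json.dumps` per line on write, one `json.loads` per line on read, skipping blanks and malformed entries. A self-contained sketch of that round-trip plus the `next_iteration` calculation (field values are made up, and the file lives in a temp directory rather than `.loop/retrain/`):

```python
import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    log_path = Path(tmp) / "training_log.jsonl"

    # Append two cycle records as JSONL, the way record() does.
    for entry in (
        {"iteration": 1, "train_status": "success"},
        {"iteration": 2, "train_status": "skipped"},
    ):
        with open(log_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

    # Read them back the way load_all() does.
    entries = []
    with open(log_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError:
                pass  # malformed lines are skipped silently

    # next_iteration(): highest recorded iteration plus one.
    next_iteration = max(e.get("iteration", 0) for e in entries) + 1
    print(next_iteration)  # 3
```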
255
timmy_automations/retrain/trajectory_exporter.py
Normal file
@@ -0,0 +1,255 @@
"""Trajectory exporter — reads session JSONL logs and extracts conversation trajectories.

A trajectory is a coherent sequence of messages + tool calls that form
a single task attempt. Each trajectory becomes one training example.

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_LOGS_DIR_DEFAULT = "logs"
_SESSION_GLOB = "session_*.jsonl"


@dataclass
class Trajectory:
    """A single conversation trajectory extracted from session logs."""

    session_date: str
    started_at: str
    ended_at: str
    messages: list[dict[str, Any]] = field(default_factory=list)
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    errors: list[dict[str, Any]] = field(default_factory=list)
    decisions: list[dict[str, Any]] = field(default_factory=list)

    @property
    def message_count(self) -> int:
        return len(self.messages)

    @property
    def tool_call_count(self) -> int:
        return len(self.tool_calls)

    @property
    def error_count(self) -> int:
        return len(self.errors)

    @property
    def has_successful_tool_call(self) -> bool:
        """True if any tool call succeeded (no error entry follows it)."""
        return self.tool_call_count > 0 and self.error_count == 0

    @property
    def is_multi_step(self) -> bool:
        """True if this trajectory involved multiple turns with tool use."""
        return self.message_count >= 2 and self.tool_call_count >= 1

    def to_chat_format(self) -> list[dict[str, str]]:
        """Convert trajectory to chat-format messages for training.

        Interleaves messages and tool-call results as assistant/tool turns.
        """
        chat: list[dict[str, str]] = []
        # Merge all entries by timestamp and emit in order
        all_entries = sorted(
            self.messages + self.tool_calls + self.decisions,
            key=lambda e: e.get("timestamp", ""),
        )
        for entry in all_entries:
            etype = entry.get("type")
            if etype == "message":
                role = "user" if entry.get("role") == "user" else "assistant"
                content = entry.get("content", "")
                if content:
                    chat.append({"role": role, "content": content})
            elif etype == "tool_call":
                tool = entry.get("tool", "unknown")
                result = entry.get("result", "")
                chat.append(
                    {
                        "role": "assistant",
                        "content": f"[tool:{tool}] {result}",
                    }
                )
            elif etype == "decision":
                decision = entry.get("decision", "")
                if decision:
                    chat.append({"role": "assistant", "content": f"[decided] {decision}"})
        return chat


class TrajectoryExporter:
    """Reads session JSONL logs and yields Trajectory objects for a date range."""

    def __init__(self, logs_dir: str | Path | None = None, repo_root: str | Path | None = None):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        if logs_dir is None:
            self._logs_dir = self._repo_root / _LOGS_DIR_DEFAULT
        else:
            self._logs_dir = Path(logs_dir)

    def export_week(self, weeks_ago: int = 0) -> list[Trajectory]:
        """Export all trajectories from the specified week.

        Args:
            weeks_ago: 0 = current week, 1 = last week, etc.

        Returns:
            List of Trajectory objects extracted from session logs.
        """
        now = datetime.now(tz=UTC)
        # Week boundaries: Mon–Sun
        days_since_monday = now.weekday()
        week_start = (now - timedelta(days=days_since_monday + 7 * weeks_ago)).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        week_end = week_start + timedelta(days=7)

        logger.info(
            "Exporting trajectories for week %s–%s",
            week_start.date().isoformat(),
            week_end.date().isoformat(),
        )

        trajectories: list[Trajectory] = []
        log_files = sorted(self._logs_dir.glob(_SESSION_GLOB))

        for log_file in log_files:
            # Parse date from filename: session_YYYY-MM-DD.jsonl
            try:
                date_str = log_file.stem.removeprefix("session_")
                file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=UTC)
            except ValueError:
                logger.debug("Skipping non-date session file: %s", log_file.name)
                continue

            if not (week_start <= file_date < week_end):
                continue

            file_trajectories = self._extract_from_file(log_file)
            trajectories.extend(file_trajectories)
            logger.info(
                "Extracted %d trajectories from %s", len(file_trajectories), log_file.name
            )

        logger.info("Total trajectories exported: %d", len(trajectories))
        return trajectories

    def _extract_from_file(self, log_file: Path) -> list[Trajectory]:
        """Parse a single session JSONL file into trajectories.

        Groups entries into trajectories by finding natural conversation
        boundaries (gaps of inactivity or topic shifts in the message stream).
        """
        entries: list[dict[str, Any]] = []
        try:
            with open(log_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.debug("Skipping malformed JSON line in %s", log_file.name)
        except OSError as exc:
            logger.warning("Could not read %s: %s", log_file, exc)
            return []

        if not entries:
            return []

        date_str = log_file.stem.removeprefix("session_")
        return self._segment_trajectories(entries, date_str)

    def _segment_trajectories(
        self, entries: list[dict[str, Any]], session_date: str
    ) -> list[Trajectory]:
        """Split a flat list of session entries into discrete trajectories.

        Segmentation rule: start a new trajectory when:
        - A user message follows a Timmy message (new conversation turn)
        - More than 5 minutes have elapsed between entries

        This produces training examples that are coherent task attempts.
        """
        if not entries:
            return []

        trajectories: list[Trajectory] = []
        current_entries: list[dict[str, Any]] = []
        prev_ts: datetime | None = None
        _SEGMENT_GAP_MINUTES = 5

        def _flush() -> None:
            if current_entries:
                traj = _build_trajectory(current_entries, session_date)
                if traj.message_count > 0:
                    trajectories.append(traj)

        for entry in entries:
            ts_raw = entry.get("timestamp", "")
            try:
                ts = datetime.fromisoformat(ts_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                ts = None

            # Time-gap segmentation
            if ts and prev_ts and (ts - prev_ts).total_seconds() > _SEGMENT_GAP_MINUTES * 60:
                _flush()
                current_entries = []

            # New-turn segmentation: user message after assistant turn
            etype = entry.get("type")
            erole = entry.get("role")
            if etype == "message" and erole == "user" and current_entries:
                # Check if previous non-error entry was a Timmy message
                for prev in reversed(current_entries):
                    if prev.get("type") == "message":
                        if prev.get("role") == "timmy":
                            _flush()
                            current_entries = []
                        break

            current_entries.append(entry)
            if ts:
                prev_ts = ts

        _flush()
        return trajectories


def _build_trajectory(entries: list[dict[str, Any]], session_date: str) -> Trajectory:
    """Build a Trajectory from a flat list of entries."""
    messages = [e for e in entries if e.get("type") == "message"]
    tool_calls = [e for e in entries if e.get("type") == "tool_call"]
    errors = [e for e in entries if e.get("type") == "error"]
    decisions = [e for e in entries if e.get("type") == "decision"]

    timestamps = [e.get("timestamp", "") for e in entries if e.get("timestamp")]
    started_at = min(timestamps) if timestamps else ""
    ended_at = max(timestamps) if timestamps else ""

    return Trajectory(
        session_date=session_date,
        started_at=started_at,
        ended_at=ended_at,
        messages=messages,
        tool_calls=tool_calls,
        errors=errors,
        decisions=decisions,
    )
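The time-gap half of the segmentation rule can be exercised in isolation: entries more than five minutes apart fall into separate trajectories. A self-contained sketch with synthetic timestamps (the entry shape mirrors the session-log schema, but the values are made up; the exporter's second rule, splitting when a user message follows a Timmy message, is omitted here for brevity):

```python
from datetime import datetime

GAP_MINUTES = 5  # mirrors _SEGMENT_GAP_MINUTES in the exporter

entries = [
    {"type": "message", "role": "user", "timestamp": "2026-03-23T10:00:00+00:00"},
    {"type": "message", "role": "timmy", "timestamp": "2026-03-23T10:00:30+00:00"},
    # 20-minute gap -> should start a new trajectory
    {"type": "message", "role": "user", "timestamp": "2026-03-23T10:20:30+00:00"},
]

segments: list[list[dict]] = []
current: list[dict] = []
prev_ts = None
for entry in entries:
    ts = datetime.fromisoformat(entry["timestamp"])
    if prev_ts and (ts - prev_ts).total_seconds() > GAP_MINUTES * 60:
        segments.append(current)  # flush the finished trajectory
        current = []
    current.append(entry)
    prev_ts = ts
if current:
    segments.append(current)

print(len(segments))  # 2
```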