Compare commits
7 Commits
claude/iss...gemini/iss

| Author | SHA1 | Date |
|---|---|---|
| | 393cf3a2e1 | |
| | 0331e0e5bb | |
| | 1be1324a0d | |
| | 32a5b092d0 | |
| | 6f404c99f2 | |
| | 300d9575f1 | |
| | 510d890eb2 | |
230
docs/research/bannerlord-vm-setup.md
Normal file
@@ -0,0 +1,230 @@
# Bannerlord Windows VM Setup Guide

**Issue:** #1098
**Parent Epic:** #1091 (Project Bannerlord)
**Date:** 2026-03-23
**Status:** Reference

---

## Overview

This document covers provisioning the Windows VM that hosts Bannerlord + the GABS mod,
verifying the GABS TCP JSON-RPC server, and confirming connectivity from Hermes.

Architecture reminder:

```
Timmy (Qwen3 on Ollama, Hermes M3 Max)
  → GABS TCP/JSON-RPC (port 4825)
    → Bannerlord.GABS C# mod
      → Game API + Harmony
        → Bannerlord (Windows VM)
```

---

## 1. Provision Windows VM

### Minimum Spec

| Resource | Minimum | Recommended |
|----------|---------|-------------|
| CPU | 4 cores | 8 cores |
| RAM | 16 GB | 32 GB |
| Disk | 100 GB SSD | 150 GB SSD |
| OS | Windows Server 2022 / Windows 11 | Windows 11 |
| Network | Private VLAN to Hermes | Private VLAN to Hermes |

### Hetzner (preferred)

```bash
# Hetzner Cloud CLI — create a CX41 (4 vCPU, 16 GB RAM, 160 GB SSD)
hcloud server create \
  --name bannerlord-vm \
  --type cx41 \
  --image windows-server-2022 \
  --location nbg1 \
  --ssh-key your-key
```

### DigitalOcean alternative

```
Droplet: General Purpose 4 vCPU / 16 GB / 100 GB SSD
Image: Windows Server 2022
Region: Same region as Hermes
```

### Post-provision

1. Enable RDP (port 3389) for initial setup only — close it after configuration
2. Open port 4825 TCP inbound from the Hermes IP only
3. Add a specific Windows Firewall allow rule for 4825 (preferable to disabling the firewall):

   ```powershell
   New-NetFirewallRule -DisplayName "GABS TCP" -Direction Inbound `
       -Protocol TCP -LocalPort 4825 -Action Allow
   ```

---

## 2. Install Steam + Bannerlord

### Steam installation

1. Download the Steam installer from store.steampowered.com
2. Install silently:

   ```powershell
   .\SteamSetup.exe /S
   ```

3. Log in with a dedicated Steam account (not a personal one)

### Bannerlord installation

```powershell
# Install Bannerlord (App ID: 261550) via SteamCMD
steamcmd +login <user> <pass> +app_update 261550 validate +quit
```

### Pin game version

GABS requires a specific Bannerlord version. To pin it and prevent auto-updates:

1. Right-click Bannerlord in Steam → Properties → Updates
2. Set "Automatic Updates" to "Only update this game when I launch it"
3. Record the current version in `docs/research/bannerlord-vm-setup.md` after installation

```powershell
# Check the installed version (buildid)
Get-Content "C:\Program Files (x86)\Steam\steamapps\appmanifest_261550.acf" |
    Select-String "buildid"
```

---

## 3. Install GABS Mod

### Source

- NexusMods: https://www.nexusmods.com/mountandblade2bannerlord/mods/10419
- GitHub: https://github.com/BUTR/Bannerlord.GABS
- AGENTS.md: https://github.com/BUTR/Bannerlord.GABS/blob/master/AGENTS.md

### Installation via Vortex (NexusMods)

1. Install Vortex Mod Manager
2. Download the GABS mod package from NexusMods
3. Install via Vortex — it handles the `Modules/` directory layout automatically
4. Enable it in the mod list and set the load order after Harmony

### Manual installation

```powershell
# Copy the mod into the Bannerlord Modules directory
$BannerlordPath = "C:\Program Files (x86)\Steam\steamapps\common\Mount & Blade II Bannerlord"
Copy-Item -Recurse ".\Bannerlord.GABS" "$BannerlordPath\Modules\Bannerlord.GABS"
```

### Required dependencies

- **Harmony** (BUTR.Harmony) — must load before GABS
- **ButterLib** — utility library

Install both via the same method as GABS.

### GABS configuration

The GABS TCP server listens on `0.0.0.0:4825` by default. To confirm or override, edit:

```
%APPDATA%\Mount and Blade II Bannerlord\Configs\Bannerlord.GABS\settings.json
```

Expected defaults:

```json
{
  "ServerHost": "0.0.0.0",
  "ServerPort": 4825,
  "LogLevel": "Information"
}
```
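As a quick sanity check, the expected defaults can be compared against an actual settings file programmatically. A minimal Python sketch, assuming the key names shown above (the `EXPECTED` dict and `check_gabs_settings` helper are illustrative, not part of GABS):

```python
import json

# Expected GABS defaults, taken from the sample above; the key names are
# assumptions from this doc, not verified against the GABS source.
EXPECTED = {"ServerHost": "0.0.0.0", "ServerPort": 4825}

def check_gabs_settings(text: str) -> list[str]:
    """Return a list of human-readable mismatches between settings and EXPECTED."""
    settings = json.loads(text)
    return [
        f"{key}: expected {want!r}, found {settings.get(key)!r}"
        for key, want in EXPECTED.items()
        if settings.get(key) != want
    ]

sample = '{"ServerHost": "0.0.0.0", "ServerPort": 4825, "LogLevel": "Information"}'
print(check_gabs_settings(sample))  # → []
```

An empty list means the file matches the defaults; each mismatch is reported as a readable string.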

---

## 4. Verify GABS TCP Server

### Start Bannerlord with GABS

Launch Bannerlord with the mod enabled. GABS starts its TCP server during game
initialisation. Watch the game log for:

```
[GABS] TCP server listening on 0.0.0.0:4825
```

Log location:

```
%APPDATA%\Mount and Blade II Bannerlord\logs\rgl_log_*.txt
```

### Local connectivity check (on the VM)

```powershell
# Verify the port is listening
netstat -an | findstr 4825

# Quick TCP probe
Test-NetConnection -ComputerName localhost -Port 4825
```

### Send a test JSON-RPC call

```powershell
$msg = '{"jsonrpc":"2.0","method":"ping","id":1}'
$client = New-Object System.Net.Sockets.TcpClient("localhost", 4825)
$stream = $client.GetStream()
$writer = New-Object System.IO.StreamWriter($stream)
$writer.AutoFlush = $true
$writer.WriteLine($msg)
$reader = New-Object System.IO.StreamReader($stream)
$response = $reader.ReadLine()
Write-Host "Response: $response"
$client.Close()
```

Expected response shape:

```json
{"jsonrpc":"2.0","result":{"status":"ok"},"id":1}
```

---

## 5. Test Connectivity from Hermes

Use `scripts/test_gabs_connectivity.py` (checked in with this issue):

```bash
# From Hermes (M3 Max)
python scripts/test_gabs_connectivity.py --host <VM_IP> --port 4825
```

The script tests:

1. TCP socket connection
2. JSON-RPC ping round-trip
3. `get_game_state` call
4. Response latency (target < 100 ms on LAN)
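The ping round-trip and latency check can be sketched in a few lines of Python. This assumes the GABS server speaks newline-delimited JSON-RPC 2.0, matching the PowerShell probe in section 4; adjust the framing if the real wire format differs (`gabs_ping` is an illustrative helper, not the checked-in script):

```python
import json
import socket
import time

def gabs_ping(host: str, port: int = 4825, timeout: float = 5.0) -> tuple[dict, float]:
    """Send a newline-delimited JSON-RPC ping; return (response, latency_seconds)."""
    request = {"jsonrpc": "2.0", "method": "ping", "id": 1}
    start = time.monotonic()
    with socket.create_connection((host, port), timeout=timeout) as sock:
        sock.sendall((json.dumps(request) + "\n").encode())
        # Read until the first newline — one JSON-RPC response per line.
        buf = b""
        while not buf.endswith(b"\n"):
            chunk = sock.recv(4096)
            if not chunk:
                break
            buf += chunk
    latency = time.monotonic() - start
    return json.loads(buf.decode()), latency
```

On a LAN the returned latency should comfortably sit under the 100 ms target; anything higher suggests a routing or firewall problem rather than GABS itself.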

---

## 6. Firewall / Network Summary

| Source | Destination | Port | Protocol | Purpose |
|--------|-------------|------|----------|---------|
| Hermes (local) | Bannerlord VM | 4825 | TCP | GABS JSON-RPC |
| Admin workstation | Bannerlord VM | 3389 | TCP | RDP setup (disable after) |

---

## 7. Reproducibility Checklist

After completing setup, record:

- [ ] VM provider + region + instance type
- [ ] Windows version + build number
- [ ] Steam account used (non-personal, credentials in secrets manager)
- [ ] Bannerlord app version (buildid from appmanifest)
- [ ] GABS version (from NexusMods or GitHub release tag)
- [ ] Harmony version
- [ ] ButterLib version
- [ ] GABS settings.json contents
- [ ] VM IP address (update Timmy config)
- [ ] Connectivity test output from `test_gabs_connectivity.py`

---

## References

- GABS GitHub: https://github.com/BUTR/Bannerlord.GABS
- GABS AGENTS.md: https://github.com/BUTR/Bannerlord.GABS/blob/master/AGENTS.md
- NexusMods page: https://www.nexusmods.com/mountandblade2bannerlord/mods/10419
- Parent Epic: #1091
- Connectivity test script: `scripts/test_gabs_connectivity.py`

@@ -14,7 +14,6 @@ repository = "http://localhost:3000/rockachopa/Timmy-time-dashboard"
packages = [
    { include = "config.py", from = "src" },
    { include = "bannerlord", from = "src" },
    { include = "dashboard", from = "src" },
    { include = "infrastructure", from = "src" },
    { include = "integrations", from = "src" },
333
scripts/export_trajectories.py
Normal file
@@ -0,0 +1,333 @@
#!/usr/bin/env python3
"""Export Timmy session logs as LoRA training data (ChatML JSONL).

Reads session JSONL files written by ``SessionLogger`` and converts them into
conversation pairs suitable for fine-tuning with ``mlx_lm.lora``.

Output format — one JSON object per line::

    {"messages": [
        {"role": "system", "content": "<Timmy system prompt>"},
        {"role": "user", "content": "<user turn>"},
        {"role": "assistant", "content": "<timmy response, with tool calls embedded>"}
    ]}

Tool calls that appear between a user turn and the next assistant message are
embedded in the assistant content using the Hermes 4 ``<tool_call>`` XML format
so the fine-tuned model learns both when to call tools and what JSON to emit.

Usage::

    # Export all session logs (default paths)
    python scripts/export_trajectories.py

    # Custom source / destination
    python scripts/export_trajectories.py \\
        --logs-dir ~/custom-logs \\
        --output ~/timmy-training-data.jsonl \\
        --min-turns 2 \\
        --verbose

Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 3 of 7)
Refs: #1103
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# ── Constants ─────────────────────────────────────────────────────────────────

TIMMY_SYSTEM_PROMPT = (
    "You are Timmy, Alexander's personal AI agent running on a local Mac. "
    "You are concise, direct, and action-oriented. "
    "You have access to a broad set of tools — use them proactively. "
    "When you need to call a tool, output it in this format:\n"
    "<tool_call>\n"
    '{"name": "function_name", "arguments": {"param": "value"}}\n'
    "</tool_call>\n\n"
    "Always provide structured, accurate responses."
)

# ── Entry grouping ────────────────────────────────────────────────────────────


def _load_entries(logs_dir: Path) -> list[dict[str, Any]]:
    """Load all session log entries, sorted chronologically by file name."""
    entries: list[dict[str, Any]] = []
    log_files = sorted(logs_dir.glob("session_*.jsonl"))
    for log_file in log_files:
        try:
            with open(log_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.warning("Skipping malformed line in %s", log_file.name)
        except OSError as exc:
            logger.warning("Cannot read %s: %s", log_file, exc)
    return entries


def _format_tool_call(entry: dict[str, Any]) -> str:
    """Render a tool_call entry as a Hermes 4 <tool_call> XML block."""
    payload = {"name": entry.get("tool", "unknown"), "arguments": entry.get("args", {})}
    return f"<tool_call>\n{json.dumps(payload)}\n</tool_call>"


def _format_tool_result(entry: dict[str, Any]) -> str:
    """Render a tool result observation as a <tool_response> block."""
    payload = {"name": entry.get("tool", "unknown"), "result": entry.get("result", "")}
    return f"<tool_response>\n{json.dumps(payload)}\n</tool_response>"


def _group_into_turns(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Group raw session entries into (user_text, assistant_parts) turn pairs.

    Returns a list of dicts with keys:
        ``user``      - user message content
        ``assistant`` - assembled assistant content (responses + tool calls)
    """
    turns: list[dict[str, Any]] = []
    pending_user: str | None = None
    assistant_parts: list[str] = []

    for entry in entries:
        etype = entry.get("type", "")
        role = entry.get("role", "")

        if etype == "message" and role == "user":
            # Flush any open turn
            if pending_user is not None and assistant_parts:
                turns.append(
                    {
                        "user": pending_user,
                        "assistant": "\n".join(assistant_parts).strip(),
                    }
                )
            elif pending_user is not None:
                # User message with no assistant response — discard
                pass
            pending_user = entry.get("content", "").strip()
            assistant_parts = []

        elif etype == "message" and role == "timmy":
            if pending_user is not None:
                content = entry.get("content", "").strip()
                if content:
                    assistant_parts.append(content)

        elif etype == "tool_call":
            if pending_user is not None:
                assistant_parts.append(_format_tool_call(entry))
                # Also append the tool result as context so the model learns the full loop
                if entry.get("result"):
                    assistant_parts.append(_format_tool_result(entry))

        # decision / error entries are skipped — they are metadata, not conversation

    # Flush the final open turn
    if pending_user is not None and assistant_parts:
        turns.append(
            {
                "user": pending_user,
                "assistant": "\n".join(assistant_parts).strip(),
            }
        )

    return turns


# ── Conversion ────────────────────────────────────────────────────────────────


def turns_to_training_examples(
    turns: list[dict[str, Any]],
    system_prompt: str = TIMMY_SYSTEM_PROMPT,
    min_assistant_len: int = 10,
) -> list[dict[str, Any]]:
    """Convert grouped turns into mlx-lm training examples.

    Each example has a ``messages`` list in ChatML order:
    ``[system, user, assistant]``.

    Args:
        turns: Output of ``_group_into_turns``.
        system_prompt: System prompt prepended to every example.
        min_assistant_len: Skip examples where the assistant turn is shorter
            than this many characters (filters out empty/trivial turns).

    Returns:
        List of training example dicts.
    """
    examples: list[dict[str, Any]] = []
    for turn in turns:
        assistant_text = turn.get("assistant", "").strip()
        user_text = turn.get("user", "").strip()
        if not user_text or len(assistant_text) < min_assistant_len:
            continue
        examples.append(
            {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                    {"role": "assistant", "content": assistant_text},
                ]
            }
        )
    return examples


def export_training_data(
    logs_dir: Path,
    output_path: Path,
    min_turns: int = 1,
    min_assistant_len: int = 10,
    verbose: bool = False,
) -> int:
    """Full export pipeline: load → group → convert → write.

    Args:
        logs_dir: Directory containing ``session_*.jsonl`` files.
        output_path: Destination ``.jsonl`` file for training data.
        min_turns: Minimum number of turns required (used for logging only).
        min_assistant_len: Minimum assistant response length to include.
        verbose: Print progress to stdout.

    Returns:
        Number of training examples written.
    """
    if verbose:
        print(f"Loading session logs from: {logs_dir}")

    entries = _load_entries(logs_dir)
    if verbose:
        print(f"  Loaded {len(entries)} raw entries")

    turns = _group_into_turns(entries)
    if verbose:
        print(f"  Grouped into {len(turns)} conversation turns")

    examples = turns_to_training_examples(turns, min_assistant_len=min_assistant_len)
    if verbose:
        print(f"  Generated {len(examples)} training examples")

    if not examples:
        print("WARNING: No training examples generated. Check that session logs exist.")
        return 0

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        for ex in examples:
            f.write(json.dumps(ex) + "\n")

    if verbose:
        print(f"  Wrote {len(examples)} examples → {output_path}")

    return len(examples)


# ── CLI ───────────────────────────────────────────────────────────────────────


def _default_logs_dir() -> Path:
    """Return the default logs directory (repo root / logs)."""
    # Walk up from this script to find the repo root (contains pyproject.toml)
    candidate = Path(__file__).resolve().parent
    for _ in range(5):
        candidate = candidate.parent
        if (candidate / "pyproject.toml").exists():
            return candidate / "logs"
    return Path.home() / "logs"


def _default_output_path() -> Path:
    return Path.home() / "timmy-training-data.jsonl"


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        description="Export Timmy session logs as LoRA training data (ChatML JSONL)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        default=_default_logs_dir(),
        help="Directory containing session_*.jsonl files (default: <repo>/logs)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=_default_output_path(),
        help="Output JSONL path (default: ~/timmy-training-data.jsonl)",
    )
    parser.add_argument(
        "--min-turns",
        type=int,
        default=1,
        help="Minimum turns to process (informational, default: 1)",
    )
    parser.add_argument(
        "--min-assistant-len",
        type=int,
        default=10,
        help="Minimum assistant response length in chars (default: 10)",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print progress information",
    )

    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.WARNING,
        format="%(levelname)s: %(message)s",
    )

    if not args.logs_dir.exists():
        print(f"ERROR: Logs directory not found: {args.logs_dir}")
        print("Run the Timmy dashboard first to generate session logs.")
        return 1

    count = export_training_data(
        logs_dir=args.logs_dir,
        output_path=args.output,
        min_turns=args.min_turns,
        min_assistant_len=args.min_assistant_len,
        verbose=args.verbose,
    )

    if count > 0:
        print(f"Exported {count} training examples to: {args.output}")
        print()
        print("Next steps:")
        print("  mkdir -p ~/timmy-lora-training")
        print(f"  cp {args.output} ~/timmy-lora-training/train.jsonl")
        print("  python scripts/lora_finetune.py --data ~/timmy-lora-training")
    else:
        print("No training examples exported.")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
399
scripts/lora_finetune.py
Normal file
@@ -0,0 +1,399 @@
|
||||
#!/usr/bin/env python3
|
||||
"""LoRA fine-tuning launcher for Hermes 4 on Timmy trajectory data.
|
||||
|
||||
Wraps ``mlx_lm.lora`` with project-specific defaults and pre-flight checks.
|
||||
Requires Apple Silicon (M-series) and the ``mlx-lm`` package.
|
||||
|
||||
Usage::
|
||||
|
||||
# Minimal — uses defaults (expects data in ~/timmy-lora-training/)
|
||||
python scripts/lora_finetune.py
|
||||
|
||||
# Custom model path and data
|
||||
python scripts/lora_finetune.py \\
|
||||
--model /path/to/hermes4-mlx \\
|
||||
--data ~/timmy-lora-training \\
|
||||
--iters 500 \\
|
||||
--adapter-path ~/timmy-lora-adapter
|
||||
|
||||
# Dry run (print command, don't execute)
|
||||
python scripts/lora_finetune.py --dry-run
|
||||
|
||||
# After training, test with the adapter
|
||||
python scripts/lora_finetune.py --test \\
|
||||
--prompt "List the open PRs on the Timmy Time Dashboard repo"
|
||||
|
||||
# Fuse adapter into base model for Ollama import
|
||||
python scripts/lora_finetune.py --fuse \\
|
||||
--save-path ~/timmy-fused-model
|
||||
|
||||
Typical workflow::
|
||||
|
||||
# 1. Export trajectories
|
||||
python scripts/export_trajectories.py --verbose
|
||||
|
||||
# 2. Prepare training dir
|
||||
mkdir -p ~/timmy-lora-training
|
||||
cp ~/timmy-training-data.jsonl ~/timmy-lora-training/train.jsonl
|
||||
|
||||
# 3. Fine-tune
|
||||
python scripts/lora_finetune.py --verbose
|
||||
|
||||
# 4. Test
|
||||
python scripts/lora_finetune.py --test
|
||||
|
||||
# 5. Fuse + import to Ollama
|
||||
python scripts/lora_finetune.py --fuse
|
||||
ollama create timmy-hermes4 -f Modelfile.timmy-hermes4
|
||||
|
||||
Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 4 of 7)
|
||||
Refs: #1103
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# ── Defaults ──────────────────────────────────────────────────────────────────
|
||||
|
||||
DEFAULT_DATA_DIR = Path.home() / "timmy-lora-training"
|
||||
DEFAULT_ADAPTER_PATH = Path.home() / "timmy-lora-adapter"
|
||||
DEFAULT_FUSED_PATH = Path.home() / "timmy-fused-model"
|
||||
|
||||
# mlx-lm model path — local HuggingFace checkout of Hermes 4 in MLX format.
|
||||
# Set MLX_HERMES4_PATH env var or pass --model to override.
|
||||
DEFAULT_MODEL_PATH_ENV = "MLX_HERMES4_PATH"
|
||||
|
||||
# Training hyperparameters (conservative for 36 GB M3 Max)
|
||||
DEFAULT_BATCH_SIZE = 1
|
||||
DEFAULT_LORA_LAYERS = 16
|
||||
DEFAULT_ITERS = 1000
|
||||
DEFAULT_LEARNING_RATE = 1e-5
|
||||
|
||||
# Test prompt used after training
|
||||
DEFAULT_TEST_PROMPT = (
|
||||
"List the open PRs on the Timmy Time Dashboard repo and triage them by priority."
|
||||
)
|
||||
|
||||
|
||||
# ── Pre-flight checks ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _check_apple_silicon() -> bool:
|
||||
"""Return True if running on Apple Silicon."""
|
||||
return platform.system() == "Darwin" and platform.machine() == "arm64"
|
||||
|
||||
|
||||
def _check_mlx_lm() -> bool:
|
||||
"""Return True if mlx-lm is installed and mlx_lm.lora is runnable."""
|
||||
return shutil.which("mlx_lm.lora") is not None or _can_import("mlx_lm")
|
||||
|
||||
|
||||
def _can_import(module: str) -> bool:
|
||||
try:
|
||||
import importlib
|
||||
|
||||
importlib.import_module(module)
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_model_path(model_arg: str | None) -> str | None:
|
||||
"""Resolve model path from arg or environment variable."""
|
||||
if model_arg:
|
||||
return model_arg
|
||||
import os
|
||||
|
||||
env_path = os.environ.get(DEFAULT_MODEL_PATH_ENV)
|
||||
if env_path:
|
||||
return env_path
|
||||
return None
|
||||
|
||||
|
||||
def _preflight(model_path: str | None, data_dir: Path, verbose: bool) -> list[str]:
|
||||
"""Run pre-flight checks and return a list of warnings (empty = all OK)."""
|
||||
warnings: list[str] = []
|
||||
|
||||
if not _check_apple_silicon():
|
||||
warnings.append(
|
||||
"Not running on Apple Silicon. mlx-lm requires an M-series Mac.\n"
|
||||
" Alternative: use Unsloth on Google Colab / RunPod / Modal."
|
||||
)
|
||||
|
||||
if not _check_mlx_lm():
|
||||
warnings.append(
|
||||
"mlx-lm not found. Install with:\n pip install mlx-lm"
|
||||
)
|
||||
|
||||
if model_path is None:
|
||||
warnings.append(
|
||||
f"No model path specified. Set {DEFAULT_MODEL_PATH_ENV} or pass --model.\n"
|
||||
" Download Hermes 4 in MLX format from HuggingFace:\n"
|
||||
" https://huggingface.co/collections/NousResearch/hermes-4-collection-68a7\n"
|
||||
" or convert the GGUF:\n"
|
||||
" mlx_lm.convert --hf-path NousResearch/Hermes-4-14B --mlx-path ~/hermes4-mlx"
|
||||
)
|
||||
elif not Path(model_path).exists():
|
||||
warnings.append(f"Model path does not exist: {model_path}")
|
||||
|
||||
train_file = data_dir / "train.jsonl"
|
||||
if not train_file.exists():
|
||||
warnings.append(
|
||||
f"Training data not found: {train_file}\n"
|
||||
" Generate it with:\n"
|
||||
" python scripts/export_trajectories.py --verbose\n"
|
||||
f" mkdir -p {data_dir}\n"
|
||||
f" cp ~/timmy-training-data.jsonl {train_file}"
|
||||
)
|
||||
|
||||
if verbose and not warnings:
|
||||
print("Pre-flight checks: all OK")
|
||||
|
||||
return warnings
|
||||
|
||||
|
||||
# ── Command builders ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_train_cmd(
|
||||
model_path: str,
|
||||
data_dir: Path,
|
||||
adapter_path: Path,
|
||||
batch_size: int,
|
||||
lora_layers: int,
|
||||
iters: int,
|
||||
learning_rate: float,
|
||||
) -> list[str]:
|
||||
return [
|
||||
sys.executable, "-m", "mlx_lm.lora",
|
||||
"--model", model_path,
|
||||
"--train",
|
||||
"--data", str(data_dir),
|
||||
"--batch-size", str(batch_size),
|
||||
"--lora-layers", str(lora_layers),
|
||||
"--iters", str(iters),
|
||||
"--learning-rate", str(learning_rate),
|
||||
"--adapter-path", str(adapter_path),
|
||||
]
|
||||
|
||||
|
||||
def _build_test_cmd(
|
||||
model_path: str,
|
||||
adapter_path: Path,
|
||||
prompt: str,
|
||||
) -> list[str]:
|
||||
return [
|
||||
sys.executable, "-m", "mlx_lm.generate",
|
||||
"--model", model_path,
|
||||
"--adapter-path", str(adapter_path),
|
||||
"--prompt", prompt,
|
||||
"--max-tokens", "512",
|
||||
]
|
||||
|
||||
|
||||
def _build_fuse_cmd(
|
||||
model_path: str,
|
||||
adapter_path: Path,
|
||||
save_path: Path,
|
||||
) -> list[str]:
|
||||
return [
|
||||
sys.executable, "-m", "mlx_lm.fuse",
|
||||
"--model", model_path,
|
||||
"--adapter-path", str(adapter_path),
|
||||
"--save-path", str(save_path),
|
||||
]
|
||||
|
||||
|
||||
# ── Runner ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _run(cmd: list[str], dry_run: bool, verbose: bool) -> int:
|
||||
"""Print and optionally execute a command."""
|
||||
print("\nCommand:")
|
||||
print(" " + " \\\n ".join(cmd))
|
||||
if dry_run:
|
||||
print("\n(dry-run — not executing)")
|
||||
return 0
|
||||
|
||||
print()
|
||||
result = subprocess.run(cmd)
|
||||
return result.returncode
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LoRA fine-tuning launcher for Hermes 4 (AutoLoRA Step 4)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
|
||||
# Mode flags (mutually exclusive-ish)
|
||||
mode = parser.add_mutually_exclusive_group()
|
||||
mode.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Run inference test with trained adapter instead of training",
|
||||
)
|
||||
mode.add_argument(
|
||||
"--fuse",
|
||||
action="store_true",
|
||||
help="Fuse adapter into base model (for Ollama import)",
|
||||
)
|
||||
|
||||
# Paths
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=None,
|
||||
help=f"Path to local MLX model (or set {DEFAULT_MODEL_PATH_ENV} env var)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=Path,
|
||||
default=DEFAULT_DATA_DIR,
|
||||
help=f"Training data directory (default: {DEFAULT_DATA_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=Path,
|
||||
default=DEFAULT_ADAPTER_PATH,
|
||||
help=f"LoRA adapter output path (default: {DEFAULT_ADAPTER_PATH})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=Path,
|
||||
default=DEFAULT_FUSED_PATH,
|
||||
help=f"Fused model output path (default: {DEFAULT_FUSED_PATH})",
|
||||
)
|
||||
|
||||
# Hyperparameters
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZE,
|
||||
help=f"Training batch size (default: {DEFAULT_BATCH_SIZE}; reduce to 1 if OOM)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-layers",
|
||||
type=int,
|
||||
default=DEFAULT_LORA_LAYERS,
|
||||
help=f"Number of LoRA layers (default: {DEFAULT_LORA_LAYERS}; reduce if OOM)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iters",
|
||||
type=int,
|
||||
default=DEFAULT_ITERS,
|
||||
help=f"Training iterations (default: {DEFAULT_ITERS})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate",
|
||||
type=float,
|
||||
default=DEFAULT_LEARNING_RATE,
|
||||
help=f"Learning rate (default: {DEFAULT_LEARNING_RATE})",
|
||||
)
|
||||
|
||||
# Misc
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
default=DEFAULT_TEST_PROMPT,
|
||||
help="Prompt for --test mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print command without executing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
"-v",
|
||||
action="store_true",
|
||||
help="Print extra progress information",
|
||||
)
|
||||
parser.add_argument(
|
||||
        "--skip-preflight",
        action="store_true",
        help="Skip pre-flight checks (useful in CI)",
    )

    args = parser.parse_args(argv)
    model_path = _resolve_model_path(args.model)

    # ── Pre-flight ──────────────────────────────────────────────────────────
    if not args.skip_preflight:
        warnings = _preflight(model_path, args.data, args.verbose)
        if warnings:
            for w in warnings:
                print(f"WARNING: {w}\n")
            if not args.dry_run:
                print("Aborting due to pre-flight warnings. Use --dry-run to see commands anyway.")
                return 1

    if model_path is None:
        # Allow dry-run without a model for documentation purposes
        model_path = "<path-to-hermes4-mlx>"

    # ── Mode dispatch ────────────────────────────────────────────────────────
    if args.test:
        print(f"Testing fine-tuned model with adapter: {args.adapter_path}")
        cmd = _build_test_cmd(model_path, args.adapter_path, args.prompt)
        return _run(cmd, args.dry_run, args.verbose)

    if args.fuse:
        print(f"Fusing adapter {args.adapter_path} into base model → {args.save_path}")
        cmd = _build_fuse_cmd(model_path, args.adapter_path, args.save_path)
        rc = _run(cmd, args.dry_run, args.verbose)
        if rc == 0 and not args.dry_run:
            print(
                f"\nFused model saved to: {args.save_path}\n"
                "To import into Ollama:\n"
                "  ollama create timmy-hermes4 -f Modelfile.hermes4-14b\n"
                "  (edit Modelfile to point FROM to the fused GGUF path)"
            )
        return rc

    # Default: train
    print("Starting LoRA fine-tuning")
    print(f"  Model: {model_path}")
    print(f"  Data: {args.data}")
    print(f"  Adapter path: {args.adapter_path}")
    print(f"  Iterations: {args.iters}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  LoRA layers: {args.lora_layers}")
    print(f"  Learning rate: {args.learning_rate}")
    print()
    print("Estimated time: 2-8 hours on M3 Max (depends on dataset size).")
    print("If OOM: reduce --lora-layers to 8 or keep --batch-size at 1.")

    cmd = _build_train_cmd(
        model_path=model_path,
        data_dir=args.data,
        adapter_path=args.adapter_path,
        batch_size=args.batch_size,
        lora_layers=args.lora_layers,
        iters=args.iters,
        learning_rate=args.learning_rate,
    )
    rc = _run(cmd, args.dry_run, args.verbose)

    if rc == 0 and not args.dry_run:
        print(
            f"\nTraining complete! Adapter saved to: {args.adapter_path}\n"
            "Test with:\n"
            "  python scripts/lora_finetune.py --test\n"
            "Then fuse + import to Ollama:\n"
            "  python scripts/lora_finetune.py --fuse"
        )

    return rc


if __name__ == "__main__":
    sys.exit(main())
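For reference, a minimal sketch of the kind of command list a helper like `_build_train_cmd` assembles. The `mlx_lm.lora` flag names here are an assumption (they are not visible in this diff), so treat this as illustrative, not as the script's actual implementation:

```python
def build_train_cmd(model_path, data_dir, adapter_path,
                    batch_size, lora_layers, iters, learning_rate):
    # Hypothetical mirror of _build_train_cmd; flag names assume mlx_lm's CLI.
    return [
        "mlx_lm.lora",
        "--model", model_path,
        "--train",
        "--data", data_dir,
        "--adapter-path", adapter_path,
        "--batch-size", str(batch_size),
        "--num-layers", str(lora_layers),
        "--iters", str(iters),
        "--learning-rate", str(learning_rate),
    ]

cmd = build_train_cmd("<path-to-hermes4-mlx>", "data/", "adapters/", 1, 16, 1000, 1e-5)
```

Building the command as a list (rather than a shell string) keeps it safe to pass straight to `subprocess.run` without quoting issues.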
244 scripts/test_gabs_connectivity.py Normal file
@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""GABS TCP connectivity and JSON-RPC smoke test.

Tests connectivity from Hermes to the Bannerlord.GABS TCP server running on the
Windows VM. Covers:
 1. TCP socket connection (port 4825 reachable)
 2. JSON-RPC ping round-trip
 3. get_game_state call (game must be running)
 4. Latency — target < 100 ms on LAN

Usage:
    python scripts/test_gabs_connectivity.py --host 10.0.0.50
    python scripts/test_gabs_connectivity.py --host 10.0.0.50 --port 4825 --timeout 5

Refs: #1098 (Bannerlord Infra — Windows VM Setup + GABS Mod Installation)
Epic: #1091 (Project Bannerlord)
"""

from __future__ import annotations

import argparse
import json
import socket
import sys
import time
from typing import Any

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 4825
DEFAULT_TIMEOUT = 5  # seconds
LATENCY_TARGET_MS = 100.0


# ── Low-level TCP helpers ─────────────────────────────────────────────────────


def _tcp_connect(host: str, port: int, timeout: float) -> socket.socket:
    """Open a TCP connection and return the socket. Raises on failure."""
    sock = socket.create_connection((host, port), timeout=timeout)
    sock.settimeout(timeout)
    return sock


def _send_recv(sock: socket.socket, payload: dict[str, Any]) -> dict[str, Any]:
    """Send a newline-delimited JSON-RPC request and return the parsed response."""
    raw = json.dumps(payload) + "\n"
    sock.sendall(raw.encode())

    buf = b""
    while b"\n" not in buf:
        chunk = sock.recv(4096)
        if not chunk:
            raise ConnectionError("Connection closed before response received")
        buf += chunk

    line = buf.split(b"\n", 1)[0]
    return json.loads(line.decode())


def _rpc(sock: socket.socket, method: str, params: dict | None = None, req_id: int = 1) -> dict[str, Any]:
    """Build and send a JSON-RPC 2.0 request, return the response dict."""
    payload: dict[str, Any] = {
        "jsonrpc": "2.0",
        "method": method,
        "id": req_id,
    }
    if params:
        payload["params"] = params
    return _send_recv(sock, payload)


# ── Test cases ────────────────────────────────────────────────────────────────


def test_tcp_connection(host: str, port: int, timeout: float) -> tuple[bool, socket.socket | None]:
    """PASS: TCP connection to host:port succeeds."""
    print(f"\n[1/4] TCP connection → {host}:{port}")
    try:
        t0 = time.monotonic()
        sock = _tcp_connect(host, port, timeout)
        elapsed_ms = (time.monotonic() - t0) * 1000
        print(f"  ✓ Connected ({elapsed_ms:.1f} ms)")
        return True, sock
    except OSError as exc:
        print(f"  ✗ Connection failed: {exc}")
        print("  Checklist:")
        print("    - Is Bannerlord running with GABS mod enabled?")
        print(f"    - Is port {port} open in Windows Firewall?")
        print(f"    - Is the VM IP correct? (got: {host})")
        return False, None


def test_ping(sock: socket.socket) -> bool:
    """PASS: JSON-RPC ping returns a 2.0 response."""
    print("\n[2/4] JSON-RPC ping")
    try:
        t0 = time.monotonic()
        resp = _rpc(sock, "ping", req_id=1)
        elapsed_ms = (time.monotonic() - t0) * 1000
        if resp.get("jsonrpc") == "2.0" and "error" not in resp:
            print(f"  ✓ Ping OK ({elapsed_ms:.1f} ms): {json.dumps(resp)}")
            return True
        print(f"  ✗ Unexpected response ({elapsed_ms:.1f} ms): {json.dumps(resp)}")
        return False
    except Exception as exc:
        print(f"  ✗ Ping failed: {exc}")
        return False


def test_game_state(sock: socket.socket) -> bool:
    """PASS: get_game_state returns a result (game must be in a campaign)."""
    print("\n[3/4] get_game_state call")
    try:
        t0 = time.monotonic()
        resp = _rpc(sock, "get_game_state", req_id=2)
        elapsed_ms = (time.monotonic() - t0) * 1000
        if "error" in resp:
            code = resp["error"].get("code", "?")
            msg = resp["error"].get("message", "")
            if code == -32601:
                # Method not found — GABS version may not expose this method
                print(f"  ~ Method not available ({elapsed_ms:.1f} ms): {msg}")
                print("    This is acceptable if the game is not yet in a campaign.")
                return True
            print(f"  ✗ RPC error ({elapsed_ms:.1f} ms) [{code}]: {msg}")
            return False
        result = resp.get("result", {})
        print(f"  ✓ Game state received ({elapsed_ms:.1f} ms):")
        for k, v in result.items():
            print(f"    {k}: {v}")
        return True
    except Exception as exc:
        print(f"  ✗ get_game_state failed: {exc}")
        return False


def test_latency(host: str, port: int, timeout: float, iterations: int = 5) -> bool:
    """PASS: Average round-trip latency is under LATENCY_TARGET_MS."""
    print(f"\n[4/4] Latency test ({iterations} pings, target < {LATENCY_TARGET_MS:.0f} ms)")
    try:
        times: list[float] = []
        for i in range(iterations):
            sock = _tcp_connect(host, port, timeout)
            try:
                t0 = time.monotonic()
                _rpc(sock, "ping", req_id=i + 10)
                times.append((time.monotonic() - t0) * 1000)
            finally:
                sock.close()

        avg_ms = sum(times) / len(times)
        min_ms = min(times)
        max_ms = max(times)
        print(f"  avg={avg_ms:.1f} ms  min={min_ms:.1f} ms  max={max_ms:.1f} ms")

        if avg_ms <= LATENCY_TARGET_MS:
            print(f"  ✓ Latency within target ({avg_ms:.1f} ms ≤ {LATENCY_TARGET_MS:.0f} ms)")
            return True
        print(
            f"  ✗ Latency too high ({avg_ms:.1f} ms > {LATENCY_TARGET_MS:.0f} ms)\n"
            f"    Check network path between Hermes and the VM."
        )
        return False
    except Exception as exc:
        print(f"  ✗ Latency test failed: {exc}")
        return False


# ── Main ──────────────────────────────────────────────────────────────────────


def main() -> int:
    parser = argparse.ArgumentParser(description="GABS TCP connectivity smoke test")
    parser.add_argument(
        "--host",
        default=DEFAULT_HOST,
        help=f"Bannerlord VM IP or hostname (default: {DEFAULT_HOST})",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"GABS TCP port (default: {DEFAULT_PORT})",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=DEFAULT_TIMEOUT,
        help=f"Socket timeout in seconds (default: {DEFAULT_TIMEOUT})",
    )
    args = parser.parse_args()

    print("=" * 60)
    print("GABS Connectivity Test Suite")
    print(f"Target: {args.host}:{args.port}")
    print(f"Timeout: {args.timeout}s")
    print("=" * 60)

    results: dict[str, bool] = {}

    # Test 1: TCP connection (gate — skip remaining if unreachable)
    ok, sock = test_tcp_connection(args.host, args.port, args.timeout)
    results["tcp_connection"] = ok
    if not ok:
        _print_summary(results)
        return 1

    # Tests 2–3 reuse the same socket
    try:
        results["ping"] = test_ping(sock)
        results["game_state"] = test_game_state(sock)
    finally:
        sock.close()

    # Test 4: latency uses fresh connections
    results["latency"] = test_latency(args.host, args.port, args.timeout)

    return _print_summary(results)


def _print_summary(results: dict[str, bool]) -> int:
    passed = sum(results.values())
    total = len(results)
    print("\n" + "=" * 60)
    print(f"Results: {passed}/{total} passed")
    print("=" * 60)
    for name, ok in results.items():
        icon = "✓" if ok else "✗"
        print(f"  {icon} {name}")

    if passed == total:
        print("\n✓ GABS connectivity verified. Timmy can reach the game.")
        print("  Next step: run benchmark level 0 (JSON compliance check).")
    elif not results.get("tcp_connection"):
        print("\n✗ TCP connection failed. VM/firewall setup incomplete.")
        print("  See docs/research/bannerlord-vm-setup.md for checklist.")
    else:
        print("\n~ Partial pass — review failures above.")

    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())
@@ -1,49 +0,0 @@
"""Bannerlord M3 — Full Campaign Strategy.

Timmy runs a complete Bannerlord campaign: economy, diplomacy, kingdom
building, and war decisions — all via sovereign local inference.

Key components:
    gabs_client    — TCP JSON-RPC client for the GABS mod (port 4825)
    types          — KingSubgoal, GameState, message schemas
    session_memory — SQLite-backed multi-day strategic plan persistence
    campaign       — CampaignOrchestrator tying all agents together
    adapter        — WorldInterface adapter for use with the benchmark runner
    agents/        — King, Vassal, and Companion agent hierarchy

Quick start::

    from bannerlord.campaign import CampaignOrchestrator
    orch = CampaignOrchestrator()
    summary = await orch.run(max_ticks=100)

Register the world adapter::

    from infrastructure.world import register_adapter
    from bannerlord.adapter import BannerlordWorldAdapter
    register_adapter("bannerlord", BannerlordWorldAdapter)

M3 done-when condition:
    Timmy establishes own kingdom with 3+ fiefs and
    survives 100 in-game days as ruler.
"""

from bannerlord.adapter import BannerlordWorldAdapter
from bannerlord.campaign import CampaignOrchestrator
from bannerlord.gabs_client import GABSClient
from bannerlord.session_memory import SessionMemory
from bannerlord.types import (
    GameState,
    KingSubgoal,
    SubgoalToken,
)

__all__ = [
    "BannerlordWorldAdapter",
    "CampaignOrchestrator",
    "GABSClient",
    "GameState",
    "KingSubgoal",
    "SessionMemory",
    "SubgoalToken",
]
@@ -1,228 +0,0 @@
"""Bannerlord M3 — WorldInterface adapter wrapping the GABS TCP client.

Plugs Bannerlord into the engine-agnostic ``WorldInterface`` contract so
the benchmark runner and heartbeat loop can drive the campaign the same way
they would drive any other game world.

Register with::

    from infrastructure.world import register_adapter
    from bannerlord.adapter import BannerlordWorldAdapter
    register_adapter("bannerlord", BannerlordWorldAdapter)
"""

from __future__ import annotations

import asyncio
import logging
from datetime import UTC, datetime
from typing import Any

from bannerlord.gabs_client import GABSClient
from bannerlord.types import GameState
from infrastructure.world.interface import WorldInterface
from infrastructure.world.types import (
    ActionResult,
    ActionStatus,
    CommandInput,
    PerceptionOutput,
)

logger = logging.getLogger(__name__)


class BannerlordWorldAdapter(WorldInterface):
    """WorldInterface adapter for Bannerlord via the GABS mod.

    ``observe()`` — fetches the full GameState from GABS and maps it to a
    ``PerceptionOutput`` with structured fields.

    ``act()`` — dispatches ``CommandInput.action`` as a GABS JSON-RPC call,
    forwarding ``parameters`` as the call args.

    ``speak()`` — sends a chat message via GABS (e.g., for companion NPC
    conversations or on-screen overlays).

    Degrades gracefully when GABS is unavailable.
    """

    def __init__(
        self,
        *,
        host: str = "127.0.0.1",
        port: int = 4825,
        timeout: float = 10.0,
    ) -> None:
        self._client = GABSClient(host=host, port=port, timeout=timeout)
        self._last_state: GameState | None = None

    # -- lifecycle ---------------------------------------------------------

    def connect(self) -> None:
        """Synchronous connect wrapper (runs async in a new event loop)."""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Inside async context — caller should use async connect
                logger.warning(
                    "BannerlordWorldAdapter.connect() called from async context; "
                    "use 'await adapter.async_connect()' instead"
                )
                return
            loop.run_until_complete(self._client.connect())
        except Exception as exc:
            logger.warning("BannerlordWorldAdapter.connect() failed: %s", exc)

    def disconnect(self) -> None:
        """Synchronous disconnect wrapper."""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                return
            loop.run_until_complete(self._client.disconnect())
        except Exception as exc:
            logger.debug("BannerlordWorldAdapter.disconnect() error: %s", exc)

    async def async_connect(self) -> bool:
        """Async connect — preferred in async contexts."""
        return await self._client.connect()

    async def async_disconnect(self) -> None:
        """Async disconnect."""
        await self._client.disconnect()

    @property
    def is_connected(self) -> bool:
        return self._client.is_connected

    # -- WorldInterface contract ------------------------------------------

    def observe(self) -> PerceptionOutput:
        """Return a PerceptionOutput derived from the current GABS GameState.

        Falls back to an empty perception if GABS is unreachable.
        """
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                logger.warning("observe() called from async context — use async_observe()")
                return self._empty_perception()
            state = loop.run_until_complete(self._client.get_game_state())
            self._last_state = state  # keep last_game_state consistent with async path
            return self._state_to_perception(state)
        except Exception as exc:
            logger.warning("BannerlordWorldAdapter.observe() error: %s", exc)
            return self._empty_perception()

    async def async_observe(self) -> PerceptionOutput:
        """Async observe — preferred in async contexts."""
        try:
            state = await self._client.get_game_state()
            self._last_state = state
            return self._state_to_perception(state)
        except Exception as exc:
            logger.warning("async_observe() error: %s", exc)
            return self._empty_perception()

    def act(self, command: CommandInput) -> ActionResult:
        """Dispatch a command to GABS. Returns success/failure."""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                logger.warning("act() called from async context — use async_act()")
                return ActionResult(status=ActionStatus.NOOP)
            return loop.run_until_complete(self.async_act(command))
        except Exception as exc:
            logger.warning("BannerlordWorldAdapter.act() error: %s", exc)
            return ActionResult(
                status=ActionStatus.FAILURE,
                message=str(exc),
            )

    async def async_act(self, command: CommandInput) -> ActionResult:
        """Async command dispatch."""
        try:
            result = await self._client._call(
                command.action, command.parameters or {}
            )
            if result is None:
                return ActionResult(
                    status=ActionStatus.FAILURE,
                    message=f"GABS returned no result for {command.action}",
                )
            return ActionResult(
                status=ActionStatus.SUCCESS,
                message=f"GABS executed: {command.action}",
                data=result if isinstance(result, dict) else {"result": result},
            )
        except Exception as exc:
            logger.warning("async_act(%s) error: %s", command.action, exc)
            return ActionResult(status=ActionStatus.FAILURE, message=str(exc))

    def speak(self, message: str, target: str | None = None) -> None:
        """Send a message via GABS (e.g., companion dialogue or overlay)."""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                return
            loop.run_until_complete(
                self._client._call("chat/send", {"message": message, "target": target})
            )
        except Exception as exc:
            logger.debug("BannerlordWorldAdapter.speak() error: %s", exc)

    # -- helpers -----------------------------------------------------------

    def _state_to_perception(self, state: GameState) -> PerceptionOutput:
        """Map a GameState snapshot to a PerceptionOutput."""
        entities: list[str] = []
        events: list[str] = []

        # Party location
        if state.party.location:
            entities.append(f"location:{state.party.location}")

        # Kingdom status
        if state.has_kingdom():
            entities.append(f"kingdom:{state.kingdom.name}")
            for fief in state.kingdom.fiefs:
                entities.append(f"fief:{fief}")

        # Active wars as events
        for war in state.kingdom.active_wars:
            events.append(f"at_war_with:{war}")

        # Faction snapshot
        for faction in state.factions:
            entities.append(f"faction:{faction.name}[{faction.army_strength}]")

        # Alerts
        if state.is_two_front_war():
            events.append("alert:two_front_war")
        if state.party.wounded_pct > 0.30:
            events.append(f"alert:wounded_{state.party.wounded_pct:.0%}")
        if state.party.food_days < 3:
            events.append("alert:low_food")

        return PerceptionOutput(
            timestamp=datetime.now(UTC),
            location=state.party.location,
            entities=entities,
            events=events,
            raw=state.raw,
        )

    @staticmethod
    def _empty_perception() -> PerceptionOutput:
        return PerceptionOutput(
            timestamp=datetime.now(UTC),
            location="",
            entities=[],
            events=["gabs:unavailable"],
            raw={"adapter": "bannerlord", "connected": False},
        )

    @property
    def last_game_state(self) -> GameState | None:
        """Return the most recently observed GameState."""
        return self._last_state
@@ -1,4 +0,0 @@
"""Bannerlord M3 — feudal agent hierarchy.

King → Vassal → Companion, following Ahilan & Dayan (2019).
"""
@@ -1 +0,0 @@
"""Bannerlord M3 — Companion worker agents (lowest tier)."""
@@ -1,61 +0,0 @@
"""Bannerlord M3 — Caravan Companion (trade operations).

Handles trade route assessment, buy/sell goods, caravan deployment.
Triggered by TRADE subgoal or when treasury is below threshold.

Minimum margin threshold: 15% (never buy goods without ≥ 15% resale margin).
"""

from __future__ import annotations

import logging

from bannerlord.types import GameState, KingSubgoal, SubgoalToken

logger = logging.getLogger(__name__)

_MIN_MARGIN_PCT = 0.15  # minimum profitable resale margin
_CARAVAN_DENAR_THRESHOLD = 10_000  # must have 10k denars to deploy a caravan


class CaravanCompanion:
    """Companion worker for trade route management."""

    AGENT_ID = "caravan_companion"

    def evaluate(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[dict]:
        """Return trade primitives to execute.

        Returns:
            List of dicts with 'primitive' and 'args' keys.
        """
        if subgoal.token != SubgoalToken.TRADE:
            return []

        actions: list[dict] = []
        party = state.party

        # Always assess prices at current location first
        actions.append({
            "primitive": "assess_prices",
            "args": {"town": party.location},
        })

        # Deploy a caravan if treasury is flush
        if party.denars >= _CARAVAN_DENAR_THRESHOLD and party.location:
            actions.append({
                "primitive": "establish_caravan",
                "args": {"town": party.location},
            })

        return actions

    @staticmethod
    def is_profitable_trade(buy_price: int, sell_price: int) -> bool:
        """Return True if the trade margin meets the minimum threshold."""
        if buy_price <= 0:
            return False
        margin = (sell_price - buy_price) / buy_price
        return margin >= _MIN_MARGIN_PCT
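The 15% margin rule is easy to sanity-check in isolation. This is a standalone copy of `is_profitable_trade` (same logic as above, reproduced here only for illustration):

```python
_MIN_MARGIN_PCT = 0.15

def is_profitable_trade(buy_price: int, sell_price: int) -> bool:
    # Reject non-positive buy prices, then require >= 15% resale margin.
    if buy_price <= 0:
        return False
    return (sell_price - buy_price) / buy_price >= _MIN_MARGIN_PCT

assert is_profitable_trade(100, 115)      # 15% margin — exactly at threshold
assert not is_profitable_trade(100, 110)  # 10% margin — below threshold
assert not is_profitable_trade(0, 50)     # guard against divide-by-zero
```

Note the `buy_price <= 0` guard doubles as the divide-by-zero protection, so the margin expression itself never needs a try/except.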
@@ -1,78 +0,0 @@
"""Bannerlord M3 — Logistics Companion (party management).

Handles recruit, supply, rest, prisoner sale, and troop upgrade primitives.
Runs on Qwen3:8b for sub-2-second response times.

Triggered by RECRUIT and HEAL subgoals, or by party condition thresholds.
"""

from __future__ import annotations

import logging

from bannerlord.types import GameState, KingSubgoal, SubgoalToken

logger = logging.getLogger(__name__)

_FOOD_WARN_DAYS = 5
_WOUND_WARN_PCT = 0.20
_PRISONER_CAP = 20


class LogisticsCompanion:
    """Companion worker for party logistics.

    Evaluates the current party state and returns a list of primitive
    action names + args to dispatch to the GABSClient.
    """

    AGENT_ID = "logistics_companion"

    def evaluate(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[dict]:
        """Return primitives to execute given current state and active subgoal.

        Returns:
            List of dicts with 'primitive' and 'args' keys.
        """
        actions: list[dict] = []
        party = state.party

        # Subgoal-driven recruitment
        if subgoal.token == SubgoalToken.RECRUIT:
            qty = subgoal.quantity or 20
            actions.append({
                "primitive": "recruit_troop",
                "args": {"troop_type": "infantry", "qty": qty},
            })

        # Emergency rest on heavy wounds
        if subgoal.token == SubgoalToken.HEAL or party.wounded_pct > _WOUND_WARN_PCT:
            actions.append({
                "primitive": "rest_party",
                "args": {"days": 3},
            })

        # Replenish food if low
        if party.food_days < _FOOD_WARN_DAYS:
            actions.append({
                "primitive": "buy_supplies",
                "args": {"qty": max(0, 10 - party.food_days)},
            })

        # Sell prisoners when near cap
        if party.prisoners >= _PRISONER_CAP:
            actions.append({
                "primitive": "sell_prisoners",
                "args": {"location": party.location},
            })

        # Upgrade troops when stable
        if subgoal.token == SubgoalToken.TRAIN:
            actions.append({
                "primitive": "upgrade_troops",
                "args": {},
            })

        return actions
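The food top-up in `evaluate` buys back up to a 10-day buffer and never a negative quantity. A standalone check of that arithmetic (the helper name is illustrative; the logic is taken from the `buy_supplies` branch above):

```python
_FOOD_WARN_DAYS = 5

def supply_topup(food_days: int) -> int:
    # Below the warn threshold, buy enough to restore a 10-day buffer;
    # max(0, ...) guards against a negative quantity if food_days > 10.
    return max(0, 10 - food_days) if food_days < _FOOD_WARN_DAYS else 0

assert supply_topup(2) == 8   # low food: buy 8 units
assert supply_topup(5) == 0   # at the warn threshold: no purchase
assert supply_topup(12) == 0  # well stocked: no purchase
```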
@@ -1,58 +0,0 @@
"""Bannerlord M3 — Scout Companion (intelligence gathering).

Handles lord tracking, garrison assessment, and patrol route mapping.
Triggered by SPY subgoal or proactively before expansion decisions.
"""

from __future__ import annotations

import logging

from bannerlord.types import GameState, KingSubgoal, SubgoalToken

logger = logging.getLogger(__name__)


class ScoutCompanion:
    """Companion worker for tactical intelligence."""

    AGENT_ID = "scout_companion"

    def evaluate(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[dict]:
        """Return scouting primitives to execute.

        Returns:
            List of dicts with 'primitive' and 'args' keys.
        """
        actions: list[dict] = []

        if subgoal.token == SubgoalToken.SPY:
            target = subgoal.target
            if target:
                actions.append({
                    "primitive": "track_lord",
                    "args": {"lord_name": target},
                })

        elif subgoal.token == SubgoalToken.EXPAND_TERRITORY:
            target = subgoal.target
            if target:
                actions.append({
                    "primitive": "assess_garrison",
                    "args": {"settlement": target},
                })

        # Proactively map patrols in active war regions
        for war_faction in state.kingdom.active_wars:
            # Find a fief belonging to the enemy as the region reference
            for faction in state.factions:
                if faction.name == war_faction and faction.fiefs:
                    actions.append({
                        "primitive": "map_patrol_routes",
                        "args": {"region": faction.fiefs[0]},
                    })
                    break  # one region per enemy faction per tick

        return actions
@@ -1,145 +0,0 @@
"""Bannerlord M3 — Diplomacy Vassal agent.

Handles relations management: alliances, peace deals, tribute, marriage.
Responds to the ALLY subgoal.

Reward function:
    R_diplo = w1 * AlliesCount
            + w2 * TruceDurationValue
            + w3 * RelationsScore_weighted
            - w4 * ActiveWarsFront
            + w5 * SubgoalBonus

Key strategic rules:
 - Never start a new war if already in 2+ wars (2-front war rule)
 - Prefer peace with weakest current enemy when overextended
 - Time alliances before declaring war to reduce isolation risk
"""

from __future__ import annotations

import logging

from bannerlord.types import (
    GameState,
    KingSubgoal,
    SubgoalToken,
    TaskMessage,
    VassalReward,
)

logger = logging.getLogger(__name__)

_W1_ALLIES = 0.30
_W2_TRUCE = 0.25
_W3_RELATIONS = 0.25
_W4_WAR_FRONTS = 0.15
_W5_SUBGOAL = 0.05

_SUBGOAL_TRIGGERS = {SubgoalToken.ALLY}
_MAX_WAR_FRONTS = 2  # flag when at 2+ simultaneous wars (two-front war)


class DiplomacyVassal:
    """Mid-tier agent responsible for diplomatic relations."""

    AGENT_ID = "diplomacy_vassal"

    def is_relevant(self, subgoal: KingSubgoal) -> bool:
        return subgoal.token in _SUBGOAL_TRIGGERS

    def plan(self, state: GameState, subgoal: KingSubgoal) -> list[TaskMessage]:
        """Return TaskMessages for the current diplomatic subgoal."""
        tasks: list[TaskMessage] = []

        if subgoal.token == SubgoalToken.ALLY:
            tasks.extend(self._plan_alliance(state, subgoal))

        return tasks

    def _plan_alliance(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[TaskMessage]:
        """Plan diplomatic outreach to reduce war fronts or build alliances."""
        tasks: list[TaskMessage] = []
        target = subgoal.target

        if not target:
            logger.warning("DiplomacyVassal: no target for ALLY subgoal")
            return tasks

        # If target is already an enemy, propose peace
        if target in state.kingdom.active_wars:
            tasks.append(
                TaskMessage(
                    from_agent=self.AGENT_ID,
                    to_agent="gabs",
                    primitive="propose_peace",
                    args={"faction": target, "tribute": 0},
                    priority=subgoal.priority * 1.5,
                )
            )
        else:
            # Otherwise pursue alliance
            tasks.append(
                TaskMessage(
                    from_agent=self.AGENT_ID,
                    to_agent="gabs",
                    primitive="send_envoy",
                    args={
                        "faction": target,
                        "message": "We seek a lasting alliance for mutual defence.",
                    },
                    priority=subgoal.priority,
                )
            )
            tasks.append(
                TaskMessage(
                    from_agent=self.AGENT_ID,
                    to_agent="gabs",
                    primitive="request_alliance",
                    args={"faction": target},
                    priority=subgoal.priority,
                )
            )

        return tasks

    def should_avoid_war(self, state: GameState) -> bool:
        """Return True if starting a new war would be strategically unsound."""
        return state.active_war_count() >= _MAX_WAR_FRONTS  # 2-front war check

    def compute_reward(
        self,
        prev_state: GameState,
        curr_state: GameState,
        active_subgoal: KingSubgoal,
    ) -> VassalReward:
        """Compute Diplomacy Vassal reward."""
        allies_count = len(curr_state.kingdom.active_alliances)
        truce_value = 50.0  # placeholder — days of truce remaining
        relations_avg = 30.0  # placeholder — weighted relations score
        war_fronts = curr_state.active_war_count()

        subgoal_bonus = 1.0 if active_subgoal.token in _SUBGOAL_TRIGGERS else 0.0

        total = (
            _W1_ALLIES * allies_count * 10
            + _W2_TRUCE * truce_value / 100
            + _W3_RELATIONS * relations_avg / 100
            - _W4_WAR_FRONTS * war_fronts * 10
            + _W5_SUBGOAL * subgoal_bonus * 10
        )

        return VassalReward(
            agent_id=self.AGENT_ID,
            component_scores={
                "allies_count": allies_count,
                "truce_value": truce_value,
                "relations_avg": relations_avg,
                "war_fronts": -war_fronts,
                "subgoal_bonus": subgoal_bonus,
            },
            subgoal_bonus=subgoal_bonus,
            total=total,
        )
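With those weights, the reward arithmetic in `compute_reward` can be checked by hand. This standalone sketch reuses the same scaling and plugs in one plausible state (1 ally, the 50/30 placeholder truce/relations values, one war front, subgoal bonus active):

```python
_W1_ALLIES, _W2_TRUCE, _W3_RELATIONS, _W4_WAR_FRONTS, _W5_SUBGOAL = (
    0.30, 0.25, 0.25, 0.15, 0.05,
)

def diplo_reward(allies: int, truce: float, relations: float,
                 war_fronts: int, bonus: float) -> float:
    # Same term-by-term scaling as DiplomacyVassal.compute_reward.
    return (
        _W1_ALLIES * allies * 10
        + _W2_TRUCE * truce / 100
        + _W3_RELATIONS * relations / 100
        - _W4_WAR_FRONTS * war_fronts * 10
        + _W5_SUBGOAL * bonus * 10
    )

# 3.0 + 0.125 + 0.075 - 1.5 + 0.5 ≈ 2.2
total = diplo_reward(1, 50.0, 30.0, 1, 1.0)
```

The war-front term dominates quickly: each additional front costs 1.5, i.e. half the value of an ally, which is what pushes the agent toward `propose_peace` when overextended.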
@@ -1,151 +0,0 @@
```python
"""Bannerlord M3 — Economy Vassal agent.

Handles settlement management, tax collection, construction, and food supply.
Responds to FORTIFY and CONSOLIDATE subgoals.

Reward function:
    R_econ = w1 * DailyDenarsIncome
           + w2 * FoodStockBuffer
           + w3 * LoyaltyAverage
           - w4 * ConstructionQueueLength
           + w5 * SubgoalBonus
"""

from __future__ import annotations

import logging

from bannerlord.types import (
    GameState,
    KingSubgoal,
    SubgoalToken,
    TaskMessage,
    VassalReward,
)

logger = logging.getLogger(__name__)

_W1_INCOME = 0.35
_W2_FOOD = 0.25
_W3_LOYALTY = 0.20
_W4_CONSTRUCTION = 0.15
_W5_SUBGOAL = 0.05

_SUBGOAL_TRIGGERS = {SubgoalToken.FORTIFY, SubgoalToken.CONSOLIDATE}

_LOW_FOOD_THRESHOLD = 3  # days of food remaining
_INCOME_TARGET = 200  # daily net income target (denars)


class EconomyVassal:
    """Mid-tier agent responsible for settlement economy."""

    AGENT_ID = "economy_vassal"

    def is_relevant(self, subgoal: KingSubgoal) -> bool:
        return subgoal.token in _SUBGOAL_TRIGGERS

    def plan(self, state: GameState, subgoal: KingSubgoal) -> list[TaskMessage]:
        """Return TaskMessages for the current economic subgoal."""
        tasks: list[TaskMessage] = []

        # Always maintain food supply
        if state.party.food_days < _LOW_FOOD_THRESHOLD:
            tasks.append(
                TaskMessage(
                    from_agent=self.AGENT_ID,
                    to_agent="logistics_companion",
                    primitive="buy_supplies",
                    args={"qty": 10},
                    priority=2.0,
                )
            )

        if subgoal.token == SubgoalToken.FORTIFY:
            tasks.extend(self._plan_fortify(state, subgoal))
        elif subgoal.token == SubgoalToken.CONSOLIDATE:
            tasks.extend(self._plan_consolidate(state))

        return tasks

    def _plan_fortify(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[TaskMessage]:
        """Queue construction projects in owned settlements."""
        tasks: list[TaskMessage] = []
        target = subgoal.target or (state.kingdom.fiefs[0] if state.kingdom.fiefs else None)
        if not target:
            return tasks
        tasks.append(
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="gabs",
                primitive="build_project",
                args={"settlement": target, "project": "granary"},
                priority=1.2,
            )
        )
        tasks.append(
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="gabs",
                primitive="set_tax_policy",
                args={"settlement": target, "policy": "normal"},
                priority=1.0,
            )
        )
        return tasks

    def _plan_consolidate(self, state: GameState) -> list[TaskMessage]:
        """Stabilise: optimise tax and food across all fiefs."""
        tasks: list[TaskMessage] = []
        net_income = state.kingdom.daily_income - state.kingdom.daily_expenses
        for fief in state.kingdom.fiefs:
            policy = "normal" if net_income >= _INCOME_TARGET else "low"
            tasks.append(
                TaskMessage(
                    from_agent=self.AGENT_ID,
                    to_agent="gabs",
                    primitive="set_tax_policy",
                    args={"settlement": fief, "policy": policy},
                    priority=0.8,
                )
            )
        return tasks

    def compute_reward(
        self,
        prev_state: GameState,
        curr_state: GameState,
        active_subgoal: KingSubgoal,
    ) -> VassalReward:
        """Compute Economy Vassal reward."""
        income_delta = (
            curr_state.kingdom.daily_income - prev_state.kingdom.daily_income
        )
        food_buffer = curr_state.party.food_days
        loyalty_avg = 70.0  # placeholder — real value from GABS raw data
        queue_len = 0  # placeholder

        subgoal_bonus = 1.0 if active_subgoal.token in _SUBGOAL_TRIGGERS else 0.0

        total = (
            _W1_INCOME * income_delta
            + _W2_FOOD * food_buffer
            + _W3_LOYALTY * loyalty_avg / 100
            - _W4_CONSTRUCTION * queue_len
            + _W5_SUBGOAL * subgoal_bonus * 10
        )

        return VassalReward(
            agent_id=self.AGENT_ID,
            component_scores={
                "income_delta": income_delta,
                "food_buffer": food_buffer,
                "loyalty_avg": loyalty_avg,
                "queue_len": -queue_len,
                "subgoal_bonus": subgoal_bonus,
            },
            subgoal_bonus=subgoal_bonus,
            total=total,
        )
```
@@ -1,266 +0,0 @@
```python
"""Bannerlord M3 — King agent (Timmy, strategic tier).

The King operates on the campaign-map timescale (1 decision per in-game day).
He reads the full GameState and emits a single KingSubgoal token that vassals
interpret. He uses Qwen3:32b via the LLM router.

Decision rules baked in (no LLM required for simple cases):
- Never initiate a second war while already fighting one (avoid 2-front wars)
- Prioritise HEAL when party wounds > 30 %
- Prioritise RECRUIT when troops < 80
- Prioritise TRADE when denars < 5,000
- If kingdom has < 3 fiefs and no active war, prioritise EXPAND_TERRITORY
- Default to CONSOLIDATE when conditions are stable
"""

from __future__ import annotations

import logging
from datetime import UTC, datetime
from typing import Any

from bannerlord.types import (
    GameState,
    KingSubgoal,
    SubgoalToken,
)

logger = logging.getLogger(__name__)

# Hard thresholds for rule-based fallback decisions
_MIN_TROOPS = 80
_MIN_DENARS = 5_000
_MAX_WOUND_PCT = 0.30
_TARGET_FIEFS = 3
_SURVIVAL_DAYS = 100


class KingAgent:
    """Strategic decision-maker for the Bannerlord campaign.

    The King agent is sovereign — it cannot be terminated by vassals.
    It decides the active subgoal at most once per campaign tick.

    Usage::

        king = KingAgent(model="qwen3:32b")
        subgoal = king.decide(game_state)
    """

    def __init__(
        self,
        *,
        model: str = "qwen3:32b",
        temperature: float = 0.1,
    ) -> None:
        self._model = model
        self._temperature = temperature
        self._last_subgoal: KingSubgoal | None = None
        self._tick = 0
        self._session_id: str | None = None

    def set_session(self, session_id: str) -> None:
        self._session_id = session_id

    # -- primary decision interface ----------------------------------------

    def decide(self, state: GameState) -> KingSubgoal:
        """Return the King's subgoal for the current campaign tick.

        Uses rule-based heuristics as the primary decision engine.
        LLM override can be wired in via ``_llm_decide`` in a future PR.

        Args:
            state: Full game state snapshot from GABS.

        Returns:
            A KingSubgoal to be broadcast to vassals.
        """
        self._tick += 1
        subgoal = self._rule_based_decide(state)
        self._last_subgoal = subgoal
        logger.info(
            "King[tick=%d, day=%d] → %s (target=%s)",
            self._tick,
            state.in_game_day,
            subgoal.token,
            subgoal.target,
        )
        return subgoal

    # -- rule-based strategy engine ----------------------------------------

    def _rule_based_decide(self, state: GameState) -> KingSubgoal:
        """Encode campaign strategy as prioritised decision rules.

        Priority order (highest to lowest):
        1. Emergency: heal if heavily wounded
        2. Survival: recruit if dangerously low on troops
        3. Economy: earn income if broke
        4. Diplomacy: seek peace if in a 2-front war
        5. Expansion: take fiefs if not at war and need more territory
        6. Alliance: seek allies when preparing for war
        7. Default: consolidate and stabilise
        """
        party = state.party
        kingdom = state.kingdom

        # 1. Emergency heal
        if party.wounded_pct > _MAX_WOUND_PCT:
            return KingSubgoal(
                token=SubgoalToken.HEAL,
                context=f"{party.wounded_pct:.0%} of party is wounded — rest required",
            )

        # 2. Critical recruitment
        if party.troops < _MIN_TROOPS:
            return KingSubgoal(
                token=SubgoalToken.RECRUIT,
                quantity=_MIN_TROOPS - party.troops,
                context=f"Party at {party.troops} troops — must reach {_MIN_TROOPS}",
            )

        # 3. Destitute treasury
        if party.denars < _MIN_DENARS:
            return KingSubgoal(
                token=SubgoalToken.TRADE,
                context=f"Treasury at {party.denars:,} denars — run trade routes",
            )

        # 4. Avoid 2-front war: seek peace when fighting 2+ enemies
        if state.is_two_front_war():
            # Pick the weakest enemy to negotiate peace with first
            weakest = self._weakest_enemy(state)
            return KingSubgoal(
                token=SubgoalToken.ALLY,
                target=weakest,
                context="2-front war detected — de-escalate with weakest enemy",
            )

        # 5. Kingdom not yet established: work toward first fief
        if not state.has_kingdom():
            if party.troops >= 120 and state.active_war_count() == 0:
                target_fief = self._select_expansion_target(state)
                return KingSubgoal(
                    token=SubgoalToken.EXPAND_TERRITORY,
                    target=target_fief,
                    context="No kingdom yet — capture a fief to establish one",
                )
            elif state.active_war_count() == 0:
                # Not ready to fight; train up first
                return KingSubgoal(
                    token=SubgoalToken.TRAIN,
                    context="Building army before first expansion",
                )

        # 6. Expand if below target fief count and no active war
        if state.fief_count() < _TARGET_FIEFS and state.active_war_count() == 0:
            target_fief = self._select_expansion_target(state)
            return KingSubgoal(
                token=SubgoalToken.EXPAND_TERRITORY,
                target=target_fief,
                priority=1.5,
                context=f"Only {state.fief_count()} fiefs — need {_TARGET_FIEFS}",
            )

        # 7. Seek allies when stable and below fief target
        if not kingdom.active_alliances and state.active_war_count() == 0:
            ally_candidate = self._best_alliance_candidate(state)
            if ally_candidate:
                return KingSubgoal(
                    token=SubgoalToken.ALLY,
                    target=ally_candidate,
                    context="Stable moment — pursue defensive alliance",
                )

        # 8. Fortify if kingdom exists and there are fiefs to improve
        if state.has_kingdom() and state.fief_count() > 0:
            if kingdom.daily_income - kingdom.daily_expenses < 100:
                return KingSubgoal(
                    token=SubgoalToken.FORTIFY,
                    context="Low net income — invest in settlements",
                )

        # 9. Default: consolidate
        return KingSubgoal(
            token=SubgoalToken.CONSOLIDATE,
            context="Stable — hold territory and recover strength",
        )

    # -- helper methods ----------------------------------------------------

    def _weakest_enemy(self, state: GameState) -> str | None:
        """Return the name of the weakest faction currently at war with us."""
        enemy_names = set(state.kingdom.active_wars)
        enemies = [f for f in state.factions if f.name in enemy_names]
        if not enemies:
            return None
        return min(enemies, key=lambda f: f.army_strength).name

    def _select_expansion_target(self, state: GameState) -> str | None:
        """Select the most vulnerable enemy settlement to target."""
        # Prefer factions already at war with others (distracted)
        for faction in state.factions:
            if len(faction.is_at_war_with) >= 2 and faction.fiefs:
                return faction.fiefs[0]
        # Fallback: weakest faction with fiefs
        candidates = [f for f in state.factions if f.fiefs]
        if candidates:
            weakest = min(candidates, key=lambda f: f.army_strength)
            return weakest.fiefs[0]
        return None

    def _best_alliance_candidate(self, state: GameState) -> str | None:
        """Return the best faction to approach for an alliance."""
        # Prefer factions with good relations and no active war with us
        enemy_names = set(state.kingdom.active_wars)
        candidates = [
            f
            for f in state.factions
            if f.name not in enemy_names
            and f.name != state.kingdom.name
        ]
        if not candidates:
            return None
        # Pick the strongest candidate (most useful ally)
        return max(candidates, key=lambda f: f.army_strength).name

    # -- accessors ---------------------------------------------------------

    @property
    def last_subgoal(self) -> KingSubgoal | None:
        return self._last_subgoal

    @property
    def tick(self) -> int:
        return self._tick

    def campaign_summary(self, state: GameState) -> dict[str, Any]:
        """Return a brief summary of the campaign status."""
        return {
            "tick": self._tick,
            "in_game_day": state.in_game_day,
            "has_kingdom": state.has_kingdom(),
            "fief_count": state.fief_count(),
            "active_wars": state.active_war_count(),
            "two_front_war": state.is_two_front_war(),
            "troops": state.party.troops,
            "denars": state.party.denars,
            "survival_goal_met": (
                state.has_kingdom()
                and state.fief_count() >= _TARGET_FIEFS
                and state.in_game_day >= _SURVIVAL_DAYS
            ),
        }

    def is_done_condition_met(self, state: GameState) -> bool:
        """Return True when the M3 done-when condition is satisfied.

        Done when: Timmy establishes own kingdom with 3+ fiefs and
        survives 100 in-game days as ruler.
        """
        return (
            state.has_kingdom()
            and state.fief_count() >= _TARGET_FIEFS
            and state.in_game_day >= _SURVIVAL_DAYS
        )
```
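The "first matching rule wins" dispatch in `_rule_based_decide` can be sketched standalone. Everything below is illustrative: `PartySnapshot` is a stand-in for the real `bannerlord.types` models, and only the first three thresholds are reproduced.

```python
from dataclasses import dataclass

# Hypothetical stand-in for the real party state model.
@dataclass
class PartySnapshot:
    troops: int
    denars: int
    wounded_pct: float

# Thresholds mirror the constants in the King agent above.
_MIN_TROOPS = 80
_MIN_DENARS = 5_000
_MAX_WOUND_PCT = 0.30

def pick_subgoal(party: PartySnapshot) -> str:
    """Evaluate rules in priority order; the first match wins."""
    if party.wounded_pct > _MAX_WOUND_PCT:
        return "HEAL"
    if party.troops < _MIN_TROOPS:
        return "RECRUIT"
    if party.denars < _MIN_DENARS:
        return "TRADE"
    return "CONSOLIDATE"

# Wounds are fine, but 50 troops is below the 80-troop floor.
print(pick_subgoal(PartySnapshot(troops=50, denars=10_000, wounded_pct=0.1)))  # RECRUIT
```

Because the rules are ordered, an emergency (heavy wounds) always pre-empts a lower-priority concern such as an empty treasury, which is the same guarantee the full agent relies on.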
@@ -1,236 +0,0 @@
```python
"""Bannerlord M3 — War Vassal agent.

Handles military operations: sieges, field battles, raids, defensive
maneuvers. Responds to EXPAND_TERRITORY, RAID_ECONOMY, and TRAIN subgoals.

Reward function (from feudal hierarchy design):
    R_war = w1 * ΔTerritoryValue
          + w2 * ΔArmyStrength_ratio
          - w3 * CasualtyCost
          - w4 * SupplyCost
          + w5 * SubgoalBonus
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from bannerlord.types import (
    GameState,
    KingSubgoal,
    SubgoalToken,
    TaskMessage,
    VassalReward,
)

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)

# Reward weights
_W1_TERRITORY = 0.40
_W2_ARMY_RATIO = 0.25
_W3_CASUALTY = 0.20
_W4_SUPPLY = 0.10
_W5_SUBGOAL = 0.05

_SUBGOAL_TRIGGERS = {
    SubgoalToken.EXPAND_TERRITORY,
    SubgoalToken.RAID_ECONOMY,
    SubgoalToken.TRAIN,
}


@dataclass
class WarContext:
    """Mutable state tracked across War Vassal decisions."""

    active_siege: str | None = None
    last_auto_resolve_result: dict = field(default_factory=dict)
    territory_gained: int = 0
    casualties_taken: int = 0


class WarVassal:
    """Mid-tier agent responsible for military operations.

    Runs at 4× the King's decision frequency. Translates KingSubgoals
    into concrete TaskMessages for the Logistics Companion (troop management)
    and issues direct GABS calls for combat actions.
    """

    AGENT_ID = "war_vassal"

    def __init__(self) -> None:
        self._ctx = WarContext()
        self._prev_army_ratio: float = 1.0

    def is_relevant(self, subgoal: KingSubgoal) -> bool:
        """Return True if this vassal should act on *subgoal*."""
        return subgoal.token in _SUBGOAL_TRIGGERS

    def plan(self, state: GameState, subgoal: KingSubgoal) -> list[TaskMessage]:
        """Return a list of TaskMessages for the current subgoal.

        Args:
            state: Current game state.
            subgoal: Active King subgoal.

        Returns:
            Ordered list of TaskMessages to dispatch.
        """
        tasks: list[TaskMessage] = []

        if subgoal.token == SubgoalToken.EXPAND_TERRITORY:
            tasks.extend(self._plan_expansion(state, subgoal))
        elif subgoal.token == SubgoalToken.RAID_ECONOMY:
            tasks.extend(self._plan_raid(state, subgoal))
        elif subgoal.token == SubgoalToken.TRAIN:
            tasks.extend(self._plan_training(state))

        return tasks

    def _plan_expansion(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[TaskMessage]:
        """Plan territory expansion toward subgoal.target."""
        tasks: list[TaskMessage] = []
        target = subgoal.target

        if not target:
            logger.warning("WarVassal.EXPAND_TERRITORY: no target in subgoal")
            return tasks

        # Scout garrison before sieging
        tasks.append(
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="scout_companion",
                primitive="assess_garrison",
                args={"settlement": target},
                priority=1.5,
            )
        )

        # Ensure troops are sufficient (delegate to logistics if thin)
        if state.party.troops < 100:
            tasks.append(
                TaskMessage(
                    from_agent=self.AGENT_ID,
                    to_agent="logistics_companion",
                    primitive="recruit_troop",
                    args={"troop_type": "infantry", "qty": 100 - state.party.troops},
                    priority=1.8,
                )
            )

        # Issue the siege order
        tasks.append(
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="gabs",
                primitive="siege_settlement",
                args={"settlement": target},
                priority=subgoal.priority,
            )
        )
        # Follow up with auto-resolve
        tasks.append(
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="gabs",
                primitive="auto_resolve_battle",
                args={},
                priority=subgoal.priority,
            )
        )
        return tasks

    def _plan_raid(
        self, state: GameState, subgoal: KingSubgoal
    ) -> list[TaskMessage]:
        """Plan economy raid for denars and food."""
        tasks: list[TaskMessage] = []
        target = subgoal.target or self._nearest_enemy_village(state)
        if not target:
            return tasks
        tasks.append(
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="gabs",
                primitive="raid_village",
                args={"village": target},
                priority=subgoal.priority,
            )
        )
        return tasks

    def _plan_training(self, state: GameState) -> list[TaskMessage]:
        """Plan troop training via auto-resolve bandit fights."""
        return [
            TaskMessage(
                from_agent=self.AGENT_ID,
                to_agent="logistics_companion",
                primitive="upgrade_troops",
                args={},
                priority=0.8,
            )
        ]

    # -- reward computation ------------------------------------------------

    def compute_reward(
        self,
        prev_state: GameState,
        curr_state: GameState,
        active_subgoal: KingSubgoal,
    ) -> VassalReward:
        """Compute the War Vassal reward signal for the last decision cycle."""
        territory_delta = (
            curr_state.fief_count() - prev_state.fief_count()
        ) * 100.0

        prev_strength = max(prev_state.party.troops, 1)
        curr_strength = curr_state.party.troops
        army_delta = (curr_strength - prev_strength) / prev_strength

        casualties = max(0, prev_state.party.troops - curr_strength)
        supply_burn = max(0, prev_state.party.food_days - curr_state.party.food_days)

        subgoal_bonus = (
            1.0 if active_subgoal.token in _SUBGOAL_TRIGGERS else 0.0
        )

        total = (
            _W1_TERRITORY * territory_delta
            + _W2_ARMY_RATIO * army_delta * 10
            - _W3_CASUALTY * casualties
            - _W4_SUPPLY * supply_burn
            + _W5_SUBGOAL * subgoal_bonus * 10
        )

        return VassalReward(
            agent_id=self.AGENT_ID,
            component_scores={
                "territory": territory_delta,
                "army_ratio": army_delta,
                "casualties": -casualties,
                "supply_burn": -supply_burn,
                "subgoal_bonus": subgoal_bonus,
            },
            subgoal_bonus=subgoal_bonus,
            total=total,
        )

    # -- helpers -----------------------------------------------------------

    @staticmethod
    def _nearest_enemy_village(state: GameState) -> str | None:
        enemy_names = set(state.kingdom.active_wars)
        for faction in state.factions:
            if faction.name in enemy_names and faction.fiefs:
                return faction.fiefs[0]
        return None
```
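All three vassal rewards share the same weighted-sum shape (R = Σ wᵢ · componentᵢ). A standalone sketch of that shape, using the War Vassal's weight constants; the component values below are made-up inputs, not game data.

```python
# Signed weights mirror the War Vassal constants above: negative weight
# for cost terms (casualties, supply burn), positive for gains.
_WEIGHTS = {
    "territory": 0.40,
    "army_ratio": 0.25,
    "casualties": -0.20,
    "supply_burn": -0.10,
    "subgoal_bonus": 0.05,
}

def weighted_reward(components: dict[str, float]) -> float:
    """Dot product of component scores with their signed weights."""
    return sum(_WEIGHTS[name] * value for name, value in components.items())

# One captured fief (scored 100), slight army growth, some losses.
total = weighted_reward(
    {"territory": 100.0, "army_ratio": 0.5, "casualties": 10.0,
     "supply_burn": 2.0, "subgoal_bonus": 10.0}
)
# 0.40*100 + 0.25*0.5 - 0.20*10 - 0.10*2 + 0.05*10 = 38.425
```

Keeping the sign inside the weight table (rather than negating component values, as the `component_scores` dicts above do for logging) makes the reward a plain dot product, which is easier to audit when tuning weights.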
@@ -1,270 +0,0 @@
```python
"""Bannerlord M3 — Campaign orchestrator.

Ties together the King agent, vassals, companions, GABS client, and session
memory into a single async campaign loop.

Architecture::

    CampaignOrchestrator.run()
    ├── GABSClient.get_game_state() → GameState
    ├── KingAgent.decide(state) → KingSubgoal
    ├── SessionMemory.log_subgoal(...)
    ├── WarVassal.plan(state, subgoal) → [TaskMessage]
    ├── EconomyVassal.plan(state, subgoal) → [TaskMessage]
    ├── DiplomacyVassal.plan(state, subgoal) → [TaskMessage]
    ├── [Companions].evaluate(state, subgoal) → [primitives]
    └── _dispatch_tasks([...]) → GABS calls

Usage::

    from bannerlord.campaign import CampaignOrchestrator
    orch = CampaignOrchestrator()
    await orch.run(max_ticks=100)
"""

from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import Any

from bannerlord.agents.companions.caravan import CaravanCompanion
from bannerlord.agents.companions.logistics import LogisticsCompanion
from bannerlord.agents.companions.scout import ScoutCompanion
from bannerlord.agents.diplomacy_vassal import DiplomacyVassal
from bannerlord.agents.economy_vassal import EconomyVassal
from bannerlord.agents.king import KingAgent
from bannerlord.agents.war_vassal import WarVassal
from bannerlord.gabs_client import GABSClient
from bannerlord.session_memory import SessionMemory
from bannerlord.types import GameState, KingSubgoal, TaskMessage

logger = logging.getLogger(__name__)

_DEFAULT_TICK_INTERVAL = 1.0  # seconds between campaign ticks (real time)
_DEFAULT_DB_PATH = Path("data/bannerlord/campaign.db")
_KINGDOM_NAME = "House Timmerson"


class CampaignOrchestrator:
    """Full-campaign strategy orchestrator for Bannerlord M3.

    Runs the King → Vassal → Companion decision loop on each campaign tick.
    Persists progress to SQLite via SessionMemory.

    Args:
        gabs_host: Hostname where GABS mod is listening.
        gabs_port: TCP port for GABS JSON-RPC (default 4825).
        tick_interval: Real-time seconds between campaign ticks.
        db_path: Path to the SQLite session memory database.
        session_id: Existing session to resume (None = new session).
    """

    def __init__(
        self,
        *,
        gabs_host: str = "127.0.0.1",
        gabs_port: int = 4825,
        tick_interval: float = _DEFAULT_TICK_INTERVAL,
        db_path: Path = _DEFAULT_DB_PATH,
        session_id: str | None = None,
    ) -> None:
        self._gabs = GABSClient(host=gabs_host, port=gabs_port)
        self._king = KingAgent()
        self._war = WarVassal()
        self._economy = EconomyVassal()
        self._diplomacy = DiplomacyVassal()
        self._logistics = LogisticsCompanion()
        self._caravan = CaravanCompanion()
        self._scout = ScoutCompanion()
        self._memory = SessionMemory(db_path)
        self._tick_interval = tick_interval
        self._session_id = session_id
        self._running = False
        self._prev_state: GameState | None = None

    # -- lifecycle ---------------------------------------------------------

    async def start(self) -> str:
        """Connect to GABS and initialise a campaign session.

        Returns the active session_id.
        """
        connected = await self._gabs.connect()
        if not connected:
            logger.warning(
                "CampaignOrchestrator: GABS unavailable — campaign will run "
                "in degraded mode (no game state updates)"
            )

        if self._session_id is None:
            self._session_id = self._memory.start_session()
        else:
            # Resume existing session
            existing = self._memory.get_session(self._session_id)
            if not existing:
                self._session_id = self._memory.start_session(self._session_id)

        self._king.set_session(self._session_id)
        logger.info("CampaignOrchestrator: session=%s", self._session_id)
        return self._session_id

    async def stop(self) -> None:
        """Gracefully stop the campaign and disconnect from GABS."""
        self._running = False
        await self._gabs.disconnect()
        logger.info("CampaignOrchestrator: stopped")

    # -- main campaign loop ------------------------------------------------

    async def run(self, max_ticks: int = 0) -> dict[str, Any]:
        """Run the campaign loop.

        Args:
            max_ticks: Stop after this many ticks. 0 = run indefinitely.

        Returns:
            Campaign summary dict.
        """
        if not self._session_id:
            await self.start()

        self._running = True
        tick = 0

        logger.info(
            "CampaignOrchestrator: starting campaign loop (max_ticks=%d)",
            max_ticks,
        )

        try:
            while self._running:
                if max_ticks > 0 and tick >= max_ticks:
                    break

                await self._tick(tick)
                tick += 1

                # Check done condition
                if self._prev_state and self._king.is_done_condition_met(
                    self._prev_state
                ):
                    logger.info(
                        "CampaignOrchestrator: M3 DONE condition met on tick %d!", tick
                    )
                    self._memory.add_note(
                        self._session_id or "",
                        self._prev_state.in_game_day,
                        "milestone",
                        "M3 done condition met — kingdom with 3+ fiefs, 100 days survived",
                    )
                    break

                await asyncio.sleep(self._tick_interval)

        except asyncio.CancelledError:
            logger.info("CampaignOrchestrator: loop cancelled")
        finally:
            await self.stop()

        return self._summary(tick)

    async def _tick(self, tick: int) -> None:
        """Execute one campaign tick."""
        # 1. Observe
        state = await self._gabs.get_game_state()
        state.tick = tick

        # 2. King decides
        subgoal = self._king.decide(state)

        # 3. Log subgoal to session memory
        if self._session_id:
            row_id = self._memory.log_subgoal(
                self._session_id, tick, state.in_game_day, subgoal
            )
            self._memory.update_session(
                self._session_id,
                in_game_day=state.in_game_day,
                fief_count=state.fief_count(),
                kingdom_name=state.kingdom.name or None,
            )

        # 4. Vassal planning
        tasks: list[TaskMessage] = []
        if self._war.is_relevant(subgoal):
            tasks.extend(self._war.plan(state, subgoal))
        if self._economy.is_relevant(subgoal):
            tasks.extend(self._economy.plan(state, subgoal))
        if self._diplomacy.is_relevant(subgoal):
            tasks.extend(self._diplomacy.plan(state, subgoal))

        # 5. Companion evaluation
        companion_actions = (
            self._logistics.evaluate(state, subgoal)
            + self._caravan.evaluate(state, subgoal)
            + self._scout.evaluate(state, subgoal)
        )

        # 6. Dispatch tasks + companion primitives to GABS
        await self._dispatch_tasks(tasks, state)
        await self._dispatch_primitives(companion_actions)

        # 7. Kingdom establishment check
        if not state.has_kingdom() and state.fief_count() > 0:
            # We have a fief but no kingdom yet — establish one
            ok = await self._gabs.establish_kingdom(_KINGDOM_NAME)
            if ok and self._session_id:
                self._memory.record_kingdom_established(
                    self._session_id, state.in_game_day, _KINGDOM_NAME
                )

        self._prev_state = state
        logger.debug(
            "Tick %d: day=%d, subgoal=%s, tasks=%d, companions=%d",
            tick,
            state.in_game_day,
            subgoal.token,
            len(tasks),
            len(companion_actions),
        )

    # -- task dispatch -----------------------------------------------------

    async def _dispatch_tasks(
        self, tasks: list[TaskMessage], state: GameState
    ) -> None:
        """Dispatch vassal TaskMessages to GABS."""
        for task in sorted(tasks, key=lambda t: t.priority, reverse=True):
            if task.to_agent != "gabs":
                # Companion-directed tasks are handled via companion.evaluate()
                continue
            await self._gabs._call(task.primitive, task.args)

    async def _dispatch_primitives(self, actions: list[dict]) -> None:
        """Dispatch companion primitive actions to GABS."""
        for action in actions:
            primitive = action.get("primitive", "")
            args = action.get("args", {})
            if primitive:
                await self._gabs._call(primitive, args)

    # -- summary -----------------------------------------------------------

    def _summary(self, ticks_run: int) -> dict[str, Any]:
        state = self._prev_state or GameState()
        summary = self._king.campaign_summary(state)
        summary["ticks_run"] = ticks_run
        summary["session_id"] = self._session_id
        return summary

    # -- accessors ---------------------------------------------------------

    @property
    def session_id(self) -> str | None:
        return self._session_id

    @property
    def memory(self) -> SessionMemory:
        return self._memory
```
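The `run()`/`max_ticks` loop shape used by the orchestrator can be sketched in isolation, with the tick body stubbed out. This is a minimal sketch of the control flow only, not the real orchestrator.

```python
import asyncio

async def run_loop(max_ticks: int = 0, interval: float = 0.0) -> int:
    """Run ticks until max_ticks is reached (0 = run indefinitely)."""
    tick = 0
    while True:
        if max_ticks > 0 and tick >= max_ticks:
            break
        # ... observe / decide / dispatch would happen here ...
        tick += 1
        # Real-time pacing between ticks (0.0 here so the sketch runs fast).
        await asyncio.sleep(interval)
    return tick

ticks = asyncio.run(run_loop(max_ticks=5))
print(ticks)  # 5
```

Checking the stop condition before doing any tick work means `max_ticks=0` cleanly maps to "run forever until cancelled", matching the docstring contract above.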
@@ -1,434 +0,0 @@
```python
"""GABS TCP JSON-RPC client — connects to the Bannerlord.GABS mod.

The GABS (Game Agent Behavior System) mod exposes 90+ tools via a
TCP JSON-RPC 2.0 server on port 4825. This client wraps the transport
into a clean async interface used by all Bannerlord agents.

Degrades gracefully: if GABS is unreachable, methods return sensible
fallbacks and log a warning (never crash).

Architecture reference:
    Timmy (Qwen3 on Ollama, M3 Max)
      → GABSClient (this module, TCP JSON-RPC, port 4825)
      → Bannerlord.GABS C# mod
      → Game API + Harmony
      → Bannerlord (Windows VM)
"""

from __future__ import annotations

import asyncio
import json
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any

from bannerlord.types import (
    FactionState,
    GameState,
    KingdomState,
    PartyState,
)

logger = logging.getLogger(__name__)

_DEFAULT_HOST = "127.0.0.1"
_DEFAULT_PORT = 4825
_DEFAULT_TIMEOUT = 10.0  # seconds
_RECONNECT_DELAY = 5.0  # seconds between reconnect attempts


@dataclass
class GABSTool:
    """Metadata for a single GABS tool."""

    name: str
    description: str
    parameters: dict[str, Any] = field(default_factory=dict)


class GABSConnectionError(Exception):
    """Raised when GABS is unreachable and a fallback is not possible."""


class GABSClient:
    """Async TCP JSON-RPC 2.0 client for the Bannerlord.GABS mod.

    Usage::

        async with GABSClient() as client:
            state = await client.get_game_state()
            await client.move_party("Vlandia")

    All public methods degrade gracefully — they return ``None`` or an
    empty structure when GABS is unavailable.
    """

    def __init__(
        self,
        host: str = _DEFAULT_HOST,
        port: int = _DEFAULT_PORT,
        timeout: float = _DEFAULT_TIMEOUT,
    ) -> None:
        self._host = host
        self._port = port
        self._timeout = timeout
        self._reader: asyncio.StreamReader | None = None
        self._writer: asyncio.StreamWriter | None = None
        self._connected = False
        self._call_id = 0
        self._available_tools: list[GABSTool] = []

    # -- lifecycle ---------------------------------------------------------

    async def connect(self) -> bool:
        """Open a TCP connection to GABS.

        Returns:
            True if connected successfully, False if GABS is unavailable.
        """
        try:
            self._reader, self._writer = await asyncio.wait_for(
                asyncio.open_connection(self._host, self._port),
                timeout=self._timeout,
            )
            self._connected = True
            logger.info("GABSClient connected to %s:%d", self._host, self._port)
            await self._discover_tools()
            return True
        except (ConnectionRefusedError, OSError, TimeoutError, asyncio.TimeoutError) as exc:
            logger.warning(
                "GABSClient could not connect to %s:%d — %s",
                self._host,
                self._port,
                exc,
            )
            self._connected = False
            return False

    async def disconnect(self) -> None:
        """Close the TCP connection."""
        if self._writer is not None:
            try:
                self._writer.close()
                await self._writer.wait_closed()
            except Exception as exc:
                logger.debug("GABSClient disconnect error: %s", exc)
        self._connected = False
        logger.info("GABSClient disconnected")

    async def __aenter__(self) -> GABSClient:
        await self.connect()
        return self

    async def __aexit__(self, *_: object) -> None:
        await self.disconnect()

    @property
    def is_connected(self) -> bool:
        return self._connected

    # -- raw JSON-RPC transport --------------------------------------------

    def _next_id(self) -> int:
        self._call_id += 1
        return self._call_id

    async def _call(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC 2.0 request and return the result.

        Returns ``None`` and logs a warning on any error.
        """
        if not self._connected:
            logger.warning("GABSClient._call(%s): not connected", method)
            return None

        request = {
            "jsonrpc": "2.0",
```
"id": self._next_id(),
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
|
||||
try:
|
||||
payload = json.dumps(request) + "\n"
|
||||
assert self._writer is not None
|
||||
self._writer.write(payload.encode())
|
||||
await asyncio.wait_for(self._writer.drain(), timeout=self._timeout)
|
||||
|
||||
assert self._reader is not None
|
||||
raw = await asyncio.wait_for(
|
||||
self._reader.readline(), timeout=self._timeout
|
||||
)
|
||||
response = json.loads(raw.decode().strip())
|
||||
|
||||
if "error" in response:
|
||||
logger.warning(
|
||||
"GABS error for %s: %s", method, response["error"].get("message")
|
||||
)
|
||||
return None
|
||||
|
||||
return response.get("result")
|
||||
|
||||
except (asyncio.TimeoutError, json.JSONDecodeError, AssertionError, OSError) as exc:
|
||||
logger.warning("GABSClient._call(%s) failed: %s", method, exc)
|
||||
self._connected = False
|
||||
return None
|
||||
|
||||
# -- tool discovery ----------------------------------------------------
|
||||
|
||||
async def _discover_tools(self) -> None:
|
||||
"""Populate self._available_tools via GABS tools/list."""
|
||||
result = await self._call("tools/list")
|
||||
if not result:
|
||||
return
|
||||
self._available_tools = [
|
||||
GABSTool(
|
||||
name=t.get("name", ""),
|
||||
description=t.get("description", ""),
|
||||
parameters=t.get("parameters", {}),
|
||||
)
|
||||
for t in (result if isinstance(result, list) else [])
|
||||
]
|
||||
logger.info("GABS: discovered %d tools", len(self._available_tools))
|
||||
|
||||
@property
|
||||
def available_tools(self) -> list[GABSTool]:
|
||||
"""Return the list of tools discovered from GABS."""
|
||||
return list(self._available_tools)
|
||||
|
||||
def tool_count(self) -> int:
|
||||
return len(self._available_tools)
|
||||
|
||||
# -- game state --------------------------------------------------------
|
||||
|
||||
async def get_game_state(self) -> GameState:
|
||||
"""Return the full campaign state snapshot.
|
||||
|
||||
Falls back to an empty GameState if GABS is unavailable.
|
||||
"""
|
||||
raw = await self._call("game/get_state")
|
||||
if raw is None:
|
||||
return GameState()
|
||||
|
||||
try:
|
||||
return self._parse_game_state(raw)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse GABS game state: %s", exc)
|
||||
return GameState()
|
||||
|
||||
def _parse_game_state(self, raw: dict[str, Any]) -> GameState:
|
||||
"""Convert raw GABS state dict into a typed GameState."""
|
||||
party_raw = raw.get("party", {})
|
||||
kingdom_raw = raw.get("kingdom", {})
|
||||
factions_raw = raw.get("factions", [])
|
||||
|
||||
party = PartyState(
|
||||
location=party_raw.get("location", ""),
|
||||
troops=party_raw.get("troops", 0),
|
||||
food_days=party_raw.get("food_days", 0),
|
||||
wounded_pct=party_raw.get("wounded_pct", 0.0),
|
||||
denars=party_raw.get("denars", 0),
|
||||
morale=party_raw.get("morale", 100.0),
|
||||
prisoners=party_raw.get("prisoners", 0),
|
||||
)
|
||||
|
||||
kingdom = KingdomState(
|
||||
name=kingdom_raw.get("name", ""),
|
||||
fiefs=kingdom_raw.get("fiefs", []),
|
||||
daily_income=kingdom_raw.get("daily_income", 0),
|
||||
daily_expenses=kingdom_raw.get("daily_expenses", 0),
|
||||
vassal_lords=kingdom_raw.get("vassal_lords", []),
|
||||
active_wars=kingdom_raw.get("active_wars", []),
|
||||
active_alliances=kingdom_raw.get("active_alliances", []),
|
||||
in_game_day=raw.get("in_game_day", 0),
|
||||
)
|
||||
|
||||
factions = [
|
||||
FactionState(
|
||||
name=f.get("name", ""),
|
||||
leader=f.get("leader", ""),
|
||||
fiefs=f.get("fiefs", []),
|
||||
army_strength=f.get("army_strength", 0),
|
||||
treasury=f.get("treasury", 0),
|
||||
is_at_war_with=f.get("is_at_war_with", []),
|
||||
relations=f.get("relations", {}),
|
||||
)
|
||||
for f in (factions_raw if isinstance(factions_raw, list) else [])
|
||||
]
|
||||
|
||||
return GameState(
|
||||
tick=raw.get("tick", 0),
|
||||
in_game_day=raw.get("in_game_day", 0),
|
||||
timestamp=datetime.now(UTC),
|
||||
party=party,
|
||||
kingdom=kingdom,
|
||||
factions=factions,
|
||||
raw=raw,
|
||||
)
|
||||
|
||||
# -- party actions -----------------------------------------------------
|
||||
|
||||
async def move_party(self, destination: str) -> bool:
|
||||
"""Command Timmy's party to move toward *destination*."""
|
||||
result = await self._call("party/move", {"destination": destination})
|
||||
return result is not None
|
||||
|
||||
async def recruit_troops(self, troop_type: str, quantity: int) -> bool:
|
||||
"""Recruit *quantity* troops of *troop_type* at current location."""
|
||||
result = await self._call(
|
||||
"party/recruit",
|
||||
{"troop_type": troop_type, "quantity": quantity},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def buy_supplies(self, quantity: int) -> bool:
|
||||
"""Purchase food supplies for *quantity* days of march."""
|
||||
result = await self._call("party/buy_supplies", {"quantity": quantity})
|
||||
return result is not None
|
||||
|
||||
async def rest_party(self, days: int) -> bool:
|
||||
"""Rest the party in current location for *days* in-game days."""
|
||||
result = await self._call("party/rest", {"days": days})
|
||||
return result is not None
|
||||
|
||||
async def auto_resolve_battle(self) -> dict[str, Any]:
|
||||
"""Trigger auto-resolve for the current battle.
|
||||
|
||||
Returns the battle outcome dict, or empty dict on failure.
|
||||
"""
|
||||
result = await self._call("battle/auto_resolve")
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def upgrade_troops(self) -> bool:
|
||||
"""Spend accumulated XP on troop tier upgrades."""
|
||||
result = await self._call("party/upgrade_troops")
|
||||
return result is not None
|
||||
|
||||
async def sell_prisoners(self, location: str) -> int:
|
||||
"""Sell prisoners at *location*. Returns denars gained."""
|
||||
result = await self._call("party/sell_prisoners", {"location": location})
|
||||
if isinstance(result, dict):
|
||||
return result.get("denars_gained", 0)
|
||||
return 0
|
||||
|
||||
# -- trade actions -----------------------------------------------------
|
||||
|
||||
async def assess_prices(self, town: str) -> dict[str, Any]:
|
||||
"""Query buy/sell prices at *town*."""
|
||||
result = await self._call("trade/assess_prices", {"town": town})
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def buy_goods(self, item: str, quantity: int) -> bool:
|
||||
"""Purchase *quantity* of *item* at current location."""
|
||||
result = await self._call("trade/buy", {"item": item, "quantity": quantity})
|
||||
return result is not None
|
||||
|
||||
async def sell_goods(self, item: str, quantity: int, location: str) -> bool:
|
||||
"""Sell *quantity* of *item* at *location*."""
|
||||
result = await self._call(
|
||||
"trade/sell",
|
||||
{"item": item, "quantity": quantity, "location": location},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def establish_caravan(self, town: str) -> bool:
|
||||
"""Deploy a caravan NPC at *town*."""
|
||||
result = await self._call("trade/establish_caravan", {"town": town})
|
||||
return result is not None
|
||||
|
||||
# -- diplomacy actions -------------------------------------------------
|
||||
|
||||
async def send_envoy(self, faction: str, message: str) -> bool:
|
||||
"""Send a diplomatic message to *faction*."""
|
||||
result = await self._call(
|
||||
"diplomacy/send_envoy",
|
||||
{"faction": faction, "message": message},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def propose_peace(self, faction: str, tribute: int = 0) -> bool:
|
||||
"""Propose peace with *faction*, optionally offering *tribute* denars."""
|
||||
result = await self._call(
|
||||
"diplomacy/propose_peace",
|
||||
{"faction": faction, "tribute": tribute},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def request_alliance(self, faction: str) -> bool:
|
||||
"""Request a military alliance with *faction*."""
|
||||
result = await self._call(
|
||||
"diplomacy/request_alliance",
|
||||
{"faction": faction},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def request_military_access(self, faction: str) -> bool:
|
||||
"""Request military access through *faction*'s territory."""
|
||||
result = await self._call(
|
||||
"diplomacy/military_access",
|
||||
{"faction": faction},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
# -- settlement / kingdom actions --------------------------------------
|
||||
|
||||
async def siege_settlement(self, settlement: str) -> bool:
|
||||
"""Begin siege of *settlement*."""
|
||||
result = await self._call("military/siege", {"settlement": settlement})
|
||||
return result is not None
|
||||
|
||||
async def raid_village(self, village: str) -> bool:
|
||||
"""Raid *village* for food and denars."""
|
||||
result = await self._call("military/raid_village", {"village": village})
|
||||
return result is not None
|
||||
|
||||
async def build_project(self, settlement: str, project: str) -> bool:
|
||||
"""Queue a construction *project* in *settlement*."""
|
||||
result = await self._call(
|
||||
"settlement/build",
|
||||
{"settlement": settlement, "project": project},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def set_tax_policy(self, settlement: str, policy: str) -> bool:
|
||||
"""Set tax *policy* for *settlement* (e.g. 'low', 'normal', 'high')."""
|
||||
result = await self._call(
|
||||
"settlement/set_tax",
|
||||
{"settlement": settlement, "policy": policy},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def appoint_governor(self, settlement: str, lord: str) -> bool:
|
||||
"""Appoint *lord* as governor of *settlement*."""
|
||||
result = await self._call(
|
||||
"settlement/appoint_governor",
|
||||
{"settlement": settlement, "lord": lord},
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def establish_kingdom(self, name: str) -> bool:
|
||||
"""Declare a new kingdom with *name* (requires a captured fief)."""
|
||||
result = await self._call("kingdom/establish", {"name": name})
|
||||
ok = result is not None
|
||||
if ok:
|
||||
logger.info("Kingdom '%s' established!", name)
|
||||
return ok
|
||||
|
||||
# -- scouting ----------------------------------------------------------
|
||||
|
||||
async def track_lord(self, lord_name: str) -> dict[str, Any]:
|
||||
"""Shadow *lord_name* and return their last-known position."""
|
||||
result = await self._call("scout/track_lord", {"lord": lord_name})
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def assess_garrison(self, settlement: str) -> dict[str, Any]:
|
||||
"""Estimate defender count for *settlement*."""
|
||||
result = await self._call("scout/assess_garrison", {"settlement": settlement})
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def map_patrol_routes(self, region: str) -> list[dict[str, Any]]:
|
||||
"""Log enemy patrol routes in *region*."""
|
||||
result = await self._call("scout/patrol_routes", {"region": region})
|
||||
return result if isinstance(result, list) else []
|
||||
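GABSClient frames every call as a single JSON-RPC 2.0 object terminated by a newline. The framing can be exercised end-to-end against an in-process stub; note the stub server and its `{"ok": true}` reply payload below are illustrative stand-ins, not the actual responses of the Bannerlord.GABS mod.

```python
import asyncio
import json


async def stub_gabs_server(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    # Read one newline-delimited JSON-RPC request and echo a result back.
    request = json.loads(await reader.readline())
    response = {
        "jsonrpc": "2.0",
        "id": request["id"],
        "result": {"ok": True, "method": request["method"]},
    }
    writer.write((json.dumps(response) + "\n").encode())
    await writer.drain()
    writer.close()


async def main() -> dict:
    # Bind the stub to an ephemeral port, then speak the same wire format GABSClient uses.
    server = await asyncio.start_server(stub_gabs_server, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    request = {"jsonrpc": "2.0", "id": 1, "method": "game/get_state", "params": {}}
    writer.write((json.dumps(request) + "\n").encode())
    await writer.drain()
    response = json.loads((await reader.readline()).decode())
    writer.close()
    server.close()
    await server.wait_closed()
    return response


response = asyncio.run(main())
print(response["result"])
```

Because each message is one line, `StreamReader.readline()` is a complete de-framing step; no length prefix or delimiter scanning is needed on either side.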
@@ -1,347 +0,0 @@
"""Bannerlord M3 — Session memory for multi-day strategic plans.

Persists the King's strategic plans, completed subgoals, and kingdom
milestones to SQLite so the campaign can be interrupted and resumed
across multiple sessions.

Pattern follows the existing EventBus persistence model.
"""

from __future__ import annotations

import json
import logging
import sqlite3
from contextlib import closing
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

from bannerlord.types import KingSubgoal, SubgoalToken

logger = logging.getLogger(__name__)

_SCHEMA = """
CREATE TABLE IF NOT EXISTS campaign_sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    started_at TEXT NOT NULL,
    last_updated TEXT NOT NULL,
    kingdom_name TEXT DEFAULT '',
    in_game_day INTEGER DEFAULT 0,
    fief_count INTEGER DEFAULT 0,
    meta TEXT DEFAULT '{}'
);

CREATE TABLE IF NOT EXISTS subgoal_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    tick INTEGER NOT NULL,
    in_game_day INTEGER NOT NULL,
    token TEXT NOT NULL,
    target TEXT,
    quantity INTEGER,
    priority REAL DEFAULT 1.0,
    deadline_days INTEGER,
    context TEXT,
    issued_at TEXT NOT NULL,
    completed_at TEXT,
    outcome TEXT DEFAULT 'pending'
);

CREATE TABLE IF NOT EXISTS strategy_notes (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    in_game_day INTEGER NOT NULL,
    note_type TEXT NOT NULL,
    content TEXT NOT NULL,
    recorded_at TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_subgoal_session ON subgoal_log(session_id);
CREATE INDEX IF NOT EXISTS idx_subgoal_tick ON subgoal_log(tick);
CREATE INDEX IF NOT EXISTS idx_notes_session ON strategy_notes(session_id);
"""


@dataclass
class CampaignMilestone:
    """A notable campaign achievement recorded in session memory."""

    in_game_day: int
    event: str
    detail: str = ""
    recorded_at: datetime = field(default_factory=lambda: datetime.now(UTC))


class SessionMemory:
    """SQLite-backed session memory for the Bannerlord campaign.

    Stores:
    - Active session metadata (kingdom name, in-game day, fief count)
    - Full subgoal history (every KingSubgoal issued and its outcome)
    - Strategy notes / milestones for campaign reflection

    Usage::

        mem = SessionMemory(Path("data/bannerlord/campaign.db"))
        session_id = mem.start_session()
        row_id = mem.log_subgoal(session_id, tick=1, in_game_day=42, subgoal=subgoal)
        mem.complete_subgoal(row_id, outcome="success")
        mem.add_note(session_id, in_game_day=42, note_type="milestone",
                     content="Kingdom established: House Timmerson")
        history = mem.get_recent_subgoals(session_id, limit=10)
    """

    def __init__(self, db_path: Path) -> None:
        self._db_path = db_path
        self._init_db()

    def _init_db(self) -> None:
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        with closing(sqlite3.connect(str(self._db_path))) as conn:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA busy_timeout=5000")
            conn.executescript(_SCHEMA)
            conn.commit()

    def _conn(self) -> sqlite3.Connection:
        conn = sqlite3.connect(str(self._db_path))
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA busy_timeout=5000")
        return conn

    # -- session lifecycle -------------------------------------------------

    def start_session(self, session_id: str | None = None) -> str:
        """Create a new campaign session. Returns the session_id."""
        if session_id is None:
            session_id = f"session_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}"
        now = datetime.now(UTC).isoformat()
        with closing(self._conn()) as conn:
            conn.execute(
                "INSERT OR IGNORE INTO campaign_sessions "
                "(session_id, started_at, last_updated) VALUES (?, ?, ?)",
                (session_id, now, now),
            )
            conn.commit()
        logger.info("SessionMemory: started campaign session %s", session_id)
        return session_id

    def update_session(
        self,
        session_id: str,
        *,
        kingdom_name: str | None = None,
        in_game_day: int | None = None,
        fief_count: int | None = None,
        meta: dict[str, Any] | None = None,
    ) -> None:
        """Update the campaign session state."""
        now = datetime.now(UTC).isoformat()
        with closing(self._conn()) as conn:
            row = conn.execute(
                "SELECT * FROM campaign_sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone()
            if not row:
                return

            current_meta = json.loads(row["meta"] or "{}")
            if meta:
                current_meta.update(meta)

            conn.execute(
                """UPDATE campaign_sessions SET
                       last_updated = ?,
                       kingdom_name = COALESCE(?, kingdom_name),
                       in_game_day = COALESCE(?, in_game_day),
                       fief_count = COALESCE(?, fief_count),
                       meta = ?
                   WHERE session_id = ?""",
                (
                    now,
                    kingdom_name,
                    in_game_day,
                    fief_count,
                    json.dumps(current_meta),
                    session_id,
                ),
            )
            conn.commit()

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return session metadata dict or None if not found."""
        with closing(self._conn()) as conn:
            row = conn.execute(
                "SELECT * FROM campaign_sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone()
            if not row:
                return None
            return dict(row)

    def list_sessions(self) -> list[dict[str, Any]]:
        """Return all campaign sessions, most recent first."""
        with closing(self._conn()) as conn:
            rows = conn.execute(
                "SELECT * FROM campaign_sessions ORDER BY last_updated DESC"
            ).fetchall()
            return [dict(r) for r in rows]

    # -- subgoal log -------------------------------------------------------

    def log_subgoal(
        self,
        session_id: str,
        tick: int,
        in_game_day: int,
        subgoal: KingSubgoal,
    ) -> int:
        """Record a subgoal emission. Returns the row id."""
        with closing(self._conn()) as conn:
            cursor = conn.execute(
                """INSERT INTO subgoal_log
                   (session_id, tick, in_game_day, token, target, quantity,
                    priority, deadline_days, context, issued_at, outcome)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending')""",
                (
                    session_id,
                    tick,
                    in_game_day,
                    str(subgoal.token),
                    subgoal.target,
                    subgoal.quantity,
                    subgoal.priority,
                    subgoal.deadline_days,
                    subgoal.context,
                    subgoal.issued_at.isoformat(),
                ),
            )
            conn.commit()
            return cursor.lastrowid or 0

    def complete_subgoal(self, row_id: int, outcome: str = "success") -> None:
        """Mark a subgoal log entry as completed."""
        with closing(self._conn()) as conn:
            conn.execute(
                "UPDATE subgoal_log SET completed_at = ?, outcome = ? WHERE id = ?",
                (datetime.now(UTC).isoformat(), outcome, row_id),
            )
            conn.commit()

    def get_recent_subgoals(
        self, session_id: str, limit: int = 20
    ) -> list[dict[str, Any]]:
        """Return the *limit* most recent subgoal log entries."""
        with closing(self._conn()) as conn:
            rows = conn.execute(
                "SELECT * FROM subgoal_log WHERE session_id = ? "
                "ORDER BY tick DESC LIMIT ?",
                (session_id, limit),
            ).fetchall()
            return [dict(r) for r in rows]

    def count_token(self, session_id: str, token: SubgoalToken) -> int:
        """Count how many times a subgoal token has been issued in a session."""
        with closing(self._conn()) as conn:
            row = conn.execute(
                "SELECT COUNT(*) as n FROM subgoal_log "
                "WHERE session_id = ? AND token = ?",
                (session_id, str(token)),
            ).fetchone()
            return row["n"] if row else 0

    # -- strategy notes ----------------------------------------------------

    def add_note(
        self,
        session_id: str,
        in_game_day: int,
        note_type: str,
        content: str,
    ) -> None:
        """Record a strategy note or milestone."""
        with closing(self._conn()) as conn:
            conn.execute(
                "INSERT INTO strategy_notes "
                "(session_id, in_game_day, note_type, content, recorded_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (
                    session_id,
                    in_game_day,
                    note_type,
                    content,
                    datetime.now(UTC).isoformat(),
                ),
            )
            conn.commit()

    def get_notes(
        self,
        session_id: str,
        note_type: str | None = None,
        limit: int = 50,
    ) -> list[dict[str, Any]]:
        """Return strategy notes, optionally filtered by type."""
        with closing(self._conn()) as conn:
            if note_type:
                rows = conn.execute(
                    "SELECT * FROM strategy_notes "
                    "WHERE session_id = ? AND note_type = ? "
                    "ORDER BY in_game_day DESC LIMIT ?",
                    (session_id, note_type, limit),
                ).fetchall()
            else:
                rows = conn.execute(
                    "SELECT * FROM strategy_notes WHERE session_id = ? "
                    "ORDER BY in_game_day DESC LIMIT ?",
                    (session_id, limit),
                ).fetchall()
            return [dict(r) for r in rows]

    def get_milestones(self, session_id: str) -> list[dict[str, Any]]:
        """Return all milestone notes for a session."""
        return self.get_notes(session_id, note_type="milestone", limit=200)

    # -- diplomatic memory -------------------------------------------------

    def record_war_declared(
        self, session_id: str, in_game_day: int, faction: str
    ) -> None:
        """Log that Timmy declared war on *faction*."""
        self.add_note(
            session_id,
            in_game_day,
            "war_declared",
            f"Declared war on {faction}",
        )

    def record_peace_agreed(
        self, session_id: str, in_game_day: int, faction: str
    ) -> None:
        """Log that Timmy agreed to peace with *faction*."""
        self.add_note(
            session_id,
            in_game_day,
            "peace_agreed",
            f"Peace agreed with {faction}",
        )

    def record_kingdom_established(
        self, session_id: str, in_game_day: int, kingdom_name: str
    ) -> None:
        """Record the kingdom establishment milestone."""
        self.add_note(
            session_id,
            in_game_day,
            "milestone",
            f"Kingdom established: {kingdom_name}",
        )
        self.update_session(session_id, kingdom_name=kingdom_name, in_game_day=in_game_day)
        logger.info(
            "SessionMemory: milestone — kingdom '%s' established on day %d",
            kingdom_name,
            in_game_day,
        )
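SessionMemory's `update_session` applies partial updates with SQLite's `COALESCE(?, column)`: when a bound parameter is `None`, the column keeps its stored value, so one UPDATE statement serves any subset of keyword arguments. A minimal self-contained sketch of the idiom against an in-memory database (table trimmed to a few columns for illustration):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
    "CREATE TABLE campaign_sessions ("
    " session_id TEXT PRIMARY KEY,"
    " kingdom_name TEXT DEFAULT '',"
    " in_game_day INTEGER DEFAULT 0)"
)
conn.execute("INSERT INTO campaign_sessions (session_id, in_game_day) VALUES ('s1', 12)")

# Passing None for in_game_day leaves that column untouched; only
# kingdom_name is actually overwritten by this statement.
conn.execute(
    "UPDATE campaign_sessions SET"
    " kingdom_name = COALESCE(?, kingdom_name),"
    " in_game_day = COALESCE(?, in_game_day)"
    " WHERE session_id = ?",
    ("House Timmerson", None, "s1"),
)
row = conn.execute("SELECT * FROM campaign_sessions WHERE session_id = 's1'").fetchone()
print(dict(row))
```

The alternative, building the SET clause dynamically from the non-`None` arguments, avoids binding unused parameters but reintroduces string construction; COALESCE keeps the statement static and parameterised.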
@@ -1,226 +0,0 @@
"""Bannerlord M3 — core data types for the campaign strategy system.

KingSubgoal schema and all message types used by the feudal agent hierarchy.
Design follows the Feudal Multi-Agent Hierarchies (Ahilan & Dayan, 2019) model
specified in docs/research/bannerlord-feudal-hierarchy-design.md.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
from typing import Any, Literal


# ---------------------------------------------------------------------------
# Subgoal vocabulary
# ---------------------------------------------------------------------------


class SubgoalToken(StrEnum):
    """Fixed vocabulary of strategic intents the King can emit."""

    EXPAND_TERRITORY = "EXPAND_TERRITORY"
    RAID_ECONOMY = "RAID_ECONOMY"
    FORTIFY = "FORTIFY"
    RECRUIT = "RECRUIT"
    TRADE = "TRADE"
    ALLY = "ALLY"
    SPY = "SPY"
    HEAL = "HEAL"
    CONSOLIDATE = "CONSOLIDATE"
    TRAIN = "TRAIN"


@dataclass
class KingSubgoal:
    """A strategic directive issued by the King agent.

    The King emits at most one subgoal per campaign tick. Vassals interpret
    the token and prioritise actions accordingly.

    Attributes:
        token: Intent from the SubgoalToken vocabulary.
        target: Named target (settlement, lord, or faction).
        quantity: Scalar for RECRUIT / TRADE operations.
        priority: 0.0–2.0 weighting that scales vassal reward.
        deadline_days: Campaign-map days to complete (None = open-ended).
        context: Free-text hint passed verbatim to vassals (not parsed).
        issued_at: Timestamp when the King emitted this subgoal.
    """

    token: SubgoalToken
    target: str | None = None
    quantity: int | None = None
    priority: float = 1.0
    deadline_days: int | None = None
    context: str | None = None
    issued_at: datetime = field(default_factory=lambda: datetime.now(UTC))

    def to_dict(self) -> dict[str, Any]:
        return {
            "token": str(self.token),
            "target": self.target,
            "quantity": self.quantity,
            "priority": self.priority,
            "deadline_days": self.deadline_days,
            "context": self.context,
            "issued_at": self.issued_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> KingSubgoal:
        return cls(
            token=SubgoalToken(data["token"]),
            target=data.get("target"),
            quantity=data.get("quantity"),
            priority=data.get("priority", 1.0),
            deadline_days=data.get("deadline_days"),
            context=data.get("context"),
            issued_at=datetime.fromisoformat(data["issued_at"])
            if "issued_at" in data
            else datetime.now(UTC),
        )


# ---------------------------------------------------------------------------
# Game state snapshot
# ---------------------------------------------------------------------------


@dataclass
class FactionState:
    """Snapshot of a faction's status on the campaign map."""

    name: str
    leader: str
    fiefs: list[str] = field(default_factory=list)
    army_strength: int = 0
    treasury: int = 0
    is_at_war_with: list[str] = field(default_factory=list)
    relations: dict[str, int] = field(default_factory=dict)


@dataclass
class PartyState:
    """Timmy's party snapshot."""

    location: str = ""
    troops: int = 0
    food_days: int = 0
    wounded_pct: float = 0.0
    denars: int = 0
    morale: float = 100.0
    prisoners: int = 0


@dataclass
class KingdomState:
    """Timmy's kingdom snapshot (only populated after kingdom is established)."""

    name: str = ""
    fiefs: list[str] = field(default_factory=list)
    daily_income: int = 0
    daily_expenses: int = 0
    vassal_lords: list[str] = field(default_factory=list)
    active_wars: list[str] = field(default_factory=list)
    active_alliances: list[str] = field(default_factory=list)
    in_game_day: int = 0


@dataclass
class GameState:
    """Full campaign state snapshot delivered by GABS on each tick.

    This is the primary input to the King agent's decision loop.
    """

    tick: int = 0
    in_game_day: int = 0
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
    party: PartyState = field(default_factory=PartyState)
    kingdom: KingdomState = field(default_factory=KingdomState)
    factions: list[FactionState] = field(default_factory=list)
    raw: dict[str, Any] = field(default_factory=dict)

    def has_kingdom(self) -> bool:
        return bool(self.kingdom.name)

    def fief_count(self) -> int:
        return len(self.kingdom.fiefs)

    def active_war_count(self) -> int:
        return len(self.kingdom.active_wars)

    def is_two_front_war(self) -> bool:
        """Return True if Timmy is engaged in 2+ simultaneous wars."""
        return self.active_war_count() >= 2


# ---------------------------------------------------------------------------
# Inter-agent message schemas
# ---------------------------------------------------------------------------


@dataclass
class SubgoalMessage:
    """King → Vassal directive."""

    msg_type: Literal["subgoal"] = "subgoal"
    from_agent: Literal["king"] = "king"
    to_agent: str = ""
    subgoal: KingSubgoal | None = None
    issued_at: datetime = field(default_factory=lambda: datetime.now(UTC))


@dataclass
class TaskMessage:
    """Vassal → Companion work order."""

    msg_type: Literal["task"] = "task"
    from_agent: str = ""
    to_agent: str = ""
    primitive: str = ""
    args: dict[str, Any] = field(default_factory=dict)
    priority: float = 1.0
    issued_at: datetime = field(default_factory=lambda: datetime.now(UTC))


@dataclass
class ResultMessage:
    """Companion/Vassal → Parent outcome report."""

    msg_type: Literal["result"] = "result"
    from_agent: str = ""
    to_agent: str = ""
    success: bool = True
    outcome: dict[str, Any] = field(default_factory=dict)
    reward_delta: float = 0.0
    completed_at: datetime = field(default_factory=lambda: datetime.now(UTC))


@dataclass
class StateUpdateMessage:
    """GABS → All agents broadcast."""

    msg_type: Literal["state"] = "state"
    game_state: GameState = field(default_factory=GameState)
    tick: int = 0
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))


# ---------------------------------------------------------------------------
# Reward signals
# ---------------------------------------------------------------------------


@dataclass
class VassalReward:
    """Computed reward signal for a vassal agent after one decision cycle."""

    agent_id: str
    component_scores: dict[str, float] = field(default_factory=dict)
    subgoal_bonus: float = 0.0
    total: float = 0.0
    computed_at: datetime = field(default_factory=lambda: datetime.now(UTC))
@@ -147,15 +147,6 @@ class Settings(BaseSettings):
     l402_macaroon_secret: str = ""
     lightning_backend: Literal["mock", "lnd"] = "mock"

-    # ── Bannerlord / GABS ────────────────────────────────────────────────
-    # TCP JSON-RPC connection to the Bannerlord.GABS mod running on the
-    # Windows VM. Override with GABS_HOST / GABS_PORT env vars.
-    gabs_host: str = "127.0.0.1"
-    gabs_port: int = 4825
-    gabs_timeout: float = 10.0  # seconds per GABS call
-    bannerlord_tick_interval: float = 1.0  # real-time seconds between campaign ticks
-    bannerlord_db_path: str = "data/bannerlord/campaign.db"
-
     # ── Privacy / Sovereignty ────────────────────────────────────────────
     # Disable Agno telemetry for air-gapped/sovereign deployments.
     # Default is False (telemetry disabled) to align with sovereign AI vision.
@@ -196,7 +196,7 @@ async def get_evening_ritual_form(request: Request, db: Session = Depends(get_db
     if not journal_entry:
         raise HTTPException(status_code=404, detail="No journal entry for today")
     return templates.TemplateResponse(
-        "calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}
+        request, "calm/evening_ritual_form.html", {"journal_entry": journal_entry}
     )
@@ -257,8 +257,9 @@ async def create_new_task(
     # After creating a new task, we might need to re-evaluate NOW/NEXT/LATER, but for simplicity
     # and given the spec, new tasks go to LATER. Promotion happens on completion/deferral.
     return templates.TemplateResponse(
+        request,
         "calm/partials/later_count.html",
-        {"request": request, "later_tasks_count": len(get_later_tasks(db))},
+        {"later_tasks_count": len(get_later_tasks(db))},
     )
@@ -287,9 +288,9 @@ async def start_task(
     promote_tasks(db)

     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -316,9 +317,9 @@ async def complete_task(
     promote_tasks(db)

     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -345,9 +346,9 @@ async def defer_task(
     promote_tasks(db)

     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -360,8 +361,7 @@ async def get_later_tasks_list(request: Request, db: Session = Depends(get_db)):
     """Render the expandable list of LATER tasks."""
     later_tasks = get_later_tasks(db)
     return templates.TemplateResponse(
-        "calm/partials/later_tasks_list.html",
-        {"request": request, "later_tasks": later_tasks},
+        request, "calm/partials/later_tasks_list.html", {"later_tasks": later_tasks}
     )
@@ -404,9 +404,9 @@ async def reorder_tasks(

     # Re-render the relevant parts of the UI
     return templates.TemplateResponse(
+        request,
         "calm/partials/now_next_later.html",
         {
-            "request": request,
             "now_task": get_now_task(db),
             "next_task": get_next_task(db),
             "later_tasks_count": len(get_later_tasks(db)),
@@ -40,9 +40,9 @@ async def tools_page(request: Request):
     total_calls = 0

     return templates.TemplateResponse(
+        request,
         "tools.html",
         {
-            "request": request,
             "available_tools": available_tools,
             "agent_tools": agent_tools,
             "total_calls": total_calls,
@@ -16,6 +16,8 @@ from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any

+from src.config import settings
+
 logger = logging.getLogger(__name__)
@@ -102,7 +104,7 @@ class EventBus:
         self._persistence_db_path.parent.mkdir(parents=True, exist_ok=True)
         with closing(sqlite3.connect(str(self._persistence_db_path))) as conn:
             conn.execute("PRAGMA journal_mode=WAL")
-            conn.execute("PRAGMA busy_timeout=5000")
+            conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
             conn.executescript(_EVENTS_SCHEMA)
             conn.commit()

@@ -114,7 +116,7 @@ class EventBus:
             return
         with closing(sqlite3.connect(str(self._persistence_db_path))) as conn:
             conn.row_factory = sqlite3.Row
-            conn.execute("PRAGMA busy_timeout=5000")
+            conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
             yield conn

     def _persist_event(self, event: Event) -> None:
@@ -18,6 +18,8 @@ from datetime import UTC, datetime
 from enum import StrEnum
 from pathlib import Path

+from src.config import settings
+
 logger = logging.getLogger(__name__)

 DB_PATH = Path("data/swarm.db")
@@ -68,7 +70,7 @@ def _get_conn() -> Generator[sqlite3.Connection, None, None]:
     with closing(sqlite3.connect(str(DB_PATH))) as conn:
         conn.row_factory = sqlite3.Row
         conn.execute("PRAGMA journal_mode=WAL")
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         conn.execute("""
             CREATE TABLE IF NOT EXISTS custom_models (
                 name TEXT PRIMARY KEY,
@@ -485,18 +485,26 @@ class CascadeRouter:
     def _quota_allows_cloud(self, provider: Provider) -> bool:
         """Check quota before routing to a cloud provider.

-        Uses the metabolic protocol: cloud calls are gated by 5-hour quota.
+        Uses the metabolic protocol via select_model(): cloud calls are only
+        allowed when the quota monitor recommends a cloud model (BURST tier).
         Returns True (allow cloud) if quota monitor is unavailable or returns None.
         """
         if _quota_monitor is None:
             return True
         try:
-            # Map provider type to task_value heuristic
-            task_value = "high"  # conservative default
-            status = _quota_monitor.check()
-            if status is None:
-                return True  # No credentials — caller decides based on config
-            return _quota_monitor.should_use_cloud(task_value)
+            suggested = _quota_monitor.select_model("high")
+            # Cloud is allowed only when select_model recommends the cloud model
+            allows = suggested == "claude-sonnet-4-6"
+            if not allows:
+                status = _quota_monitor.check()
+                tier = status.recommended_tier.value if status else "unknown"
+                logger.info(
+                    "Metabolic protocol: %s tier — downshifting %s to local (%s)",
+                    tier,
+                    provider.name,
+                    suggested,
+                )
+            return allows
         except Exception as exc:
             logger.warning("Quota check failed, allowing cloud: %s", exc)
             return True
@@ -12,11 +12,6 @@ Quick start::
     register_adapter("mock", MockWorldAdapter)
     world = get_adapter("mock")
     perception = world.observe()

-Registered adapters:
-    "mock"       — in-memory stub for testing
-    "tes3mp"     — Morrowind multiplayer (stub, pending PR #864)
-    "bannerlord" — Bannerlord via GABS mod (M3 campaign strategy)
-
 """

 from infrastructure.world.registry import AdapterRegistry
@@ -27,27 +22,6 @@ register_adapter = _registry.register
 get_adapter = _registry.get
 list_adapters = _registry.list_adapters

-# -- Built-in adapter registration -----------------------------------------
-# Adapters are registered lazily to avoid import errors when their
-# optional dependencies (e.g., GABS TCP connection) are unavailable.
-
-def _register_builtin_adapters() -> None:
-    from infrastructure.world.adapters.mock import MockWorldAdapter
-    from infrastructure.world.adapters.tes3mp import TES3MPWorldAdapter
-
-    _registry.register("mock", MockWorldAdapter)
-    _registry.register("tes3mp", TES3MPWorldAdapter)
-
-    try:
-        from bannerlord.adapter import BannerlordWorldAdapter
-
-        _registry.register("bannerlord", BannerlordWorldAdapter)
-    except Exception:
-        # bannerlord package not installed or import error — skip silently
-        pass
-
-
-_register_builtin_adapters()
-
 __all__ = [
     "register_adapter",
     "get_adapter",
@@ -22,6 +22,8 @@ from dataclasses import dataclass
 from datetime import UTC, datetime
 from pathlib import Path

+from src.config import settings
+
 logger = logging.getLogger(__name__)

 DB_PATH = Path("data/spark.db")
@@ -47,7 +49,7 @@ def _get_conn() -> Generator[sqlite3.Connection, None, None]:
     with closing(sqlite3.connect(str(DB_PATH))) as conn:
         conn.row_factory = sqlite3.Row
         conn.execute("PRAGMA journal_mode=WAL")
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         conn.execute("""
             CREATE TABLE IF NOT EXISTS spark_predictions (
                 id TEXT PRIMARY KEY,
@@ -19,6 +19,8 @@ from dataclasses import dataclass
 from datetime import UTC, datetime
 from pathlib import Path

+from src.config import settings
+
 logger = logging.getLogger(__name__)

 DB_PATH = Path("data/spark.db")
@@ -63,7 +65,7 @@ def _get_conn() -> Generator[sqlite3.Connection, None, None]:
     with closing(sqlite3.connect(str(DB_PATH))) as conn:
         conn.row_factory = sqlite3.Row
         conn.execute("PRAGMA journal_mode=WAL")
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         conn.execute("""
             CREATE TABLE IF NOT EXISTS spark_events (
                 id TEXT PRIMARY KEY,
@@ -13,8 +13,8 @@ from dataclasses import dataclass

 import httpx

 from config import settings
+from timmy.research_tools import get_llm_client, google_web_search
 from timmy.research_triage import triage_research_report
-from timmy.research_tools import google_web_search, get_llm_client

 logger = logging.getLogger(__name__)
@@ -6,7 +6,6 @@ import logging
 import os
-from typing import Any

 from config import settings
 from serpapi import GoogleSearch

 logger = logging.getLogger(__name__)
@@ -1,146 +0,0 @@
-"""Tests for the CampaignOrchestrator — mocked GABS client."""
-
-from __future__ import annotations
-
-import asyncio
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from bannerlord.campaign import CampaignOrchestrator
-from bannerlord.types import (
-    FactionState,
-    GameState,
-    KingdomState,
-    PartyState,
-)
-
-
-def _make_state(
-    *,
-    in_game_day: int = 50,
-    troops: int = 150,
-    denars: int = 10_000,
-    food_days: int = 10,
-    kingdom_name: str = "House Timmerson",
-    fiefs: list | None = None,
-    active_wars: list | None = None,
-) -> GameState:
-    return GameState(
-        in_game_day=in_game_day,
-        party=PartyState(
-            troops=troops,
-            denars=denars,
-            food_days=food_days,
-            location="Epicrotea",
-        ),
-        kingdom=KingdomState(
-            name=kingdom_name,
-            fiefs=fiefs or ["Epicrotea"],
-            active_wars=active_wars or [],
-            daily_income=500,
-            daily_expenses=300,
-        ),
-        factions=[
-            FactionState(
-                name="Vlandia",
-                leader="Derthert",
-                fiefs=["Pravend"],
-                army_strength=200,
-            )
-        ],
-    )
-
-
-@pytest.fixture
-def tmp_db(tmp_path):
-    return tmp_path / "test_campaign.db"
-
-
-@pytest.fixture
-def orch(tmp_db):
-    return CampaignOrchestrator(
-        db_path=tmp_db,
-        tick_interval=0.0,
-    )
-
-
-class TestCampaignOrchestratorLifecycle:
-    @pytest.mark.asyncio
-    async def test_start_creates_session(self, orch):
-        with patch.object(orch._gabs, "connect", return_value=False):
-            sid = await orch.start()
-        assert sid is not None
-        assert orch.session_id == sid
-
-    @pytest.mark.asyncio
-    async def test_start_resumes_existing_session(self, orch):
-        existing_sid = orch._memory.start_session("existing_run")
-        orch._session_id = existing_sid
-        with patch.object(orch._gabs, "connect", return_value=False):
-            sid = await orch.start()
-        assert sid == existing_sid
-
-    @pytest.mark.asyncio
-    async def test_stop_disconnects_gabs(self, orch):
-        disconnect_mock = AsyncMock()
-        orch._gabs.disconnect = disconnect_mock
-        await orch.stop()
-        disconnect_mock.assert_awaited_once()
-
-
-class TestCampaignTick:
-    @pytest.mark.asyncio
-    async def test_single_tick_logs_subgoal(self, orch, tmp_db):
-        state = _make_state()
-        orch._gabs.get_game_state = AsyncMock(return_value=state)
-        orch._gabs._call = AsyncMock(return_value=None)
-        orch._session_id = orch._memory.start_session()
-
-        await orch._tick(1)
-
-        entries = orch.memory.get_recent_subgoals(orch.session_id, limit=5)
-        assert len(entries) == 1
-
-    @pytest.mark.asyncio
-    async def test_run_stops_at_max_ticks(self, orch):
-        state = _make_state()
-        orch._gabs.connect = AsyncMock(return_value=False)
-        orch._gabs.get_game_state = AsyncMock(return_value=state)
-        orch._gabs._call = AsyncMock(return_value=None)
-        orch._gabs.disconnect = AsyncMock()
-
-        summary = await orch.run(max_ticks=3)
-        assert summary["ticks_run"] == 3
-
-    @pytest.mark.asyncio
-    async def test_run_detects_done_condition(self, orch):
-        """Campaign stops early when M3 done condition is met."""
-        state = _make_state(
-            in_game_day=110,
-            fiefs=["A", "B", "C"],
-        )
-        orch._gabs.connect = AsyncMock(return_value=False)
-        orch._gabs.get_game_state = AsyncMock(return_value=state)
-        orch._gabs._call = AsyncMock(return_value=None)
-        orch._gabs.disconnect = AsyncMock()
-
-        summary = await orch.run(max_ticks=100)
-        # Should stop at tick 1 because done condition is met immediately
-        assert summary["ticks_run"] <= 2
-        assert summary["survival_goal_met"]
-
-    @pytest.mark.asyncio
-    async def test_summary_shape(self, orch):
-        state = _make_state(in_game_day=110, fiefs=["A", "B", "C"])
-        orch._gabs.connect = AsyncMock(return_value=False)
-        orch._gabs.get_game_state = AsyncMock(return_value=state)
-        orch._gabs._call = AsyncMock(return_value=None)
-        orch._gabs.disconnect = AsyncMock()
-
-        summary = await orch.run(max_ticks=1)
-        assert "ticks_run" in summary
-        assert "session_id" in summary
-        assert "has_kingdom" in summary
-        assert "fief_count" in summary
@@ -1,161 +0,0 @@
-"""Tests for Bannerlord companion worker agents."""
-
-from bannerlord.agents.companions.caravan import CaravanCompanion
-from bannerlord.agents.companions.logistics import LogisticsCompanion
-from bannerlord.agents.companions.scout import ScoutCompanion
-from bannerlord.types import (
-    FactionState,
-    GameState,
-    KingSubgoal,
-    KingdomState,
-    PartyState,
-    SubgoalToken,
-)
-
-
-def _state(
-    *,
-    troops: int = 150,
-    denars: int = 10_000,
-    food_days: int = 10,
-    wounded_pct: float = 0.0,
-    prisoners: int = 0,
-    location: str = "Epicrotea",
-    active_wars: list | None = None,
-    factions: list | None = None,
-) -> GameState:
-    return GameState(
-        party=PartyState(
-            troops=troops,
-            denars=denars,
-            food_days=food_days,
-            wounded_pct=wounded_pct,
-            prisoners=prisoners,
-            location=location,
-        ),
-        kingdom=KingdomState(
-            name="House Timmerson",
-            active_wars=active_wars or [],
-        ),
-        factions=factions or [],
-    )
-
-
-class TestLogisticsCompanion:
-    def setup_method(self):
-        self.companion = LogisticsCompanion()
-
-    def test_recruits_on_recruit_subgoal(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.RECRUIT, quantity=30)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "recruit_troop" for a in actions)
-
-    def test_rests_on_heal_subgoal(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.HEAL)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "rest_party" for a in actions)
-
-    def test_rests_when_heavily_wounded(self):
-        state = _state(wounded_pct=0.25)
-        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "rest_party" for a in actions)
-
-    def test_buys_food_when_low(self):
-        state = _state(food_days=2)
-        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "buy_supplies" for a in actions)
-
-    def test_no_food_purchase_when_stocked(self):
-        state = _state(food_days=10)
-        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
-        actions = self.companion.evaluate(state, sg)
-        assert not any(a["primitive"] == "buy_supplies" for a in actions)
-
-    def test_sells_prisoners_at_cap(self):
-        state = _state(prisoners=20)
-        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "sell_prisoners" for a in actions)
-
-    def test_upgrades_troops_on_train(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.TRAIN)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "upgrade_troops" for a in actions)
-
-
-class TestCaravanCompanion:
-    def setup_method(self):
-        self.companion = CaravanCompanion()
-
-    def test_no_actions_when_not_trade_subgoal(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.RAID_ECONOMY)
-        actions = self.companion.evaluate(state, sg)
-        assert actions == []
-
-    def test_assesses_prices_on_trade(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.TRADE)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "assess_prices" for a in actions)
-
-    def test_deploys_caravan_when_flush(self):
-        state = _state(denars=15_000)
-        sg = KingSubgoal(token=SubgoalToken.TRADE)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "establish_caravan" for a in actions)
-
-    def test_no_caravan_when_broke(self):
-        state = _state(denars=5_000)
-        sg = KingSubgoal(token=SubgoalToken.TRADE)
-        actions = self.companion.evaluate(state, sg)
-        assert not any(a["primitive"] == "establish_caravan" for a in actions)
-
-    def test_profitable_trade_threshold(self):
-        assert CaravanCompanion.is_profitable_trade(100, 116)  # 16% margin = ok
-        assert not CaravanCompanion.is_profitable_trade(100, 114)  # 14% = below threshold
-        assert not CaravanCompanion.is_profitable_trade(0, 100)  # zero buy price = no
-
-
-class TestScoutCompanion:
-    def setup_method(self):
-        self.companion = ScoutCompanion()
-
-    def test_tracks_lord_on_spy_subgoal(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.SPY, target="Derthert")
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "track_lord" for a in actions)
-
-    def test_assesses_garrison_on_expand(self):
-        state = _state()
-        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY, target="Pravend")
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "assess_garrison" for a in actions)
-
-    def test_maps_patrols_in_war_regions(self):
-        state = _state(
-            active_wars=["Vlandia"],
-            factions=[
-                FactionState(
-                    name="Vlandia",
-                    leader="Derthert",
-                    fiefs=["Pravend"],
-                    army_strength=300,
-                )
-            ],
-        )
-        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
-        actions = self.companion.evaluate(state, sg)
-        assert any(a["primitive"] == "map_patrol_routes" for a in actions)
-
-    def test_no_patrol_map_when_no_wars(self):
-        state = _state(active_wars=[])
-        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
-        actions = self.companion.evaluate(state, sg)
-        assert not any(a["primitive"] == "map_patrol_routes" for a in actions)
@@ -1,162 +0,0 @@
-"""Tests for the GABSClient — uses mocked asyncio streams, no real TCP."""
-
-from __future__ import annotations
-
-import json
-import pytest
-from unittest.mock import AsyncMock, MagicMock, patch
-
-from bannerlord.gabs_client import GABSClient
-from bannerlord.types import GameState
-
-
-@pytest.fixture
-def client():
-    return GABSClient(host="127.0.0.1", port=4825, timeout=2.0)
-
-
-class TestGABSClientConnection:
-    @pytest.mark.asyncio
-    async def test_connect_returns_false_when_refused(self, client):
-        with patch(
-            "asyncio.open_connection",
-            side_effect=ConnectionRefusedError("refused"),
-        ):
-            result = await client.connect()
-        assert result is False
-        assert not client.is_connected
-
-    @pytest.mark.asyncio
-    async def test_connect_success(self, client):
-        mock_reader = AsyncMock()
-        mock_writer = MagicMock()
-        mock_writer.drain = AsyncMock()
-        mock_writer.close = MagicMock()
-        mock_writer.wait_closed = AsyncMock()
-
-        # Simulate tools/list response on connect
-        tools_response = json.dumps({"jsonrpc": "2.0", "id": 1, "result": []}) + "\n"
-        mock_reader.readline = AsyncMock(return_value=tools_response.encode())
-
-        with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
-            with patch("asyncio.wait_for", side_effect=_passthrough_wait_for):
-                result = await client.connect()
-
-        assert result is True
-        assert client.is_connected
-
-    @pytest.mark.asyncio
-    async def test_disconnect_when_not_connected(self, client):
-        # Should not raise
-        await client.disconnect()
-        assert not client.is_connected
-
-
-class TestGABSClientCall:
-    @pytest.mark.asyncio
-    async def test_call_returns_none_when_disconnected(self, client):
-        result = await client._call("game/get_state")
-        assert result is None
-
-    @pytest.mark.asyncio
-    async def test_call_id_increments(self, client):
-        assert client._next_id() == 1
-        assert client._next_id() == 2
-        assert client._next_id() == 3
-
-    @pytest.mark.asyncio
-    async def test_get_game_state_returns_empty_when_disconnected(self, client):
-        state = await client.get_game_state()
-        assert isinstance(state, GameState)
-        assert state.tick == 0
-        assert not state.has_kingdom()
-
-    @pytest.mark.asyncio
-    async def test_move_party_returns_false_when_disconnected(self, client):
-        result = await client.move_party("Vlandia")
-        assert result is False
-
-    @pytest.mark.asyncio
-    async def test_propose_peace_returns_false_when_disconnected(self, client):
-        result = await client.propose_peace("Vlandia")
-        assert result is False
-
-    @pytest.mark.asyncio
-    async def test_assess_prices_returns_empty_dict_when_disconnected(self, client):
-        result = await client.assess_prices("Pravend")
-        assert result == {}
-
-    @pytest.mark.asyncio
-    async def test_map_patrol_routes_returns_empty_list_when_disconnected(self, client):
-        result = await client.map_patrol_routes("Vlandia")
-        assert result == []
-
-
-class TestGameStateParsing:
-    def test_parse_game_state_full(self):
-        client = GABSClient()
-        raw = {
-            "tick": 5,
-            "in_game_day": 42,
-            "party": {
-                "location": "Pravend",
-                "troops": 200,
-                "food_days": 8,
-                "wounded_pct": 0.1,
-                "denars": 15000,
-                "morale": 85.0,
-                "prisoners": 3,
-            },
-            "kingdom": {
-                "name": "House Timmerson",
-                "fiefs": ["Pravend", "Epicrotea"],
-                "daily_income": 500,
-                "daily_expenses": 300,
-                "vassal_lords": ["Lord A"],
-                "active_wars": ["Sturgia"],
-                "active_alliances": ["Battania"],
-            },
-            "factions": [
-                {
-                    "name": "Sturgia",
-                    "leader": "Raganvad",
-                    "fiefs": ["Varcheg"],
-                    "army_strength": 250,
-                    "treasury": 5000,
-                    "is_at_war_with": ["House Timmerson"],
-                    "relations": {"House Timmerson": -50},
-                }
-            ],
-        }
-        state = client._parse_game_state(raw)
-        assert state.tick == 5
-        assert state.in_game_day == 42
-        assert state.party.location == "Pravend"
-        assert state.party.troops == 200
-        assert state.kingdom.name == "House Timmerson"
-        assert state.fief_count() == 2
-        assert len(state.factions) == 1
-        assert state.factions[0].name == "Sturgia"
-        assert state.has_kingdom()
-        assert not state.is_two_front_war()
-
-    def test_parse_game_state_minimal(self):
-        client = GABSClient()
-        state = client._parse_game_state({})
-        assert isinstance(state, GameState)
-        assert not state.has_kingdom()
-
-    def test_tool_count_zero_before_connect(self):
-        client = GABSClient()
-        assert client.tool_count() == 0
-        assert client.available_tools == []
-
-
-# ---------------------------------------------------------------------------
-# Helper
-# ---------------------------------------------------------------------------
-
-async def _passthrough_wait_for(coro, timeout=None):
-    """Stand-in for asyncio.wait_for that just awaits the coroutine."""
-    import asyncio
-    return await coro
@@ -1,176 +0,0 @@
-"""Tests for the King agent decision rules."""
-
-import pytest
-
-from bannerlord.agents.king import KingAgent, _MIN_DENARS, _MIN_TROOPS, _TARGET_FIEFS
-from bannerlord.types import (
-    FactionState,
-    GameState,
-    KingdomState,
-    PartyState,
-    SubgoalToken,
-)
-
-
-def _make_state(
-    *,
-    troops: int = 150,
-    denars: int = 10_000,
-    wounded_pct: float = 0.0,
-    food_days: int = 10,
-    kingdom_name: str = "",
-    fiefs: list | None = None,
-    active_wars: list | None = None,
-    active_alliances: list | None = None,
-    factions: list | None = None,
-    in_game_day: int = 50,
-) -> GameState:
-    return GameState(
-        in_game_day=in_game_day,
-        party=PartyState(
-            troops=troops,
-            denars=denars,
-            wounded_pct=wounded_pct,
-            food_days=food_days,
-            location="Epicrotea",
-        ),
-        kingdom=KingdomState(
-            name=kingdom_name,
-            fiefs=fiefs or [],
-            active_wars=active_wars or [],
-            active_alliances=active_alliances or [],
-        ),
-        factions=factions or [],
-    )
-
-
-class TestKingAgentRules:
-    def setup_method(self):
-        self.king = KingAgent()
-
-    def test_heal_when_heavily_wounded(self):
-        state = _make_state(wounded_pct=0.35)
-        sg = self.king.decide(state)
-        assert sg.token == SubgoalToken.HEAL
-
-    def test_recruit_when_low_troops(self):
-        state = _make_state(troops=30, wounded_pct=0.0)
-        sg = self.king.decide(state)
-        assert sg.token == SubgoalToken.RECRUIT
-        assert sg.quantity == _MIN_TROOPS - 30
-
-    def test_trade_when_broke(self):
-        state = _make_state(troops=150, denars=2_000)
-        sg = self.king.decide(state)
-        assert sg.token == SubgoalToken.TRADE
-
-    def test_no_two_front_war_rule(self):
-        """King must avoid 2-front wars by seeking peace."""
-        state = _make_state(
-            active_wars=["Vlandia", "Sturgia"],
-            kingdom_name="House Timmerson",
-            factions=[
-                FactionState(name="Vlandia", leader="Derthert", army_strength=500),
-                FactionState(name="Sturgia", leader="Raganvad", army_strength=200),
-            ],
-        )
-        sg = self.king.decide(state)
-        assert sg.token == SubgoalToken.ALLY
-        # Should target weakest enemy (Sturgia at 200 strength)
-        assert sg.target == "Sturgia"
-
-    def test_expand_territory_when_no_kingdom(self):
-        state = _make_state(
-            troops=150,
-            denars=10_000,
-            kingdom_name="",
-            factions=[
-                FactionState(
-                    name="Vlandia",
-                    leader="Derthert",
-                    fiefs=["Pravend"],
-                    army_strength=100,
-                    is_at_war_with=["Battania", "Aserai"],
-                )
-            ],
-        )
-        sg = self.king.decide(state)
-        # Distracted faction should be the expansion target
-        assert sg.token == SubgoalToken.EXPAND_TERRITORY
-
-    def test_train_when_troops_insufficient_for_expansion(self):
-        state = _make_state(
-            troops=90,
-            denars=10_000,
-            kingdom_name="",
-            factions=[
-                FactionState(
-                    name="Vlandia",
-                    leader="Derthert",
-                    fiefs=["Pravend"],
-                    army_strength=100,
-                )
-            ],
-        )
-        sg = self.king.decide(state)
-        assert sg.token == SubgoalToken.TRAIN
-
-    def test_expand_when_below_target_fiefs(self):
-        state = _make_state(
-            kingdom_name="House Timmerson",
-            fiefs=["Epicrotea"],
-            factions=[
-                FactionState(
-                    name="Vlandia",
-                    leader="Derthert",
-                    fiefs=["Pravend", "Sargot"],
-                    army_strength=100,
-                )
-            ],
-        )
-        sg = self.king.decide(state)
-        assert sg.token == SubgoalToken.EXPAND_TERRITORY
-
-    def test_consolidate_when_stable(self):
-        state = _make_state(
-            kingdom_name="House Timmerson",
-            fiefs=["Epicrotea", "Pravend", "Sargot"],
-            active_alliances=["Battania"],
-        )
-        sg = self.king.decide(state)
-        assert sg.token in {SubgoalToken.CONSOLIDATE, SubgoalToken.FORTIFY}
-
-    def test_tick_increments(self):
-        king = KingAgent()
-        state = _make_state()
-        king.decide(state)
-        king.decide(state)
-        assert king.tick == 2
-
-    def test_done_condition_not_met_without_kingdom(self):
-        state = _make_state(in_game_day=200)
-        assert not self.king.is_done_condition_met(state)
-
-    def test_done_condition_met(self):
-        state = _make_state(
-            kingdom_name="House Timmerson",
-            fiefs=["A", "B", "C"],
-            in_game_day=110,
-        )
-        assert self.king.is_done_condition_met(state)
-
-    def test_done_condition_not_met_insufficient_days(self):
-        state = _make_state(
-            kingdom_name="House Timmerson",
-            fiefs=["A", "B", "C"],
-            in_game_day=50,
-        )
-        assert not self.king.is_done_condition_met(state)
-
-    def test_campaign_summary_shape(self):
-        state = _make_state(kingdom_name="House Timmerson", fiefs=["A"])
-        summary = self.king.campaign_summary(state)
-        assert "tick" in summary
-        assert "has_kingdom" in summary
-        assert "fief_count" in summary
-        assert "survival_goal_met" in summary
@@ -1,140 +0,0 @@
"""Tests for SessionMemory — SQLite-backed campaign persistence."""

import tempfile
from pathlib import Path

import pytest

from bannerlord.session_memory import SessionMemory
from bannerlord.types import KingSubgoal, SubgoalToken


@pytest.fixture
def memory(tmp_path):
    return SessionMemory(tmp_path / "test_campaign.db")


class TestSessionLifecycle:
    def test_start_session_returns_id(self, memory):
        sid = memory.start_session()
        assert sid.startswith("session_")

    def test_start_session_with_explicit_id(self, memory):
        sid = memory.start_session("my_run_001")
        assert sid == "my_run_001"

    def test_start_idempotent(self, memory):
        sid1 = memory.start_session("run")
        sid2 = memory.start_session("run")
        assert sid1 == sid2

    def test_get_session_returns_dict(self, memory):
        sid = memory.start_session()
        row = memory.get_session(sid)
        assert row is not None
        assert row["session_id"] == sid

    def test_get_unknown_session_returns_none(self, memory):
        assert memory.get_session("does_not_exist") is None

    def test_list_sessions(self, memory):
        memory.start_session("s1")
        memory.start_session("s2")
        sessions = memory.list_sessions()
        ids = [s["session_id"] for s in sessions]
        assert "s1" in ids
        assert "s2" in ids

    def test_update_session(self, memory):
        sid = memory.start_session()
        memory.update_session(sid, kingdom_name="House Timmerson", fief_count=2, in_game_day=45)
        row = memory.get_session(sid)
        assert row["kingdom_name"] == "House Timmerson"
        assert row["fief_count"] == 2
        assert row["in_game_day"] == 45


class TestSubgoalLog:
    def test_log_and_retrieve_subgoal(self, memory):
        sid = memory.start_session()
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY, target="Epicrotea")
        row_id = memory.log_subgoal(sid, tick=1, in_game_day=10, subgoal=sg)
        assert row_id > 0
        entries = memory.get_recent_subgoals(sid, limit=5)
        assert len(entries) == 1
        assert entries[0]["token"] == "EXPAND_TERRITORY"
        assert entries[0]["target"] == "Epicrotea"
        assert entries[0]["outcome"] == "pending"

    def test_complete_subgoal(self, memory):
        sid = memory.start_session()
        sg = KingSubgoal(token=SubgoalToken.TRADE)
        row_id = memory.log_subgoal(sid, tick=2, in_game_day=11, subgoal=sg)
        memory.complete_subgoal(row_id, outcome="success")
        entries = memory.get_recent_subgoals(sid, limit=5)
        assert entries[0]["outcome"] == "success"
        assert entries[0]["completed_at"] is not None

    def test_count_token(self, memory):
        sid = memory.start_session()
        for i in range(3):
            memory.log_subgoal(
                sid, tick=i, in_game_day=i,
                subgoal=KingSubgoal(token=SubgoalToken.RECRUIT)
            )
        memory.log_subgoal(
            sid, tick=10, in_game_day=10,
            subgoal=KingSubgoal(token=SubgoalToken.TRADE)
        )
        assert memory.count_token(sid, SubgoalToken.RECRUIT) == 3
        assert memory.count_token(sid, SubgoalToken.TRADE) == 1
        assert memory.count_token(sid, SubgoalToken.ALLY) == 0

    def test_recent_subgoals_respects_limit(self, memory):
        sid = memory.start_session()
        for i in range(10):
            memory.log_subgoal(
                sid, tick=i, in_game_day=i,
                subgoal=KingSubgoal(token=SubgoalToken.CONSOLIDATE)
            )
        entries = memory.get_recent_subgoals(sid, limit=3)
        assert len(entries) == 3


class TestStrategyNotes:
    def test_add_and_get_notes(self, memory):
        sid = memory.start_session()
        memory.add_note(sid, in_game_day=5, note_type="intel", content="Vlandia weakened")
        notes = memory.get_notes(sid)
        assert len(notes) == 1
        assert notes[0]["content"] == "Vlandia weakened"

    def test_filter_notes_by_type(self, memory):
        sid = memory.start_session()
        memory.add_note(sid, 1, "milestone", "Kingdom established")
        memory.add_note(sid, 2, "intel", "Enemy sighted")
        milestones = memory.get_notes(sid, note_type="milestone")
        assert len(milestones) == 1
        assert milestones[0]["content"] == "Kingdom established"

    def test_record_kingdom_established(self, memory):
        sid = memory.start_session()
        memory.record_kingdom_established(sid, in_game_day=42, kingdom_name="House Timmerson")
        milestones = memory.get_milestones(sid)
        assert any("House Timmerson" in m["content"] for m in milestones)
        row = memory.get_session(sid)
        assert row["kingdom_name"] == "House Timmerson"

    def test_record_war_declared(self, memory):
        sid = memory.start_session()
        memory.record_war_declared(sid, in_game_day=20, faction="Vlandia")
        notes = memory.get_notes(sid, note_type="war_declared")
        assert len(notes) == 1
        assert "Vlandia" in notes[0]["content"]

    def test_record_peace_agreed(self, memory):
        sid = memory.start_session()
        memory.record_peace_agreed(sid, in_game_day=30, faction="Sturgia")
        notes = memory.get_notes(sid, note_type="peace_agreed")
        assert len(notes) == 1
        assert "Sturgia" in notes[0]["content"]
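The SessionMemory tests above read back columns like `session_id`, `kingdom_name`, `fief_count`, `token`, `outcome`, and `completed_at`. A minimal SQLite schema consistent with those assertions can be sketched as follows — this is a hypothetical reconstruction; the actual table and column names in `bannerlord.session_memory` may differ:

```python
import sqlite3

# Hypothetical schema inferred from the columns the tests assert on.
SCHEMA = """
CREATE TABLE IF NOT EXISTS sessions (
    session_id   TEXT PRIMARY KEY,
    kingdom_name TEXT,
    fief_count   INTEGER DEFAULT 0,
    in_game_day  INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS subgoal_log (
    id           INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id   TEXT,
    tick         INTEGER,
    in_game_day  INTEGER,
    token        TEXT,
    target       TEXT,
    outcome      TEXT DEFAULT 'pending',
    completed_at TEXT
);
CREATE TABLE IF NOT EXISTS strategy_notes (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id  TEXT,
    in_game_day INTEGER,
    note_type   TEXT,
    content     TEXT
);
"""

conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
```

The `outcome TEXT DEFAULT 'pending'` default matches the `entries[0]["outcome"] == "pending"` assertion for freshly logged subgoals.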
@@ -1,103 +0,0 @@
"""Tests for Bannerlord M3 core data types."""

from datetime import UTC, datetime

from bannerlord.types import (
    FactionState,
    GameState,
    KingSubgoal,
    KingdomState,
    PartyState,
    SubgoalToken,
    VassalReward,
)


class TestSubgoalToken:
    def test_all_tokens_are_strings(self):
        for token in SubgoalToken:
            assert isinstance(str(token), str)

    def test_round_trip_from_string(self):
        assert SubgoalToken("EXPAND_TERRITORY") == SubgoalToken.EXPAND_TERRITORY
        assert SubgoalToken("ALLY") == SubgoalToken.ALLY


class TestKingSubgoal:
    def test_defaults(self):
        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
        assert sg.priority == 1.0
        assert sg.target is None
        assert sg.quantity is None
        assert isinstance(sg.issued_at, datetime)

    def test_to_dict_round_trip(self):
        sg = KingSubgoal(
            token=SubgoalToken.EXPAND_TERRITORY,
            target="Epicrotea",
            quantity=None,
            priority=1.5,
            deadline_days=10,
            context="capture castle",
        )
        d = sg.to_dict()
        assert d["token"] == "EXPAND_TERRITORY"
        assert d["target"] == "Epicrotea"
        assert d["priority"] == 1.5

        restored = KingSubgoal.from_dict(d)
        assert restored.token == SubgoalToken.EXPAND_TERRITORY
        assert restored.target == "Epicrotea"
        assert restored.priority == 1.5
        assert restored.deadline_days == 10

    def test_from_dict_without_issued_at(self):
        d = {"token": "TRADE"}
        sg = KingSubgoal.from_dict(d)
        assert sg.token == SubgoalToken.TRADE
        assert isinstance(sg.issued_at, datetime)


class TestGameState:
    def test_empty_state(self):
        state = GameState()
        assert not state.has_kingdom()
        assert state.fief_count() == 0
        assert state.active_war_count() == 0
        assert not state.is_two_front_war()

    def test_has_kingdom(self):
        state = GameState(kingdom=KingdomState(name="House Timmerson"))
        assert state.has_kingdom()

    def test_fief_count(self):
        state = GameState(
            kingdom=KingdomState(name="House Timmerson", fiefs=["Pravend", "Epicrotea"])
        )
        assert state.fief_count() == 2

    def test_two_front_war(self):
        state = GameState(
            kingdom=KingdomState(
                name="House Timmerson",
                active_wars=["Vlandia", "Sturgia"],
            )
        )
        assert state.is_two_front_war()

    def test_single_war_not_two_front(self):
        state = GameState(
            kingdom=KingdomState(
                name="House Timmerson",
                active_wars=["Vlandia"],
            )
        )
        assert not state.is_two_front_war()


class TestVassalReward:
    def test_defaults(self):
        reward = VassalReward(agent_id="war_vassal")
        assert reward.total == 0.0
        assert reward.subgoal_bonus == 0.0
        assert isinstance(reward.computed_at, datetime)
@@ -1,179 +0,0 @@
"""Tests for Bannerlord vassal agents (War, Economy, Diplomacy)."""

from bannerlord.agents.diplomacy_vassal import DiplomacyVassal
from bannerlord.agents.economy_vassal import EconomyVassal
from bannerlord.agents.war_vassal import WarVassal
from bannerlord.types import (
    FactionState,
    GameState,
    KingSubgoal,
    KingdomState,
    PartyState,
    SubgoalToken,
)


def _state(
    *,
    troops: int = 150,
    denars: int = 10_000,
    food_days: int = 10,
    kingdom_name: str = "House Timmerson",
    fiefs: list | None = None,
    active_wars: list | None = None,
    active_alliances: list | None = None,
    factions: list | None = None,
) -> GameState:
    return GameState(
        party=PartyState(
            troops=troops,
            denars=denars,
            food_days=food_days,
            location="Epicrotea",
        ),
        kingdom=KingdomState(
            name=kingdom_name,
            fiefs=fiefs or ["Epicrotea"],
            active_wars=active_wars or [],
            active_alliances=active_alliances or [],
        ),
        factions=factions or [],
    )


class TestWarVassal:
    def setup_method(self):
        self.vassal = WarVassal()

    def test_is_relevant_expand(self):
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY)
        assert self.vassal.is_relevant(sg)

    def test_is_relevant_raid(self):
        sg = KingSubgoal(token=SubgoalToken.RAID_ECONOMY)
        assert self.vassal.is_relevant(sg)

    def test_not_relevant_for_trade(self):
        sg = KingSubgoal(token=SubgoalToken.TRADE)
        assert not self.vassal.is_relevant(sg)

    def test_plan_expansion_with_target(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY, target="Pravend")
        tasks = self.vassal.plan(state, sg)
        primitives = [t.primitive for t in tasks]
        assert "siege_settlement" in primitives
        assert "auto_resolve_battle" in primitives

    def test_plan_expansion_scouts_garrison_first(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY, target="Pravend")
        tasks = self.vassal.plan(state, sg)
        # Scout garrison should be the first task (highest priority)
        assert tasks[0].primitive == "assess_garrison"

    def test_plan_expansion_recruits_when_low(self):
        state = _state(troops=80)
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY, target="Pravend")
        tasks = self.vassal.plan(state, sg)
        primitives = [t.primitive for t in tasks]
        assert "recruit_troop" in primitives

    def test_plan_expansion_no_target(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY, target=None)
        tasks = self.vassal.plan(state, sg)
        assert tasks == []

    def test_plan_raid_with_target(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.RAID_ECONOMY, target="Enemy Village")
        tasks = self.vassal.plan(state, sg)
        assert any(t.primitive == "raid_village" for t in tasks)

    def test_compute_reward_territory_gain(self):
        prev = _state(fiefs=["A"])
        curr = _state(fiefs=["A", "B"])
        sg = KingSubgoal(token=SubgoalToken.EXPAND_TERRITORY)
        reward = self.vassal.compute_reward(prev, curr, sg)
        assert reward.total > 0
        assert reward.component_scores["territory"] > 0


class TestEconomyVassal:
    def setup_method(self):
        self.vassal = EconomyVassal()

    def test_is_relevant_fortify(self):
        sg = KingSubgoal(token=SubgoalToken.FORTIFY)
        assert self.vassal.is_relevant(sg)

    def test_is_relevant_consolidate(self):
        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
        assert self.vassal.is_relevant(sg)

    def test_not_relevant_for_raid(self):
        sg = KingSubgoal(token=SubgoalToken.RAID_ECONOMY)
        assert not self.vassal.is_relevant(sg)

    def test_plan_fortify_queues_build(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.FORTIFY, target="Epicrotea")
        tasks = self.vassal.plan(state, sg)
        primitives = [t.primitive for t in tasks]
        assert "build_project" in primitives
        assert "set_tax_policy" in primitives

    def test_plan_buys_food_when_low(self):
        state = _state(food_days=2)
        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
        tasks = self.vassal.plan(state, sg)
        primitives = [t.primitive for t in tasks]
        assert "buy_supplies" in primitives

    def test_plan_consolidate_sets_tax(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.CONSOLIDATE)
        tasks = self.vassal.plan(state, sg)
        assert any(t.primitive == "set_tax_policy" for t in tasks)


class TestDiplomacyVassal:
    def setup_method(self):
        self.vassal = DiplomacyVassal()

    def test_is_relevant_ally(self):
        sg = KingSubgoal(token=SubgoalToken.ALLY)
        assert self.vassal.is_relevant(sg)

    def test_not_relevant_for_train(self):
        sg = KingSubgoal(token=SubgoalToken.TRAIN)
        assert not self.vassal.is_relevant(sg)

    def test_plan_proposes_peace_with_enemy(self):
        state = _state(active_wars=["Vlandia"])
        sg = KingSubgoal(token=SubgoalToken.ALLY, target="Vlandia")
        tasks = self.vassal.plan(state, sg)
        assert any(t.primitive == "propose_peace" for t in tasks)

    def test_plan_requests_alliance_with_non_enemy(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.ALLY, target="Battania")
        tasks = self.vassal.plan(state, sg)
        primitives = [t.primitive for t in tasks]
        assert "send_envoy" in primitives
        assert "request_alliance" in primitives

    def test_plan_no_target_returns_empty(self):
        state = _state()
        sg = KingSubgoal(token=SubgoalToken.ALLY, target=None)
        tasks = self.vassal.plan(state, sg)
        assert tasks == []

    def test_should_avoid_war_when_two_fronts(self):
        state = _state(active_wars=["A", "B"])
        assert self.vassal.should_avoid_war(state)

    def test_should_not_avoid_war_when_one_front(self):
        state = _state(active_wars=["A"])
        assert not self.vassal.should_avoid_war(state)
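The `is_relevant` assertions in the vassal tests above pin down a token-to-vassal routing table: War handles EXPAND_TERRITORY and RAID_ECONOMY, Economy handles FORTIFY and CONSOLIDATE, Diplomacy handles ALLY, and each rejects the others' tokens. That contract can be sketched as a plain lookup table — a hypothetical reduction of the real vassal classes, covering only the tokens the tests exercise:

```python
from enum import Enum


class SubgoalToken(Enum):
    # Only the tokens exercised by the is_relevant tests above
    EXPAND_TERRITORY = "EXPAND_TERRITORY"
    RAID_ECONOMY = "RAID_ECONOMY"
    FORTIFY = "FORTIFY"
    CONSOLIDATE = "CONSOLIDATE"
    ALLY = "ALLY"
    TRADE = "TRADE"
    TRAIN = "TRAIN"


# Relevance table matching the test assertions; the real vassals
# may accept additional tokens not covered by these tests.
RELEVANCE = {
    "war": {SubgoalToken.EXPAND_TERRITORY, SubgoalToken.RAID_ECONOMY},
    "economy": {SubgoalToken.FORTIFY, SubgoalToken.CONSOLIDATE},
    "diplomacy": {SubgoalToken.ALLY},
}


def is_relevant(vassal: str, token: SubgoalToken) -> bool:
    return token in RELEVANCE[vassal]
```

A table like this keeps the king's dispatch logic declarative: adding a vassal means adding one entry rather than another if/elif branch.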
@@ -6,8 +6,8 @@ import time
 from pathlib import Path

 import pytest

-from infrastructure.db_pool import ConnectionPool
+from src.config import settings
+from src.infrastructure.db_pool import ConnectionPool


 class TestConnectionPoolInit:
@@ -330,9 +330,9 @@ class TestPragmaApplication:
         """busy_timeout pragma set on a pooled connection persists."""
         pool = ConnectionPool(tmp_path / "test.db")
         conn = pool.get_connection()
-        conn.execute("PRAGMA busy_timeout=5000")
+        conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
         timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
-        assert timeout == 5000
+        assert timeout == settings.db_busy_timeout_ms
         pool.close_connection()

     def test_pragmas_apply_per_connection(self, tmp_path):
@@ -664,10 +664,10 @@ class TestVllmMlxProvider:
         )
         router.providers = [provider]

-        # Quota monitor returns False (block cloud) — vllm_mlx should still be tried
+        # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
         with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
-            mock_qm.check.return_value = object()
-            mock_qm.should_use_cloud.return_value = False
+            mock_qm.select_model.return_value = "qwen3:14b"
+            mock_qm.check.return_value = None

             with patch.object(router, "_call_vllm_mlx") as mock_call:
                 mock_call.return_value = {
@@ -681,6 +681,115 @@ class TestVllmMlxProvider:
        assert result["content"] == "Local MLX response"


class TestMetabolicProtocol:
    """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING."""

    def _make_anthropic_provider(self) -> "Provider":
        return Provider(
            name="anthropic-primary",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="test-key",
            models=[{"name": "claude-sonnet-4-6", "default": True}],
        )

    async def test_cloud_provider_allowed_in_burst_tier(self):
        """BURST tier (quota healthy): cloud provider is tried."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns cloud model → BURST tier
            mock_qm.select_model.return_value = "claude-sonnet-4-6"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hard question"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Cloud response"

    async def test_cloud_provider_skipped_in_active_tier(self):
        """ACTIVE tier (5-hour >= 50%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns local 14B → ACTIVE tier
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "question"}],
                    )

        mock_call.assert_not_called()

    async def test_cloud_provider_skipped_in_resting_tier(self):
        """RESTING tier (7-day >= 80%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns local 8B → RESTING tier
            mock_qm.select_model.return_value = "qwen3:8b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "simple question"}],
                    )

        mock_call.assert_not_called()

    async def test_local_provider_always_tried_regardless_of_quota(self):
        """Local (ollama/vllm_mlx) providers bypass the metabolic protocol."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "qwen3:14b", "default": True}],
        )
        router.providers = [provider]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:8b"  # RESTING tier

            with patch.object(router, "_call_ollama") as mock_call:
                mock_call.return_value = {"content": "Local response", "model": "qwen3:14b"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Local response"

    async def test_no_quota_monitor_allows_cloud(self):
        """When quota monitor is None (unavailable), cloud providers are allowed."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor", None):
            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "question"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Cloud response"


class TestCascadeRouterReload:
    """Test hot-reload of providers.yaml."""
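The metabolic-protocol tests above pin down a gate: cloud providers are only tried in the BURST tier, the ACTIVE tier kicks in when the 5-hour window reaches 50%, RESTING when the 7-day window reaches 80%, and local providers bypass the gate entirely. A hypothetical sketch of that gate, with thresholds taken from the test docstrings rather than from the router source:

```python
from enum import Enum


class Tier(Enum):
    BURST = "burst"      # quota healthy: cloud allowed
    ACTIVE = "active"    # 5-hour window >= 50%: downshift to local 14B
    RESTING = "resting"  # 7-day window >= 80%: downshift to local 8B


def tier_for(five_hour_pct: float, seven_day_pct: float) -> Tier:
    # RESTING dominates: the long window is the harder budget to recover
    if seven_day_pct >= 0.80:
        return Tier.RESTING
    if five_hour_pct >= 0.50:
        return Tier.ACTIVE
    return Tier.BURST


def should_try_provider(provider_type: str, tier: Tier) -> bool:
    # Local providers bypass the metabolic protocol entirely
    if provider_type in {"ollama", "vllm_mlx"}:
        return True
    return tier is Tier.BURST
```

Under this gate a cascade with only cloud providers raises "All providers failed" in ACTIVE/RESTING tiers, which is exactly what the skipped-tier tests assert.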
285
tests/scripts/test_export_trajectories.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""Unit tests for scripts/export_trajectories.py.
|
||||
|
||||
Tests trajectory conversion logic — no I/O, no Ollama, no mlx.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import scripts.export_trajectories as et
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def simple_session(tmp_path: Path) -> Path:
|
||||
"""Write a minimal session JSONL file and return the logs dir."""
|
||||
logs_dir = tmp_path / "logs"
|
||||
logs_dir.mkdir()
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "What time is it?", "timestamp": "2026-03-01T10:00:00"},
|
||||
{"type": "message", "role": "timmy", "content": "It is 10:00 AM.", "timestamp": "2026-03-01T10:00:01"},
|
||||
{"type": "message", "role": "user", "content": "Thanks!", "timestamp": "2026-03-01T10:00:05"},
|
||||
{"type": "message", "role": "timmy", "content": "You're welcome!", "timestamp": "2026-03-01T10:00:06"},
|
||||
]
|
||||
session_file = logs_dir / "session_2026-03-01.jsonl"
|
||||
session_file.write_text("\n".join(json.dumps(e) for e in entries) + "\n")
|
||||
return logs_dir
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tool_call_session(tmp_path: Path) -> Path:
|
||||
"""Write a session JSONL with tool calls."""
|
||||
logs_dir = tmp_path / "logs"
|
||||
logs_dir.mkdir()
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "Read CLAUDE.md", "timestamp": "2026-03-01T10:00:00"},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"tool": "read_file",
|
||||
"args": {"path": "CLAUDE.md"},
|
||||
"result": "# CLAUDE.md content here",
|
||||
"timestamp": "2026-03-01T10:00:01",
|
||||
},
|
||||
{"type": "message", "role": "timmy", "content": "Here is the content.", "timestamp": "2026-03-01T10:00:02"},
|
||||
]
|
||||
session_file = logs_dir / "session_2026-03-01.jsonl"
|
||||
session_file.write_text("\n".join(json.dumps(e) for e in entries) + "\n")
|
||||
return logs_dir
|
||||
|
||||
|
||||
# ── _load_entries ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_entries_returns_all(simple_session: Path) -> None:
|
||||
entries = et._load_entries(simple_session)
|
||||
assert len(entries) == 4
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_entries_skips_malformed(tmp_path: Path) -> None:
|
||||
logs_dir = tmp_path / "logs"
|
||||
logs_dir.mkdir()
|
||||
session = logs_dir / "session_2026-03-01.jsonl"
|
||||
session.write_text(
|
||||
'{"type": "message", "role": "user", "content": "hi"}\n'
|
||||
"NOT_JSON\n"
|
||||
'{"type": "message", "role": "timmy", "content": "hello"}\n'
|
||||
)
|
||||
entries = et._load_entries(logs_dir)
|
||||
assert len(entries) == 2 # malformed line skipped
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_entries_empty_dir(tmp_path: Path) -> None:
|
||||
logs_dir = tmp_path / "logs"
|
||||
logs_dir.mkdir()
|
||||
entries = et._load_entries(logs_dir)
|
||||
assert entries == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_entries_multiple_files(tmp_path: Path) -> None:
|
||||
logs_dir = tmp_path / "logs"
|
||||
logs_dir.mkdir()
|
||||
for day in ("2026-03-01", "2026-03-02"):
|
||||
entry = {"type": "message", "role": "user", "content": f"day {day}"}
|
||||
(logs_dir / f"session_{day}.jsonl").write_text(json.dumps(entry) + "\n")
|
||||
entries = et._load_entries(logs_dir)
|
||||
assert len(entries) == 2
|
||||
|
||||
|
||||
# ── _format_tool_call ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_tool_call_structure() -> None:
|
||||
entry = {
|
||||
"type": "tool_call",
|
||||
"tool": "read_file",
|
||||
"args": {"path": "/tmp/foo.txt"},
|
||||
"result": "file contents",
|
||||
}
|
||||
result = et._format_tool_call(entry)
|
||||
assert result.startswith("<tool_call>")
|
||||
assert result.endswith("</tool_call>")
|
||||
payload = json.loads(result.split("\n")[1])
|
||||
assert payload["name"] == "read_file"
|
||||
assert payload["arguments"]["path"] == "/tmp/foo.txt"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_tool_call_missing_tool() -> None:
|
||||
entry = {"type": "tool_call", "args": {}}
|
||||
result = et._format_tool_call(entry)
|
||||
assert "unknown" in result
|
||||
|
||||
|
||||
# ── _group_into_turns ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_group_basic_conversation() -> None:
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "hello"},
|
||||
{"type": "message", "role": "timmy", "content": "hi there"},
|
||||
{"type": "message", "role": "user", "content": "bye"},
|
||||
{"type": "message", "role": "timmy", "content": "goodbye"},
|
||||
]
|
||||
turns = et._group_into_turns(entries)
|
||||
assert len(turns) == 2
|
||||
assert turns[0]["user"] == "hello"
|
||||
assert turns[0]["assistant"] == "hi there"
|
||||
assert turns[1]["user"] == "bye"
|
||||
assert turns[1]["assistant"] == "goodbye"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_group_with_tool_call() -> None:
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "check the file"},
|
||||
{"type": "tool_call", "tool": "read_file", "args": {"path": "x"}, "result": "content"},
|
||||
{"type": "message", "role": "timmy", "content": "Done."},
|
||||
]
|
||||
turns = et._group_into_turns(entries)
|
||||
assert len(turns) == 1
|
||||
assert "<tool_call>" in turns[0]["assistant"]
|
||||
assert "Done." in turns[0]["assistant"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_group_skips_user_without_response() -> None:
|
||||
"""User message with no timmy response should not create a turn."""
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "hello"},
|
||||
# No timmy response
|
||||
{"type": "message", "role": "user", "content": "are you there?"},
|
||||
{"type": "message", "role": "timmy", "content": "Yes!"},
|
||||
]
|
||||
turns = et._group_into_turns(entries)
|
||||
assert len(turns) == 1
|
||||
assert turns[0]["user"] == "are you there?"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_group_ignores_errors_and_decisions() -> None:
|
||||
entries = [
|
||||
{"type": "message", "role": "user", "content": "hello"},
|
||||
{"type": "error", "error": "something failed"},
|
||||
{"type": "decision", "decision": "retry"},
|
||||
{"type": "message", "role": "timmy", "content": "Got it."},
|
||||
]
|
||||
turns = et._group_into_turns(entries)
|
||||
assert len(turns) == 1
|
||||
assert "error" not in turns[0]["assistant"]
|
||||
assert "retry" not in turns[0]["assistant"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_group_empty_entries() -> None:
|
||||
assert et._group_into_turns([]) == []
|
||||
|
||||
|
||||
# ── turns_to_training_examples ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_training_examples_structure() -> None:
|
||||
turns = [{"user": "hello", "assistant": "hi there, how can I help?"}]
|
||||
examples = et.turns_to_training_examples(turns)
|
||||
assert len(examples) == 1
|
    msgs = examples[0]["messages"]
    assert msgs[0]["role"] == "system"
    assert msgs[1]["role"] == "user"
    assert msgs[1]["content"] == "hello"
    assert msgs[2]["role"] == "assistant"
    assert msgs[2]["content"] == "hi there, how can I help?"


@pytest.mark.unit
def test_training_examples_filters_short_responses() -> None:
    turns = [
        {"user": "hello", "assistant": "ok"},  # too short
        {"user": "hello", "assistant": "This is a longer response that passes."},
    ]
    examples = et.turns_to_training_examples(turns, min_assistant_len=10)
    assert len(examples) == 1
    assert examples[0]["messages"][2]["content"] == "This is a longer response that passes."


@pytest.mark.unit
def test_training_examples_filters_empty_user() -> None:
    turns = [{"user": "", "assistant": "some response here"}]
    examples = et.turns_to_training_examples(turns)
    assert len(examples) == 0


@pytest.mark.unit
def test_training_examples_uses_custom_system_prompt() -> None:
    turns = [{"user": "hi", "assistant": "hello there!"}]
    examples = et.turns_to_training_examples(turns, system_prompt="Custom prompt.")
    assert examples[0]["messages"][0]["content"] == "Custom prompt."


# ── export_training_data (integration-style, uses tmp_path) ──────────────────


@pytest.mark.unit
def test_export_training_data_writes_jsonl(simple_session: Path, tmp_path: Path) -> None:
    output = tmp_path / "train.jsonl"
    count = et.export_training_data(logs_dir=simple_session, output_path=output)
    assert count == 2
    assert output.exists()
    lines = [
        json.loads(line) for line in output.read_text().splitlines() if line.strip()
    ]
    assert len(lines) == 2
    for line in lines:
        assert "messages" in line
        roles = [m["role"] for m in line["messages"]]
        assert roles == ["system", "user", "assistant"]


@pytest.mark.unit
def test_export_training_data_with_tool_calls(tool_call_session: Path, tmp_path: Path) -> None:
    output = tmp_path / "train.jsonl"
    count = et.export_training_data(logs_dir=tool_call_session, output_path=output)
    assert count == 1
    line = json.loads(output.read_text().strip())
    assistant_content = line["messages"][2]["content"]
    assert "<tool_call>" in assistant_content
    assert "read_file" in assistant_content


@pytest.mark.unit
def test_export_training_data_returns_zero_for_empty_logs(tmp_path: Path) -> None:
    logs_dir = tmp_path / "logs"
    logs_dir.mkdir()
    output = tmp_path / "train.jsonl"
    count = et.export_training_data(logs_dir=logs_dir, output_path=output)
    assert count == 0
    assert not output.exists()


# ── CLI ───────────────────────────────────────────────────────────────────────


@pytest.mark.unit
def test_cli_missing_logs_dir(tmp_path: Path) -> None:
    rc = et.main(["--logs-dir", str(tmp_path / "nonexistent"), "--output", str(tmp_path / "out.jsonl")])
    assert rc == 1


@pytest.mark.unit
def test_cli_exports_and_returns_zero(simple_session: Path, tmp_path: Path) -> None:
    output = tmp_path / "out.jsonl"
    rc = et.main([
        "--logs-dir", str(simple_session),
        "--output", str(output),
    ])
    assert rc == 0
    assert output.exists()
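Each exported JSONL line holds one complete chat example. The shape these tests assert can be checked in isolation; this is a minimal sketch over hypothetical sample data, not the exporter's actual output:

```python
import json

# Hypothetical sample of one exported JSONL line in the chat format
# asserted by the tests above (system → user → assistant).
sample_jsonl = json.dumps({
    "messages": [
        {"role": "system", "content": "You are Timmy."},
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi there, how can I help?"},
    ]
})

lines = [json.loads(line) for line in sample_jsonl.splitlines() if line.strip()]
for line in lines:
    roles = [m["role"] for m in line["messages"]]
    # Every example must carry exactly this role ordering.
    assert roles == ["system", "user", "assistant"]
print(len(lines))
```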
546 tests/unit/test_retrain_loop.py Normal file
@@ -0,0 +1,546 @@
"""Unit tests for the AutoLoRA continuous improvement loop.

Covers trajectory extraction, quality filtering, dataset management,
and the retrain orchestrator.

Refs: #1105
"""

from __future__ import annotations

import json
from datetime import UTC, datetime, timedelta
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import Trajectory, TrajectoryExporter

_REPO_ROOT = Path(__file__).resolve().parent.parent.parent

# ── Fixtures ─────────────────────────────────────────────────────────────────


def _ts(offset_minutes: int = 0) -> str:
    """Return an ISO timestamp offset from now."""
    return (datetime.now(tz=UTC) + timedelta(minutes=offset_minutes)).isoformat()


def _make_session_log(entries: list[dict], date_str: str, tmp_path: Path) -> Path:
    """Write session JSONL entries to a temp log file."""
    log_dir = tmp_path / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"session_{date_str}.jsonl"
    with open(log_file, "w") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")
    return log_file


def _user_msg(content: str, offset: int = 0) -> dict:
    return {"type": "message", "role": "user", "content": content, "timestamp": _ts(offset)}


def _timmy_msg(content: str, confidence: float | None = None, offset: int = 0) -> dict:
    entry = {"type": "message", "role": "timmy", "content": content, "timestamp": _ts(offset)}
    if confidence is not None:
        entry["confidence"] = confidence
    return entry


def _tool_call(tool: str = "bash", result: str = "ok", offset: int = 0) -> dict:
    return {
        "type": "tool_call",
        "tool": tool,
        "args": {},
        "result": result,
        "timestamp": _ts(offset),
    }


def _error_entry(msg: str = "Something failed", offset: int = 0) -> dict:
    return {"type": "error", "error": msg, "timestamp": _ts(offset)}


def _decision_entry(decision: str = "Use approach A", offset: int = 0) -> dict:
    return {"type": "decision", "decision": decision, "timestamp": _ts(offset)}


# ── Trajectory dataclass tests ────────────────────────────────────────────────


class TestTrajectory:
    def test_message_count(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hi"), _timmy_msg("hello")],
        )
        assert t.message_count == 2

    def test_tool_call_count(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call(), _tool_call()],
        )
        assert t.tool_call_count == 2

    def test_has_successful_tool_call_when_no_errors(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call()],
            errors=[],
        )
        assert t.has_successful_tool_call is True

    def test_has_successful_tool_call_false_when_errors(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call()],
            errors=[_error_entry()],
        )
        assert t.has_successful_tool_call is False

    def test_is_multi_step(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        assert t.is_multi_step is True

    def test_is_not_multi_step_single_message(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("hello")],
            tool_calls=[],
        )
        assert t.is_multi_step is False

    def test_to_chat_format_ordering(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("question", offset=0), _timmy_msg("answer", offset=2)],
            tool_calls=[_tool_call(offset=1)],
        )
        chat = t.to_chat_format()
        roles = [m["role"] for m in chat]
        assert "user" in roles
        assert "assistant" in roles

    def test_to_chat_format_empty_content_skipped(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg(""), _timmy_msg("response")],
        )
        chat = t.to_chat_format()
        # Empty user message should be skipped
        assert all(m["content"] for m in chat)


# ── TrajectoryExporter tests ──────────────────────────────────────────────────


class TestTrajectoryExporter:
    def test_export_empty_logs_dir(self, tmp_path):
        (tmp_path / "logs").mkdir()
        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        assert result == []

    def test_export_reads_session_files(self, tmp_path):
        # Write a session file for this week
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")
        entries = [
            _user_msg("tell me about Python"),
            _timmy_msg("Python is great"),
        ]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        assert len(result) >= 1

    def test_export_skips_old_sessions(self, tmp_path):
        # Write a session file for 3 weeks ago
        three_weeks_ago = datetime.now(tz=UTC) - timedelta(weeks=3)
        date_str = three_weeks_ago.strftime("%Y-%m-%d")
        entries = [_user_msg("old message"), _timmy_msg("old response")]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        # Request current week — should not include 3-week-old data
        result = exporter.export_week(weeks_ago=0)
        assert result == []

    def test_export_segments_by_gap(self, tmp_path):
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")

        # Two conversations separated by 10 minutes
        t1 = (today - timedelta(minutes=15)).isoformat()
        t2 = (today - timedelta(minutes=14)).isoformat()
        t3 = (today - timedelta(minutes=2)).isoformat()
        t4 = (today - timedelta(minutes=1)).isoformat()

        entries = [
            {"type": "message", "role": "user", "content": "first q", "timestamp": t1},
            {"type": "message", "role": "timmy", "content": "first a", "timestamp": t2},
            {"type": "message", "role": "user", "content": "second q", "timestamp": t3},
            {"type": "message", "role": "timmy", "content": "second a", "timestamp": t4},
        ]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        # Should have at least 1 trajectory (may be 1 or 2 depending on segmentation)
        assert len(result) >= 1

    def test_handles_malformed_log_file(self, tmp_path):
        log_dir = tmp_path / "logs"
        log_dir.mkdir()
        today = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        (log_dir / f"session_{today}.jsonl").write_text("not json\n{}\n")

        exporter = TrajectoryExporter(logs_dir=log_dir, repo_root=tmp_path)
        # Should not raise, just return empty or partial results
        result = exporter.export_week(weeks_ago=0)
        assert isinstance(result, list)


# ── QualityFilter tests ───────────────────────────────────────────────────────


class TestQualityFilter:
    def _make_high_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do task"), _timmy_msg("done", confidence=0.9)],
            tool_calls=[_tool_call(), _tool_call()],
            errors=[],
            decisions=[_decision_entry()],
        )

    def _make_medium_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hello"), _timmy_msg("hi")],
            tool_calls=[],
            errors=[],
        )

    def _make_low_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("oops")],  # No user message
            errors=[_error_entry()],
        )

    def test_high_quality_classification(self):
        qf = QualityFilter()
        result = qf.assess(self._make_high_quality())
        assert result.quality == TrajectoryQuality.HIGH
        assert result.score >= 4.0
        assert result.is_trainable

    def test_medium_quality_classification(self):
        qf = QualityFilter()
        result = qf.assess(self._make_medium_quality())
        assert result.quality == TrajectoryQuality.MEDIUM
        assert result.is_trainable

    def test_low_quality_no_user_message(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("random")],
        )
        result = qf.assess(t)
        assert result.quality == TrajectoryQuality.LOW
        assert not result.is_trainable

    def test_error_penalizes_score(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("go"), _timmy_msg("fail")],
            tool_calls=[_tool_call()],
            errors=[_error_entry(), _error_entry()],
        )
        result = qf.assess(t)
        assert result.score < qf.assess(self._make_high_quality()).score

    def test_low_confidence_penalizes_score(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("q"), _timmy_msg("a", confidence=0.2)],
        )
        result = qf.assess(t)
        assert result.score < 1.0

    def test_filter_returns_stats(self):
        qf = QualityFilter()
        trajectories = [
            self._make_high_quality(),
            self._make_medium_quality(),
            self._make_low_quality(),
        ]
        trainable, stats = qf.filter(trajectories)
        assert stats["total"] == 3
        assert stats["accepted"] == len(trainable)
        assert stats["high"] + stats["medium"] + stats["low"] == 3

    def test_filter_empty_list(self):
        qf = QualityFilter()
        trainable, stats = qf.filter([])
        assert trainable == []
        assert stats["total"] == 0
        assert stats["accepted"] == 0


# ── TrainingDataset tests ─────────────────────────────────────────────────────


class TestTrainingDataset:
    def _make_result(self, quality=TrajectoryQuality.HIGH, score=5.0) -> object:
        from timmy_automations.retrain.quality_filter import QualityResult

        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(-5),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        return QualityResult(trajectory=t, quality=quality, score=score, reasons=[])

    def test_count_empty_dataset(self, tmp_path):
        ds = TrainingDataset(
            dataset_path=".loop/retrain/training_data.jsonl",
            repo_root=tmp_path,
        )
        assert ds.count() == 0

    def test_append_adds_examples(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        result = ds.append([self._make_result()], "2026-W12")
        assert result.new_examples == 1
        assert result.total_examples == 1
        assert ds.count() == 1

    def test_append_idempotent(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r = self._make_result()
        ds.append([r], "2026-W12")
        result2 = ds.append([r], "2026-W12")
        # Same trajectory shouldn't be added twice
        assert result2.new_examples == 0
        assert ds.count() == 1

    def test_append_different_weeks(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r1 = self._make_result()
        ds.append([r1], "2026-W11")
        ds.append([r1], "2026-W12")
        # Different week tags = different records
        assert ds.count() == 2

    def test_dataset_file_is_valid_jsonl(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        with open(ds.dataset_path) as f:
            lines = [line.strip() for line in f if line.strip()]
        assert len(lines) == 1
        record = json.loads(lines[0])
        assert "messages" in record
        assert "week" in record
        assert "quality" in record

    def test_index_updated_after_append(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        index_path = tmp_path / ".loop" / "retrain" / "dataset_index.json"
        assert index_path.exists()
        index = json.loads(index_path.read_text())
        assert index["total_examples"] == 1
        assert "2026-W12" in index["weeks"]


# ── TrainingLog tests ─────────────────────────────────────────────────────────


class TestTrainingLog:
    def _make_metrics(self, iteration: int = 1) -> CycleMetrics:
        return CycleMetrics(
            iteration=iteration,
            week="2026-W12",
            ran_at=datetime.now(tz=UTC).isoformat(),
            trajectories_total=10,
            trajectories_high=5,
            trajectories_medium=3,
            trajectories_low=2,
            trajectories_accepted=8,
            examples_added=5,
            dataset_total=5,
            train_status="completed",
            train_loss=1.2345,
            train_duration_seconds=120.5,
            adapter_path=".loop/retrain/adapters/iter_0001/adapters.npz",
            model_name="hermes4-14b-ft-0001",
            notes="First fine-tune cycle complete",
        )

    def test_next_iteration_starts_at_1(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        assert log.next_iteration() == 1

    def test_next_iteration_increments(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        assert log.next_iteration() == 2

    def test_record_creates_log_file(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        assert log.log_path.exists()

    def test_load_all_returns_records(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        log.record(self._make_metrics(iteration=2))
        entries = log.load_all()
        assert len(entries) == 2
        assert entries[0]["iteration"] == 1

    def test_latest_returns_last_entry(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        log.record(self._make_metrics(iteration=2))
        latest = log.latest()
        assert latest is not None
        assert latest["iteration"] == 2

    def test_latest_returns_none_when_empty(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        assert log.latest() is None

    def test_summary_markdown_written(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        summary_path = tmp_path / ".loop" / "retrain" / "training_log.md"
        assert summary_path.exists()
        content = summary_path.read_text()
        assert "AutoLoRA Training Log" in content
        assert "2026-W12" in content
        assert "completed" in content

    def test_skill_accuracy_in_summary(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        m = self._make_metrics()
        m.skill_accuracy = {"tool_calling": 0.85, "reasoning": 0.72}
        log.record(m)
        content = (tmp_path / ".loop" / "retrain" / "training_log.md").read_text()
        assert "tool_calling" in content
        assert "reasoning" in content


# ── RetrainOrchestrator integration tests ─────────────────────────────────────


class TestRetrainOrchestrator:
    def test_run_dry_run_no_data(self, tmp_path):
        """Dry run with no session logs should complete without errors."""
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        result = orc.run(weeks_ago=0)
        assert result.train_status in ("skipped",)
        assert result.examples_added == 0
        assert result.iteration == 1

    def test_run_creates_log_entry(self, tmp_path):
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        orc.run(weeks_ago=0)
        log = TrainingLog(repo_root=tmp_path)
        entries = log.load_all()
        assert len(entries) == 1

    def test_run_with_session_data(self, tmp_path):
        """Run with actual session data — should export, filter, and log."""
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")
        entries = [
            _user_msg("deploy the service", offset=-10),
            _tool_call("bash", "deployed successfully", offset=-9),
            _tool_call("bash", "health check ok", offset=-8),
            _timmy_msg("Service deployed and healthy", confidence=0.92, offset=-7),
            _user_msg("run the tests", offset=-6),
            _tool_call("bash", "All tests passed", offset=-5),
            _timmy_msg("All 42 tests passed", confidence=0.95, offset=-4),
        ]
        _make_session_log(entries, date_str, tmp_path)

        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        result = orc.run(weeks_ago=0)

        assert result.trajectories_exported >= 1
        assert result.iteration == 1
        # In dry_run mode, fine-tune is skipped but trajectories should be processed
        assert result.train_status == "skipped"

    def test_iteration_increments_on_second_run(self, tmp_path):
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        r1 = orc.run(weeks_ago=0)
        r2 = orc.run(weeks_ago=0)
        assert r2.iteration == r1.iteration + 1

    def test_automations_json_has_retrain_entry(self):
        """Verify the retrain automation is registered in automations.json."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        assert config_path.exists()
        manifest = json.loads(config_path.read_text())
        ids = [a["id"] for a in manifest.get("automations", [])]
        assert "retrain" in ids

    def test_retrain_automation_config(self):
        """Verify retrain automation has correct schedule and config."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        manifest = json.loads(config_path.read_text())
        retrain = next(a for a in manifest["automations"] if a["id"] == "retrain")
        assert retrain["schedule"] == "weekly_sunday"
        assert retrain["trigger"] == "scheduled"
        assert retrain["config"]["base_model"] == "hermes4-14b"
        assert retrain["config"]["weeks_ago"] == 1
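test_export_segments_by_gap only asserts "at least 1" trajectory because the exporter's gap threshold is internal to TrajectoryExporter. One plausible segmentation sketch, with the threshold value assumed (the real exporter may use a different GAP and entry schema):

```python
from datetime import datetime, timedelta, timezone

GAP = timedelta(minutes=5)  # assumed threshold — not taken from the real exporter


def segment_by_gap(entries: list[dict]) -> list[list[dict]]:
    """Split a flat entry stream into trajectories wherever the
    timestamp gap between consecutive entries exceeds GAP."""
    trajectories: list[list[dict]] = []
    current: list[dict] = []
    prev: datetime | None = None
    for entry in entries:
        ts = datetime.fromisoformat(entry["timestamp"])
        if prev is not None and ts - prev > GAP:
            trajectories.append(current)
            current = []
        current.append(entry)
        prev = ts
    if current:
        trajectories.append(current)
    return trajectories


base = datetime(2026, 3, 17, 12, 0, tzinfo=timezone.utc)
entries = [
    {"content": "first q", "timestamp": (base + timedelta(minutes=0)).isoformat()},
    {"content": "first a", "timestamp": (base + timedelta(minutes=1)).isoformat()},
    {"content": "second q", "timestamp": (base + timedelta(minutes=13)).isoformat()},
    {"content": "second a", "timestamp": (base + timedelta(minutes=14)).isoformat()},
]
# The 12-minute gap between "first a" and "second q" exceeds GAP,
# so the stream splits into two trajectories.
print(len(segment_by_gap(entries)))  # → 2
```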
@@ -4,7 +4,7 @@
   "_health_snapshot": {
     "note": "Quick health check before coding — CI, P0/P1 issues, flakiness"
   },
-  "last_updated": "2026-03-21",
+  "last_updated": "2026-03-23",
   "automations": [
     {
       "id": "cycle_retro",
@@ -268,6 +268,36 @@
         "ci_timeout_seconds": 5
       },
       "outputs": []
-    }
+    },
+    {
+      "id": "retrain",
+      "name": "AutoLoRA Continuous Improvement Loop",
+      "description": "Weekly sovereignty loop — exports trajectories, filters quality, appends to training dataset, triggers LoRA fine-tune, loads new adapter, and logs iteration metrics",
+      "script": "timmy_automations/retrain/retrain.py",
+      "category": "autolora",
+      "enabled": true,
+      "trigger": "scheduled",
+      "schedule": "weekly_sunday",
+      "executable": "python3",
+      "epic": "#1091",
+      "pipeline": "AutoLoRA Sovereignty Loop (Step 6 of 7)",
+      "config": {
+        "weeks_ago": 1,
+        "base_model": "hermes4-14b",
+        "dry_run": false,
+        "logs_dir": "logs",
+        "dataset_path": ".loop/retrain/training_data.jsonl",
+        "adapter_dir": ".loop/retrain/adapters",
+        "training_log_path": ".loop/retrain/training_log.jsonl",
+        "training_summary_path": ".loop/retrain/training_log.md"
+      },
+      "outputs": [
+        ".loop/retrain/training_data.jsonl",
+        ".loop/retrain/dataset_index.json",
+        ".loop/retrain/training_log.jsonl",
+        ".loop/retrain/training_log.md",
+        ".loop/retrain/adapters/"
+      ]
+    }
   ]
 }
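A scheduler (or the unit tests above) resolves the retrain automation by its `id` field. A minimal lookup sketch over a cut-down stand-in manifest (only a subset of the real keys is reproduced here):

```python
import json

# Stand-in for the automations.json manifest, trimmed to the fields
# the lookup below actually touches.
manifest_text = json.dumps({
    "automations": [
        {"id": "cycle_retro"},
        {
            "id": "retrain",
            "trigger": "scheduled",
            "schedule": "weekly_sunday",
            "config": {"base_model": "hermes4-14b", "weeks_ago": 1},
        },
    ]
})

manifest = json.loads(manifest_text)
# next(...) raises StopIteration if the id is missing, which is what
# the registration test above guards against.
retrain = next(a for a in manifest["automations"] if a["id"] == "retrain")
print(retrain["schedule"])  # → weekly_sunday
```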
26 timmy_automations/retrain/__init__.py Normal file
@@ -0,0 +1,26 @@
"""AutoLoRA continuous improvement loop — sovereignty engine for Timmy.

Implements the weekly retrain cycle:
    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator, RetrainResult
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

__all__ = [
    "QualityFilter",
    "RetrainOrchestrator",
    "RetrainResult",
    "TrainingDataset",
    "TrainingLog",
    "TrajectoryExporter",
    "TrajectoryQuality",
]
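The cycle named in the docstring — export, filter, fine-tune — can be sketched as a thin pipeline over pluggable stages. This is an illustrative wiring with stub callables, not the real RetrainOrchestrator:

```python
from typing import Callable


def run_cycle(
    export: Callable[[], list],
    filter_quality: Callable[[list], list],
    fine_tune: Callable[[list], str],
) -> str:
    """Illustrative shape of one weekly cycle: export trajectories,
    keep the trainable ones, fine-tune if anything survived filtering."""
    trajectories = export()
    trainable = filter_quality(trajectories)
    if not trainable:
        return "skipped"  # nothing worth training on this week
    return fine_tune(trainable)


# Stubs stand in for the real exporter, filter, and trainer.
status = run_cycle(
    export=lambda: ["traj-1", "traj-2"],
    filter_quality=lambda ts: [t for t in ts if t != "traj-2"],
    fine_tune=lambda ts: "completed",
)
print(status)  # → completed
```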
262 timmy_automations/retrain/lora_trainer.py Normal file
@@ -0,0 +1,262 @@
|
||||
"""LoRA trainer — triggers fine-tune job and loads the resulting adapter.
|
||||
|
||||
Supports two backends:
|
||||
1. mlx-lm (default, Apple Silicon) — `mlx_lm.lora` CLI
|
||||
2. Ollama create (adapter packaging into a new Ollama model)
|
||||
|
||||
Graceful degradation: if neither backend is available, logs a warning
|
||||
and returns a skipped result — the rest of the loop continues.
|
||||
|
||||
Refs: #1105
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_BASE_MODEL = "hermes4-14b"
|
||||
_DEFAULT_ADAPTER_DIR = ".loop/retrain/adapters"
|
||||
_MLX_LM_BIN = "mlx_lm.lora"
|
||||
_OLLAMA_BIN = "ollama"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainResult:
|
||||
"""Result of a LoRA fine-tune run."""
|
||||
|
||||
status: str # "completed" | "skipped" | "failed"
|
||||
adapter_path: str | None
|
||||
model_name: str | None
|
||||
iteration: int
|
||||
duration_seconds: float
|
||||
message: str
|
||||
train_loss: float | None = None
|
||||
|
||||
|
||||
class LoRATrainer:
|
||||
"""Orchestrates LoRA fine-tuning and adapter loading.
|
||||
|
||||
Workflow:
|
||||
1. Run mlx_lm.lora fine-tune on the training dataset
|
||||
2. Save the resulting adapter to .loop/retrain/adapters/<iteration>/
|
||||
3. Create (or update) an Ollama model that uses the new adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model: str = _DEFAULT_BASE_MODEL,
|
||||
adapter_dir: str | Path | None = None,
|
||||
repo_root: str | Path | None = None,
|
||||
dry_run: bool = False,
|
||||
):
|
||||
if repo_root is None:
|
||||
repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
self._repo_root = Path(repo_root)
|
||||
|
||||
self._base_model = base_model
|
||||
self._adapter_dir = self._repo_root / (adapter_dir or _DEFAULT_ADAPTER_DIR)
|
||||
self._adapter_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._dry_run = dry_run
|
||||
|
||||
def train(self, dataset_path: Path, iteration: int) -> TrainResult:
|
||||
"""Run LoRA fine-tuning on the dataset.
|
||||
|
||||
Args:
|
||||
dataset_path: Path to the JSONL training dataset.
|
||||
iteration: Current fine-tune iteration number (used for naming).
|
||||
|
||||
Returns:
|
||||
TrainResult with status, adapter path, and metrics.
|
||||
"""
|
||||
started = datetime.now(tz=UTC)
|
||||
|
||||
if not dataset_path.exists() or dataset_path.stat().st_size == 0:
|
||||
return TrainResult(
|
||||
status="skipped",
|
||||
adapter_path=None,
|
||||
model_name=None,
|
||||
iteration=iteration,
|
||||
duration_seconds=0.0,
|
||||
message="Training dataset is empty — skipping fine-tune",
|
||||
)
|
||||
|
||||
if self._dry_run:
|
||||
logger.info("[dry-run] Would fine-tune %s on %s", self._base_model, dataset_path)
|
||||
adapter_path = self._adapter_dir / f"iter_{iteration:04d}" / "adapters.npz"
|
||||
return TrainResult(
|
||||
status="skipped",
|
||||
adapter_path=str(adapter_path),
|
||||
model_name=f"{self._base_model}-ft-{iteration:04d}",
|
||||
iteration=iteration,
|
||||
duration_seconds=0.0,
|
||||
message="dry-run mode — no training performed",
|
||||
)
|
||||
|
||||
# Determine which backend is available
|
||||
if shutil.which(_MLX_LM_BIN):
|
||||
return self._train_mlx(dataset_path, iteration, started)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s not found — skipping LoRA fine-tune (install mlx-lm to enable)",
|
||||
_MLX_LM_BIN,
|
||||
)
|
||||
return TrainResult(
|
||||
status="skipped",
|
||||
adapter_path=None,
|
||||
model_name=None,
|
||||
iteration=iteration,
|
||||
duration_seconds=0.0,
|
||||
message=(
|
||||
f"{_MLX_LM_BIN} not available. "
|
||||
"Install mlx-lm on Apple Silicon to enable LoRA fine-tuning."
|
||||
),
|
||||
)
|
||||
|
||||
def _train_mlx(
|
||||
self, dataset_path: Path, iteration: int, started: datetime
|
||||
) -> TrainResult:
|
||||
"""Run mlx_lm.lora fine-tune."""
|
||||
adapter_out = self._adapter_dir / f"iter_{iteration:04d}"
|
||||
adapter_out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cmd = [
|
||||
_MLX_LM_BIN,
|
||||
"--model", self._base_model,
|
||||
"--data", str(dataset_path),
|
||||
"--adapter-path", str(adapter_out),
|
||||
"--train",
|
||||
"--iters", "100",
|
||||
"--batch-size", "1",
|
||||
"--learning-rate", "1e-5",
|
||||
]
|
||||
|
||||
logger.info("Starting mlx-lm LoRA fine-tune: iteration %d", iteration)
|
||||
logger.info("Command: %s", " ".join(cmd))
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3600, # 1 hour max
|
||||
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
duration = (datetime.now(tz=UTC) - started).total_seconds()
|
||||
return TrainResult(
|
||||
status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=duration,
                message="Fine-tune timed out after 1 hour",
            )
        except Exception as exc:
            duration = (datetime.now(tz=UTC) - started).total_seconds()
            return TrainResult(
                status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=duration,
                message=f"Fine-tune subprocess error: {exc}",
            )

        duration = (datetime.now(tz=UTC) - started).total_seconds()

        if result.returncode != 0:
            logger.error("mlx-lm fine-tune failed: %s", result.stderr[:500])
            return TrainResult(
                status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=duration,
                message=f"mlx_lm.lora exited {result.returncode}: {result.stderr[:300]}",
            )

        # Parse final train loss from stdout if available
        train_loss = _parse_train_loss(result.stdout)

        adapter_file = adapter_out / "adapters.npz"
        model_name = f"{self._base_model}-ft-{iteration:04d}"

        # Attempt to register with Ollama
        ollama_ok = self._register_ollama_adapter(adapter_out, model_name)
        if not ollama_ok:
            logger.warning("Ollama adapter registration failed — adapter saved locally")

        logger.info(
            "Fine-tune complete: iteration=%d loss=%.4f duration=%.1fs adapter=%s",
            iteration,
            train_loss or 0.0,
            duration,
            adapter_file,
        )

        return TrainResult(
            status="completed",
            adapter_path=str(adapter_file),
            model_name=model_name,
            iteration=iteration,
            duration_seconds=duration,
            message=f"LoRA fine-tune completed successfully in {duration:.0f}s",
            train_loss=train_loss,
        )

    def _register_ollama_adapter(self, adapter_dir: Path, model_name: str) -> bool:
        """Create an Ollama model entry for the new adapter.

        Writes a minimal Modelfile and runs `ollama create`.
        """
        if not shutil.which(_OLLAMA_BIN):
            logger.debug("Ollama not found — skipping adapter registration")
            return False

        modelfile_content = (
            f"FROM {self._base_model}\n"
            f"ADAPTER {adapter_dir}\n"
        )
        modelfile_path = adapter_dir / "Modelfile"
        try:
            modelfile_path.write_text(modelfile_content)
            result = subprocess.run(
                [_OLLAMA_BIN, "create", model_name, "-f", str(modelfile_path)],
                capture_output=True,
                text=True,
                timeout=300,
            )
            if result.returncode == 0:
                logger.info("Ollama model registered: %s", model_name)
                return True
            else:
                logger.warning("ollama create failed: %s", result.stderr[:200])
                return False
        except Exception as exc:
            logger.warning("Ollama adapter registration error: %s", exc)
            return False


def _parse_train_loss(stdout: str) -> float | None:
    """Extract the final training loss from mlx-lm stdout."""
    loss: float | None = None
    for line in stdout.splitlines():
        line_lower = line.lower()
        if "train loss" in line_lower or "loss:" in line_lower:
            parts = line.split()
            for i, part in enumerate(parts):
                if "loss" in part.lower() and i + 1 < len(parts):
                    try:
                        loss = float(parts[i + 1].strip(",:"))
                    except ValueError:
                        pass
    return loss
172
timmy_automations/retrain/quality_filter.py
Normal file
@@ -0,0 +1,172 @@
"""Quality filter — keeps only high-value trajectories for LoRA training.

Criteria for a high-quality training example:
1. Tool calls succeeded (tool calls present, no error entries)
2. Multi-step tasks completed (≥2 messages + ≥1 tool call)
3. No low-confidence signals (confidence < 0.5 on any Timmy message)
4. Minimum meaningful exchange (≥1 user message + ≥1 Timmy message)

Refs: #1105
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import StrEnum

from timmy_automations.retrain.trajectory_exporter import Trajectory

logger = logging.getLogger(__name__)

_MIN_CONFIDENCE = 0.5


class TrajectoryQuality(StrEnum):
    """Quality classification for a trajectory."""

    HIGH = "high"      # Multi-step + tool success — ideal training data
    MEDIUM = "medium"  # Single exchange, no errors — acceptable
    LOW = "low"        # Error-prone or trivial — skip


@dataclass
class QualityResult:
    """Result of quality assessment for a single trajectory."""

    trajectory: Trajectory
    quality: TrajectoryQuality
    score: float
    reasons: list[str]

    @property
    def is_trainable(self) -> bool:
        return self.quality in (TrajectoryQuality.HIGH, TrajectoryQuality.MEDIUM)


class QualityFilter:
    """Filters trajectories to keep only those worth training on.

    Scoring:
    - +1 pt: base score for any valid clean exchange (no errors)
    - +3 pts: multi-step task (≥2 messages + ≥1 tool call)
    - +2 pts: tool calls present and no errors
    - +1 pt: decision recorded (deliberate choice made)
    - -2 pts: any error entry
    - -1 pt: any low-confidence response (confidence < 0.5)

    HIGH ≥ 4, MEDIUM 1–3, LOW ≤ 0
    """

    def __init__(self, min_confidence: float = _MIN_CONFIDENCE):
        self._min_confidence = min_confidence

    def assess(self, trajectory: Trajectory) -> QualityResult:
        """Score and classify a single trajectory."""
        score = 0.0
        reasons: list[str] = []

        # Minimum viable exchange check
        user_msgs = [m for m in trajectory.messages if m.get("role") == "user"]
        timmy_msgs = [m for m in trajectory.messages if m.get("role") == "timmy"]

        if not user_msgs or not timmy_msgs:
            return QualityResult(
                trajectory=trajectory,
                quality=TrajectoryQuality.LOW,
                score=0.0,
                reasons=["Missing user or assistant messages — not a valid exchange"],
            )

        # Multi-step bonus
        if trajectory.is_multi_step:
            score += 3.0
            reasons.append(
                f"Multi-step task: {trajectory.message_count} messages, "
                f"{trajectory.tool_call_count} tool calls"
            )

        # Base score for any clean exchange (user + timmy, no tool call required)
        if trajectory.error_count == 0:
            score += 1.0
            reasons.append("Clean exchange (no errors)")

        # Tool call quality
        if trajectory.tool_call_count > 0:
            if trajectory.error_count == 0:
                score += 2.0
                reasons.append(
                    f"All {trajectory.tool_call_count} tool call(s) succeeded"
                )
            else:
                score -= 2.0
                reasons.append(
                    f"{trajectory.error_count} error(s) during {trajectory.tool_call_count} tool call(s)"
                )
        elif trajectory.error_count > 0:
            score -= 2.0
            reasons.append(f"{trajectory.error_count} error(s) with no tool calls")

        # Decision bonus
        if trajectory.decisions:
            score += 1.0
            reasons.append(f"Decisions recorded: {len(trajectory.decisions)}")

        # Confidence penalty
        low_conf = [
            m
            for m in timmy_msgs
            if m.get("confidence") is not None
            and m["confidence"] < self._min_confidence
        ]
        if low_conf:
            score -= len(low_conf)
            reasons.append(
                f"{len(low_conf)} low-confidence response(s) (threshold={self._min_confidence})"
            )

        # Classify
        if score >= 4.0:
            quality = TrajectoryQuality.HIGH
        elif score >= 1.0:
            quality = TrajectoryQuality.MEDIUM
        else:
            quality = TrajectoryQuality.LOW

        return QualityResult(
            trajectory=trajectory,
            quality=quality,
            score=score,
            reasons=reasons,
        )

    def filter(
        self, trajectories: list[Trajectory]
    ) -> tuple[list[QualityResult], dict[str, int]]:
        """Assess all trajectories and return trainable ones with stats.

        Returns:
            (trainable_results, stats_dict) where stats_dict has keys
            'total', 'high', 'medium', 'low', 'accepted'.
        """
        results = [self.assess(t) for t in trajectories]
        trainable = [r for r in results if r.is_trainable]

        stats = {
            "total": len(results),
            "high": sum(1 for r in results if r.quality == TrajectoryQuality.HIGH),
            "medium": sum(1 for r in results if r.quality == TrajectoryQuality.MEDIUM),
            "low": sum(1 for r in results if r.quality == TrajectoryQuality.LOW),
            "accepted": len(trainable),
        }

        logger.info(
            "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)",
            stats["accepted"],
            stats["total"],
            stats["high"],
            stats["medium"],
            stats["low"],
        )

        return trainable, stats
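The scoring table in the `QualityFilter` docstring can be traced by hand. A minimal sketch for a hypothetical trajectory that is multi-step, has only successful tool calls, one recorded decision, no errors, and no low-confidence responses:

```python
# Walk the QualityFilter scoring rules for one hypothetical clean,
# multi-step trajectory (values chosen for illustration).
score = 0.0
score += 3.0  # multi-step: ≥2 messages and ≥1 tool call
score += 1.0  # clean exchange: no error entries
score += 2.0  # tool calls present and all succeeded
score += 1.0  # a decision was recorded
# No penalties: no errors (-2) and no low-confidence responses (-1 each).

# Classification thresholds: HIGH ≥ 4, MEDIUM 1–3, LOW ≤ 0.
quality = "high" if score >= 4.0 else "medium" if score >= 1.0 else "low"
print(score, quality)  # → 7.0 high
```

Note that a single error entry (-2) on an otherwise single-exchange trajectory (no multi-step bonus) drops the score to 0 or below, landing it in LOW — which is exactly the "error-prone or trivial — skip" bucket.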
292
timmy_automations/retrain/retrain.py
Normal file
@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""AutoLoRA continuous improvement loop — the sovereignty retrain script.

Implements the weekly retrain cycle end-to-end:
    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat forever

Run:
    python3 timmy_automations/retrain/retrain.py
    python3 timmy_automations/retrain/retrain.py --dry-run
    python3 timmy_automations/retrain/retrain.py --weeks-ago 1

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

# Allow running directly from repo root
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from timmy_automations.retrain.lora_trainer import LoRATrainer
from timmy_automations.retrain.quality_filter import QualityFilter
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("retrain")


@dataclass
class RetrainResult:
    """Result of a complete retrain cycle."""

    iteration: int
    week: str
    trajectories_exported: int
    trajectories_accepted: int
    examples_added: int
    dataset_total: int
    train_status: str
    adapter_path: str | None
    model_name: str | None
    train_loss: float | None
    duration_seconds: float
    notes: str


class RetrainOrchestrator:
    """Orchestrates the complete AutoLoRA continuous improvement loop.

    Step 1: Export this week's conversation trajectories from session logs
    Step 2: Filter for high-quality exchanges
    Step 3: Append to the training dataset
    Step 4: Trigger LoRA fine-tune
    Step 5: Load the new adapter (via Ollama)
    Step 6: Log iteration, loss, skill accuracy
    """

    def __init__(
        self,
        base_model: str = "hermes4-14b",
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        if repo_root is None:
            repo_root = _REPO_ROOT
        self._repo_root = Path(repo_root)
        self._dry_run = dry_run

        self.exporter = TrajectoryExporter(repo_root=self._repo_root)
        self.quality_filter = QualityFilter()
        self.dataset = TrainingDataset(repo_root=self._repo_root)
        self.trainer = LoRATrainer(
            base_model=base_model,
            repo_root=self._repo_root,
            dry_run=dry_run,
        )
        self.log = TrainingLog(repo_root=self._repo_root)

    def run(self, weeks_ago: int = 1) -> RetrainResult:
        """Execute one complete retrain cycle.

        Args:
            weeks_ago: Which week to process. 0 = current week (partial),
                1 = last week (default, Sunday night run), etc.

        Returns:
            RetrainResult with full cycle summary.
        """
        started = datetime.now(tz=UTC)
        iteration = self.log.next_iteration()

        # Determine ISO week tag
        from datetime import timedelta

        now = datetime.now(tz=UTC)
        target_date = now - timedelta(weeks=weeks_ago)
        week_tag = f"{target_date.year}-W{target_date.isocalendar().week:02d}"

        logger.info(
            "=== AutoLoRA Retrain Cycle %d | Week: %s | dry_run=%s ===",
            iteration,
            week_tag,
            self._dry_run,
        )

        # Step 1: Export trajectories
        logger.info("Step 1: Exporting trajectories for %s...", week_tag)
        trajectories = self.exporter.export_week(weeks_ago=weeks_ago)
        logger.info("Exported %d raw trajectories", len(trajectories))

        # Step 2: Quality filter
        logger.info("Step 2: Applying quality filter...")
        trainable, filter_stats = self.quality_filter.filter(trajectories)
        logger.info(
            "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)",
            filter_stats["accepted"],
            filter_stats["total"],
            filter_stats["high"],
            filter_stats["medium"],
            filter_stats["low"],
        )

        # Step 3: Append to dataset
        logger.info("Step 3: Appending to training dataset...")
        append_result = self.dataset.append(trainable, week_tag)
        logger.info(
            "Dataset: +%d new examples (%d total)",
            append_result.new_examples,
            append_result.total_examples,
        )

        # Step 4: LoRA fine-tune
        logger.info("Step 4: Triggering LoRA fine-tune (iteration=%d)...", iteration)
        train_result = self.trainer.train(
            dataset_path=self.dataset.dataset_path,
            iteration=iteration,
        )
        logger.info(
            "Train result: status=%s loss=%s duration=%.1fs",
            train_result.status,
            train_result.train_loss,
            train_result.duration_seconds,
        )

        # Step 5 & 6: Log cycle
        duration = (datetime.now(tz=UTC) - started).total_seconds()
        metrics = CycleMetrics(
            iteration=iteration,
            week=week_tag,
            ran_at=started.isoformat(),
            trajectories_total=filter_stats["total"],
            trajectories_high=filter_stats["high"],
            trajectories_medium=filter_stats["medium"],
            trajectories_low=filter_stats["low"],
            trajectories_accepted=filter_stats["accepted"],
            examples_added=append_result.new_examples,
            dataset_total=append_result.total_examples,
            train_status=train_result.status,
            train_loss=train_result.train_loss,
            train_duration_seconds=train_result.duration_seconds,
            adapter_path=train_result.adapter_path,
            model_name=train_result.model_name,
            notes=train_result.message,
        )
        self.log.record(metrics)

        result = RetrainResult(
            iteration=iteration,
            week=week_tag,
            trajectories_exported=len(trajectories),
            trajectories_accepted=filter_stats["accepted"],
            examples_added=append_result.new_examples,
            dataset_total=append_result.total_examples,
            train_status=train_result.status,
            adapter_path=train_result.adapter_path,
            model_name=train_result.model_name,
            train_loss=train_result.train_loss,
            duration_seconds=duration,
            notes=train_result.message,
        )

        logger.info(
            "=== Cycle %d complete: status=%s examples_added=%d total=%.1fs ===",
            iteration,
            train_result.status,
            append_result.new_examples,
            duration,
        )

        return result


def _print_result(result: RetrainResult, as_json: bool = False) -> None:
    """Print cycle result to stdout."""
    if as_json:
        print(
            json.dumps(
                {
                    "iteration": result.iteration,
                    "week": result.week,
                    "trajectories_exported": result.trajectories_exported,
                    "trajectories_accepted": result.trajectories_accepted,
                    "examples_added": result.examples_added,
                    "dataset_total": result.dataset_total,
                    "train_status": result.train_status,
                    "adapter_path": result.adapter_path,
                    "model_name": result.model_name,
                    "train_loss": result.train_loss,
                    "duration_seconds": result.duration_seconds,
                    "notes": result.notes,
                },
                indent=2,
            )
        )
        return

    print(f"\n{'='*60}")
    print(f" AutoLoRA Retrain — Cycle {result.iteration}")
    print(f" Week: {result.week}")
    print(f"{'='*60}")
    print(f" Trajectories: {result.trajectories_exported} exported, {result.trajectories_accepted} accepted")
    print(f" Dataset: +{result.examples_added} examples ({result.dataset_total} total)")
    print(f" Fine-tune: {result.train_status}")
    if result.train_loss is not None:
        print(f" Train loss: {result.train_loss:.4f}")
    if result.model_name:
        print(f" New model: {result.model_name}")
    if result.adapter_path:
        print(f" Adapter: {result.adapter_path}")
    print(f" Duration: {result.duration_seconds:.1f}s")
    print(f" Notes: {result.notes}")
    print(f"{'='*60}\n")


def main() -> int:
    parser = argparse.ArgumentParser(
        description="AutoLoRA continuous improvement loop — sovereignty engine for Timmy"
    )
    parser.add_argument(
        "--weeks-ago",
        type=int,
        default=1,
        help="Which week to process: 0=current (partial), 1=last week (default)",
    )
    parser.add_argument(
        "--base-model",
        default="hermes4-14b",
        help="Ollama base model name (default: hermes4-14b)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Export and filter trajectories but skip actual fine-tuning",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help="Output result as JSON",
    )
    args = parser.parse_args()

    orchestrator = RetrainOrchestrator(
        base_model=args.base_model,
        dry_run=args.dry_run,
    )
    result = orchestrator.run(weeks_ago=args.weeks_ago)
    _print_result(result, as_json=args.as_json)

    # Exit 0 even on skipped/failed training — the loop must continue
    return 0


if __name__ == "__main__":
    sys.exit(main())
180
timmy_automations/retrain/training_dataset.py
Normal file
@@ -0,0 +1,180 @@
"""Training dataset manager — appends filtered trajectories to a JSONL training file.

Maintains a growing dataset of high-quality conversation examples in the
chat format expected by mlx-lm / HuggingFace fine-tuning pipelines.

Output format (one JSON object per line):
    {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityResult

logger = logging.getLogger(__name__)

_DEFAULT_DATASET_PATH = ".loop/retrain/training_data.jsonl"
_DEFAULT_INDEX_PATH = ".loop/retrain/dataset_index.json"


@dataclass
class AppendResult:
    """Result of appending trajectories to the training dataset."""

    new_examples: int
    total_examples: int
    dataset_path: str
    week_tag: str


class TrainingDataset:
    """Manages the LoRA training dataset file.

    Each entry is a chat-format example:
        {"messages": [...], "week": "2026-W12", "quality": "high", "added_at": "..."}
    """

    def __init__(
        self,
        dataset_path: str | Path | None = None,
        index_path: str | Path | None = None,
        repo_root: str | Path | None = None,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._dataset_path = self._repo_root / (dataset_path or _DEFAULT_DATASET_PATH)
        self._index_path = self._repo_root / (index_path or _DEFAULT_INDEX_PATH)

        self._dataset_path.parent.mkdir(parents=True, exist_ok=True)

    @property
    def dataset_path(self) -> Path:
        return self._dataset_path

    def count(self) -> int:
        """Return the number of examples currently in the dataset."""
        if not self._dataset_path.exists():
            return 0
        count = 0
        with open(self._dataset_path) as f:
            for line in f:
                if line.strip():
                    count += 1
        return count

    def append(
        self, quality_results: list[QualityResult], week_tag: str
    ) -> AppendResult:
        """Append high-quality trajectories to the training dataset.

        Deduplicates by (week_tag, session_date, started_at) so re-running
        the export for the same week is idempotent.

        Args:
            quality_results: Filtered, trainable quality results.
            week_tag: ISO week string, e.g. "2026-W12".

        Returns:
            AppendResult with counts.
        """
        existing_keys = self._load_existing_keys()
        new_count = 0
        added_at = datetime.now(tz=UTC).isoformat()

        with open(self._dataset_path, "a") as f:
            for result in quality_results:
                traj = result.trajectory
                dedup_key = f"{week_tag}|{traj.session_date}|{traj.started_at}"
                if dedup_key in existing_keys:
                    logger.debug("Skipping duplicate trajectory: %s", dedup_key)
                    continue

                chat_messages = traj.to_chat_format()
                if len(chat_messages) < 2:
                    logger.debug(
                        "Skipping trajectory with %d chat messages (need ≥2)",
                        len(chat_messages),
                    )
                    continue

                record = {
                    "messages": chat_messages,
                    "week": week_tag,
                    "quality": result.quality.value,
                    "score": result.score,
                    "session_date": traj.session_date,
                    "started_at": traj.started_at,
                    "tool_calls": traj.tool_call_count,
                    "added_at": added_at,
                }
                f.write(json.dumps(record) + "\n")
                existing_keys.add(dedup_key)
                new_count += 1

        total = self.count()
        self._update_index(week_tag, new_count, total)
        logger.info("Dataset: appended %d new examples (total=%d)", new_count, total)

        return AppendResult(
            new_examples=new_count,
            total_examples=total,
            dataset_path=str(self._dataset_path),
            week_tag=week_tag,
        )

    def _load_existing_keys(self) -> set[str]:
        """Load deduplication keys from the existing dataset."""
        keys: set[str] = set()
        if not self._dataset_path.exists():
            return keys
        with open(self._dataset_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    week = record.get("week", "")
                    session_date = record.get("session_date", "")
                    started_at = record.get("started_at", "")
                    keys.add(f"{week}|{session_date}|{started_at}")
                except json.JSONDecodeError:
                    continue
        return keys

    def _update_index(self, week_tag: str, new_count: int, total: int) -> None:
        """Update the dataset index JSON with latest run metadata."""
        index: dict = {}
        if self._index_path.exists():
            try:
                index = json.loads(self._index_path.read_text())
            except (json.JSONDecodeError, OSError):
                index = {}

        index.setdefault("weeks", {})
        index["weeks"][week_tag] = {
            "examples_added": new_count,
            "updated_at": datetime.now(tz=UTC).isoformat(),
        }
        index["total_examples"] = total
        index["last_updated"] = datetime.now(tz=UTC).isoformat()

        self._index_path.write_text(json.dumps(index, indent=2))
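A hypothetical record in the shape `append()` writes, together with the dedup key it derives (`week|session_date|started_at`). All field values here are made up for illustration; the real values come from the exported trajectory.

```python
import json

# One JSONL record as TrainingDataset.append would serialize it
# (illustrative content, not real session data).
record = {
    "messages": [
        {"role": "user", "content": "Ship the weekly report."},
        {"role": "assistant", "content": "[tool:report] done"},
    ],
    "week": "2026-W12",
    "quality": "high",
    "score": 7.0,
    "session_date": "2026-03-20",
    "started_at": "2026-03-20T09:00:00+00:00",
    "tool_calls": 1,
    "added_at": "2026-03-22T21:00:00+00:00",
}

# The idempotency key: re-running the same week's export maps the same
# trajectory to the same key, so it is skipped on the second pass.
dedup_key = f"{record['week']}|{record['session_date']}|{record['started_at']}"
line = json.dumps(record)  # one line in training_data.jsonl
print(dedup_key)  # → 2026-W12|2026-03-20|2026-03-20T09:00:00+00:00
```

Because the key omits message content, two distinct trajectories that somehow share a `started_at` timestamp within one session date would collide; with per-second ISO timestamps that is unlikely but worth knowing.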
183
timmy_automations/retrain/training_log.py
Normal file
@@ -0,0 +1,183 @@
"""Training log — records each fine-tune cycle with metrics and skill deltas.

Writes to .loop/retrain/training_log.jsonl (one entry per cycle) and
maintains a human-readable .loop/retrain/training_log.md summary.

Each log entry captures:
- Iteration count
- Week processed
- Quality filter stats
- Examples added to dataset
- LoRA train result (loss, duration, adapter path)
- Skill accuracy deltas (from smoke tests)

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_DEFAULT_LOG_PATH = ".loop/retrain/training_log.jsonl"
_DEFAULT_SUMMARY_PATH = ".loop/retrain/training_log.md"


@dataclass
class CycleMetrics:
    """Metrics for a single retrain cycle."""

    iteration: int
    week: str
    ran_at: str

    # Quality filter
    trajectories_total: int = 0
    trajectories_high: int = 0
    trajectories_medium: int = 0
    trajectories_low: int = 0
    trajectories_accepted: int = 0

    # Dataset
    examples_added: int = 0
    dataset_total: int = 0

    # Training
    train_status: str = "skipped"
    train_loss: float | None = None
    train_duration_seconds: float = 0.0
    adapter_path: str | None = None
    model_name: str | None = None

    # Skill accuracy (optional, from smoke tests)
    skill_accuracy: dict[str, float] = field(default_factory=dict)
    skill_delta: dict[str, float] = field(default_factory=dict)

    # Human-readable summary
    notes: str = ""


class TrainingLog:
    """Persistent log of all retrain cycles."""

    def __init__(
        self,
        log_path: str | Path | None = None,
        summary_path: str | Path | None = None,
        repo_root: str | Path | None = None,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._log_path = self._repo_root / (log_path or _DEFAULT_LOG_PATH)
        self._summary_path = self._repo_root / (summary_path or _DEFAULT_SUMMARY_PATH)
        self._log_path.parent.mkdir(parents=True, exist_ok=True)

    @property
    def log_path(self) -> Path:
        return self._log_path

    def next_iteration(self) -> int:
        """Return the next iteration number (1-indexed)."""
        entries = self.load_all()
        if not entries:
            return 1
        return max(e.get("iteration", 0) for e in entries) + 1

    def record(self, metrics: CycleMetrics) -> None:
        """Append a cycle metrics record to the log."""
        entry = asdict(metrics)
        with open(self._log_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

        self._update_summary(metrics)
        logger.info(
            "Training log: iteration=%d week=%s status=%s examples_added=%d",
            metrics.iteration,
            metrics.week,
            metrics.train_status,
            metrics.examples_added,
        )

    def load_all(self) -> list[dict[str, Any]]:
        """Load all cycle records from the log."""
        if not self._log_path.exists():
            return []
        entries: list[dict[str, Any]] = []
        with open(self._log_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.debug("Skipping malformed log entry")
        return entries

    def latest(self) -> dict[str, Any] | None:
        """Return the most recent cycle record."""
        entries = self.load_all()
        return entries[-1] if entries else None

    def _update_summary(self, metrics: CycleMetrics) -> None:
        """Rewrite the markdown summary with all cycles."""
        all_entries = self.load_all()

        lines = [
            "# AutoLoRA Training Log\n",
            f"*Updated: {datetime.now(tz=UTC).isoformat()}*\n",
            f"*Total iterations: {len(all_entries)}*\n",
            "",
            "## Cycles\n",
            "| # | Week | Status | Loss | Examples | Duration |",
            "|---|------|--------|------|----------|----------|",
        ]

        for entry in reversed(all_entries[-20:]):  # Last 20 cycles
            loss = f"{entry.get('train_loss', 0.0) or 0.0:.4f}" if entry.get("train_loss") else "—"
            lines.append(
                f"| {entry.get('iteration', '?')} "
                f"| {entry.get('week', '?')} "
                f"| {entry.get('train_status', '?')} "
                f"| {loss} "
                f"| +{entry.get('examples_added', 0)} ({entry.get('dataset_total', 0)} total) "
                f"| {entry.get('train_duration_seconds', 0.0):.0f}s |"
            )

        lines.append("")
        lines.append("## Skill Accuracy Over Time\n")

        # Collect all unique skills
        all_skills: set[str] = set()
        for entry in all_entries:
            all_skills.update(entry.get("skill_accuracy", {}).keys())

        if all_skills:
            skill_header = "| # | Week | " + " | ".join(sorted(all_skills)) + " |"
            skill_sep = "|---|------|" + "|".join("---" for _ in all_skills) + "|"
            lines.extend([skill_header, skill_sep])
            for entry in reversed(all_entries[-10:]):
                acc = entry.get("skill_accuracy", {})
                row = f"| {entry.get('iteration', '?')} | {entry.get('week', '?')} | "
                row += " | ".join(
                    f"{acc.get(s, 0.0):.0%}" if s in acc else "—"
                    for s in sorted(all_skills)
                )
                row += " |"
                lines.append(row)
        else:
            lines.append("*No skill accuracy data yet — run smoke tests after fine-tuning.*")

        lines.append("")
        if metrics.notes:
            lines.append(f"## Latest Notes\n\n{metrics.notes}\n")

        self._summary_path.write_text("\n".join(lines))
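`next_iteration()` is a one-liner over the loaded entries: the maximum recorded iteration plus one, or 1 on an empty log. A minimal sketch with fabricated log entries:

```python
# Fabricated JSONL entries as load_all() would return them.
entries = [{"iteration": 1}, {"iteration": 2}, {"iteration": 3}]

# Same rule as TrainingLog.next_iteration: 1-indexed, max existing + 1.
next_it = max(e.get("iteration", 0) for e in entries) + 1 if entries else 1
print(next_it)  # → 4
```

Using `max()` rather than `len()` keeps numbering monotonic even if earlier entries were pruned or a cycle was logged out of order.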
255
timmy_automations/retrain/trajectory_exporter.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Trajectory exporter — reads session JSONL logs and extracts conversation trajectories.
|
||||
|
||||
A trajectory is a coherent sequence of messages + tool calls that form
|
||||
a single task attempt. Each trajectory becomes one training example.
|
||||
|
||||
Refs: #1105
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LOGS_DIR_DEFAULT = "logs"
|
||||
_SESSION_GLOB = "session_*.jsonl"
|
||||
|
||||
|
||||
@dataclass
class Trajectory:
    """A single conversation trajectory extracted from session logs."""

    session_date: str
    started_at: str
    ended_at: str
    messages: list[dict[str, Any]] = field(default_factory=list)
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    errors: list[dict[str, Any]] = field(default_factory=list)
    decisions: list[dict[str, Any]] = field(default_factory=list)

    @property
    def message_count(self) -> int:
        return len(self.messages)

    @property
    def tool_call_count(self) -> int:
        return len(self.tool_calls)

    @property
    def error_count(self) -> int:
        return len(self.errors)

    @property
    def has_successful_tool_call(self) -> bool:
        """True if the trajectory made at least one tool call and logged no errors."""
        return self.tool_call_count > 0 and self.error_count == 0

    @property
    def is_multi_step(self) -> bool:
        """True if this trajectory involved multiple turns with tool use."""
        return self.message_count >= 2 and self.tool_call_count >= 1

    def to_chat_format(self) -> list[dict[str, str]]:
        """Convert trajectory to chat-format messages for training.

        Interleaves messages and tool-call results as assistant/tool turns.
        """
        chat: list[dict[str, str]] = []
        # Merge all entries by timestamp and emit in order
        all_entries = sorted(
            self.messages + self.tool_calls + self.decisions,
            key=lambda e: e.get("timestamp", ""),
        )
        for entry in all_entries:
            etype = entry.get("type")
            if etype == "message":
                role = "user" if entry.get("role") == "user" else "assistant"
                content = entry.get("content", "")
                if content:
                    chat.append({"role": role, "content": content})
            elif etype == "tool_call":
                tool = entry.get("tool", "unknown")
                result = entry.get("result", "")
                chat.append(
                    {
                        "role": "assistant",
                        "content": f"[tool:{tool}] {result}",
                    }
                )
            elif etype == "decision":
                decision = entry.get("decision", "")
                if decision:
                    chat.append({"role": "assistant", "content": f"[decided] {decision}"})
        return chat
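The message/tool-call mapping above can be exercised standalone. A minimal sketch with invented sample entries (the field names `type`, `role`, `tool`, `result`, and `timestamp` follow the parsing code in this module; the contents are hypothetical, not real session data):

```python
# Standalone sketch of the to_chat_format mapping, using invented sample
# entries (hypothetical data, not taken from real session logs).
entries = [
    {"type": "message", "role": "user", "content": "Scout the pass",
     "timestamp": "2026-03-23T10:00:00Z"},
    {"type": "tool_call", "tool": "gabs.move_party", "result": "arrived",
     "timestamp": "2026-03-23T10:00:04Z"},
    {"type": "message", "role": "timmy", "content": "Pass is clear.",
     "timestamp": "2026-03-23T10:00:06Z"},
]

chat = []
for e in sorted(entries, key=lambda e: e.get("timestamp", "")):
    if e["type"] == "message":
        # Non-user roles (here "timmy") are emitted as assistant turns.
        role = "user" if e["role"] == "user" else "assistant"
        chat.append({"role": role, "content": e["content"]})
    elif e["type"] == "tool_call":
        chat.append({"role": "assistant", "content": f"[tool:{e['tool']}] {e['result']}"})

# chat[1]["content"] == "[tool:gabs.move_party] arrived"
```

The tool result is folded into an assistant turn rather than a separate `tool` role so the output stays compatible with chat templates that only accept user/assistant messages.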
class TrajectoryExporter:
    """Reads session JSONL logs and yields Trajectory objects for a date range."""

    def __init__(self, logs_dir: str | Path | None = None, repo_root: str | Path | None = None):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        if logs_dir is None:
            self._logs_dir = self._repo_root / _LOGS_DIR_DEFAULT
        else:
            self._logs_dir = Path(logs_dir)

    def export_week(self, weeks_ago: int = 0) -> list[Trajectory]:
        """Export all trajectories from the specified week.

        Args:
            weeks_ago: 0 = current week, 1 = last week, etc.

        Returns:
            List of Trajectory objects extracted from session logs.
        """
        now = datetime.now(tz=UTC)
        # Week boundaries: Mon–Sun
        days_since_monday = now.weekday()
        week_start = (now - timedelta(days=days_since_monday + 7 * weeks_ago)).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        week_end = week_start + timedelta(days=7)

        logger.info(
            "Exporting trajectories for week %s–%s",
            week_start.date().isoformat(),
            week_end.date().isoformat(),
        )

        trajectories: list[Trajectory] = []
        log_files = sorted(self._logs_dir.glob(_SESSION_GLOB))

        for log_file in log_files:
            # Parse date from filename: session_YYYY-MM-DD.jsonl
            try:
                date_str = log_file.stem.removeprefix("session_")
                file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=UTC)
            except ValueError:
                logger.debug("Skipping non-date session file: %s", log_file.name)
                continue

            if not (week_start <= file_date < week_end):
                continue

            file_trajectories = self._extract_from_file(log_file)
            trajectories.extend(file_trajectories)
            logger.info(
                "Extracted %d trajectories from %s", len(file_trajectories), log_file.name
            )

        logger.info("Total trajectories exported: %d", len(trajectories))
        return trajectories

    def _extract_from_file(self, log_file: Path) -> list[Trajectory]:
        """Parse a single session JSONL file into trajectories.

        Groups entries into trajectories by finding natural conversation
        boundaries (gaps of inactivity or topic shifts in the message stream).
        """
        entries: list[dict[str, Any]] = []
        try:
            with open(log_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.debug("Skipping malformed JSON line in %s", log_file.name)
        except OSError as exc:
            logger.warning("Could not read %s: %s", log_file, exc)
            return []

        if not entries:
            return []

        date_str = log_file.stem.removeprefix("session_")
        return self._segment_trajectories(entries, date_str)

    def _segment_trajectories(
        self, entries: list[dict[str, Any]], session_date: str
    ) -> list[Trajectory]:
        """Split a flat list of session entries into discrete trajectories.

        Segmentation rule: start a new trajectory when:
        - A user message follows a Timmy message (new conversation turn)
        - More than 5 minutes have elapsed between entries

        This produces training examples that are coherent task attempts.
        """
        if not entries:
            return []

        trajectories: list[Trajectory] = []
        current_entries: list[dict[str, Any]] = []
        prev_ts: datetime | None = None
        _SEGMENT_GAP_MINUTES = 5

        def _flush() -> None:
            if current_entries:
                traj = _build_trajectory(current_entries, session_date)
                if traj.message_count > 0:
                    trajectories.append(traj)

        for entry in entries:
            ts_raw = entry.get("timestamp", "")
            try:
                ts = datetime.fromisoformat(ts_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                ts = None

            # Time-gap segmentation
            if ts and prev_ts and (ts - prev_ts).total_seconds() > _SEGMENT_GAP_MINUTES * 60:
                _flush()
                current_entries = []

            # New-turn segmentation: user message after assistant turn
            etype = entry.get("type")
            erole = entry.get("role")
            if etype == "message" and erole == "user" and current_entries:
                # Check only the most recent message entry; if it was a Timmy
                # message, this user message starts a new conversation turn.
                for prev in reversed(current_entries):
                    if prev.get("type") == "message":
                        if prev.get("role") == "timmy":
                            _flush()
                            current_entries = []
                        break

            current_entries.append(entry)
            if ts:
                prev_ts = ts

        _flush()
        return trajectories
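The time-gap half of the segmentation rule can be illustrated on its own. A minimal sketch with invented timestamps, grouping entries into segments whenever consecutive entries are more than 5 minutes apart:

```python
from datetime import datetime, timedelta

# Standalone sketch of the 5-minute time-gap rule used by
# _segment_trajectories; the timestamps are invented for illustration.
GAP = timedelta(minutes=5)
stamps = [
    "2026-03-23T10:00:00+00:00",
    "2026-03-23T10:02:00+00:00",  # 2 min later  -> same trajectory
    "2026-03-23T10:20:00+00:00",  # 18 min later -> new trajectory
]
parsed = [datetime.fromisoformat(s) for s in stamps]

segments = [[parsed[0]]]
for prev, cur in zip(parsed, parsed[1:]):
    if cur - prev > GAP:
        segments.append([])  # gap exceeded: start a new segment
    segments[-1].append(cur)

# len(segments) == 2
```

The strict `>` comparison means an entry exactly 5 minutes after its predecessor still belongs to the same trajectory, matching the "more than 5 minutes" wording in the docstring.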
def _build_trajectory(entries: list[dict[str, Any]], session_date: str) -> Trajectory:
    """Build a Trajectory from a flat list of entries."""
    messages = [e for e in entries if e.get("type") == "message"]
    tool_calls = [e for e in entries if e.get("type") == "tool_call"]
    errors = [e for e in entries if e.get("type") == "error"]
    decisions = [e for e in entries if e.get("type") == "decision"]

    timestamps = [e.get("timestamp", "") for e in entries if e.get("timestamp")]
    started_at = min(timestamps) if timestamps else ""
    ended_at = max(timestamps) if timestamps else ""

    return Trajectory(
        session_date=session_date,
        started_at=started_at,
        ended_at=ended_at,
        messages=messages,
        tool_calls=tool_calls,
        errors=errors,
        decisions=decisions,
    )