#!/usr/bin/env python3
"""AutoLoRA continuous improvement loop — the sovereignty retrain script.

Implements the weekly retrain cycle end-to-end:

    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat forever

Run:

    python3 timmy_automations/retrain/retrain.py
    python3 timmy_automations/retrain/retrain.py --dry-run
    python3 timmy_automations/retrain/retrain.py --weeks-ago 1

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path
# Allow running directly from repo root
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
from timmy_automations.retrain.lora_trainer import LoRATrainer
from timmy_automations.retrain.quality_filter import QualityFilter
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("retrain")
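# Given the format and datefmt above, log lines look like (illustrative):
#   2026-03-22T02:00:00 INFO     retrain: === AutoLoRA Retrain Cycle 3 | ...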


@dataclass
class RetrainResult:
    """Result of a complete retrain cycle."""

    iteration: int
    week: str
    trajectories_exported: int
    trajectories_accepted: int
    examples_added: int
    dataset_total: int
    train_status: str
    adapter_path: str | None
    model_name: str | None
    train_loss: float | None
    duration_seconds: float
    notes: str


class RetrainOrchestrator:
    """Orchestrates the complete AutoLoRA continuous improvement loop.

    Step 1: Export this week's conversation trajectories from session logs
    Step 2: Filter for high-quality exchanges
    Step 3: Append to the training dataset
    Step 4: Trigger LoRA fine-tune
    Step 5: Load the new adapter (via Ollama)
    Step 6: Log iteration, loss, skill accuracy
    """
    def __init__(
        self,
        base_model: str = "hermes4-14b",
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        if repo_root is None:
            repo_root = _REPO_ROOT
        self._repo_root = Path(repo_root)
        self._dry_run = dry_run
        self.exporter = TrajectoryExporter(repo_root=self._repo_root)
        self.quality_filter = QualityFilter()
        self.dataset = TrainingDataset(repo_root=self._repo_root)
        self.trainer = LoRATrainer(
            base_model=base_model,
            repo_root=self._repo_root,
            dry_run=dry_run,
        )
        self.log = TrainingLog(repo_root=self._repo_root)

    def run(self, weeks_ago: int = 1) -> RetrainResult:
        """Execute one complete retrain cycle.

        Args:
            weeks_ago: Which week to process. 0 = current week (partial),
                1 = last week (default, Sunday night run), etc.

        Returns:
            RetrainResult with full cycle summary.
        """
        started = datetime.now(tz=UTC)
        iteration = self.log.next_iteration()

        # Determine the ISO week tag. Use the ISO year from isocalendar()
        # rather than the calendar year: around New Year the two can differ
        # (e.g. 2025-12-29 falls in ISO week 2026-W01).
        target_date = started - timedelta(weeks=weeks_ago)
        iso = target_date.isocalendar()
        week_tag = f"{iso.year}-W{iso.week:02d}"

        logger.info(
            "=== AutoLoRA Retrain Cycle %d | Week: %s | dry_run=%s ===",
            iteration,
            week_tag,
            self._dry_run,
        )

        # Step 1: Export trajectories
        logger.info("Step 1: Exporting trajectories for %s...", week_tag)
        trajectories = self.exporter.export_week(weeks_ago=weeks_ago)
        logger.info("Exported %d raw trajectories", len(trajectories))

        # Step 2: Quality filter
        logger.info("Step 2: Applying quality filter...")
        trainable, filter_stats = self.quality_filter.filter(trajectories)
        logger.info(
            "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)",
            filter_stats["accepted"],
            filter_stats["total"],
            filter_stats["high"],
            filter_stats["medium"],
            filter_stats["low"],
        )

        # Step 3: Append to dataset
        logger.info("Step 3: Appending to training dataset...")
        append_result = self.dataset.append(trainable, week_tag)
        logger.info(
            "Dataset: +%d new examples (%d total)",
            append_result.new_examples,
            append_result.total_examples,
        )

        # Step 4: LoRA fine-tune
        logger.info("Step 4: Triggering LoRA fine-tune (iteration=%d)...", iteration)
        train_result = self.trainer.train(
            dataset_path=self.dataset.dataset_path,
            iteration=iteration,
        )
        logger.info(
            "Train result: status=%s loss=%s duration=%.1fs",
            train_result.status,
            train_result.train_loss,
            train_result.duration_seconds,
        )
# Step 5 & 6: Log cycle
duration = (datetime.now(tz=UTC) - started).total_seconds()
metrics = CycleMetrics(
iteration=iteration,
week=week_tag,
ran_at=started.isoformat(),
trajectories_total=filter_stats["total"],
trajectories_high=filter_stats["high"],
trajectories_medium=filter_stats["medium"],
trajectories_low=filter_stats["low"],
trajectories_accepted=filter_stats["accepted"],
examples_added=append_result.new_examples,
dataset_total=append_result.total_examples,
train_status=train_result.status,
train_loss=train_result.train_loss,
train_duration_seconds=train_result.duration_seconds,
adapter_path=train_result.adapter_path,
model_name=train_result.model_name,
notes=train_result.message,
)
self.log.record(metrics)
result = RetrainResult(
iteration=iteration,
week=week_tag,
trajectories_exported=len(trajectories),
trajectories_accepted=filter_stats["accepted"],
examples_added=append_result.new_examples,
dataset_total=append_result.total_examples,
train_status=train_result.status,
adapter_path=train_result.adapter_path,
model_name=train_result.model_name,
train_loss=train_result.train_loss,
duration_seconds=duration,
notes=train_result.message,
)
logger.info(
"=== Cycle %d complete: status=%s examples_added=%d total=%.1fs ===",
iteration,
train_result.status,
append_result.new_examples,
duration,
)
return result
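
# A minimal programmatic sketch (hypothetical usage; main() below does the
# same thing with CLI argument parsing):
#
#   orchestrator = RetrainOrchestrator(dry_run=True)
#   result = orchestrator.run(weeks_ago=1)
#   print(result.train_status, result.examples_added)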


def _print_result(result: RetrainResult, as_json: bool = False) -> None:
    """Print cycle result to stdout."""
    if as_json:
        print(
            json.dumps(
                {
                    "iteration": result.iteration,
                    "week": result.week,
                    "trajectories_exported": result.trajectories_exported,
                    "trajectories_accepted": result.trajectories_accepted,
                    "examples_added": result.examples_added,
                    "dataset_total": result.dataset_total,
                    "train_status": result.train_status,
                    "adapter_path": result.adapter_path,
                    "model_name": result.model_name,
                    "train_loss": result.train_loss,
                    "duration_seconds": result.duration_seconds,
                    "notes": result.notes,
                },
                indent=2,
            )
        )
        return

    print(f"\n{'='*60}")
    print(f" AutoLoRA Retrain — Cycle {result.iteration}")
    print(f" Week: {result.week}")
    print(f"{'='*60}")
    print(f" Trajectories: {result.trajectories_exported} exported, {result.trajectories_accepted} accepted")
    print(f" Dataset: +{result.examples_added} examples ({result.dataset_total} total)")
    print(f" Fine-tune: {result.train_status}")
    if result.train_loss is not None:
        print(f" Train loss: {result.train_loss:.4f}")
    if result.model_name:
        print(f" New model: {result.model_name}")
    if result.adapter_path:
        print(f" Adapter: {result.adapter_path}")
    print(f" Duration: {result.duration_seconds:.1f}s")
    print(f" Notes: {result.notes}")
    print(f"{'='*60}\n")


def main() -> int:
    parser = argparse.ArgumentParser(
        description="AutoLoRA continuous improvement loop — sovereignty engine for Timmy"
    )
    parser.add_argument(
        "--weeks-ago",
        type=int,
        default=1,
        help="Which week to process: 0=current (partial), 1=last week (default)",
    )
    parser.add_argument(
        "--base-model",
        default="hermes4-14b",
        help="Ollama base model name (default: hermes4-14b)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Export and filter trajectories but skip actual fine-tuning",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help="Output result as JSON",
    )
    args = parser.parse_args()

    orchestrator = RetrainOrchestrator(
        base_model=args.base_model,
        dry_run=args.dry_run,
    )
    result = orchestrator.run(weeks_ago=args.weeks_ago)
    _print_result(result, as_json=args.as_json)
    # Exit 0 even on skipped/failed training — the loop must continue
    return 0


if __name__ == "__main__":
    sys.exit(main())