Compare commits

..

7 Commits

Author SHA1 Message Date
a186d35773 test: Add config template test suite (#696)
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 24s
Smoke Test / smoke (pull_request) Failing after 20s
Validate Config / YAML Lint (pull_request) Failing after 13s
Validate Config / JSON Validate (pull_request) Successful in 13s
PR Checklist / pr-checklist (pull_request) Failing after 3m49s
Validate Config / Shell Script Lint (pull_request) Failing after 48s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 1m19s
Validate Config / Cron Syntax Check (pull_request) Successful in 8s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 7s
Validate Config / Playbook Schema Validation (pull_request) Successful in 18s
Architecture Lint / Lint Repository (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
2026-04-15 14:57:10 +00:00
4475f82c0e feat: Add gateway overlay (#696) 2026-04-15 14:55:15 +00:00
b25f5428b3 feat: Add cron overlay (#696) 2026-04-15 14:55:12 +00:00
e0f72ffee6 feat: Add prod overlay (#696) 2026-04-15 14:55:09 +00:00
d554e6396c feat: Add dev overlay (#696) 2026-04-15 14:55:07 +00:00
aa3672d08c feat: Add base config template (#696) 2026-04-15 14:55:04 +00:00
aae1a952f4 feat: Add config template system with environment overlays (#696) 2026-04-15 14:55:01 +00:00
8 changed files with 471 additions and 568 deletions

41
config/base.yaml Normal file
View File

@@ -0,0 +1,41 @@
# Base config — shared defaults across all environments
# Overridden by {env}.overlay.yaml on merge

model:
  name: "nousresearch/hermes-4-14b"
  provider: "openrouter"
  temperature: 0.7
  max_tokens: 4096

provider:
  name: "openrouter"
  base_url: "https://openrouter.ai/api/v1"

cron:
  enabled: false
  interval_seconds: 300
  max_concurrent: 3

gateway:
  enabled: false
  cors_origins: []
  port: 8080

display:
  spinner: true
  colors: true
  tool_progress: true

tools:
  enabled: true
  browser: true
  web_search: true

session:
  save_trajectories: false
  max_iterations: 90
  context_compression: true

logging:
  level: "INFO"
  file: null
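
The header comments describe the merge contract: these are shared defaults, and each {env}.overlay.yaml wins on conflict via a deep merge. A minimal sketch of loading the merged result in Python, assuming PyYAML is installed and the scripts/config_template.py loader added later in this change is importable from the repo root:

```python
# Sketch only: uses the loader added in this PR (scripts/config_template.py).
# Assumes PyYAML is installed and the repo root is on sys.path.
from scripts.config_template import ConfigTemplate, load_config

cfg = load_config("dev")
print(cfg["model"]["name"])          # "nousresearch/hermes-4-14b" (from base.yaml)
print(cfg["model"]["temperature"])   # 0.9 (dev overlay overrides the base 0.7)

# Dot-notation access via the class API
t = ConfigTemplate()
t.load("dev")
print(t.get("logging.level"))        # "DEBUG" (added by the dev overlay)
```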

24
config/cron.overlay.yaml Normal file
View File

@@ -0,0 +1,24 @@
# Cron/headless environment overlay
# Deterministic, no display

model:
  temperature: 0.1
  max_tokens: 4096

cron:
  enabled: true
  interval_seconds: 120
  max_concurrent: 8

display:
  spinner: false
  colors: false
  tool_progress: false

session:
  save_trajectories: false
  max_iterations: 60

logging:
  level: "INFO"
  file: "/var/log/timmy/cron.log"

20
config/dev.overlay.yaml Normal file
View File

@@ -0,0 +1,20 @@
# Dev environment overlay
# Higher verbosity, faster iteration

model:
  temperature: 0.9
  max_tokens: 2048

cron:
  interval_seconds: 60
  max_concurrent: 1

display:
  tool_progress: true

session:
  save_trajectories: true
  max_iterations: 30

logging:
  level: "DEBUG"

20
config/gateway.overlay.yaml Normal file
View File

@@ -0,0 +1,20 @@
# Gateway environment overlay
# Multi-platform messaging, no cron

model:
  temperature: 0.5

cron:
  enabled: false

gateway:
  enabled: true
  cors_origins: ["*"]
  port: 8080

session:
  save_trajectories: true
  max_iterations: 50

logging:
  level: "INFO"

22
config/prod.overlay.yaml Normal file
View File

@@ -0,0 +1,22 @@
# Prod environment overlay
# Lower temperature, stable settings

model:
  temperature: 0.3
  max_tokens: 4096

cron:
  enabled: true
  interval_seconds: 600
  max_concurrent: 5

gateway:
  enabled: true
  port: 8080

session:
  save_trajectories: false
  max_iterations: 120

logging:
  level: "WARNING"
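
Across the four overlays the main knobs are model temperature, cron scheduling, and log level. A small sketch (same assumptions as above) that prints the effective values per environment:

```python
# Sketch only: compare effective settings across every known environment.
# Assumes PyYAML and the config/ directory introduced in this change.
from scripts.config_template import ConfigTemplate

for env in ConfigTemplate.list_environments():   # ["dev", "prod", "cron", "gateway"]
    t = ConfigTemplate()
    t.load(env)
    print(f"{env:8} temp={t.get('model.temperature')} "
          f"cron={t.get('cron.enabled')} log={t.get('logging.level')}")
# Expected: dev 0.9/DEBUG, prod 0.3/WARNING, cron 0.1/INFO, gateway 0.5/INFO
```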

View File

@@ -1,568 +0,0 @@
#!/usr/bin/env python3
"""
orchestrator.py — Shared Pipeline Orchestrator
SQLite-backed job queue with parallel workers, token budget tracking,
checkpoint resume, rate limiting, and error retry.
All 5 pipelines use this orchestrator for consistent execution.
Usage:
python3 orchestrator.py --pipeline training_factory --jobs jobs.jsonl
python3 orchestrator.py --pipeline adversary --jobs jobs.jsonl --workers 5
python3 orchestrator.py --status
python3 orchestrator.py --resume training_factory
python3 orchestrator.py --report training_factory
"""
import json
import os
import sys
import time
import sqlite3
import hashlib
import threading
import signal
from datetime import datetime, timezone
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
DB_PATH = Path.home() / ".hermes" / "pipeline" / "orchestrator.db"
REPORT_DIR = Path.home() / ".hermes" / "pipeline" / "reports"
# ============================================================
# Data Structures
# ============================================================
@dataclass
class JobStatus:
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
RETRYING = "retrying"
SKIPPED = "skipped"
@dataclass
class PipelineStats:
"""Runtime statistics for a pipeline run."""
pipeline: str
total_jobs: int = 0
completed: int = 0
failed: int = 0
skipped: int = 0
tokens_used: int = 0
tokens_budget: int = 5_000_000
elapsed_seconds: float = 0.0
start_time: str = ""
jobs_per_minute: float = 0.0
def to_dict(self):
return {
"pipeline": self.pipeline,
"total_jobs": self.total_jobs,
"completed": self.completed,
"failed": self.failed,
"skipped": self.skipped,
"tokens_used": self.tokens_used,
"tokens_budget": self.tokens_budget,
"elapsed_seconds": round(self.elapsed_seconds, 1),
"start_time": self.start_time,
"jobs_per_minute": round(self.jobs_per_minute, 2),
}
# ============================================================
# Database
# ============================================================
def get_db():
"""Get SQLite database connection."""
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH), timeout=30, check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
_init_db(conn)
return conn
def _init_db(conn):
"""Initialize database schema."""
conn.executescript("""
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
pipeline TEXT NOT NULL,
job_key TEXT NOT NULL,
payload TEXT NOT NULL,
status TEXT DEFAULT 'pending',
attempts INTEGER DEFAULT 0,
max_attempts INTEGER DEFAULT 3,
tokens_used INTEGER DEFAULT 0,
error TEXT,
result TEXT,
checkpoint TEXT,
created_at TEXT DEFAULT (datetime('now')),
started_at TEXT,
completed_at TEXT,
UNIQUE(pipeline, job_key)
);
CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_status ON jobs(pipeline, status);
CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_key ON jobs(pipeline, job_key);
CREATE TABLE IF NOT EXISTS pipeline_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
pipeline TEXT NOT NULL,
started_at TEXT DEFAULT (datetime('now')),
completed_at TEXT,
total_jobs INTEGER DEFAULT 0,
completed INTEGER DEFAULT 0,
failed INTEGER DEFAULT 0,
tokens_used INTEGER DEFAULT 0,
report TEXT
);
""")
conn.commit()
# ============================================================
# Job Queue
# ============================================================
class JobQueue:
"""SQLite-backed job queue."""
def __init__(self, pipeline: str, conn=None):
self.pipeline = pipeline
self.conn = conn or get_db()
def enqueue(self, job_key: str, payload: dict, max_attempts: int = 3):
"""Add a job to the queue (skip if already exists)."""
try:
self.conn.execute(
"INSERT INTO jobs (pipeline, job_key, payload, max_attempts) VALUES (?, ?, ?, ?)",
(self.pipeline, job_key, json.dumps(payload), max_attempts),
)
self.conn.commit()
return True
except sqlite3.IntegrityError:
# Already exists — check if it needs retry
row = self.conn.execute(
"SELECT status FROM jobs WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
).fetchone()
if row and row[0] == "failed":
# Reset for retry
self.conn.execute(
"UPDATE jobs SET status='pending', attempts=0, error=NULL WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
)
self.conn.commit()
return True
return False
def enqueue_batch(self, jobs: List[dict], key_field: str = "id"):
"""Enqueue multiple jobs. Returns (added, skipped) counts."""
added = 0
skipped = 0
for job in jobs:
key = str(job.get(key_field, hashlib.md5(json.dumps(job).encode()).hexdigest()[:12]))
if self.enqueue(key, job):
added += 1
else:
skipped += 1
return added, skipped
def claim_next(self) -> Optional[dict]:
"""Claim the next pending job (atomic)."""
row = self.conn.execute(
"""UPDATE jobs SET status='running', started_at=datetime('now')
WHERE id = (
SELECT id FROM jobs WHERE pipeline=? AND status IN ('pending', 'retrying')
ORDER BY attempts ASC, created_at ASC LIMIT 1
) RETURNING *""",
(self.pipeline,),
).fetchone()
if not row:
return None
cols = [d[1] for d in self.conn.execute("PRAGMA table_info(jobs)").fetchall()]
return dict(zip(cols, row))
def complete(self, job_key: str, result: dict, tokens_used: int = 0):
"""Mark a job as completed."""
self.conn.execute(
"""UPDATE jobs SET status='completed', completed_at=datetime('now'),
result=?, tokens_used=? WHERE pipeline=? AND job_key=?""",
(json.dumps(result), tokens_used, self.pipeline, job_key),
)
self.conn.commit()
def fail(self, job_key: str, error: str, retry: bool = True):
"""Mark a job as failed, optionally retry."""
row = self.conn.execute(
"SELECT attempts, max_attempts FROM jobs WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
).fetchone()
if not row:
return
attempts, max_attempts = row
new_attempts = attempts + 1
if retry and new_attempts < max_attempts:
# Exponential backoff: 2^attempts seconds
delay = min(2 ** new_attempts, 60)
self.conn.execute(
"""UPDATE jobs SET status='retrying', attempts=?, error=?
WHERE pipeline=? AND job_key=?""",
(new_attempts, error, self.pipeline, job_key),
)
else:
self.conn.execute(
"""UPDATE jobs SET status='failed', attempts=?, error=?,
completed_at=datetime('now') WHERE pipeline=? AND job_key=?""",
(new_attempts, error, self.pipeline, job_key),
)
self.conn.commit()
def save_checkpoint(self, job_key: str, checkpoint: dict):
"""Save progress checkpoint for resume."""
self.conn.execute(
"UPDATE jobs SET checkpoint=? WHERE pipeline=? AND job_key=?",
(json.dumps(checkpoint), self.pipeline, job_key),
)
self.conn.commit()
def get_checkpoint(self, job_key: str) -> Optional[dict]:
"""Get saved checkpoint."""
row = self.conn.execute(
"SELECT checkpoint FROM jobs WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
).fetchone()
if row and row[0]:
return json.loads(row[0])
return None
def stats(self) -> dict:
"""Get queue statistics."""
rows = self.conn.execute(
"""SELECT status, COUNT(*), COALESCE(SUM(tokens_used), 0)
FROM jobs WHERE pipeline=? GROUP BY status""",
(self.pipeline,),
).fetchall()
result = {"total": 0, "tokens_used": 0}
for status, count, tokens in rows:
result[status] = count
result["total"] += count
result["tokens_used"] += tokens
return result
# ============================================================
# Orchestrator
# ============================================================
class Orchestrator:
"""
Shared orchestrator for all pipelines.
Features:
- Parallel worker pool (configurable)
- Token budget tracking
- Checkpoint resume
- Rate limiting
- Error retry with exponential backoff
- Final report generation
"""
def __init__(self, pipeline: str, workers: int = 10, token_budget: int = 5_000_000):
self.pipeline = pipeline
self.workers = workers
self.token_budget = token_budget
self.queue = JobQueue(pipeline)
self.conn = self.queue.conn
self._shutdown = False
self._stats = PipelineStats(pipeline=pipeline, tokens_budget=token_budget)
self._rate_limit_delay = 0.1 # seconds between jobs
self._response_cache: Dict[str, dict] = {}
signal.signal(signal.SIGINT, self._handle_signal)
signal.signal(signal.SIGTERM, self._handle_signal)
def _handle_signal(self, signum, frame):
"""Graceful shutdown on signal."""
print(f"\nReceived signal {signum}. Shutting down gracefully...")
self._shutdown = True
def load_jobs(self, jobs_path: str, key_field: str = "id"):
"""Load jobs from a JSONL file into the queue."""
jobs = []
with open(jobs_path) as f:
for line in f:
line = line.strip()
if line:
jobs.append(json.loads(line))
added, skipped = self.queue.enqueue_batch(jobs, key_field)
print(f"Loaded: {added} new, {skipped} existing")
def run(self, job_handler: Callable[[dict], dict] = None):
"""
Run the orchestrator. Processes all pending jobs with parallel workers.
Args:
job_handler: function(job_payload) -> dict with 'tokens_used' key
"""
start = time.time()
self._stats.start_time = datetime.now(timezone.utc).isoformat()
# Record run
self.conn.execute(
"INSERT INTO pipeline_runs (pipeline, started_at) VALUES (?, ?)",
(self.pipeline, self._stats.start_time),
)
run_id = self.conn.execute("SELECT last_insert_rowid()").fetchone()[0]
self.conn.commit()
stats = self.queue.stats()
self._stats.total_jobs = stats.get("pending", 0) + stats.get("retrying", 0)
print(f"\nPipeline: {self.pipeline}")
print(f"Jobs: {self._stats.total_jobs} pending | Workers: {self.workers} | Budget: {self.token_budget:,} tokens")
print()
if self._stats.total_jobs == 0:
print("No jobs to process.")
return
completed = 0
failed = 0
skipped = 0
tokens_used = 0
with ThreadPoolExecutor(max_workers=self.workers) as executor:
futures = {}
while not self._shutdown:
# Check token budget
if tokens_used >= self.token_budget:
print(f"Token budget exhausted ({tokens_used:,}/{self.token_budget:,})")
break
# Fill worker pool
while len(futures) < self.workers and not self._shutdown:
job = self.queue.claim_next()
if not job:
break
# Check response cache (zero-token retries)
job_key = job["job_key"]
payload = json.loads(job["payload"])
cache_key = hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()
if cache_key in self._response_cache:
result = self._response_cache[cache_key]
self.queue.complete(job_key, result, tokens_used=0)
skipped += 1
continue
# Submit to worker
future = executor.submit(self._process_job, job, job_handler)
futures[future] = job
# Rate limiting
time.sleep(self._rate_limit_delay)
if not futures:
break
# Collect results
done = []
for future in as_completed(futures, timeout=1):
job = futures[future]
try:
result = future.result()
if result.get("success"):
tokens = result.get("tokens_used", 0)
tokens_used += tokens
self.queue.complete(job["job_key"], result, tokens_used=tokens)
completed += 1
else:
error = result.get("error", "unknown error")
self.queue.fail(job["job_key"], error, retry=True)
failed += 1
except Exception as e:
self.queue.fail(job["job_key"], str(e), retry=True)
failed += 1
done.append(future)
# Progress update
total = completed + failed + skipped
if total % 10 == 0:
elapsed = time.time() - start
rate = completed / (elapsed / 60) if elapsed > 0 else 0
print(f" Progress: {total}/{self._stats.total_jobs} | "
f"completed={completed} failed={failed} | "
f"tokens={tokens_used:,} | "
f"{rate:.1f}/min")
for f in done:
del futures[f]
# Final report
elapsed = time.time() - start
self._stats.completed = completed
self._stats.failed = failed
self._stats.skipped = skipped
self._stats.tokens_used = tokens_used
self._stats.elapsed_seconds = elapsed
self._stats.jobs_per_minute = completed / (elapsed / 60) if elapsed > 0 else 0
# Save run
self.conn.execute(
"""UPDATE pipeline_runs SET completed_at=?, total_jobs=?, completed=?,
failed=?, tokens_used=?, report=? WHERE id=?""",
(datetime.now(timezone.utc).isoformat(), self._stats.total_jobs,
completed, failed, tokens_used, json.dumps(self._stats.to_dict()), run_id),
)
self.conn.commit()
# Print report
print(f"\n{'='*50}")
print(f"Pipeline: {self.pipeline}")
print(f"Completed: {completed}/{self._stats.total_jobs}")
print(f"Failed: {failed}")
print(f"Skipped (cached): {skipped}")
print(f"Tokens: {tokens_used:,}/{self.token_budget:,}")
print(f"Time: {elapsed:.1f}s ({self._stats.jobs_per_minute:.1f}/min)")
print(f"{'='*50}")
# Save report file
self._save_report()
def _process_job(self, job: dict, handler: Callable = None) -> dict:
"""Process a single job."""
payload = json.loads(job["payload"])
job_key = job["job_key"]
checkpoint = self.queue.get_checkpoint(job_key)
if handler:
try:
result = handler(payload, checkpoint=checkpoint)
return result or {"success": True, "tokens_used": 0}
except Exception as e:
return {"success": False, "error": str(e)}
else:
# Default handler: just mark as complete
return {"success": True, "tokens_used": 0}
def _save_report(self):
"""Save pipeline run report."""
REPORT_DIR.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
path = REPORT_DIR / f"{self.pipeline}_{ts}.json"
with open(path, "w") as f:
json.dump(self._stats.to_dict(), f, indent=2)
print(f"Report: {path}")
def resume(self):
"""Resume failed/retrying jobs from a previous run."""
stats = self.queue.stats()
retrying = stats.get("retrying", 0)
failed = stats.get("failed", 0)
print(f"Resume {self.pipeline}: {retrying} retrying, {failed} failed to reset")
# Reset failed jobs to pending for retry
self.conn.execute(
"UPDATE jobs SET status='pending', attempts=0 WHERE pipeline=? AND status='failed'",
(self.pipeline,),
)
self.conn.execute(
"UPDATE jobs SET status='pending' WHERE pipeline=? AND status='retrying'",
(self.pipeline,),
)
self.conn.commit()
def status(self):
"""Show pipeline status."""
stats = self.queue.stats()
print(f"\nPipeline: {self.pipeline}")
for k, v in sorted(stats.items()):
print(f" {k}: {v}")
# ============================================================
# CLI
# ============================================================
def show_all_status():
"""Show status of all pipelines."""
conn = get_db()
pipelines = conn.execute(
"SELECT DISTINCT pipeline FROM jobs ORDER BY pipeline"
).fetchall()
if not pipelines:
print("No pipelines in database.")
return
print(f"\nAll Pipeline Status")
print(f"{'='*60}")
for (pipeline,) in pipelines:
queue = JobQueue(pipeline, conn)
stats = queue.stats()
total = stats.get("total", 0)
pending = stats.get("pending", 0)
running = stats.get("running", 0)
completed = stats.get("completed", 0)
failed = stats.get("failed", 0)
tokens = stats.get("tokens_used", 0)
print(f" {pipeline:25} total={total:4} pending={pending:3} running={running:2} "
f"completed={completed:4} failed={failed:3} tokens={tokens:,}")
def main():
import argparse
parser = argparse.ArgumentParser(description="Shared Pipeline Orchestrator")
parser.add_argument("--pipeline", "-p", help="Pipeline name")
parser.add_argument("--jobs", "-j", help="Jobs JSONL file to load")
parser.add_argument("--workers", "-w", type=int, default=10, help="Parallel workers")
parser.add_argument("--budget", "-b", type=int, default=5_000_000, help="Token budget")
parser.add_argument("--status", action="store_true", help="Show status")
parser.add_argument("--resume", action="store_true", help="Resume failed jobs")
parser.add_argument("--key-field", default="id", help="Job key field name")
args = parser.parse_args()
if args.status:
if args.pipeline:
orch = Orchestrator(args.pipeline)
orch.status()
else:
show_all_status()
return
if not args.pipeline:
parser.error("--pipeline is required")
orch = Orchestrator(args.pipeline, workers=args.workers, token_budget=args.budget)
if args.jobs:
orch.load_jobs(args.jobs, key_field=args.key_field)
if args.resume:
orch.resume()
if args.jobs or args.resume:
orch.run()
if __name__ == "__main__":
main()

211
scripts/config_template.py Normal file
View File

@@ -0,0 +1,211 @@
#!/usr/bin/env python3
"""
Config Template System — Environment-Specific Overlays (Issue #696)
Loads base.yaml + {env}.overlay.yaml with deep merge.
Overlay keys override base keys. Supports dot notation access.
Usage:
from scripts.config_template import ConfigTemplate, load_config
config = load_config("dev")
template = ConfigTemplate()
template.load("prod")
model = template.get("model.name")
CLI:
python3 scripts/config_template.py --env prod
python3 scripts/config_template.py --env dev --diff
python3 scripts/config_template.py --env prod --validate
python3 scripts/config_template.py --list-envs
"""
import argparse
import copy
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
try:
import yaml
except ImportError:
yaml = None
CONFIG_DIR = Path(__file__).resolve().parent.parent / "config"
KNOWN_ENVS = ("dev", "prod", "cron", "gateway")
def _deep_merge(base: dict, overlay: dict) -> dict:
"""Deep merge overlay into base. Overlay values win on conflict."""
result = copy.deepcopy(base)
for key, value in overlay.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = copy.deepcopy(value)
return result
def _get_dotted(data: dict, key: str, default: Any = None) -> Any:
"""Get value from dict using dot notation: 'model.name' -> data['model']['name']."""
parts = key.split(".")
current = data
for part in parts:
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _diff_dicts(base: dict, overlay: dict, prefix: str = "") -> List[dict]:
"""Compute diff between base and overlay configs."""
diffs = []
all_keys = set(list(base.keys()) + list(overlay.keys()))
for key in sorted(all_keys):
path = f"{prefix}.{key}" if prefix else key
in_base = key in base
in_overlay = key in overlay
if in_base and not in_overlay:
diffs.append({"key": path, "type": "removed_in_overlay", "base": base[key]})
elif not in_base and in_overlay:
diffs.append({"key": path, "type": "added_in_overlay", "overlay": overlay[key]})
elif isinstance(base[key], dict) and isinstance(overlay[key], dict):
diffs.extend(_diff_dicts(base[key], overlay[key], path))
elif base[key] != overlay[key]:
diffs.append({
"key": path, "type": "changed",
"base": base[key], "overlay": overlay[key]
})
return diffs
def _validate_config(config: dict) -> List[str]:
"""Validate config structure, return list of warnings."""
warnings = []
if "model" not in config:
warnings.append("Missing 'model' section")
elif "name" not in config.get("model", {}):
warnings.append("Missing 'model.name'")
if "provider" not in config:
warnings.append("Missing 'provider' section")
for key in config:
if not isinstance(key, str):
warnings.append(f"Non-string key: {key!r}")
return warnings
def _load_yaml_file(path: Path) -> dict:
"""Load a YAML file, return empty dict if missing."""
if not path.exists():
return {}
if yaml is None:
raise ImportError("PyYAML required: pip install pyyaml")
with open(path) as f:
data = yaml.safe_load(f)
return data if isinstance(data, dict) else {}
class ConfigTemplate:
"""Environment-specific config template with overlay merge."""
def __init__(self, config_dir: Optional[str] = None):
self.config_dir = Path(config_dir) if config_dir else CONFIG_DIR
self.base: Dict[str, Any] = {}
self.overlay: Dict[str, Any] = {}
self.merged: Dict[str, Any] = {}
self.env: Optional[str] = None
def load(self, env: str) -> dict:
"""Load base + overlay for the given environment."""
if env not in KNOWN_ENVS:
raise ValueError(f"Unknown environment '{env}'. Known: {', '.join(KNOWN_ENVS)}")
self.env = env
self.base = _load_yaml_file(self.config_dir / "base.yaml")
self.overlay = _load_yaml_file(self.config_dir / f"{env}.overlay.yaml")
self.merged = _deep_merge(self.base, self.overlay)
return self.merged
def get(self, key: str, default: Any = None) -> Any:
"""Get value with dot notation from merged config."""
return _get_dotted(self.merged, key, default)
def diff(self) -> List[dict]:
"""Show diff between base and current overlay."""
return _diff_dicts(self.base, self.overlay)
def validate(self) -> List[str]:
"""Validate merged config structure."""
return _validate_config(self.merged)
@staticmethod
def list_environments() -> List[str]:
"""List known environments."""
return list(KNOWN_ENVS)
def load_config(env: str, config_dir: Optional[str] = None) -> dict:
"""One-shot: load merged config for an environment."""
t = ConfigTemplate(config_dir)
return t.load(env)
def main():
parser = argparse.ArgumentParser(description="Config Template System")
parser.add_argument("--env", default="dev", help="Environment name")
parser.add_argument("--diff", action="store_true", help="Show diff between base and overlay")
parser.add_argument("--validate", action="store_true", help="Validate merged config")
parser.add_argument("--list-envs", action="store_true", help="List known environments")
parser.add_argument("--config-dir", default=None, help="Config directory path")
parser.add_argument("--json", action="store_true", help="Output as JSON")
args = parser.parse_args()
if args.list_envs:
envs = ConfigTemplate.list_environments()
if args.json:
print(json.dumps(envs, indent=2))
else:
for e in envs:
print(f" {e}")
return
template = ConfigTemplate(args.config_dir)
template.load(args.env)
if args.diff:
diffs = template.diff()
if args.json:
print(json.dumps(diffs, indent=2))
else:
if not diffs:
print(f"No differences between base and {args.env} overlay")
for d in diffs:
if d["type"] == "changed":
print(f" {d['key']}: {d['base']!r} -> {d['overlay']!r}")
elif d["type"] == "added_in_overlay":
print(f" {d['key']}: + {d['overlay']!r}")
elif d["type"] == "removed_in_overlay":
print(f" {d['key']}: - {d['base']!r}")
elif args.validate:
warnings = template.validate()
if args.json:
print(json.dumps({"valid": len(warnings) == 0, "warnings": warnings}, indent=2))
else:
if warnings:
for w in warnings:
print(f" WARNING: {w}")
else:
print(f" Config valid for {args.env}")
else:
if args.json:
print(json.dumps(template.merged, indent=2))
else:
print(f"Config for {args.env}:")
for k, v in template.merged.items():
print(f" {k}: {v!r}")
if __name__ == "__main__":
main()
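
The docstring shows the CLI entry points; the same diff and validation helpers can also be called directly. A short sketch, same import assumptions as above:

```python
# Sketch only: programmatic diff and validation for the prod environment.
from scripts.config_template import ConfigTemplate

t = ConfigTemplate()
t.load("prod")

for d in t.diff():                     # what prod.overlay.yaml changes vs. base.yaml
    if d["type"] == "changed":
        print(f"{d['key']}: {d['base']!r} -> {d['overlay']!r}")

warnings = t.validate()                # structural checks on the merged config
print("valid" if not warnings else warnings)
```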

View File

@@ -0,0 +1,133 @@
#!/usr/bin/env python3
"""Tests for config_template.py — issue #696"""
import os
import sys
import tempfile
import pytest
from pathlib import Path
# Add parent dir for import
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts"))
from config_template import ConfigTemplate, _deep_merge, _get_dotted, _diff_dicts, _validate_config
@pytest.fixture
def tmp_config_dir(tmp_path):
"""Create a temp config dir with base + overlay files."""
import yaml
base = {
"model": {"name": "base-model", "temperature": 0.7},
"cron": {"enabled": False, "interval": 300},
"display": {"colors": True},
}
overlay = {
"model": {"temperature": 0.9},
"cron": {"enabled": True},
"logging": {"level": "DEBUG"},
}
with open(tmp_path / "base.yaml", "w") as f:
yaml.dump(base, f)
with open(tmp_path / "test.overlay.yaml", "w") as f:
yaml.dump(overlay, f)
return tmp_path
class TestDeepMerge:
def test_overlay_wins(self):
base = {"a": 1, "b": 2}
overlay = {"b": 99}
result = _deep_merge(base, overlay)
assert result == {"a": 1, "b": 99}
def test_deep_merge_nested(self):
base = {"model": {"name": "x", "temp": 0.7}}
overlay = {"model": {"temp": 0.9}}
result = _deep_merge(base, overlay)
assert result["model"]["name"] == "x"
assert result["model"]["temp"] == 0.9
def test_new_keys_added(self):
base = {"a": 1}
overlay = {"b": 2}
result = _deep_merge(base, overlay)
assert result == {"a": 1, "b": 2}
def test_originals_unchanged(self):
base = {"a": {"inner": 1}}
overlay = {"a": {"inner": 99}}
_deep_merge(base, overlay)
assert base["a"]["inner"] == 1
class TestDottedAccess:
def test_simple_key(self):
assert _get_dotted({"a": 1}, "a") == 1
def test_nested_key(self):
assert _get_dotted({"a": {"b": {"c": 42}}}, "a.b.c") == 42
def test_missing_key_returns_default(self):
assert _get_dotted({"a": 1}, "x", "fallback") == "fallback"
def test_partial_path(self):
assert _get_dotted({"a": 1}, "a.b.c", None) is None
class TestDiff:
def test_no_diff(self):
assert _diff_dicts({"a": 1}, {"a": 1}) == []
def test_changed_value(self):
diffs = _diff_dicts({"a": 1}, {"a": 2})
assert len(diffs) == 1
assert diffs[0]["type"] == "changed"
def test_added_key(self):
diffs = _diff_dicts({"a": 1}, {"a": 1, "b": 2})
added = [d for d in diffs if d["type"] == "added_in_overlay"]
assert len(added) == 1
assert added[0]["key"] == "b"
class TestValidation:
def test_valid_config(self):
config = {"model": {"name": "x"}, "provider": {"name": "y"}}
assert _validate_config(config) == []
def test_missing_model(self):
warnings = _validate_config({"provider": {}})
assert any("model" in w for w in warnings)
class TestConfigTemplate:
def test_load(self, tmp_config_dir):
t = ConfigTemplate(str(tmp_config_dir))
merged = t.load("test")
assert merged["model"]["name"] == "base-model"
assert merged["model"]["temperature"] == 0.9
assert merged["cron"]["enabled"] is True
assert merged["logging"]["level"] == "DEBUG"
def test_get_dotted(self, tmp_config_dir):
t = ConfigTemplate(str(tmp_config_dir))
t.load("test")
assert t.get("model.temperature") == 0.9
assert t.get("nonexistent", "default") == "default"
def test_diff(self, tmp_config_dir):
t = ConfigTemplate(str(tmp_config_dir))
t.load("test")
diffs = t.diff()
assert len(diffs) > 0
def test_unknown_env_raises(self, tmp_config_dir):
t = ConfigTemplate(str(tmp_config_dir))
with pytest.raises(ValueError, match="Unknown environment"):
t.load("nonexistent")
def test_list_environments(self):
envs = ConfigTemplate.list_environments()
assert "dev" in envs
assert "prod" in envs
assert "cron" in envs
assert "gateway" in envs
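
The suite needs only pytest and PyYAML. One way to run it programmatically (the test file path below is a guess, since the diff header above does not show where this file lives):

```python
# Sketch only: invoke the new tests via pytest's Python API.
# "tests/test_config_template.py" is a hypothetical path; adjust to the real location.
import sys
import pytest

sys.exit(pytest.main(["-v", "tests/test_config_template.py"]))
```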