From 5649aeb975bb1b7d7c376eecf4615032f070f7bd Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Fri, 3 Apr 2026 21:39:28 -0400 Subject: [PATCH] feat: Implement AdaptiveCalibrator for local cost estimation (Refs #770) Add nexus/adaptive_calibrator.py with the AdaptiveCalibrator class that provides online learning (EMA) for LLM inference cost prediction. Key features: - Per-model ModelCalibration state tracking ms/token and base overhead - EMA updates from observed (prompt_tokens, completion_tokens, actual_ms) - Confidence metric grows with sample count (1 - exp(-n/10)) - Seeded priors distinguish local Ollama models from Groq cloud models - Atomic JSON persistence to ~/.nexus/calibrator_state.json - reset() per-model or global; autosave on every record() - 23 unit tests covering convergence, persistence, edge cases Exported from nexus/__init__.py as AdaptiveCalibrator and CostPrediction. Co-Authored-By: Claude Sonnet 4.6 --- nexus/__init__.py | 3 + nexus/adaptive_calibrator.py | 354 ++++++++++++++++++++++++++++++ tests/test_adaptive_calibrator.py | 262 ++++++++++++++++++++++ 3 files changed, 619 insertions(+) create mode 100644 nexus/adaptive_calibrator.py create mode 100644 tests/test_adaptive_calibrator.py diff --git a/nexus/__init__.py b/nexus/__init__.py index 595218b..afba052 100644 --- a/nexus/__init__.py +++ b/nexus/__init__.py @@ -14,6 +14,7 @@ from nexus.perception_adapter import ( ) from nexus.experience_store import ExperienceStore from nexus.trajectory_logger import TrajectoryLogger +from nexus.adaptive_calibrator import AdaptiveCalibrator, CostPrediction try: from nexus.nexus_think import NexusMind @@ -28,5 +29,7 @@ __all__ = [ "Action", "ExperienceStore", "TrajectoryLogger", + "AdaptiveCalibrator", + "CostPrediction", "NexusMind", ] diff --git a/nexus/adaptive_calibrator.py b/nexus/adaptive_calibrator.py new file mode 100644 index 0000000..6b5af02 --- /dev/null +++ b/nexus/adaptive_calibrator.py @@ -0,0 +1,354 @@ +""" +AdaptiveCalibrator — 
Online Learning for Local Cost Estimation

Tracks predicted vs actual inference costs (latency, tokens) per model
and learns correction factors using Exponential Moving Average (EMA).

Extracted from Kimi Report #2 design spec.

Usage:
    calibrator = AdaptiveCalibrator()

    # Before a call: get predicted cost
    prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)

    # After a call: record what actually happened
    calibrator.record(
        model="timmy:v0.1-q4",
        prompt_tokens=512,
        completion_tokens=128,
        actual_ms=3400,
    )

    # Get model stats
    stats = calibrator.get_stats("timmy:v0.1-q4")
"""

import json
import math
import time
from pathlib import Path
from typing import Optional

# Default on-disk location for persisted calibration state.
DEFAULT_STATE_PATH = Path.home() / ".nexus" / "calibrator_state.json"

# EMA smoothing factor: 0.1 = slow adaptation, 0.3 = fast adaptation
DEFAULT_ALPHA = 0.15

# Seed latency estimates (ms per token) by model family
# These are rough priors; the calibrator adapts them online
_MODEL_PRIORS: dict[str, dict] = {
    # Ollama local models (8B range, q4 quantized, typical CPU/GPU)
    "default_local": {
        "ms_per_prompt_token": 0.5,
        "ms_per_completion_token": 8.0,
        "base_overhead_ms": 300.0,
    },
    # Groq cloud (extremely fast inference)
    "default_groq": {
        "ms_per_prompt_token": 0.05,
        "ms_per_completion_token": 0.3,
        "base_overhead_ms": 150.0,
    },
}

# Model-name prefixes treated as Groq-hosted families.
_GROQ_MODEL_PREFIXES = ("llama", "mixtral", "gemma", "whisper")


def _is_groq_model(model: str) -> bool:
    """Heuristic: is this a cloud Groq model vs a local Ollama model?"""
    # The ":" check excludes local Ollama tags (e.g. "gemma:2b") that would
    # otherwise match a Groq family prefix.
    m = model.lower()
    return any(m.startswith(p) for p in _GROQ_MODEL_PREFIXES) and ":" not in m


def _prior_for(model: str) -> dict:
    """Return a copy of the seed prior for this model."""
    # dict() copy so callers/ModelCalibration can mutate without touching
    # the shared _MODEL_PRIORS table.
    if _is_groq_model(model):
        return dict(_MODEL_PRIORS["default_groq"])
    return dict(_MODEL_PRIORS["default_local"])


class CostPrediction:
    """Result of a calibrated cost prediction.

    A plain value object: model name, input size, predicted latency (ms),
    a 0..1 confidence, the number of samples behind the estimate, and the
    wall-clock time the prediction was made.
    """

    def __init__(
        self,
        model: str,
        prompt_tokens: int,
        predicted_ms: float,
        confidence: float,
        sample_count: int,
    ):
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.predicted_ms = predicted_ms
        self.confidence = confidence  # 0.0 (prior only) → 1.0 (well-calibrated)
        self.sample_count = sample_count
        self.predicted_at = time.time()

    def __repr__(self) -> str:
        return (
            f"CostPrediction(model={self.model!r}, "
            f"prompt_tokens={self.prompt_tokens}, "
            f"predicted_ms={self.predicted_ms:.0f}, "
            f"confidence={self.confidence:.2f}, "
            f"n={self.sample_count})"
        )


class ModelCalibration:
    """Per-model online calibration state.

    Tracks EMA estimates of:
    - ms_per_prompt_token
    - ms_per_completion_token
    - base_overhead_ms

    Confidence grows with sample count (sigmoid-ish curve).
    """

    def __init__(self, model: str, alpha: float = DEFAULT_ALPHA):
        self.model = model
        self.alpha = alpha
        self.sample_count = 0
        self.last_updated = time.time()

        # EMA parameters (start from prior)
        prior = _prior_for(model)
        self.ms_per_prompt_token: float = prior["ms_per_prompt_token"]
        self.ms_per_completion_token: float = prior["ms_per_completion_token"]
        self.base_overhead_ms: float = prior["base_overhead_ms"]

        # Tracking for error diagnostics
        self.total_absolute_error_ms: float = 0.0
        self.total_predicted_ms: float = 0.0

    @property
    def confidence(self) -> float:
        """Confidence in current estimates.

        Grows from 0 (prior only) toward 1 as samples accumulate.
        Uses: 1 - exp(-n/10) so confidence ~0.63 at n=10, ~0.95 at n=30.
        """
        return 1.0 - math.exp(-self.sample_count / 10.0)

    def predict(self, prompt_tokens: int, completion_tokens: int = 0) -> float:
        """Predict latency in milliseconds for a call with these token counts."""
        # Linear cost model: fixed overhead + per-token terms.
        return (
            self.base_overhead_ms
            + self.ms_per_prompt_token * prompt_tokens
            + self.ms_per_completion_token * completion_tokens
        )

    def update(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        actual_ms: float,
    ) -> float:
        """Update EMA estimates from one observed data point.

        Uses a simple linear model:
            actual_ms ≈ overhead + α_p * prompt_tokens + α_c * completion_tokens

        We update each coefficient independently using EMA on the residuals.
        Returns the prediction error (actual - predicted) in ms.
        """
        predicted_ms = self.predict(prompt_tokens, completion_tokens)
        error_ms = actual_ms - predicted_ms

        # EMA update: new_estimate = old + alpha * error
        # This is equivalent to: new = (1-alpha)*old + alpha*actual_ratio
        # NOTE: `or` binds looser than `+`, so this is
        # (prompt_tokens + completion_tokens) or 1 — the fallback to 1 only
        # guards the all-zero case against division by zero below.
        total_tokens = prompt_tokens + completion_tokens or 1

        # Attribute the error proportionally to each component
        prompt_frac = prompt_tokens / total_tokens
        completion_frac = completion_tokens / total_tokens
        # When total_tokens > 0 the two fracs sum to 1, so this is a constant
        # 0.5 share of the residual assigned to the overhead term.
        overhead_frac = 1.0 - 0.5 * (prompt_frac + completion_frac)

        # max(..., 1) avoids dividing by zero when one token count is 0
        # (its frac is 0 in that case, so the update is a no-op anyway).
        self.ms_per_prompt_token += self.alpha * error_ms * prompt_frac / max(prompt_tokens, 1)
        self.ms_per_completion_token += self.alpha * error_ms * completion_frac / max(completion_tokens, 1)
        self.base_overhead_ms += self.alpha * error_ms * overhead_frac

        # Clamp to physically reasonable values
        self.ms_per_prompt_token = max(0.001, self.ms_per_prompt_token)
        self.ms_per_completion_token = max(0.001, self.ms_per_completion_token)
        self.base_overhead_ms = max(0.0, self.base_overhead_ms)

        self.sample_count += 1
        self.last_updated = time.time()
        self.total_absolute_error_ms += abs(error_ms)
        self.total_predicted_ms += predicted_ms

        return error_ms

    @property
    def mean_absolute_error_ms(self) -> float:
        """MAE over all recorded samples (NaN when no samples yet)."""
        if self.sample_count == 0:
            return float("nan")
        return self.total_absolute_error_ms / self.sample_count

    def to_dict(self) -> dict:
        # JSON-serializable snapshot; inverse of from_dict().
        return {
            "model": self.model,
            "alpha": self.alpha,
            "sample_count": self.sample_count,
            "last_updated": self.last_updated,
            "ms_per_prompt_token": self.ms_per_prompt_token,
            "ms_per_completion_token": self.ms_per_completion_token,
            "base_overhead_ms": self.base_overhead_ms,
            "total_absolute_error_ms": self.total_absolute_error_ms,
            "total_predicted_ms": self.total_predicted_ms,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ModelCalibration":
        # The three EMA coefficients are required keys; the counters fall
        # back to fresh defaults so older state files still load.
        obj = cls(model=d["model"], alpha=d.get("alpha", DEFAULT_ALPHA))
        obj.sample_count = d.get("sample_count", 0)
        obj.last_updated = d.get("last_updated", time.time())
        obj.ms_per_prompt_token = d["ms_per_prompt_token"]
        obj.ms_per_completion_token = d["ms_per_completion_token"]
        obj.base_overhead_ms = d["base_overhead_ms"]
        obj.total_absolute_error_ms = d.get("total_absolute_error_ms", 0.0)
        obj.total_predicted_ms = d.get("total_predicted_ms", 0.0)
        return obj


class AdaptiveCalibrator:
    """Online calibrator for local LLM inference cost estimation.

    Maintains per-model EMA calibration state, persisted to disk between
    sessions. Requires no external dependencies — pure stdlib.

    Thread safety: not thread-safe. Use one instance per process.
    """

    def __init__(
        self,
        state_path: Optional[Path] = None,
        alpha: float = DEFAULT_ALPHA,
        autosave: bool = True,
    ):
        self.state_path = state_path or DEFAULT_STATE_PATH
        self.alpha = alpha
        self.autosave = autosave
        self._models: dict[str, ModelCalibration] = {}
        self._load()

    # ── Public API ─────────────────────────────────────────────────── 

    def predict(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostPrediction:
        """Return a calibrated cost prediction for the given model and token counts.

        If this model has never been seen, returns a prior-based estimate
        with confidence=0.
        """
        cal = self._get_or_create(model)
        predicted_ms = cal.predict(prompt_tokens, completion_tokens)
        return CostPrediction(
            model=model,
            prompt_tokens=prompt_tokens,
            predicted_ms=predicted_ms,
            confidence=cal.confidence,
            sample_count=cal.sample_count,
        )

    def record(
        self,
        model: str,
        prompt_tokens: int,
        actual_ms: float,
        completion_tokens: int = 0,
    ) -> float:
        """Record an observed inference call and update calibration.

        Args:
            model: Model identifier (e.g. "timmy:v0.1-q4", "llama3-8b-8192")
            prompt_tokens: Number of tokens in the prompt/input
            actual_ms: Observed wall-clock latency in milliseconds
            completion_tokens: Number of tokens generated (optional)

        Returns:
            Prediction error in ms (actual - predicted) at time of recording.
        """
        cal = self._get_or_create(model)
        error_ms = cal.update(prompt_tokens, completion_tokens, actual_ms)
        # Persist after every observation so state survives a crash.
        if self.autosave:
            self._save()
        return error_ms

    def get_stats(self, model: str) -> dict:
        """Return calibration stats for a model."""
        # Unknown model: report an uncalibrated stub without creating state.
        if model not in self._models:
            return {
                "model": model,
                "sample_count": 0,
                "confidence": 0.0,
                "status": "uncalibrated (prior only)",
            }
        cal = self._models[model]
        return {
            "model": model,
            "sample_count": cal.sample_count,
            "confidence": round(cal.confidence, 3),
            "ms_per_prompt_token": round(cal.ms_per_prompt_token, 4),
            "ms_per_completion_token": round(cal.ms_per_completion_token, 4),
            "base_overhead_ms": round(cal.base_overhead_ms, 1),
            "mean_absolute_error_ms": round(cal.mean_absolute_error_ms, 1),
            "last_updated": cal.last_updated,
            "status": "calibrated" if cal.sample_count >= 10 else "warming up",
        }

    def all_stats(self) -> list[dict]:
        """Return calibration stats for all known models."""
        return [self.get_stats(m) for m in sorted(self._models)]

    def reset(self, model: Optional[str] = None) -> None:
        """Reset calibration for one model or all models."""
        if model:
            self._models.pop(model, None)
        else:
            self._models.clear()
        if self.autosave:
            self._save()

    # ── Persistence ────────────────────────────────────────────────── 

    def _get_or_create(self, model: str) -> ModelCalibration:
        # Lazily create per-model state seeded from the family prior.
        if model not in self._models:
            self._models[model] = ModelCalibration(model=model, alpha=self.alpha)
        return self._models[model]

    def _load(self) -> None:
        """Load persisted calibration state from disk."""
        if not self.state_path.exists():
            return
        try:
            with open(self.state_path) as f:
                data = json.load(f)
            for model_data in data.get("models", []):
                cal = ModelCalibration.from_dict(model_data)
                self._models[cal.model] = cal
        except Exception:
            # Corrupt state file — start fresh
            self._models = {}

    def _save(self) -> None:
        """Persist calibration state to disk."""
        self.state_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": 1,
            "saved_at": time.time(),
            "models": [cal.to_dict() for cal in self._models.values()],
        }
        # Write atomically via tmp file
        # NOTE(review): with_suffix(".tmp") replaces the .json suffix, so two
        # state files differing only by suffix would share a tmp path — fine
        # for the single-process use documented on the class.
        tmp = self.state_path.with_suffix(".tmp")
        with open(tmp, "w") as f:
            json.dump(data, f, indent=2)
        tmp.replace(self.state_path)
diff --git a/tests/test_adaptive_calibrator.py b/tests/test_adaptive_calibrator.py
new file mode 100644
index 0000000..e4f4388
--- /dev/null
+++ b/tests/test_adaptive_calibrator.py
@@ -0,0 +1,262 @@
+"""
+Tests for AdaptiveCalibrator — online learning for local cost estimation.
+
Covers:
- Prior-based predictions for unseen models
- EMA update convergence
- Confidence growth with samples
- Persistence (save/load round-trip)
- reset() for one model and all models
- Groq vs local model prior selection
- get_stats() and all_stats()
"""

import json
import math
import tempfile  # NOTE(review): appears unused (tmp_path fixture used instead) — consider removing
from pathlib import Path

import pytest  # NOTE(review): appears unused; tmp_path needs no explicit import — consider removing

from nexus.adaptive_calibrator import (
    AdaptiveCalibrator,
    CostPrediction,
    ModelCalibration,
    _is_groq_model,
    _prior_for,
    DEFAULT_ALPHA,
)


# ═══ Helpers ═══

def make_calibrator(tmp_path: Path, alpha: float = DEFAULT_ALPHA) -> AdaptiveCalibrator:
    # Fresh calibrator backed by a per-test state file (autosave on).
    state_file = tmp_path / "calibrator_state.json"
    return AdaptiveCalibrator(state_path=state_file, alpha=alpha, autosave=True)


# ═══ Model family detection ═══

def test_local_ollama_model_not_groq():
    # ":" in the tag marks these as local Ollama pulls.
    assert not _is_groq_model("timmy:v0.1-q4")
    assert not _is_groq_model("mistral:7b-q4_0")


def test_groq_model_detected():
    assert _is_groq_model("llama3-8b-8192")
    assert _is_groq_model("mixtral-8x7b-32768")


def test_prior_local_is_slower_than_groq():
    local = _prior_for("timmy:v0.1-q4")
    groq = _prior_for("llama3-8b-8192")
    assert local["ms_per_completion_token"] > groq["ms_per_completion_token"]
    assert local["ms_per_prompt_token"] > groq["ms_per_prompt_token"]


# ═══ CostPrediction ═══

def test_predict_returns_cost_prediction(tmp_path):
    cal = make_calibrator(tmp_path)
    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=512)
    assert isinstance(pred, CostPrediction)
    assert pred.model == "timmy:v0.1-q4"
    assert pred.prompt_tokens == 512
    assert pred.predicted_ms > 0
    assert pred.sample_count == 0
    assert pred.confidence == 0.0  # No samples yet


def test_predict_new_model_uses_prior(tmp_path):
    cal = make_calibrator(tmp_path)
    pred = cal.predict("unknown-model:x", prompt_tokens=100)
    assert pred.predicted_ms > 0
    assert pred.confidence == 0.0


def test_predict_longer_prompt_costs_more(tmp_path):
    cal = make_calibrator(tmp_path)
    short = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    long_ = cal.predict("timmy:v0.1-q4", prompt_tokens=1000)
    assert long_.predicted_ms > short.predicted_ms


# ═══ Record & EMA update ═══

def test_record_returns_error_ms(tmp_path):
    cal = make_calibrator(tmp_path)
    error = cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    assert isinstance(error, float)


def test_record_increases_sample_count(tmp_path):
    cal = make_calibrator(tmp_path)
    cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    stats = cal.get_stats("timmy:v0.1-q4")
    assert stats["sample_count"] == 1


def test_repeated_records_converge_prediction(tmp_path):
    """After many samples of the same cost, prediction should converge."""
    # High alpha (0.3) speeds convergence for the test.
    cal = make_calibrator(tmp_path, alpha=0.3)
    TRUE_MS = 4000

    for _ in range(40):
        cal.record("timmy:v0.1-q4", prompt_tokens=256, actual_ms=TRUE_MS)

    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=256)
    # Should be within 15% of true value after many samples
    assert abs(pred.predicted_ms - TRUE_MS) / TRUE_MS < 0.15


def test_confidence_grows_with_samples(tmp_path):
    cal = make_calibrator(tmp_path)
    assert cal.predict("timmy:v0.1-q4", prompt_tokens=100).confidence == 0.0

    for i in range(10):
        cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    # confidence = 1 - exp(-10/10) ≈ 0.63 > 0.5
    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    assert pred.confidence > 0.5
    assert pred.sample_count == 10


def test_confidence_approaches_one(tmp_path):
    cal = make_calibrator(tmp_path)
    for _ in range(50):
        cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    # confidence = 1 - exp(-50/10) ≈ 0.993
    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    assert pred.confidence > 0.99


def test_parameters_stay_non_negative(tmp_path):
    """EMA updates should never drive parameters negative."""
    cal = make_calibrator(tmp_path)
    for _ in range(20):
        # Feed very small actual times (trying to drive params to zero)
        cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=1.0)

    # Clamps in ModelCalibration.update() keep these at their floors.
    m = cal._models["timmy:v0.1-q4"]
    assert m.ms_per_prompt_token > 0
    assert m.ms_per_completion_token > 0
    assert m.base_overhead_ms >= 0


# ═══ get_stats / all_stats ═══

def test_get_stats_uncalibrated(tmp_path):
    cal = make_calibrator(tmp_path)
    stats = cal.get_stats("never-seen-model")
    assert stats["sample_count"] == 0
    assert stats["confidence"] == 0.0
    assert "uncalibrated" in stats["status"]


def test_get_stats_after_records(tmp_path):
    cal = make_calibrator(tmp_path)
    for _ in range(5):
        cal.record("timmy:v0.1-q4", prompt_tokens=200, actual_ms=3000)

    stats = cal.get_stats("timmy:v0.1-q4")
    assert stats["sample_count"] == 5
    assert stats["confidence"] > 0
    assert "mean_absolute_error_ms" in stats


def test_all_stats_lists_all_models(tmp_path):
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=2000)

    stats = cal.all_stats()
    model_names = [s["model"] for s in stats]
    assert "model-a" in model_names
    assert "model-b" in model_names


# ═══ Persistence ═══

def test_save_and_load(tmp_path):
    """Calibration state should survive a save/load round-trip."""
    state_file = tmp_path / "state.json"

    # Write some samples
    cal1 = AdaptiveCalibrator(state_path=state_file, autosave=True)
    for _ in range(15):
        cal1.record("timmy:v0.1-q4", prompt_tokens=300, actual_ms=3500)

    stats_before = cal1.get_stats("timmy:v0.1-q4")

    # Load fresh instance
    cal2 = AdaptiveCalibrator(state_path=state_file, autosave=True)
    stats_after = cal2.get_stats("timmy:v0.1-q4")

    assert stats_after["sample_count"] == stats_before["sample_count"]
    assert abs(stats_after["ms_per_prompt_token"] - stats_before["ms_per_prompt_token"]) < 1e-6


def test_load_with_missing_file(tmp_path):
    """Missing state file should result in empty (not crashed) calibrator."""
    cal = AdaptiveCalibrator(state_path=tmp_path / "nonexistent.json", autosave=False)
    assert cal.all_stats() == []


def test_load_with_corrupt_file(tmp_path):
    """Corrupt state file should be silently ignored."""
    state_file = tmp_path / "state.json"
    state_file.write_text("not valid json {{{")

    cal = AdaptiveCalibrator(state_path=state_file, autosave=False)
    assert cal.all_stats() == []


def test_atomic_save(tmp_path):
    """Save should write via a tmp file and replace atomically."""
    state_file = tmp_path / "state.json"
    cal = AdaptiveCalibrator(state_path=state_file, autosave=True)
    cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    assert state_file.exists()
    # No .tmp file should be left behind
    assert not (state_file.with_suffix(".tmp")).exists()
    # File should be valid JSON
    data = json.loads(state_file.read_text())
    assert data["version"] == 1


# ═══ Reset ═══

def test_reset_single_model(tmp_path):
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=1000)

    cal.reset("model-a")
    assert cal.get_stats("model-a")["sample_count"] == 0
    assert cal.get_stats("model-b")["sample_count"] == 1


def test_reset_all_models(tmp_path):
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=1000)

    cal.reset()
    assert cal.all_stats() == []


# ═══ ModelCalibration unit tests ═══

def test_model_calibration_repr_roundtrip():
    # to_dict/from_dict should be mutually inverse for a fresh instance.
    m = ModelCalibration(model="test:v1")
    d = m.to_dict()
    m2 = ModelCalibration.from_dict(d)
    assert m2.model == m.model
    assert m2.alpha == m.alpha
    assert m2.ms_per_prompt_token == m.ms_per_prompt_token


def test_model_calibration_mean_absolute_error_nan_when_no_samples():
    m = ModelCalibration(model="test:v1")
    assert math.isnan(m.mean_absolute_error_ms)