feat: Implement AdaptiveCalibrator for local cost estimation (Refs #770)
Some checks failed
CI / validate (pull_request) Has been cancelled

Add nexus/adaptive_calibrator.py with the AdaptiveCalibrator class that
provides online learning (EMA) for LLM inference cost prediction.

Key features:
- Per-model ModelCalibration state tracking ms/token and base overhead
- EMA updates from observed (prompt_tokens, completion_tokens, actual_ms)
- Confidence metric grows with sample count (1 - exp(-n/10))
- Seeded priors distinguish local Ollama models from Groq cloud models
- Atomic JSON persistence to ~/.nexus/calibrator_state.json
- reset() per-model or global; autosave on every record()
- 23 unit tests covering convergence, persistence, edge cases

Exported from nexus/__init__.py as AdaptiveCalibrator and CostPrediction.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alexander Whitestone
2026-04-03 21:39:28 -04:00
parent 29e64ef01f
commit 5649aeb975
3 changed files with 619 additions and 0 deletions

View File

@@ -14,6 +14,7 @@ from nexus.perception_adapter import (
)
from nexus.experience_store import ExperienceStore
from nexus.trajectory_logger import TrajectoryLogger
from nexus.adaptive_calibrator import AdaptiveCalibrator, CostPrediction
try:
from nexus.nexus_think import NexusMind
@@ -28,5 +29,7 @@ __all__ = [
"Action",
"ExperienceStore",
"TrajectoryLogger",
"AdaptiveCalibrator",
"CostPrediction",
"NexusMind",
]

View File

@@ -0,0 +1,354 @@
"""
AdaptiveCalibrator — Online Learning for Local Cost Estimation
Tracks predicted vs actual inference costs (latency, tokens) per model
and learns correction factors using Exponential Moving Average (EMA).
Extracted from Kimi Report #2 design spec.
Usage:
calibrator = AdaptiveCalibrator()
# Before a call: get predicted cost
prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
# After a call: record what actually happened
calibrator.record(
model="timmy:v0.1-q4",
prompt_tokens=512,
completion_tokens=128,
actual_ms=3400,
)
# Get model stats
stats = calibrator.get_stats("timmy:v0.1-q4")
"""
import json
import math
import time
from pathlib import Path
from typing import Optional
# Default location for persisted calibration state (JSON; see AdaptiveCalibrator._save).
DEFAULT_STATE_PATH = Path.home() / ".nexus" / "calibrator_state.json"
# EMA smoothing factor: 0.1 = slow adaptation, 0.3 = fast adaptation
DEFAULT_ALPHA = 0.15
# Seed latency estimates (ms per token) by model family
# These are rough priors; the calibrator adapts them online
_MODEL_PRIORS: dict[str, dict] = {
    # Ollama local models (8B range, q4 quantized, typical CPU/GPU)
    "default_local": {
        "ms_per_prompt_token": 0.5,
        "ms_per_completion_token": 8.0,
        "base_overhead_ms": 300.0,
    },
    # Groq cloud (extremely fast inference)
    "default_groq": {
        "ms_per_prompt_token": 0.05,
        "ms_per_completion_token": 0.3,
        "base_overhead_ms": 150.0,
    },
}
# Model-name prefixes treated as Groq cloud models (see _is_groq_model).
_GROQ_MODEL_PREFIXES = ("llama", "mixtral", "gemma", "whisper")
def _is_groq_model(model: str) -> bool:
    """Heuristic: classify a model name as Groq cloud vs local Ollama.

    Groq models start with a known family prefix and carry no Ollama-style
    ":tag" suffix.
    """
    name = model.lower()
    if ":" in name:
        # Ollama tags look like "model:tag" — always local.
        return False
    return name.startswith(_GROQ_MODEL_PREFIXES)
def _prior_for(model: str) -> dict:
    """Return a shallow copy of the seed prior for this model's family."""
    family = "default_groq" if _is_groq_model(model) else "default_local"
    return dict(_MODEL_PRIORS[family])
class CostPrediction:
    """Result of a calibrated cost prediction.

    Attributes:
        model: Model identifier the prediction applies to.
        prompt_tokens: Prompt size the prediction was computed for.
        predicted_ms: Estimated wall-clock latency in milliseconds.
        confidence: 0.0 (prior only) → 1.0 (well-calibrated).
        sample_count: Number of observations behind the estimate.
        predicted_at: Unix timestamp of when the prediction was made.
    """

    def __init__(
        self,
        model: str,
        prompt_tokens: int,
        predicted_ms: float,
        confidence: float,
        sample_count: int,
    ):
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.predicted_ms = predicted_ms
        self.confidence = confidence
        self.sample_count = sample_count
        self.predicted_at = time.time()

    def __repr__(self) -> str:
        fields = (
            f"model={self.model!r}",
            f"prompt_tokens={self.prompt_tokens}",
            f"predicted_ms={self.predicted_ms:.0f}",
            f"confidence={self.confidence:.2f}",
            f"n={self.sample_count}",
        )
        return f"CostPrediction({', '.join(fields)})"
class ModelCalibration:
    """Per-model online calibration state.

    Maintains EMA estimates of three latency coefficients:
      - ms_per_prompt_token
      - ms_per_completion_token
      - base_overhead_ms

    Confidence grows with sample count (sigmoid-ish curve).
    """

    def __init__(self, model: str, alpha: float = DEFAULT_ALPHA):
        self.model = model
        self.alpha = alpha  # EMA smoothing factor
        self.sample_count = 0
        self.last_updated = time.time()
        # Coefficients start at the family prior and adapt online.
        seed = _prior_for(model)
        self.ms_per_prompt_token: float = seed["ms_per_prompt_token"]
        self.ms_per_completion_token: float = seed["ms_per_completion_token"]
        self.base_overhead_ms: float = seed["base_overhead_ms"]
        # Running totals used only for error diagnostics.
        self.total_absolute_error_ms: float = 0.0
        self.total_predicted_ms: float = 0.0

    @property
    def confidence(self) -> float:
        """Confidence in the current estimates, in [0, 1).

        1 - exp(-n/10): 0 with no samples, ~0.63 at n=10, ~0.95 at n=30.
        """
        return 1.0 - math.exp(-self.sample_count / 10.0)

    def predict(self, prompt_tokens: int, completion_tokens: int = 0) -> float:
        """Predict latency in milliseconds for a call with these token counts."""
        return (
            self.base_overhead_ms
            + self.ms_per_prompt_token * prompt_tokens
            + self.ms_per_completion_token * completion_tokens
        )

    def update(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        actual_ms: float,
    ) -> float:
        """Fold one observed data point into the EMA estimates.

        Linear model: actual_ms ≈ overhead + α_p * prompt_tokens
                                          + α_c * completion_tokens.
        The prediction error is attributed proportionally to each component
        and each coefficient is nudged by alpha * its share.

        Returns the prediction error (actual - predicted) in ms.
        """
        predicted_ms = self.predict(prompt_tokens, completion_tokens)
        error_ms = actual_ms - predicted_ms
        # `or 1` guards the zero-token case so the fractions stay defined.
        total_tokens = prompt_tokens + completion_tokens or 1
        prompt_frac = prompt_tokens / total_tokens
        completion_frac = completion_tokens / total_tokens
        overhead_frac = 1.0 - 0.5 * (prompt_frac + completion_frac)
        self.ms_per_prompt_token += self.alpha * error_ms * prompt_frac / max(prompt_tokens, 1)
        self.ms_per_completion_token += self.alpha * error_ms * completion_frac / max(completion_tokens, 1)
        self.base_overhead_ms += self.alpha * error_ms * overhead_frac
        # Clamp to physically reasonable values.
        self.ms_per_prompt_token = max(0.001, self.ms_per_prompt_token)
        self.ms_per_completion_token = max(0.001, self.ms_per_completion_token)
        self.base_overhead_ms = max(0.0, self.base_overhead_ms)
        self.sample_count += 1
        self.last_updated = time.time()
        self.total_absolute_error_ms += abs(error_ms)
        self.total_predicted_ms += predicted_ms
        return error_ms

    @property
    def mean_absolute_error_ms(self) -> float:
        """Mean absolute prediction error over all samples (NaN if none)."""
        if not self.sample_count:
            return float("nan")
        return self.total_absolute_error_ms / self.sample_count

    def to_dict(self) -> dict:
        """Serialize calibration state to a JSON-friendly dict."""
        return {
            "model": self.model,
            "alpha": self.alpha,
            "sample_count": self.sample_count,
            "last_updated": self.last_updated,
            "ms_per_prompt_token": self.ms_per_prompt_token,
            "ms_per_completion_token": self.ms_per_completion_token,
            "base_overhead_ms": self.base_overhead_ms,
            "total_absolute_error_ms": self.total_absolute_error_ms,
            "total_predicted_ms": self.total_predicted_ms,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ModelCalibration":
        """Rebuild calibration state from a to_dict() payload."""
        cal = cls(model=d["model"], alpha=d.get("alpha", DEFAULT_ALPHA))
        cal.sample_count = d.get("sample_count", 0)
        cal.last_updated = d.get("last_updated", time.time())
        cal.ms_per_prompt_token = d["ms_per_prompt_token"]
        cal.ms_per_completion_token = d["ms_per_completion_token"]
        cal.base_overhead_ms = d["base_overhead_ms"]
        cal.total_absolute_error_ms = d.get("total_absolute_error_ms", 0.0)
        cal.total_predicted_ms = d.get("total_predicted_ms", 0.0)
        return cal
class AdaptiveCalibrator:
    """Online calibrator for local LLM inference cost estimation.

    Maintains per-model EMA calibration state, persisted to disk between
    sessions. Requires no external dependencies — pure stdlib.

    Thread safety: not thread-safe. Use one instance per process.
    """

    def __init__(
        self,
        state_path: Optional[Path] = None,
        alpha: float = DEFAULT_ALPHA,
        autosave: bool = True,
    ):
        """
        Args:
            state_path: JSON file for persisted state; defaults to
                DEFAULT_STATE_PATH (~/.nexus/calibrator_state.json).
            alpha: EMA smoothing factor for newly created model states.
            autosave: When True, persist after every record() and reset().
        """
        self.state_path = state_path or DEFAULT_STATE_PATH
        self.alpha = alpha
        self.autosave = autosave
        self._models: dict[str, ModelCalibration] = {}
        self._load()

    # ── Public API ───────────────────────────────────────────────────
    def predict(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostPrediction:
        """Return a calibrated cost prediction for the given model and token counts.

        If this model has never been seen, returns a prior-based estimate
        with confidence=0. Note: this seeds (and keeps) prior-only state
        for the model in self._models.
        """
        cal = self._get_or_create(model)
        predicted_ms = cal.predict(prompt_tokens, completion_tokens)
        return CostPrediction(
            model=model,
            prompt_tokens=prompt_tokens,
            predicted_ms=predicted_ms,
            confidence=cal.confidence,
            sample_count=cal.sample_count,
        )

    def record(
        self,
        model: str,
        prompt_tokens: int,
        actual_ms: float,
        completion_tokens: int = 0,
    ) -> float:
        """Record an observed inference call and update calibration.

        Args:
            model: Model identifier (e.g. "timmy:v0.1-q4", "llama3-8b-8192")
            prompt_tokens: Number of tokens in the prompt/input
            actual_ms: Observed wall-clock latency in milliseconds
            completion_tokens: Number of tokens generated (optional)

        Returns:
            Prediction error in ms (actual - predicted) at time of recording.
        """
        cal = self._get_or_create(model)
        error_ms = cal.update(prompt_tokens, completion_tokens, actual_ms)
        if self.autosave:
            self._save()
        return error_ms

    def get_stats(self, model: str) -> dict:
        """Return calibration stats for a model.

        FIX: a model with zero recorded samples is reported as uncalibrated,
        including one that predict() already seeded into self._models.
        Previously such a model fell through to the calibrated branch and
        produced mean_absolute_error_ms=NaN with status "warming up",
        making the "uncalibrated" branch unreachable for it.
        """
        cal = self._models.get(model)
        if cal is None or cal.sample_count == 0:
            return {
                "model": model,
                "sample_count": 0,
                "confidence": 0.0,
                "status": "uncalibrated (prior only)",
            }
        return {
            "model": model,
            "sample_count": cal.sample_count,
            "confidence": round(cal.confidence, 3),
            "ms_per_prompt_token": round(cal.ms_per_prompt_token, 4),
            "ms_per_completion_token": round(cal.ms_per_completion_token, 4),
            "base_overhead_ms": round(cal.base_overhead_ms, 1),
            "mean_absolute_error_ms": round(cal.mean_absolute_error_ms, 1),
            "last_updated": cal.last_updated,
            "status": "calibrated" if cal.sample_count >= 10 else "warming up",
        }

    def all_stats(self) -> list[dict]:
        """Return calibration stats for all known models, sorted by name."""
        return [self.get_stats(m) for m in sorted(self._models)]

    def reset(self, model: Optional[str] = None):
        """Reset calibration for one model (if given) or for all models."""
        if model:
            self._models.pop(model, None)
        else:
            self._models.clear()
        if self.autosave:
            self._save()

    # ── Persistence ──────────────────────────────────────────────────
    def _get_or_create(self, model: str) -> ModelCalibration:
        """Return this model's calibration state, creating it from priors."""
        if model not in self._models:
            self._models[model] = ModelCalibration(model=model, alpha=self.alpha)
        return self._models[model]

    def _load(self):
        """Load persisted calibration state from disk (best-effort)."""
        if not self.state_path.exists():
            return
        try:
            with open(self.state_path) as f:
                data = json.load(f)
            for model_data in data.get("models", []):
                cal = ModelCalibration.from_dict(model_data)
                self._models[cal.model] = cal
        except Exception:
            # Corrupt state file — deliberately start fresh rather than crash.
            self._models = {}

    def _save(self):
        """Persist calibration state to disk atomically (tmp file + rename)."""
        self.state_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": 1,
            "saved_at": time.time(),
            "models": [cal.to_dict() for cal in self._models.values()],
        }
        # Write atomically via tmp file so a crash never leaves partial JSON.
        tmp = self.state_path.with_suffix(".tmp")
        with open(tmp, "w") as f:
            json.dump(data, f, indent=2)
        tmp.replace(self.state_path)

View File

@@ -0,0 +1,262 @@
"""
Tests for AdaptiveCalibrator — online learning for local cost estimation.
Covers:
- Prior-based predictions for unseen models
- EMA update convergence
- Confidence growth with samples
- Persistence (save/load round-trip)
- reset() for one model and all models
- Groq vs local model prior selection
- get_stats() and all_stats()
"""
import json
import math
import tempfile
from pathlib import Path
import pytest
from nexus.adaptive_calibrator import (
AdaptiveCalibrator,
CostPrediction,
ModelCalibration,
_is_groq_model,
_prior_for,
DEFAULT_ALPHA,
)
# ═══ Helpers ═══
def make_calibrator(tmp_path: Path, alpha: float = DEFAULT_ALPHA) -> AdaptiveCalibrator:
    """Build a calibrator whose state file lives under pytest's tmp_path."""
    return AdaptiveCalibrator(
        state_path=tmp_path / "calibrator_state.json",
        alpha=alpha,
        autosave=True,
    )
# ═══ Model family detection ═══
def test_local_ollama_model_not_groq():
    """Names with an Ollama-style ":tag" must be classed as local."""
    for name in ("timmy:v0.1-q4", "mistral:7b-q4_0"):
        assert not _is_groq_model(name)
def test_groq_model_detected():
    """Known cloud-model names without a ":tag" are classed as Groq."""
    for name in ("llama3-8b-8192", "mixtral-8x7b-32768"):
        assert _is_groq_model(name)
def test_prior_local_is_slower_than_groq():
    """Local priors should predict slower per-token latency than Groq."""
    local_prior = _prior_for("timmy:v0.1-q4")
    groq_prior = _prior_for("llama3-8b-8192")
    for key in ("ms_per_completion_token", "ms_per_prompt_token"):
        assert local_prior[key] > groq_prior[key]
# ═══ CostPrediction ═══
def test_predict_returns_cost_prediction(tmp_path):
    """predict() yields a populated CostPrediction for a fresh model."""
    calibrator = make_calibrator(tmp_path)
    prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
    assert isinstance(prediction, CostPrediction)
    assert (prediction.model, prediction.prompt_tokens) == ("timmy:v0.1-q4", 512)
    assert prediction.predicted_ms > 0
    # No observations yet: zero samples, zero confidence.
    assert prediction.sample_count == 0
    assert prediction.confidence == 0.0
def test_predict_new_model_uses_prior(tmp_path):
    """An unseen model still gets a positive prior-based estimate."""
    prediction = make_calibrator(tmp_path).predict("unknown-model:x", prompt_tokens=100)
    assert prediction.predicted_ms > 0
    assert prediction.confidence == 0.0
def test_predict_longer_prompt_costs_more(tmp_path):
    """Predicted latency must grow with prompt length."""
    calibrator = make_calibrator(tmp_path)
    cost_100, cost_1000 = (
        calibrator.predict("timmy:v0.1-q4", prompt_tokens=n).predicted_ms
        for n in (100, 1000)
    )
    assert cost_1000 > cost_100
# ═══ Record & EMA update ═══
def test_record_returns_error_ms(tmp_path):
    """record() reports the signed prediction error as a float."""
    calibrator = make_calibrator(tmp_path)
    err = calibrator.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    assert isinstance(err, float)
def test_record_increases_sample_count(tmp_path):
    """Each record() call bumps the model's sample count."""
    calibrator = make_calibrator(tmp_path)
    calibrator.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    assert calibrator.get_stats("timmy:v0.1-q4")["sample_count"] == 1
def test_repeated_records_converge_prediction(tmp_path):
    """Feeding a constant true cost should pull predictions toward it."""
    calibrator = make_calibrator(tmp_path, alpha=0.3)
    true_ms = 4000
    for _ in range(40):
        calibrator.record("timmy:v0.1-q4", prompt_tokens=256, actual_ms=true_ms)
    predicted = calibrator.predict("timmy:v0.1-q4", prompt_tokens=256).predicted_ms
    # Converged to within 15% of the true value after many samples.
    assert abs(predicted - true_ms) / true_ms < 0.15
def test_confidence_grows_with_samples(tmp_path):
    """Confidence starts at 0 and exceeds 0.5 after 10 samples."""
    cal = make_calibrator(tmp_path)
    assert cal.predict("timmy:v0.1-q4", prompt_tokens=100).confidence == 0.0
    # Idiom fix: the loop variable was `i` but never used — renamed to `_`.
    for _ in range(10):
        cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)
    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    assert pred.confidence > 0.5
    assert pred.sample_count == 10
def test_confidence_approaches_one(tmp_path):
    """With many samples, confidence should approach (not reach) 1."""
    cal = make_calibrator(tmp_path)
    for _ in range(50):
        cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)
    assert cal.predict("timmy:v0.1-q4", prompt_tokens=100).confidence > 0.99
def test_parameters_stay_non_negative(tmp_path):
    """EMA updates must never drive any coefficient negative."""
    cal = make_calibrator(tmp_path)
    # Very small actual latencies push every coefficient toward zero.
    for _ in range(20):
        cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=1.0)
    state = cal._models["timmy:v0.1-q4"]
    assert state.ms_per_prompt_token > 0
    assert state.ms_per_completion_token > 0
    assert state.base_overhead_ms >= 0
# ═══ get_stats / all_stats ═══
def test_get_stats_uncalibrated(tmp_path):
    """A never-seen model reports zero samples and uncalibrated status."""
    stats = make_calibrator(tmp_path).get_stats("never-seen-model")
    assert stats["sample_count"] == 0
    assert stats["confidence"] == 0.0
    assert "uncalibrated" in stats["status"]
def test_get_stats_after_records(tmp_path):
    """After a few records, stats expose count, confidence and MAE."""
    cal = make_calibrator(tmp_path)
    for _ in range(5):
        cal.record("timmy:v0.1-q4", prompt_tokens=200, actual_ms=3000)
    stats = cal.get_stats("timmy:v0.1-q4")
    assert stats["sample_count"] == 5
    assert stats["confidence"] > 0
    assert "mean_absolute_error_ms" in stats
def test_all_stats_lists_all_models(tmp_path):
    """all_stats() covers every model that has been recorded."""
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=2000)
    names = {entry["model"] for entry in cal.all_stats()}
    assert {"model-a", "model-b"} <= names
# ═══ Persistence ═══
def test_save_and_load(tmp_path):
    """Calibration state must survive a save/load round-trip."""
    state_file = tmp_path / "state.json"
    # First instance writes samples (autosave persists each one).
    writer = AdaptiveCalibrator(state_path=state_file, autosave=True)
    for _ in range(15):
        writer.record("timmy:v0.1-q4", prompt_tokens=300, actual_ms=3500)
    before = writer.get_stats("timmy:v0.1-q4")
    # Fresh instance reloads the same state from disk.
    reader = AdaptiveCalibrator(state_path=state_file, autosave=True)
    after = reader.get_stats("timmy:v0.1-q4")
    assert after["sample_count"] == before["sample_count"]
    assert abs(after["ms_per_prompt_token"] - before["ms_per_prompt_token"]) < 1e-6
def test_load_with_missing_file(tmp_path):
    """A missing state file yields an empty calibrator, not a crash."""
    cal = AdaptiveCalibrator(state_path=tmp_path / "nonexistent.json", autosave=False)
    assert cal.all_stats() == []
def test_load_with_corrupt_file(tmp_path):
    """Unparseable state files are ignored; the calibrator starts empty."""
    state_file = tmp_path / "state.json"
    state_file.write_text("not valid json {{{")
    cal = AdaptiveCalibrator(state_path=state_file, autosave=False)
    assert cal.all_stats() == []
def test_atomic_save(tmp_path):
    """Saving writes valid JSON in place and leaves no .tmp file behind."""
    state_file = tmp_path / "state.json"
    cal = AdaptiveCalibrator(state_path=state_file, autosave=True)
    cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)
    assert state_file.exists()
    # The temporary file must have been renamed away.
    assert not state_file.with_suffix(".tmp").exists()
    payload = json.loads(state_file.read_text())
    assert payload["version"] == 1
# ═══ Reset ═══
def test_reset_single_model(tmp_path):
    """reset(model) clears only that model's calibration."""
    cal = make_calibrator(tmp_path)
    for name in ("model-a", "model-b"):
        cal.record(name, prompt_tokens=100, actual_ms=1000)
    cal.reset("model-a")
    assert cal.get_stats("model-a")["sample_count"] == 0
    assert cal.get_stats("model-b")["sample_count"] == 1
def test_reset_all_models(tmp_path):
    """reset() with no argument wipes every model."""
    cal = make_calibrator(tmp_path)
    for name in ("model-a", "model-b"):
        cal.record(name, prompt_tokens=100, actual_ms=1000)
    cal.reset()
    assert cal.all_stats() == []
# ═══ ModelCalibration unit tests ═══
def test_model_calibration_repr_roundtrip():
    """to_dict()/from_dict() preserves the core calibration fields."""
    original = ModelCalibration(model="test:v1")
    restored = ModelCalibration.from_dict(original.to_dict())
    assert restored.model == original.model
    assert restored.alpha == original.alpha
    assert restored.ms_per_prompt_token == original.ms_per_prompt_token
def test_model_calibration_mean_absolute_error_nan_when_no_samples():
    """MAE is NaN (not zero) before any samples are recorded."""
    fresh = ModelCalibration(model="test:v1")
    assert math.isnan(fresh.mean_absolute_error_ms)