Files
the-nexus/nexus/adaptive_calibrator.py
Alexander Whitestone 5649aeb975
Some checks failed
CI / validate (pull_request) Has been cancelled
feat: Implement AdaptiveCalibrator for local cost estimation (Refs #770)
Add nexus/adaptive_calibrator.py with the AdaptiveCalibrator class that
provides online learning (EMA) for LLM inference cost prediction.

Key features:
- Per-model ModelCalibration state tracking ms/token and base overhead
- EMA updates from observed (prompt_tokens, completion_tokens, actual_ms)
- Confidence metric grows with sample count (1 - exp(-n/10))
- Seeded priors distinguish local Ollama models from Groq cloud models
- Atomic JSON persistence to ~/.nexus/calibrator_state.json
- reset() per-model or global; autosave on every record()
- 23 unit tests covering convergence, persistence, edge cases

Exported from nexus/__init__.py as AdaptiveCalibrator and CostPrediction.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-03 21:39:28 -04:00

355 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AdaptiveCalibrator — Online Learning for Local Cost Estimation
Tracks predicted vs actual inference costs (latency, tokens) per model
and learns correction factors using Exponential Moving Average (EMA).
Extracted from Kimi Report #2 design spec.
Usage:
calibrator = AdaptiveCalibrator()
# Before a call: get predicted cost
prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
# After a call: record what actually happened
calibrator.record(
model="timmy:v0.1-q4",
prompt_tokens=512,
completion_tokens=128,
actual_ms=3400,
)
# Get model stats
stats = calibrator.get_stats("timmy:v0.1-q4")
"""
import json
import math
import time
from pathlib import Path
from typing import Optional
# Default on-disk location of the persisted calibration state.
DEFAULT_STATE_PATH = Path.home().joinpath(".nexus", "calibrator_state.json")

# EMA smoothing factor: 0.1 = slow adaptation, 0.3 = fast adaptation
DEFAULT_ALPHA = 0.15

# Seed latency coefficients per model family. They are only rough starting
# points; each model's coefficients are refined online from observations.
_MODEL_PRIORS: dict[str, dict] = {
    # Ollama local models (8B range, q4 quantized, typical CPU/GPU)
    "default_local": dict(
        ms_per_prompt_token=0.5,
        ms_per_completion_token=8.0,
        base_overhead_ms=300.0,
    ),
    # Groq cloud (extremely fast inference)
    "default_groq": dict(
        ms_per_prompt_token=0.05,
        ms_per_completion_token=0.3,
        base_overhead_ms=150.0,
    ),
}
_GROQ_MODEL_PREFIXES = ("llama", "mixtral", "gemma", "whisper")


def _is_groq_model(model: str) -> bool:
    """Guess whether *model* names a Groq cloud model rather than a local one.

    Heuristic only: the (case-insensitive) name must start with one of
    ``_GROQ_MODEL_PREFIXES`` and must not contain a ``":"``. Names with a
    colon (tag-style identifiers such as ``"timmy:v0.1-q4"``) are always
    treated as local models.
    """
    name = model.lower()
    if ":" in name:
        return False
    return name.startswith(_GROQ_MODEL_PREFIXES)
def _prior_for(model: str) -> dict:
    """Return a fresh copy of the seed prior that applies to *model*.

    Groq-looking names (per ``_is_groq_model``) get the ``"default_groq"``
    prior; everything else gets ``"default_local"``. A copy is returned so
    callers can modify it without touching the shared ``_MODEL_PRIORS``.
    """
    family = "default_groq" if _is_groq_model(model) else "default_local"
    return _MODEL_PRIORS[family].copy()
class CostPrediction:
    """Outcome of one calibrated latency prediction.

    Attributes:
        model: Model identifier the prediction was made for.
        prompt_tokens: Prompt size the prediction assumed.
        predicted_ms: Estimated latency in milliseconds.
        confidence: 0.0 (prior only) up to 1.0 (well-calibrated).
        sample_count: Number of observations behind the estimate.
        predicted_at: ``time.time()`` timestamp taken at construction.
    """

    def __init__(
        self,
        model: str,
        prompt_tokens: int,
        predicted_ms: float,
        confidence: float,
        sample_count: int,
    ):
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.predicted_ms = predicted_ms
        self.confidence = confidence
        self.sample_count = sample_count
        # Wall-clock moment this prediction object was created.
        self.predicted_at = time.time()

    def __repr__(self) -> str:
        parts = (
            f"model={self.model!r}",
            f"prompt_tokens={self.prompt_tokens}",
            f"predicted_ms={self.predicted_ms:.0f}",
            f"confidence={self.confidence:.2f}",
            f"n={self.sample_count}",
        )
        return "CostPrediction(" + ", ".join(parts) + ")"
class ModelCalibration:
    """Per-model online calibration state.

    Latency is modelled as a linear function of token counts::

        predicted_ms = base_overhead_ms
                       + ms_per_prompt_token * prompt_tokens
                       + ms_per_completion_token * completion_tokens

    The three coefficients start from the seed prior for the model and are
    refined with an EMA-style correction on every observation (see
    ``update``). Confidence grows with sample count as ``1 - exp(-n/10)``.
    """

    def __init__(self, model: str, alpha: float = DEFAULT_ALPHA):
        """Create calibration state for *model*, seeded from its prior.

        Args:
            model: Model identifier; also selects the seed prior.
            alpha: EMA smoothing factor — the fraction of each observed
                prediction error that is corrected per update.
        """
        self.model = model
        self.alpha = alpha
        self.sample_count = 0
        self.last_updated = time.time()
        # EMA parameters (start from prior)
        prior = _prior_for(model)
        self.ms_per_prompt_token: float = prior["ms_per_prompt_token"]
        self.ms_per_completion_token: float = prior["ms_per_completion_token"]
        self.base_overhead_ms: float = prior["base_overhead_ms"]
        # Running totals for error diagnostics (see mean_absolute_error_ms).
        self.total_absolute_error_ms: float = 0.0
        self.total_predicted_ms: float = 0.0

    @property
    def confidence(self) -> float:
        """Confidence in current estimates.

        Grows from 0 (prior only) toward 1 as samples accumulate.
        Uses: 1 - exp(-n/10) so confidence ~0.63 at n=10, ~0.95 at n=30.
        """
        return 1.0 - math.exp(-self.sample_count / 10.0)

    def predict(self, prompt_tokens: int, completion_tokens: int = 0) -> float:
        """Predict latency in milliseconds for a call with these token counts."""
        return (
            self.base_overhead_ms
            + self.ms_per_prompt_token * prompt_tokens
            + self.ms_per_completion_token * completion_tokens
        )

    def update(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        actual_ms: float,
    ) -> float:
        """Fold one observed call into the EMA estimates.

        The prediction error is split into shares that sum to exactly 1:
        half goes to ``base_overhead_ms`` and the other half is divided
        between the two per-token rates in proportion to their token
        counts. A call with no tokens attributes the whole error to
        overhead. Each coefficient moves by ``alpha * error * share``, and
        the rates additionally divide by their token count. As a result,
        the prediction for this same input moves by exactly
        ``alpha * error`` before clamping, which is a true EMA step.

        Args:
            prompt_tokens: Tokens in the prompt/input.
            completion_tokens: Tokens generated.
            actual_ms: Observed wall-clock latency in milliseconds.

        Returns:
            Prediction error ``actual_ms - predicted_ms`` in ms, measured
            before this update is applied.
        """
        predicted_ms = self.predict(prompt_tokens, completion_tokens)
        error_ms = actual_ms - predicted_ms

        token_total = prompt_tokens + completion_tokens
        if token_total > 0:
            overhead_share = 0.5
            prompt_share = 0.5 * prompt_tokens / token_total
            completion_share = 0.5 * completion_tokens / token_total
        else:
            # No tokens: the whole error can only come from fixed overhead.
            overhead_share, prompt_share, completion_share = 1.0, 0.0, 0.0

        step_ms = self.alpha * error_ms
        # Per-token rates spread their share over their token count;
        # max(..., 1) guards the division when a count is zero (its share is
        # then zero as well).
        self.ms_per_prompt_token += step_ms * prompt_share / max(prompt_tokens, 1)
        self.ms_per_completion_token += (
            step_ms * completion_share / max(completion_tokens, 1)
        )
        self.base_overhead_ms += step_ms * overhead_share

        # Clamp to physically reasonable values
        self.ms_per_prompt_token = max(0.001, self.ms_per_prompt_token)
        self.ms_per_completion_token = max(0.001, self.ms_per_completion_token)
        self.base_overhead_ms = max(0.0, self.base_overhead_ms)

        self.sample_count += 1
        self.last_updated = time.time()
        self.total_absolute_error_ms += abs(error_ms)
        self.total_predicted_ms += predicted_ms
        return error_ms

    @property
    def mean_absolute_error_ms(self) -> float:
        """MAE over all recorded samples (NaN before the first sample)."""
        if self.sample_count == 0:
            return float("nan")
        return self.total_absolute_error_ms / self.sample_count

    def to_dict(self) -> dict:
        """Serialize this state to a JSON-compatible dict (see from_dict)."""
        return {
            "model": self.model,
            "alpha": self.alpha,
            "sample_count": self.sample_count,
            "last_updated": self.last_updated,
            "ms_per_prompt_token": self.ms_per_prompt_token,
            "ms_per_completion_token": self.ms_per_completion_token,
            "base_overhead_ms": self.base_overhead_ms,
            "total_absolute_error_ms": self.total_absolute_error_ms,
            "total_predicted_ms": self.total_predicted_ms,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ModelCalibration":
        """Rebuild state produced by ``to_dict``.

        Raises:
            KeyError: if ``model`` or a coefficient key is missing.
            TypeError / ValueError: if a numeric field is not numeric.
        """
        obj = cls(model=d["model"], alpha=float(d.get("alpha", DEFAULT_ALPHA)))
        obj.sample_count = int(d.get("sample_count", 0))
        obj.last_updated = float(d.get("last_updated", time.time()))
        # Coerce eagerly so malformed state fails here, at load time,
        # instead of later inside predict()/update().
        obj.ms_per_prompt_token = float(d["ms_per_prompt_token"])
        obj.ms_per_completion_token = float(d["ms_per_completion_token"])
        obj.base_overhead_ms = float(d["base_overhead_ms"])
        obj.total_absolute_error_ms = float(d.get("total_absolute_error_ms", 0.0))
        obj.total_predicted_ms = float(d.get("total_predicted_ms", 0.0))
        return obj
class AdaptiveCalibrator:
    """Online calibrator for local LLM inference cost estimation.

    Maintains per-model EMA calibration state, persisted to disk between
    sessions. Requires no external dependencies — pure stdlib.

    Thread safety: not thread-safe. Use one instance per process.
    """

    def __init__(
        self,
        state_path: "Optional[Path | str]" = None,
        alpha: float = DEFAULT_ALPHA,
        autosave: bool = True,
    ):
        """Create a calibrator and load any persisted state.

        Args:
            state_path: JSON state file (Path or str). Falsy values fall
                back to ``DEFAULT_STATE_PATH``.
            alpha: EMA smoothing factor for newly seen models. Models loaded
                from disk keep the alpha they were saved with.
            autosave: If true, ``record()`` and ``reset()`` write the state
                file immediately.
        """
        self.state_path = Path(state_path) if state_path else DEFAULT_STATE_PATH
        self.alpha = alpha
        self.autosave = autosave
        self._models: dict[str, ModelCalibration] = {}
        self._load()

    # ── Public API ───────────────────────────────────────────────────

    def predict(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostPrediction:
        """Return a calibrated cost prediction for the given model and token counts.

        If this model has never been seen, returns a prior-based estimate
        with confidence=0, without registering the model.
        """
        cal = self._models.get(model)
        if cal is None:
            # Use a throwaway prior-only calibration. Registering it would
            # persist an empty entry and make get_stats()/all_stats()
            # report the model as "warming up" with a NaN error.
            cal = ModelCalibration(model=model, alpha=self.alpha)
        predicted_ms = cal.predict(prompt_tokens, completion_tokens)
        return CostPrediction(
            model=model,
            prompt_tokens=prompt_tokens,
            predicted_ms=predicted_ms,
            confidence=cal.confidence,
            sample_count=cal.sample_count,
        )

    def record(
        self,
        model: str,
        prompt_tokens: int,
        actual_ms: float,
        completion_tokens: int = 0,
    ) -> float:
        """Record an observed inference call and update calibration.

        Args:
            model: Model identifier (e.g. "timmy:v0.1-q4", "llama3-8b-8192")
            prompt_tokens: Number of tokens in the prompt/input
            actual_ms: Observed wall-clock latency in milliseconds
            completion_tokens: Number of tokens generated (optional)

        Returns:
            Prediction error in ms (actual - predicted) at time of recording.
        """
        cal = self._get_or_create(model)
        error_ms = cal.update(prompt_tokens, completion_tokens, actual_ms)
        if self.autosave:
            self._save()
        return error_ms

    def get_stats(self, model: str) -> dict:
        """Return calibration stats for a model.

        Models with no recorded samples (including ones never seen) are
        reported as uncalibrated.
        """
        cal = self._models.get(model)
        if cal is None or cal.sample_count == 0:
            return {
                "model": model,
                "sample_count": 0,
                "confidence": 0.0,
                "status": "uncalibrated (prior only)",
            }
        return {
            "model": model,
            "sample_count": cal.sample_count,
            "confidence": round(cal.confidence, 3),
            "ms_per_prompt_token": round(cal.ms_per_prompt_token, 4),
            "ms_per_completion_token": round(cal.ms_per_completion_token, 4),
            "base_overhead_ms": round(cal.base_overhead_ms, 1),
            "mean_absolute_error_ms": round(cal.mean_absolute_error_ms, 1),
            "last_updated": cal.last_updated,
            "status": "calibrated" if cal.sample_count >= 10 else "warming up",
        }

    def all_stats(self) -> list[dict]:
        """Return calibration stats for all known models."""
        return [self.get_stats(m) for m in sorted(self._models)]

    def reset(self, model: Optional[str] = None):
        """Reset calibration for one model, or for all models when None."""
        # Compare against None explicitly: a falsy id such as "" must reset
        # only that model, not wipe every model's state.
        if model is not None:
            self._models.pop(model, None)
        else:
            self._models.clear()
        if self.autosave:
            self._save()

    # ── Persistence ──────────────────────────────────────────────────

    def _get_or_create(self, model: str) -> ModelCalibration:
        """Return the registered calibration for *model*, creating it if absent."""
        if model not in self._models:
            self._models[model] = ModelCalibration(model=model, alpha=self.alpha)
        return self._models[model]

    def _load(self):
        """Load persisted calibration state from disk.

        A missing file is a no-op. An unreadable or corrupt file is
        ignored, and the calibrator starts fresh.
        """
        if not self.state_path.exists():
            return
        try:
            with open(self.state_path, encoding="utf-8") as f:
                data = json.load(f)
            loaded: dict[str, ModelCalibration] = {}
            for model_data in data.get("models", []):
                cal = ModelCalibration.from_dict(model_data)
                loaded[cal.model] = cal
        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            # Corrupt state file (bad JSON, wrong shape, missing or
            # non-numeric fields) — start fresh
            self._models = {}
            return
        self._models = loaded

    def _save(self):
        """Persist calibration state to disk."""
        self.state_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": 1,
            "saved_at": time.time(),
            "models": [cal.to_dict() for cal in self._models.values()],
        }
        # Write to a sibling temp file, then swap it into place with
        # Path.replace, so a failure mid-write leaves the previous state
        # file intact.
        tmp = self.state_path.with_name(self.state_path.name + ".tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp.replace(self.state_path)
        except BaseException:
            try:
                tmp.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; surface the original error
            raise