Some checks failed
CI / validate (pull_request) Has been cancelled
Add nexus/adaptive_calibrator.py with the AdaptiveCalibrator class that provides online learning (EMA) for LLM inference cost prediction. Key features: - Per-model ModelCalibration state tracking ms/token and base overhead - EMA updates from observed (prompt_tokens, completion_tokens, actual_ms) - Confidence metric grows with sample count (1 - exp(-n/10)) - Seeded priors distinguish local Ollama models from Groq cloud models - Atomic JSON persistence to ~/.nexus/calibrator_state.json - reset() per-model or global; autosave on every record() - 23 unit tests covering convergence, persistence, edge cases Exported from nexus/__init__.py as AdaptiveCalibrator and CostPrediction. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""
|
||
AdaptiveCalibrator — Online Learning for Local Cost Estimation
|
||
|
||
Tracks predicted vs actual inference costs (latency, tokens) per model
|
||
and learns correction factors using Exponential Moving Average (EMA).
|
||
|
||
Extracted from Kimi Report #2 design spec.
|
||
|
||
Usage:
|
||
calibrator = AdaptiveCalibrator()
|
||
|
||
# Before a call: get predicted cost
|
||
prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
|
||
|
||
# After a call: record what actually happened
|
||
calibrator.record(
|
||
model="timmy:v0.1-q4",
|
||
prompt_tokens=512,
|
||
completion_tokens=128,
|
||
actual_ms=3400,
|
||
)
|
||
|
||
# Get model stats
|
||
stats = calibrator.get_stats("timmy:v0.1-q4")
|
||
"""
|
||
|
||
import json
|
||
import math
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
# Default on-disk location for persisted calibration state.
DEFAULT_STATE_PATH = Path.home() / ".nexus" / "calibrator_state.json"

# EMA smoothing factor: 0.1 = slow adaptation, 0.3 = fast adaptation
DEFAULT_ALPHA = 0.15

# Seed latency estimates (ms per token) by model family.
# These are rough priors; the calibrator adapts them online.
_MODEL_PRIORS: dict[str, dict] = {
    # Ollama local models (8B range, q4 quantized, typical CPU/GPU)
    "default_local": {
        "ms_per_prompt_token": 0.5,
        "ms_per_completion_token": 8.0,
        "base_overhead_ms": 300.0,
    },
    # Groq cloud (extremely fast inference)
    "default_groq": {
        "ms_per_prompt_token": 0.05,
        "ms_per_completion_token": 0.3,
        "base_overhead_ms": 150.0,
    },
}
|
||
|
||
# Model-name prefixes that indicate a Groq-hosted cloud model.
_GROQ_MODEL_PREFIXES = ("llama", "mixtral", "gemma", "whisper")


def _is_groq_model(model: str) -> bool:
    """Heuristic: is this a cloud Groq model vs a local Ollama model?"""
    name = model.lower()
    # Ollama-style tags contain a colon (e.g. "timmy:v0.1-q4") — always local.
    if ":" in name:
        return False
    # str.startswith accepts a tuple of candidate prefixes.
    return name.startswith(_GROQ_MODEL_PREFIXES)
|
||
|
||
|
||
def _prior_for(model: str) -> dict:
    """Return a fresh copy of the seed prior appropriate for this model."""
    family = "default_groq" if _is_groq_model(model) else "default_local"
    return dict(_MODEL_PRIORS[family])
|
||
|
||
|
||
class CostPrediction:
    """Result of a calibrated cost prediction.

    Attributes:
        model: Model identifier the prediction was made for.
        prompt_tokens: Prompt size the prediction assumed.
        predicted_ms: Estimated wall-clock latency in milliseconds.
        confidence: 0.0 (prior only) → 1.0 (well-calibrated).
        sample_count: Number of observations behind the estimate.
        predicted_at: Unix timestamp at construction time.
    """

    def __init__(
        self,
        model: str,
        prompt_tokens: int,
        predicted_ms: float,
        confidence: float,
        sample_count: int,
    ):
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.predicted_ms = predicted_ms
        self.confidence = confidence
        self.sample_count = sample_count
        self.predicted_at = time.time()

    def __repr__(self) -> str:
        fields = ", ".join(
            [
                f"model={self.model!r}",
                f"prompt_tokens={self.prompt_tokens}",
                f"predicted_ms={self.predicted_ms:.0f}",
                f"confidence={self.confidence:.2f}",
                f"n={self.sample_count}",
            ]
        )
        return f"CostPrediction({fields})"
|
||
|
||
|
||
class ModelCalibration:
|
||
"""Per-model online calibration state.
|
||
|
||
Tracks EMA estimates of:
|
||
- ms_per_prompt_token
|
||
- ms_per_completion_token
|
||
- base_overhead_ms
|
||
|
||
Confidence grows with sample count (sigmoid-ish curve).
|
||
"""
|
||
|
||
def __init__(self, model: str, alpha: float = DEFAULT_ALPHA):
|
||
self.model = model
|
||
self.alpha = alpha
|
||
self.sample_count = 0
|
||
self.last_updated = time.time()
|
||
|
||
# EMA parameters (start from prior)
|
||
prior = _prior_for(model)
|
||
self.ms_per_prompt_token: float = prior["ms_per_prompt_token"]
|
||
self.ms_per_completion_token: float = prior["ms_per_completion_token"]
|
||
self.base_overhead_ms: float = prior["base_overhead_ms"]
|
||
|
||
# Tracking for error diagnostics
|
||
self.total_absolute_error_ms: float = 0.0
|
||
self.total_predicted_ms: float = 0.0
|
||
|
||
@property
|
||
def confidence(self) -> float:
|
||
"""Confidence in current estimates.
|
||
|
||
Grows from 0 (prior only) toward 1 as samples accumulate.
|
||
Uses: 1 - exp(-n/10) so confidence ~0.63 at n=10, ~0.95 at n=30.
|
||
"""
|
||
return 1.0 - math.exp(-self.sample_count / 10.0)
|
||
|
||
def predict(self, prompt_tokens: int, completion_tokens: int = 0) -> float:
|
||
"""Predict latency in milliseconds for a call with these token counts."""
|
||
return (
|
||
self.base_overhead_ms
|
||
+ self.ms_per_prompt_token * prompt_tokens
|
||
+ self.ms_per_completion_token * completion_tokens
|
||
)
|
||
|
||
def update(
|
||
self,
|
||
prompt_tokens: int,
|
||
completion_tokens: int,
|
||
actual_ms: float,
|
||
) -> float:
|
||
"""Update EMA estimates from one observed data point.
|
||
|
||
Uses a simple linear model:
|
||
actual_ms ≈ overhead + α_p * prompt_tokens + α_c * completion_tokens
|
||
|
||
We update each coefficient independently using EMA on the residuals.
|
||
Returns the prediction error (actual - predicted) in ms.
|
||
"""
|
||
predicted_ms = self.predict(prompt_tokens, completion_tokens)
|
||
error_ms = actual_ms - predicted_ms
|
||
|
||
# EMA update: new_estimate = old + alpha * error
|
||
# This is equivalent to: new = (1-alpha)*old + alpha*actual_ratio
|
||
total_tokens = prompt_tokens + completion_tokens or 1
|
||
|
||
# Attribute the error proportionally to each component
|
||
prompt_frac = prompt_tokens / total_tokens
|
||
completion_frac = completion_tokens / total_tokens
|
||
overhead_frac = 1.0 - 0.5 * (prompt_frac + completion_frac)
|
||
|
||
self.ms_per_prompt_token += self.alpha * error_ms * prompt_frac / max(prompt_tokens, 1)
|
||
self.ms_per_completion_token += self.alpha * error_ms * completion_frac / max(completion_tokens, 1)
|
||
self.base_overhead_ms += self.alpha * error_ms * overhead_frac
|
||
|
||
# Clamp to physically reasonable values
|
||
self.ms_per_prompt_token = max(0.001, self.ms_per_prompt_token)
|
||
self.ms_per_completion_token = max(0.001, self.ms_per_completion_token)
|
||
self.base_overhead_ms = max(0.0, self.base_overhead_ms)
|
||
|
||
self.sample_count += 1
|
||
self.last_updated = time.time()
|
||
self.total_absolute_error_ms += abs(error_ms)
|
||
self.total_predicted_ms += predicted_ms
|
||
|
||
return error_ms
|
||
|
||
@property
|
||
def mean_absolute_error_ms(self) -> float:
|
||
"""MAE over all recorded samples."""
|
||
if self.sample_count == 0:
|
||
return float("nan")
|
||
return self.total_absolute_error_ms / self.sample_count
|
||
|
||
def to_dict(self) -> dict:
|
||
return {
|
||
"model": self.model,
|
||
"alpha": self.alpha,
|
||
"sample_count": self.sample_count,
|
||
"last_updated": self.last_updated,
|
||
"ms_per_prompt_token": self.ms_per_prompt_token,
|
||
"ms_per_completion_token": self.ms_per_completion_token,
|
||
"base_overhead_ms": self.base_overhead_ms,
|
||
"total_absolute_error_ms": self.total_absolute_error_ms,
|
||
"total_predicted_ms": self.total_predicted_ms,
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, d: dict) -> "ModelCalibration":
|
||
obj = cls(model=d["model"], alpha=d.get("alpha", DEFAULT_ALPHA))
|
||
obj.sample_count = d.get("sample_count", 0)
|
||
obj.last_updated = d.get("last_updated", time.time())
|
||
obj.ms_per_prompt_token = d["ms_per_prompt_token"]
|
||
obj.ms_per_completion_token = d["ms_per_completion_token"]
|
||
obj.base_overhead_ms = d["base_overhead_ms"]
|
||
obj.total_absolute_error_ms = d.get("total_absolute_error_ms", 0.0)
|
||
obj.total_predicted_ms = d.get("total_predicted_ms", 0.0)
|
||
return obj
|
||
|
||
|
||
class AdaptiveCalibrator:
|
||
"""Online calibrator for local LLM inference cost estimation.
|
||
|
||
Maintains per-model EMA calibration state, persisted to disk between
|
||
sessions. Requires no external dependencies — pure stdlib.
|
||
|
||
Thread safety: not thread-safe. Use one instance per process.
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
state_path: Optional[Path] = None,
|
||
alpha: float = DEFAULT_ALPHA,
|
||
autosave: bool = True,
|
||
):
|
||
self.state_path = state_path or DEFAULT_STATE_PATH
|
||
self.alpha = alpha
|
||
self.autosave = autosave
|
||
self._models: dict[str, ModelCalibration] = {}
|
||
self._load()
|
||
|
||
# ── Public API ───────────────────────────────────────────────────
|
||
|
||
def predict(
|
||
self,
|
||
model: str,
|
||
prompt_tokens: int,
|
||
completion_tokens: int = 0,
|
||
) -> CostPrediction:
|
||
"""Return a calibrated cost prediction for the given model and token counts.
|
||
|
||
If this model has never been seen, returns a prior-based estimate
|
||
with confidence=0.
|
||
"""
|
||
cal = self._get_or_create(model)
|
||
predicted_ms = cal.predict(prompt_tokens, completion_tokens)
|
||
return CostPrediction(
|
||
model=model,
|
||
prompt_tokens=prompt_tokens,
|
||
predicted_ms=predicted_ms,
|
||
confidence=cal.confidence,
|
||
sample_count=cal.sample_count,
|
||
)
|
||
|
||
def record(
|
||
self,
|
||
model: str,
|
||
prompt_tokens: int,
|
||
actual_ms: float,
|
||
completion_tokens: int = 0,
|
||
) -> float:
|
||
"""Record an observed inference call and update calibration.
|
||
|
||
Args:
|
||
model: Model identifier (e.g. "timmy:v0.1-q4", "llama3-8b-8192")
|
||
prompt_tokens: Number of tokens in the prompt/input
|
||
actual_ms: Observed wall-clock latency in milliseconds
|
||
completion_tokens: Number of tokens generated (optional)
|
||
|
||
Returns:
|
||
Prediction error in ms (actual - predicted) at time of recording.
|
||
"""
|
||
cal = self._get_or_create(model)
|
||
error_ms = cal.update(prompt_tokens, completion_tokens, actual_ms)
|
||
if self.autosave:
|
||
self._save()
|
||
return error_ms
|
||
|
||
def get_stats(self, model: str) -> dict:
|
||
"""Return calibration stats for a model."""
|
||
if model not in self._models:
|
||
return {
|
||
"model": model,
|
||
"sample_count": 0,
|
||
"confidence": 0.0,
|
||
"status": "uncalibrated (prior only)",
|
||
}
|
||
cal = self._models[model]
|
||
return {
|
||
"model": model,
|
||
"sample_count": cal.sample_count,
|
||
"confidence": round(cal.confidence, 3),
|
||
"ms_per_prompt_token": round(cal.ms_per_prompt_token, 4),
|
||
"ms_per_completion_token": round(cal.ms_per_completion_token, 4),
|
||
"base_overhead_ms": round(cal.base_overhead_ms, 1),
|
||
"mean_absolute_error_ms": round(cal.mean_absolute_error_ms, 1),
|
||
"last_updated": cal.last_updated,
|
||
"status": "calibrated" if cal.sample_count >= 10 else "warming up",
|
||
}
|
||
|
||
def all_stats(self) -> list[dict]:
|
||
"""Return calibration stats for all known models."""
|
||
return [self.get_stats(m) for m in sorted(self._models)]
|
||
|
||
def reset(self, model: Optional[str] = None):
|
||
"""Reset calibration for one model or all models."""
|
||
if model:
|
||
self._models.pop(model, None)
|
||
else:
|
||
self._models.clear()
|
||
if self.autosave:
|
||
self._save()
|
||
|
||
# ── Persistence ──────────────────────────────────────────────────
|
||
|
||
def _get_or_create(self, model: str) -> ModelCalibration:
|
||
if model not in self._models:
|
||
self._models[model] = ModelCalibration(model=model, alpha=self.alpha)
|
||
return self._models[model]
|
||
|
||
def _load(self):
|
||
"""Load persisted calibration state from disk."""
|
||
if not self.state_path.exists():
|
||
return
|
||
try:
|
||
with open(self.state_path) as f:
|
||
data = json.load(f)
|
||
for model_data in data.get("models", []):
|
||
cal = ModelCalibration.from_dict(model_data)
|
||
self._models[cal.model] = cal
|
||
except Exception:
|
||
# Corrupt state file — start fresh
|
||
self._models = {}
|
||
|
||
def _save(self):
|
||
"""Persist calibration state to disk."""
|
||
self.state_path.parent.mkdir(parents=True, exist_ok=True)
|
||
data = {
|
||
"version": 1,
|
||
"saved_at": time.time(),
|
||
"models": [cal.to_dict() for cal in self._models.values()],
|
||
}
|
||
# Write atomically via tmp file
|
||
tmp = self.state_path.with_suffix(".tmp")
|
||
with open(tmp, "w") as f:
|
||
json.dump(data, f, indent=2)
|
||
tmp.replace(self.state_path)
|