feat: Implement AdaptiveCalibrator for local cost estimation (Refs #770)
Some checks failed
CI / validate (pull_request) Has been cancelled
Some checks failed
CI / validate (pull_request) Has been cancelled
Add nexus/adaptive_calibrator.py with the AdaptiveCalibrator class that provides online learning (EMA) for LLM inference cost prediction. Key features: - Per-model ModelCalibration state tracking ms/token and base overhead - EMA updates from observed (prompt_tokens, completion_tokens, actual_ms) - Confidence metric grows with sample count (1 - exp(-n/10)) - Seeded priors distinguish local Ollama models from Groq cloud models - Atomic JSON persistence to ~/.nexus/calibrator_state.json - reset() per-model or global; autosave on every record() - 23 unit tests covering convergence, persistence, edge cases Exported from nexus/__init__.py as AdaptiveCalibrator and CostPrediction. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from nexus.perception_adapter import (
|
|||||||
)
|
)
|
||||||
from nexus.experience_store import ExperienceStore
|
from nexus.experience_store import ExperienceStore
|
||||||
from nexus.trajectory_logger import TrajectoryLogger
|
from nexus.trajectory_logger import TrajectoryLogger
|
||||||
|
from nexus.adaptive_calibrator import AdaptiveCalibrator, CostPrediction
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from nexus.nexus_think import NexusMind
|
from nexus.nexus_think import NexusMind
|
||||||
@@ -28,5 +29,7 @@ __all__ = [
|
|||||||
"Action",
|
"Action",
|
||||||
"ExperienceStore",
|
"ExperienceStore",
|
||||||
"TrajectoryLogger",
|
"TrajectoryLogger",
|
||||||
|
"AdaptiveCalibrator",
|
||||||
|
"CostPrediction",
|
||||||
"NexusMind",
|
"NexusMind",
|
||||||
]
|
]
|
||||||
|
|||||||
354
nexus/adaptive_calibrator.py
Normal file
354
nexus/adaptive_calibrator.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
"""
|
||||||
|
AdaptiveCalibrator — Online Learning for Local Cost Estimation
|
||||||
|
|
||||||
|
Tracks predicted vs actual inference costs (latency, tokens) per model
|
||||||
|
and learns correction factors using Exponential Moving Average (EMA).
|
||||||
|
|
||||||
|
Extracted from Kimi Report #2 design spec.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
calibrator = AdaptiveCalibrator()
|
||||||
|
|
||||||
|
# Before a call: get predicted cost
|
||||||
|
prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
|
||||||
|
|
||||||
|
# After a call: record what actually happened
|
||||||
|
calibrator.record(
|
||||||
|
model="timmy:v0.1-q4",
|
||||||
|
prompt_tokens=512,
|
||||||
|
completion_tokens=128,
|
||||||
|
actual_ms=3400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get model stats
|
||||||
|
stats = calibrator.get_stats("timmy:v0.1-q4")
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
DEFAULT_STATE_PATH = Path.home() / ".nexus" / "calibrator_state.json"
|
||||||
|
|
||||||
|
# EMA smoothing factor: 0.1 = slow adaptation, 0.3 = fast adaptation
|
||||||
|
DEFAULT_ALPHA = 0.15
|
||||||
|
|
||||||
|
# Seed latency estimates (ms per token) by model family
|
||||||
|
# These are rough priors; the calibrator adapts them online
|
||||||
|
_MODEL_PRIORS: dict[str, dict] = {
|
||||||
|
# Ollama local models (8B range, q4 quantized, typical CPU/GPU)
|
||||||
|
"default_local": {
|
||||||
|
"ms_per_prompt_token": 0.5,
|
||||||
|
"ms_per_completion_token": 8.0,
|
||||||
|
"base_overhead_ms": 300.0,
|
||||||
|
},
|
||||||
|
# Groq cloud (extremely fast inference)
|
||||||
|
"default_groq": {
|
||||||
|
"ms_per_prompt_token": 0.05,
|
||||||
|
"ms_per_completion_token": 0.3,
|
||||||
|
"base_overhead_ms": 150.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_GROQ_MODEL_PREFIXES = ("llama", "mixtral", "gemma", "whisper")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_groq_model(model: str) -> bool:
|
||||||
|
"""Heuristic: is this a cloud Groq model vs a local Ollama model?"""
|
||||||
|
m = model.lower()
|
||||||
|
return any(m.startswith(p) for p in _GROQ_MODEL_PREFIXES) and ":" not in m
|
||||||
|
|
||||||
|
|
||||||
|
def _prior_for(model: str) -> dict:
|
||||||
|
"""Return a copy of the seed prior for this model."""
|
||||||
|
if _is_groq_model(model):
|
||||||
|
return dict(_MODEL_PRIORS["default_groq"])
|
||||||
|
return dict(_MODEL_PRIORS["default_local"])
|
||||||
|
|
||||||
|
|
||||||
|
class CostPrediction:
    """Result of a calibrated cost prediction."""

    def __init__(
        self,
        model: str,
        prompt_tokens: int,
        predicted_ms: float,
        confidence: float,
        sample_count: int,
    ):
        # Snapshot of the inputs plus the wall-clock time of the prediction.
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.predicted_ms = predicted_ms
        # 0.0 = prior only → 1.0 = well-calibrated.
        self.confidence = confidence
        self.sample_count = sample_count
        self.predicted_at = time.time()

    def __repr__(self) -> str:
        fields = [
            f"CostPrediction(model={self.model!r}",
            f"prompt_tokens={self.prompt_tokens}",
            f"predicted_ms={self.predicted_ms:.0f}",
            f"confidence={self.confidence:.2f}",
            f"n={self.sample_count})",
        ]
        return ", ".join(fields)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCalibration:
    """Per-model online calibration state.

    Tracks EMA estimates of:
    - ms_per_prompt_token
    - ms_per_completion_token
    - base_overhead_ms

    Confidence grows with sample count (sigmoid-ish curve).
    """

    def __init__(self, model: str, alpha: float = DEFAULT_ALPHA):
        self.model = model
        self.alpha = alpha
        self.sample_count = 0
        self.last_updated = time.time()

        # EMA parameters start from the family-specific seed prior.
        seed = _prior_for(model)
        self.ms_per_prompt_token: float = seed["ms_per_prompt_token"]
        self.ms_per_completion_token: float = seed["ms_per_completion_token"]
        self.base_overhead_ms: float = seed["base_overhead_ms"]

        # Running totals kept purely for error diagnostics.
        self.total_absolute_error_ms: float = 0.0
        self.total_predicted_ms: float = 0.0

    @property
    def confidence(self) -> float:
        """Confidence in current estimates.

        Grows from 0 (prior only) toward 1 as samples accumulate.
        Uses: 1 - exp(-n/10) so confidence ~0.63 at n=10, ~0.95 at n=30.
        """
        return 1.0 - math.exp(-self.sample_count / 10.0)

    def predict(self, prompt_tokens: int, completion_tokens: int = 0) -> float:
        """Predict latency in milliseconds for a call with these token counts."""
        prompt_cost = self.ms_per_prompt_token * prompt_tokens
        completion_cost = self.ms_per_completion_token * completion_tokens
        return self.base_overhead_ms + prompt_cost + completion_cost

    def update(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        actual_ms: float,
    ) -> float:
        """Update EMA estimates from one observed data point.

        Uses a simple linear model:
            actual_ms ≈ overhead + α_p * prompt_tokens + α_c * completion_tokens

        Each coefficient is nudged independently via EMA on the residual.
        Returns the prediction error (actual - predicted) in ms.
        """
        predicted_ms = self.predict(prompt_tokens, completion_tokens)
        error_ms = actual_ms - predicted_ms

        # Avoid a zero division when a zero-token observation comes in.
        total_tokens = (prompt_tokens + completion_tokens) or 1

        # Attribute the residual proportionally to each component.
        prompt_frac = prompt_tokens / total_tokens
        completion_frac = completion_tokens / total_tokens
        overhead_frac = 1.0 - 0.5 * (prompt_frac + completion_frac)

        step = self.alpha * error_ms
        # EMA nudge, then clamp each parameter to a physically sensible range.
        self.ms_per_prompt_token = max(
            0.001,
            self.ms_per_prompt_token + step * prompt_frac / max(prompt_tokens, 1),
        )
        self.ms_per_completion_token = max(
            0.001,
            self.ms_per_completion_token
            + step * completion_frac / max(completion_tokens, 1),
        )
        self.base_overhead_ms = max(0.0, self.base_overhead_ms + step * overhead_frac)

        self.sample_count += 1
        self.last_updated = time.time()
        self.total_absolute_error_ms += abs(error_ms)
        self.total_predicted_ms += predicted_ms

        return error_ms

    @property
    def mean_absolute_error_ms(self) -> float:
        """MAE over all recorded samples (NaN when nothing recorded yet)."""
        if not self.sample_count:
            return float("nan")
        return self.total_absolute_error_ms / self.sample_count

    def to_dict(self) -> dict:
        """Serialize the calibration state to a JSON-friendly dict."""
        field_names = (
            "model",
            "alpha",
            "sample_count",
            "last_updated",
            "ms_per_prompt_token",
            "ms_per_completion_token",
            "base_overhead_ms",
            "total_absolute_error_ms",
            "total_predicted_ms",
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, d: dict) -> "ModelCalibration":
        """Rebuild calibration state from a dict produced by to_dict()."""
        restored = cls(model=d["model"], alpha=d.get("alpha", DEFAULT_ALPHA))
        restored.sample_count = d.get("sample_count", 0)
        restored.last_updated = d.get("last_updated", time.time())
        # The three EMA parameters are mandatory: a missing key raises
        # KeyError, which the caller treats as a corrupt state file.
        restored.ms_per_prompt_token = d["ms_per_prompt_token"]
        restored.ms_per_completion_token = d["ms_per_completion_token"]
        restored.base_overhead_ms = d["base_overhead_ms"]
        restored.total_absolute_error_ms = d.get("total_absolute_error_ms", 0.0)
        restored.total_predicted_ms = d.get("total_predicted_ms", 0.0)
        return restored
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveCalibrator:
    """Online calibrator for local LLM inference cost estimation.

    Maintains per-model EMA calibration state, persisted to disk between
    sessions. Requires no external dependencies — pure stdlib.

    Thread safety: not thread-safe. Use one instance per process.
    """

    def __init__(
        self,
        state_path: Optional[Path] = None,
        alpha: float = DEFAULT_ALPHA,
        autosave: bool = True,
    ):
        """Create a calibrator.

        Args:
            state_path: Where state is persisted (default: ~/.nexus/calibrator_state.json)
            alpha: EMA smoothing factor given to newly created per-model states
            autosave: Persist state after every record() and reset()
        """
        self.state_path = state_path or DEFAULT_STATE_PATH
        self.alpha = alpha
        self.autosave = autosave
        self._models: dict[str, ModelCalibration] = {}
        self._load()

    # ── Public API ───────────────────────────────────────────────────

    def predict(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostPrediction:
        """Return a calibrated cost prediction for the given model and token counts.

        If this model has never been seen, returns a prior-based estimate
        with confidence=0.
        """
        cal = self._get_or_create(model)
        return CostPrediction(
            model=model,
            prompt_tokens=prompt_tokens,
            predicted_ms=cal.predict(prompt_tokens, completion_tokens),
            confidence=cal.confidence,
            sample_count=cal.sample_count,
        )

    def record(
        self,
        model: str,
        prompt_tokens: int,
        actual_ms: float,
        completion_tokens: int = 0,
    ) -> float:
        """Record an observed inference call and update calibration.

        Args:
            model: Model identifier (e.g. "timmy:v0.1-q4", "llama3-8b-8192")
            prompt_tokens: Number of tokens in the prompt/input
            actual_ms: Observed wall-clock latency in milliseconds
            completion_tokens: Number of tokens generated (optional)

        Returns:
            Prediction error in ms (actual - predicted) at time of recording.
        """
        cal = self._get_or_create(model)
        error_ms = cal.update(prompt_tokens, completion_tokens, actual_ms)
        if self.autosave:
            self._save()
        return error_ms

    def get_stats(self, model: str) -> dict:
        """Return calibration stats for a model."""
        cal = self._models.get(model)
        # Fix: a model that was only predict()-ed (never record()-ed) is
        # present in _models with sample_count == 0. The old check
        # (`model not in self._models`) sent it down the calibrated branch,
        # where mean_absolute_error_ms evaluated to round(nan, 1). Treat
        # zero-sample models as uncalibrated instead.
        if cal is None or cal.sample_count == 0:
            return {
                "model": model,
                "sample_count": 0,
                "confidence": 0.0,
                "status": "uncalibrated (prior only)",
            }
        return {
            "model": model,
            "sample_count": cal.sample_count,
            "confidence": round(cal.confidence, 3),
            "ms_per_prompt_token": round(cal.ms_per_prompt_token, 4),
            "ms_per_completion_token": round(cal.ms_per_completion_token, 4),
            "base_overhead_ms": round(cal.base_overhead_ms, 1),
            "mean_absolute_error_ms": round(cal.mean_absolute_error_ms, 1),
            "last_updated": cal.last_updated,
            "status": "calibrated" if cal.sample_count >= 10 else "warming up",
        }

    def all_stats(self) -> list[dict]:
        """Return calibration stats for all known models (sorted by name)."""
        return [self.get_stats(m) for m in sorted(self._models)]

    def reset(self, model: Optional[str] = None):
        """Reset calibration for one model (if given) or all models."""
        if model:
            self._models.pop(model, None)
        else:
            self._models.clear()
        if self.autosave:
            self._save()

    # ── Persistence ──────────────────────────────────────────────────

    def _get_or_create(self, model: str) -> ModelCalibration:
        """Fetch the calibration state for *model*, creating it on first use."""
        if model not in self._models:
            self._models[model] = ModelCalibration(model=model, alpha=self.alpha)
        return self._models[model]

    def _load(self):
        """Load persisted calibration state from disk (no-op when absent)."""
        if not self.state_path.exists():
            return
        try:
            with open(self.state_path) as f:
                data = json.load(f)
            for model_data in data.get("models", []):
                cal = ModelCalibration.from_dict(model_data)
                self._models[cal.model] = cal
        except Exception:
            # Corrupt state file — start fresh rather than crash at startup.
            self._models = {}

    def _save(self):
        """Persist calibration state to disk atomically (tmp file + replace)."""
        self.state_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": 1,
            "saved_at": time.time(),
            "models": [cal.to_dict() for cal in self._models.values()],
        }
        # Write atomically via tmp file so a crash never leaves a torn file.
        tmp = self.state_path.with_suffix(".tmp")
        with open(tmp, "w") as f:
            json.dump(data, f, indent=2)
        tmp.replace(self.state_path)
|
||||||
262
tests/test_adaptive_calibrator.py
Normal file
262
tests/test_adaptive_calibrator.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
"""
|
||||||
|
Tests for AdaptiveCalibrator — online learning for local cost estimation.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Prior-based predictions for unseen models
|
||||||
|
- EMA update convergence
|
||||||
|
- Confidence growth with samples
|
||||||
|
- Persistence (save/load round-trip)
|
||||||
|
- reset() for one model and all models
|
||||||
|
- Groq vs local model prior selection
|
||||||
|
- get_stats() and all_stats()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nexus.adaptive_calibrator import (
|
||||||
|
AdaptiveCalibrator,
|
||||||
|
CostPrediction,
|
||||||
|
ModelCalibration,
|
||||||
|
_is_groq_model,
|
||||||
|
_prior_for,
|
||||||
|
DEFAULT_ALPHA,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ Helpers ═══
|
||||||
|
|
||||||
|
def make_calibrator(tmp_path: Path, alpha: float = DEFAULT_ALPHA) -> AdaptiveCalibrator:
    """Build a calibrator whose state file lives under pytest's tmp_path."""
    return AdaptiveCalibrator(
        state_path=tmp_path / "calibrator_state.json",
        alpha=alpha,
        autosave=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ Model family detection ═══
|
||||||
|
|
||||||
|
def test_local_ollama_model_not_groq():
    # Ollama-style "name:tag" identifiers must be classified as local.
    for name in ("timmy:v0.1-q4", "mistral:7b-q4_0"):
        assert not _is_groq_model(name)


def test_groq_model_detected():
    for name in ("llama3-8b-8192", "mixtral-8x7b-32768"):
        assert _is_groq_model(name)


def test_prior_local_is_slower_than_groq():
    local_prior = _prior_for("timmy:v0.1-q4")
    groq_prior = _prior_for("llama3-8b-8192")
    for key in ("ms_per_completion_token", "ms_per_prompt_token"):
        assert local_prior[key] > groq_prior[key]
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ CostPrediction ═══
|
||||||
|
|
||||||
|
def test_predict_returns_cost_prediction(tmp_path):
    calibrator = make_calibrator(tmp_path)
    prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
    assert isinstance(prediction, CostPrediction)
    assert prediction.model == "timmy:v0.1-q4"
    assert prediction.prompt_tokens == 512
    assert prediction.predicted_ms > 0
    # No samples yet: zero count, zero confidence.
    assert prediction.sample_count == 0
    assert prediction.confidence == 0.0


def test_predict_new_model_uses_prior(tmp_path):
    prediction = make_calibrator(tmp_path).predict("unknown-model:x", prompt_tokens=100)
    assert prediction.predicted_ms > 0
    assert prediction.confidence == 0.0


def test_predict_longer_prompt_costs_more(tmp_path):
    calibrator = make_calibrator(tmp_path)
    cheap = calibrator.predict("timmy:v0.1-q4", prompt_tokens=100)
    pricey = calibrator.predict("timmy:v0.1-q4", prompt_tokens=1000)
    assert pricey.predicted_ms > cheap.predicted_ms
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ Record & EMA update ═══
|
||||||
|
|
||||||
|
def test_record_returns_error_ms(tmp_path):
    calibrator = make_calibrator(tmp_path)
    residual = calibrator.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    assert isinstance(residual, float)


def test_record_increases_sample_count(tmp_path):
    calibrator = make_calibrator(tmp_path)
    calibrator.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    assert calibrator.get_stats("timmy:v0.1-q4")["sample_count"] == 1


def test_repeated_records_converge_prediction(tmp_path):
    """After many samples of the same cost, prediction should converge."""
    calibrator = make_calibrator(tmp_path, alpha=0.3)
    true_ms = 4000
    for _ in range(40):
        calibrator.record("timmy:v0.1-q4", prompt_tokens=256, actual_ms=true_ms)

    prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=256)
    # Should land within 15% of the true value after many samples.
    assert abs(prediction.predicted_ms - true_ms) / true_ms < 0.15


def test_confidence_grows_with_samples(tmp_path):
    calibrator = make_calibrator(tmp_path)
    assert calibrator.predict("timmy:v0.1-q4", prompt_tokens=100).confidence == 0.0

    for _ in range(10):
        calibrator.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=100)
    assert prediction.confidence > 0.5
    assert prediction.sample_count == 10


def test_confidence_approaches_one(tmp_path):
    calibrator = make_calibrator(tmp_path)
    for _ in range(50):
        calibrator.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    assert calibrator.predict("timmy:v0.1-q4", prompt_tokens=100).confidence > 0.99


def test_parameters_stay_non_negative(tmp_path):
    """EMA updates should never drive parameters negative."""
    calibrator = make_calibrator(tmp_path)
    # Feed tiny actual times, trying to drive the parameters to zero.
    for _ in range(20):
        calibrator.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=1.0)

    state = calibrator._models["timmy:v0.1-q4"]
    assert state.ms_per_prompt_token > 0
    assert state.ms_per_completion_token > 0
    assert state.base_overhead_ms >= 0
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ get_stats / all_stats ═══
|
||||||
|
|
||||||
|
def test_get_stats_uncalibrated(tmp_path):
    stats = make_calibrator(tmp_path).get_stats("never-seen-model")
    assert stats["sample_count"] == 0
    assert stats["confidence"] == 0.0
    assert "uncalibrated" in stats["status"]


def test_get_stats_after_records(tmp_path):
    calibrator = make_calibrator(tmp_path)
    for _ in range(5):
        calibrator.record("timmy:v0.1-q4", prompt_tokens=200, actual_ms=3000)

    stats = calibrator.get_stats("timmy:v0.1-q4")
    assert stats["sample_count"] == 5
    assert stats["confidence"] > 0
    assert "mean_absolute_error_ms" in stats


def test_all_stats_lists_all_models(tmp_path):
    calibrator = make_calibrator(tmp_path)
    calibrator.record("model-a", prompt_tokens=100, actual_ms=1000)
    calibrator.record("model-b", prompt_tokens=100, actual_ms=2000)

    reported = [entry["model"] for entry in calibrator.all_stats()]
    assert "model-a" in reported
    assert "model-b" in reported
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ Persistence ═══
|
||||||
|
|
||||||
|
def test_save_and_load(tmp_path):
    """Calibration state should survive a save/load round-trip."""
    state_file = tmp_path / "state.json"

    # First instance writes samples; autosave persists each one.
    writer = AdaptiveCalibrator(state_path=state_file, autosave=True)
    for _ in range(15):
        writer.record("timmy:v0.1-q4", prompt_tokens=300, actual_ms=3500)
    before = writer.get_stats("timmy:v0.1-q4")

    # Second instance loads the same file fresh.
    reader = AdaptiveCalibrator(state_path=state_file, autosave=True)
    after = reader.get_stats("timmy:v0.1-q4")

    assert after["sample_count"] == before["sample_count"]
    assert abs(after["ms_per_prompt_token"] - before["ms_per_prompt_token"]) < 1e-6


def test_load_with_missing_file(tmp_path):
    """Missing state file should result in empty (not crashed) calibrator."""
    calibrator = AdaptiveCalibrator(
        state_path=tmp_path / "nonexistent.json", autosave=False
    )
    assert calibrator.all_stats() == []


def test_load_with_corrupt_file(tmp_path):
    """Corrupt state file should be silently ignored."""
    state_file = tmp_path / "state.json"
    state_file.write_text("not valid json {{{")

    assert AdaptiveCalibrator(state_path=state_file, autosave=False).all_stats() == []


def test_atomic_save(tmp_path):
    """Save should write via a tmp file and replace atomically."""
    state_file = tmp_path / "state.json"
    calibrator = AdaptiveCalibrator(state_path=state_file, autosave=True)
    calibrator.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    assert state_file.exists()
    # No .tmp sibling left behind, and the file parses as JSON.
    assert not state_file.with_suffix(".tmp").exists()
    assert json.loads(state_file.read_text())["version"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ Reset ═══
|
||||||
|
|
||||||
|
def test_reset_single_model(tmp_path):
    calibrator = make_calibrator(tmp_path)
    calibrator.record("model-a", prompt_tokens=100, actual_ms=1000)
    calibrator.record("model-b", prompt_tokens=100, actual_ms=1000)

    calibrator.reset("model-a")
    # Only the targeted model is wiped; the other keeps its sample.
    assert calibrator.get_stats("model-a")["sample_count"] == 0
    assert calibrator.get_stats("model-b")["sample_count"] == 1


def test_reset_all_models(tmp_path):
    calibrator = make_calibrator(tmp_path)
    calibrator.record("model-a", prompt_tokens=100, actual_ms=1000)
    calibrator.record("model-b", prompt_tokens=100, actual_ms=1000)

    calibrator.reset()
    assert calibrator.all_stats() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ═══ ModelCalibration unit tests ═══
|
||||||
|
|
||||||
|
def test_model_calibration_repr_roundtrip():
    original = ModelCalibration(model="test:v1")
    restored = ModelCalibration.from_dict(original.to_dict())
    assert restored.model == original.model
    assert restored.alpha == original.alpha
    assert restored.ms_per_prompt_token == original.ms_per_prompt_token


def test_model_calibration_mean_absolute_error_nan_when_no_samples():
    fresh = ModelCalibration(model="test:v1")
    assert math.isnan(fresh.mean_absolute_error_ms)
|
||||||
Reference in New Issue
Block a user