Compare commits
1 Commits
improvemen
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5649aeb975 |
@@ -14,6 +14,7 @@ from nexus.perception_adapter import (
|
||||
)
|
||||
from nexus.experience_store import ExperienceStore
|
||||
from nexus.trajectory_logger import TrajectoryLogger
|
||||
from nexus.adaptive_calibrator import AdaptiveCalibrator, CostPrediction
|
||||
|
||||
try:
|
||||
from nexus.nexus_think import NexusMind
|
||||
@@ -28,5 +29,7 @@ __all__ = [
|
||||
"Action",
|
||||
"ExperienceStore",
|
||||
"TrajectoryLogger",
|
||||
"AdaptiveCalibrator",
|
||||
"CostPrediction",
|
||||
"NexusMind",
|
||||
]
|
||||
|
||||
354
nexus/adaptive_calibrator.py
Normal file
354
nexus/adaptive_calibrator.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
AdaptiveCalibrator — Online Learning for Local Cost Estimation
|
||||
|
||||
Tracks predicted vs actual inference costs (latency, tokens) per model
|
||||
and learns correction factors using Exponential Moving Average (EMA).
|
||||
|
||||
Extracted from Kimi Report #2 design spec.
|
||||
|
||||
Usage:
|
||||
calibrator = AdaptiveCalibrator()
|
||||
|
||||
# Before a call: get predicted cost
|
||||
prediction = calibrator.predict("timmy:v0.1-q4", prompt_tokens=512)
|
||||
|
||||
# After a call: record what actually happened
|
||||
calibrator.record(
|
||||
model="timmy:v0.1-q4",
|
||||
prompt_tokens=512,
|
||||
completion_tokens=128,
|
||||
actual_ms=3400,
|
||||
)
|
||||
|
||||
# Get model stats
|
||||
stats = calibrator.get_stats("timmy:v0.1-q4")
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
DEFAULT_STATE_PATH = Path.home() / ".nexus" / "calibrator_state.json"
|
||||
|
||||
# EMA smoothing factor: 0.1 = slow adaptation, 0.3 = fast adaptation
|
||||
DEFAULT_ALPHA = 0.15
|
||||
|
||||
# Seed latency estimates (ms per token) by model family
|
||||
# These are rough priors; the calibrator adapts them online
|
||||
# Seed latency estimates (ms per token) by model family.
# These are rough priors; the calibrator adapts them online via EMA updates
# in ModelCalibration.update(). Values are copied (never mutated in place)
# by _prior_for(), so the dicts below stay pristine.
_MODEL_PRIORS: dict[str, dict] = {
    # Ollama local models (8B range, q4 quantized, typical CPU/GPU)
    "default_local": {
        "ms_per_prompt_token": 0.5,
        "ms_per_completion_token": 8.0,
        "base_overhead_ms": 300.0,
    },
    # Groq cloud (extremely fast inference)
    "default_groq": {
        "ms_per_prompt_token": 0.05,
        "ms_per_completion_token": 0.3,
        "base_overhead_ms": 150.0,
    },
}
|
||||
|
||||
_GROQ_MODEL_PREFIXES = ("llama", "mixtral", "gemma", "whisper")
|
||||
|
||||
|
||||
def _is_groq_model(model: str) -> bool:
|
||||
"""Heuristic: is this a cloud Groq model vs a local Ollama model?"""
|
||||
m = model.lower()
|
||||
return any(m.startswith(p) for p in _GROQ_MODEL_PREFIXES) and ":" not in m
|
||||
|
||||
|
||||
def _prior_for(model: str) -> dict:
    """Return a copy of the seed prior for this model.

    A shallow copy is returned so callers (ModelCalibration) can adapt
    the values without mutating the shared _MODEL_PRIORS table.
    """
    key = "default_groq" if _is_groq_model(model) else "default_local"
    return dict(_MODEL_PRIORS[key])
|
||||
|
||||
|
||||
class CostPrediction:
    """Result of a calibrated cost prediction.

    Attributes:
        model: Model identifier the prediction applies to.
        prompt_tokens: Prompt size the prediction was made for.
        predicted_ms: Predicted wall-clock latency in milliseconds.
        confidence: 0.0 (prior only) → 1.0 (well-calibrated).
        sample_count: Number of observations behind the calibration.
        predicted_at: Unix timestamp (seconds) when the prediction was made.
    """

    def __init__(
        self,
        model: str,
        prompt_tokens: int,
        predicted_ms: float,
        confidence: float,
        sample_count: int,
    ):
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.predicted_ms = predicted_ms
        # 0.0 (prior only) → 1.0 (well-calibrated)
        self.confidence = confidence
        self.sample_count = sample_count
        self.predicted_at = time.time()

    def __repr__(self) -> str:
        # Assemble the fields individually, then join — output is identical
        # to a single multi-part f-string.
        fields = [
            f"model={self.model!r}",
            f"prompt_tokens={self.prompt_tokens}",
            f"predicted_ms={self.predicted_ms:.0f}",
            f"confidence={self.confidence:.2f}",
            f"n={self.sample_count}",
        ]
        return "CostPrediction(" + ", ".join(fields) + ")"
|
||||
|
||||
|
||||
class ModelCalibration:
    """Per-model online calibration state.

    Tracks EMA estimates of:
      - ms_per_prompt_token
      - ms_per_completion_token
      - base_overhead_ms

    Confidence grows with sample count (sigmoid-ish curve).
    """

    def __init__(self, model: str, alpha: float = DEFAULT_ALPHA):
        # Model identifier this calibration belongs to.
        self.model = model
        # EMA smoothing factor (0.1 = slow adaptation, 0.3 = fast).
        self.alpha = alpha
        # Number of observations folded in so far (drives confidence).
        self.sample_count = 0
        self.last_updated = time.time()

        # EMA parameters (start from prior)
        prior = _prior_for(model)
        self.ms_per_prompt_token: float = prior["ms_per_prompt_token"]
        self.ms_per_completion_token: float = prior["ms_per_completion_token"]
        self.base_overhead_ms: float = prior["base_overhead_ms"]

        # Tracking for error diagnostics
        self.total_absolute_error_ms: float = 0.0
        self.total_predicted_ms: float = 0.0

    @property
    def confidence(self) -> float:
        """Confidence in current estimates.

        Grows from 0 (prior only) toward 1 as samples accumulate.
        Uses: 1 - exp(-n/10) so confidence ~0.63 at n=10, ~0.95 at n=30.
        """
        return 1.0 - math.exp(-self.sample_count / 10.0)

    def predict(self, prompt_tokens: int, completion_tokens: int = 0) -> float:
        """Predict latency in milliseconds for a call with these token counts."""
        # Linear cost model: fixed overhead + per-token rates.
        return (
            self.base_overhead_ms
            + self.ms_per_prompt_token * prompt_tokens
            + self.ms_per_completion_token * completion_tokens
        )

    def update(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        actual_ms: float,
    ) -> float:
        """Update EMA estimates from one observed data point.

        Uses a simple linear model:
            actual_ms ≈ overhead + α_p * prompt_tokens + α_c * completion_tokens

        We update each coefficient independently using EMA on the residuals.
        Returns the prediction error (actual - predicted) in ms.
        """
        predicted_ms = self.predict(prompt_tokens, completion_tokens)
        error_ms = actual_ms - predicted_ms

        # EMA update: new_estimate = old + alpha * error
        # This is equivalent to: new = (1-alpha)*old + alpha*actual_ratio
        # NOTE: "+" binds before "or", so this is (p + c) or 1 — the guard
        # only kicks in when both token counts are zero.
        total_tokens = prompt_tokens + completion_tokens or 1

        # Attribute the error proportionally to each component
        prompt_frac = prompt_tokens / total_tokens
        completion_frac = completion_tokens / total_tokens
        # NOTE(review): when any tokens are present, prompt_frac +
        # completion_frac == 1, so overhead_frac is a constant 0.5 — half
        # of the residual is always attributed to base overhead. Presumably
        # intentional damping; confirm against the design spec.
        overhead_frac = 1.0 - 0.5 * (prompt_frac + completion_frac)

        # Per-token rates are nudged by the residual share divided by the
        # token count (max(…, 1) avoids division by zero).
        self.ms_per_prompt_token += self.alpha * error_ms * prompt_frac / max(prompt_tokens, 1)
        self.ms_per_completion_token += self.alpha * error_ms * completion_frac / max(completion_tokens, 1)
        self.base_overhead_ms += self.alpha * error_ms * overhead_frac

        # Clamp to physically reasonable values
        self.ms_per_prompt_token = max(0.001, self.ms_per_prompt_token)
        self.ms_per_completion_token = max(0.001, self.ms_per_completion_token)
        self.base_overhead_ms = max(0.0, self.base_overhead_ms)

        self.sample_count += 1
        self.last_updated = time.time()
        self.total_absolute_error_ms += abs(error_ms)
        self.total_predicted_ms += predicted_ms

        return error_ms

    @property
    def mean_absolute_error_ms(self) -> float:
        """MAE over all recorded samples; NaN when no samples exist."""
        if self.sample_count == 0:
            return float("nan")
        return self.total_absolute_error_ms / self.sample_count

    def to_dict(self) -> dict:
        """Serialize all calibration state to a JSON-compatible dict."""
        return {
            "model": self.model,
            "alpha": self.alpha,
            "sample_count": self.sample_count,
            "last_updated": self.last_updated,
            "ms_per_prompt_token": self.ms_per_prompt_token,
            "ms_per_completion_token": self.ms_per_completion_token,
            "base_overhead_ms": self.base_overhead_ms,
            "total_absolute_error_ms": self.total_absolute_error_ms,
            "total_predicted_ms": self.total_predicted_ms,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ModelCalibration":
        """Reconstruct calibration state from a to_dict() payload.

        The three EMA coefficients are required keys; counters and
        timestamps fall back to sensible defaults for older payloads.
        """
        obj = cls(model=d["model"], alpha=d.get("alpha", DEFAULT_ALPHA))
        obj.sample_count = d.get("sample_count", 0)
        obj.last_updated = d.get("last_updated", time.time())
        obj.ms_per_prompt_token = d["ms_per_prompt_token"]
        obj.ms_per_completion_token = d["ms_per_completion_token"]
        obj.base_overhead_ms = d["base_overhead_ms"]
        obj.total_absolute_error_ms = d.get("total_absolute_error_ms", 0.0)
        obj.total_predicted_ms = d.get("total_predicted_ms", 0.0)
        return obj
|
||||
|
||||
|
||||
class AdaptiveCalibrator:
    """Online calibrator for local LLM inference cost estimation.

    Maintains per-model EMA calibration state, persisted to disk between
    sessions. Requires no external dependencies — pure stdlib.

    Thread safety: not thread-safe. Use one instance per process.
    """

    def __init__(
        self,
        state_path: Optional[Path] = None,
        alpha: float = DEFAULT_ALPHA,
        autosave: bool = True,
    ):
        # Where state is persisted (default: ~/.nexus/calibrator_state.json).
        self.state_path = state_path or DEFAULT_STATE_PATH
        # EMA smoothing factor handed to each new ModelCalibration.
        self.alpha = alpha
        # When True, state is written to disk after every record()/reset().
        self.autosave = autosave
        # model name -> ModelCalibration
        self._models: dict[str, ModelCalibration] = {}
        # Eagerly restore any persisted state.
        self._load()

    # ── Public API ───────────────────────────────────────────────────

    def predict(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostPrediction:
        """Return a calibrated cost prediction for the given model and token counts.

        If this model has never been seen, returns a prior-based estimate
        with confidence=0.
        """
        cal = self._get_or_create(model)
        predicted_ms = cal.predict(prompt_tokens, completion_tokens)
        return CostPrediction(
            model=model,
            prompt_tokens=prompt_tokens,
            predicted_ms=predicted_ms,
            confidence=cal.confidence,
            sample_count=cal.sample_count,
        )

    def record(
        self,
        model: str,
        prompt_tokens: int,
        actual_ms: float,
        completion_tokens: int = 0,
    ) -> float:
        """Record an observed inference call and update calibration.

        Args:
            model: Model identifier (e.g. "timmy:v0.1-q4", "llama3-8b-8192")
            prompt_tokens: Number of tokens in the prompt/input
            actual_ms: Observed wall-clock latency in milliseconds
            completion_tokens: Number of tokens generated (optional)

        Returns:
            Prediction error in ms (actual - predicted) at time of recording.
        """
        cal = self._get_or_create(model)
        error_ms = cal.update(prompt_tokens, completion_tokens, actual_ms)
        if self.autosave:
            # Persist after every observation so state survives crashes.
            self._save()
        return error_ms

    def get_stats(self, model: str) -> dict:
        """Return calibration stats for a model.

        Unknown models get a stub dict (sample_count=0, confidence=0.0)
        without creating calibration state as a side effect.
        """
        if model not in self._models:
            return {
                "model": model,
                "sample_count": 0,
                "confidence": 0.0,
                "status": "uncalibrated (prior only)",
            }
        cal = self._models[model]
        return {
            "model": model,
            "sample_count": cal.sample_count,
            "confidence": round(cal.confidence, 3),
            "ms_per_prompt_token": round(cal.ms_per_prompt_token, 4),
            "ms_per_completion_token": round(cal.ms_per_completion_token, 4),
            "base_overhead_ms": round(cal.base_overhead_ms, 1),
            "mean_absolute_error_ms": round(cal.mean_absolute_error_ms, 1),
            "last_updated": cal.last_updated,
            # 10 samples is the "calibrated" threshold (confidence ~0.63).
            "status": "calibrated" if cal.sample_count >= 10 else "warming up",
        }

    def all_stats(self) -> list[dict]:
        """Return calibration stats for all known models, sorted by name."""
        return [self.get_stats(m) for m in sorted(self._models)]

    def reset(self, model: Optional[str] = None):
        """Reset calibration for one model (if given) or all models."""
        if model:
            # pop(..., None): resetting an unknown model is a no-op.
            self._models.pop(model, None)
        else:
            self._models.clear()
        if self.autosave:
            self._save()

    # ── Persistence ──────────────────────────────────────────────────

    def _get_or_create(self, model: str) -> ModelCalibration:
        # Lazily create calibration state seeded from the model's prior.
        if model not in self._models:
            self._models[model] = ModelCalibration(model=model, alpha=self.alpha)
        return self._models[model]

    def _load(self):
        """Load persisted calibration state from disk."""
        if not self.state_path.exists():
            return
        try:
            with open(self.state_path) as f:
                data = json.load(f)
            for model_data in data.get("models", []):
                cal = ModelCalibration.from_dict(model_data)
                self._models[cal.model] = cal
        except Exception:
            # Corrupt state file — start fresh (deliberate best-effort load).
            self._models = {}

    def _save(self):
        """Persist calibration state to disk."""
        self.state_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": 1,
            "saved_at": time.time(),
            "models": [cal.to_dict() for cal in self._models.values()],
        }
        # Write atomically via tmp file
        tmp = self.state_path.with_suffix(".tmp")
        with open(tmp, "w") as f:
            json.dump(data, f, indent=2)
        # Path.replace is an atomic rename on the same filesystem.
        tmp.replace(self.state_path)
|
||||
262
tests/test_adaptive_calibrator.py
Normal file
262
tests/test_adaptive_calibrator.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
Tests for AdaptiveCalibrator — online learning for local cost estimation.

Covers:
- Prior-based predictions for unseen models
- EMA update convergence
- Confidence growth with samples
- Persistence (save/load round-trip)
- reset() for one model and all models
- Groq vs local model prior selection
- get_stats() and all_stats()
"""

import json
import math
import tempfile  # NOTE(review): unused — tests use the pytest tmp_path fixture
from pathlib import Path

import pytest  # NOTE(review): unused — no pytest.raises/parametrize here

from nexus.adaptive_calibrator import (
    AdaptiveCalibrator,
    CostPrediction,
    ModelCalibration,
    _is_groq_model,
    _prior_for,
    DEFAULT_ALPHA,
)


# ═══ Helpers ═══

def make_calibrator(tmp_path: Path, alpha: float = DEFAULT_ALPHA) -> AdaptiveCalibrator:
    """Build a calibrator whose state file lives under the test's tmp dir."""
    state_file = tmp_path / "calibrator_state.json"
    return AdaptiveCalibrator(state_path=state_file, alpha=alpha, autosave=True)


# ═══ Model family detection ═══

def test_local_ollama_model_not_groq():
    """Ollama-style tagged names (with ':') are classified as local."""
    assert not _is_groq_model("timmy:v0.1-q4")
    assert not _is_groq_model("mistral:7b-q4_0")


def test_groq_model_detected():
    """Untagged names with known Groq prefixes are classified as Groq."""
    assert _is_groq_model("llama3-8b-8192")
    assert _is_groq_model("mixtral-8x7b-32768")


def test_prior_local_is_slower_than_groq():
    """Local priors must predict slower per-token latency than Groq priors."""
    local = _prior_for("timmy:v0.1-q4")
    groq = _prior_for("llama3-8b-8192")
    assert local["ms_per_completion_token"] > groq["ms_per_completion_token"]
    assert local["ms_per_prompt_token"] > groq["ms_per_prompt_token"]


# ═══ CostPrediction ═══

def test_predict_returns_cost_prediction(tmp_path):
    """predict() returns a populated CostPrediction for an unseen model."""
    cal = make_calibrator(tmp_path)
    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=512)
    assert isinstance(pred, CostPrediction)
    assert pred.model == "timmy:v0.1-q4"
    assert pred.prompt_tokens == 512
    assert pred.predicted_ms > 0
    assert pred.sample_count == 0
    assert pred.confidence == 0.0  # No samples yet


def test_predict_new_model_uses_prior(tmp_path):
    """Unknown models still get a positive prior-based estimate."""
    cal = make_calibrator(tmp_path)
    pred = cal.predict("unknown-model:x", prompt_tokens=100)
    assert pred.predicted_ms > 0
    assert pred.confidence == 0.0


def test_predict_longer_prompt_costs_more(tmp_path):
    """Predicted latency is monotonic in prompt length."""
    cal = make_calibrator(tmp_path)
    short = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    long_ = cal.predict("timmy:v0.1-q4", prompt_tokens=1000)
    assert long_.predicted_ms > short.predicted_ms


# ═══ Record & EMA update ═══

def test_record_returns_error_ms(tmp_path):
    """record() returns the (actual - predicted) error as a float."""
    cal = make_calibrator(tmp_path)
    error = cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    assert isinstance(error, float)


def test_record_increases_sample_count(tmp_path):
    """Each record() bumps the per-model sample count."""
    cal = make_calibrator(tmp_path)
    cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=5000)
    stats = cal.get_stats("timmy:v0.1-q4")
    assert stats["sample_count"] == 1


def test_repeated_records_converge_prediction(tmp_path):
    """After many samples of the same cost, prediction should converge."""
    # High alpha (0.3) for fast adaptation within 40 samples.
    cal = make_calibrator(tmp_path, alpha=0.3)
    TRUE_MS = 4000

    for _ in range(40):
        cal.record("timmy:v0.1-q4", prompt_tokens=256, actual_ms=TRUE_MS)

    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=256)
    # Should be within 15% of true value after many samples
    assert abs(pred.predicted_ms - TRUE_MS) / TRUE_MS < 0.15


def test_confidence_grows_with_samples(tmp_path):
    """Confidence starts at 0 and exceeds 0.5 after 10 samples (1-exp(-1))."""
    cal = make_calibrator(tmp_path)
    assert cal.predict("timmy:v0.1-q4", prompt_tokens=100).confidence == 0.0

    for i in range(10):
        cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    assert pred.confidence > 0.5
    assert pred.sample_count == 10


def test_confidence_approaches_one(tmp_path):
    """Confidence asymptotically approaches 1 (>0.99 at 50 samples)."""
    cal = make_calibrator(tmp_path)
    for _ in range(50):
        cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    pred = cal.predict("timmy:v0.1-q4", prompt_tokens=100)
    assert pred.confidence > 0.99


def test_parameters_stay_non_negative(tmp_path):
    """EMA updates should never drive parameters negative."""
    cal = make_calibrator(tmp_path)
    for _ in range(20):
        # Feed very small actual times (trying to drive params to zero)
        cal.record("timmy:v0.1-q4", prompt_tokens=512, actual_ms=1.0)

    m = cal._models["timmy:v0.1-q4"]
    assert m.ms_per_prompt_token > 0
    assert m.ms_per_completion_token > 0
    assert m.base_overhead_ms >= 0


# ═══ get_stats / all_stats ═══

def test_get_stats_uncalibrated(tmp_path):
    """Stats for a never-seen model are an 'uncalibrated' stub."""
    cal = make_calibrator(tmp_path)
    stats = cal.get_stats("never-seen-model")
    assert stats["sample_count"] == 0
    assert stats["confidence"] == 0.0
    assert "uncalibrated" in stats["status"]


def test_get_stats_after_records(tmp_path):
    """Stats reflect recorded samples and include error diagnostics."""
    cal = make_calibrator(tmp_path)
    for _ in range(5):
        cal.record("timmy:v0.1-q4", prompt_tokens=200, actual_ms=3000)

    stats = cal.get_stats("timmy:v0.1-q4")
    assert stats["sample_count"] == 5
    assert stats["confidence"] > 0
    assert "mean_absolute_error_ms" in stats


def test_all_stats_lists_all_models(tmp_path):
    """all_stats() covers every model ever recorded."""
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=2000)

    stats = cal.all_stats()
    model_names = [s["model"] for s in stats]
    assert "model-a" in model_names
    assert "model-b" in model_names


# ═══ Persistence ═══

def test_save_and_load(tmp_path):
    """Calibration state should survive a save/load round-trip."""
    state_file = tmp_path / "state.json"

    # Write some samples
    cal1 = AdaptiveCalibrator(state_path=state_file, autosave=True)
    for _ in range(15):
        cal1.record("timmy:v0.1-q4", prompt_tokens=300, actual_ms=3500)

    stats_before = cal1.get_stats("timmy:v0.1-q4")

    # Load fresh instance
    cal2 = AdaptiveCalibrator(state_path=state_file, autosave=True)
    stats_after = cal2.get_stats("timmy:v0.1-q4")

    assert stats_after["sample_count"] == stats_before["sample_count"]
    assert abs(stats_after["ms_per_prompt_token"] - stats_before["ms_per_prompt_token"]) < 1e-6


def test_load_with_missing_file(tmp_path):
    """Missing state file should result in empty (not crashed) calibrator."""
    cal = AdaptiveCalibrator(state_path=tmp_path / "nonexistent.json", autosave=False)
    assert cal.all_stats() == []


def test_load_with_corrupt_file(tmp_path):
    """Corrupt state file should be silently ignored."""
    state_file = tmp_path / "state.json"
    state_file.write_text("not valid json {{{")

    cal = AdaptiveCalibrator(state_path=state_file, autosave=False)
    assert cal.all_stats() == []


def test_atomic_save(tmp_path):
    """Save should write via a tmp file and replace atomically."""
    state_file = tmp_path / "state.json"
    cal = AdaptiveCalibrator(state_path=state_file, autosave=True)
    cal.record("timmy:v0.1-q4", prompt_tokens=100, actual_ms=2000)

    assert state_file.exists()
    # No .tmp file should be left behind
    assert not (state_file.with_suffix(".tmp")).exists()
    # File should be valid JSON
    data = json.loads(state_file.read_text())
    assert data["version"] == 1


# ═══ Reset ═══

def test_reset_single_model(tmp_path):
    """reset(model) clears one model and leaves others intact."""
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=1000)

    cal.reset("model-a")
    assert cal.get_stats("model-a")["sample_count"] == 0
    assert cal.get_stats("model-b")["sample_count"] == 1


def test_reset_all_models(tmp_path):
    """reset() with no argument clears all calibration state."""
    cal = make_calibrator(tmp_path)
    cal.record("model-a", prompt_tokens=100, actual_ms=1000)
    cal.record("model-b", prompt_tokens=100, actual_ms=1000)

    cal.reset()
    assert cal.all_stats() == []


# ═══ ModelCalibration unit tests ═══

def test_model_calibration_repr_roundtrip():
    """to_dict()/from_dict() preserves model, alpha, and EMA coefficients."""
    m = ModelCalibration(model="test:v1")
    d = m.to_dict()
    m2 = ModelCalibration.from_dict(d)
    assert m2.model == m.model
    assert m2.alpha == m.alpha
    assert m2.ms_per_prompt_token == m.ms_per_prompt_token


def test_model_calibration_mean_absolute_error_nan_when_no_samples():
    """MAE is NaN (not a crash or zero) before any samples are recorded."""
    m = ModelCalibration(model="test:v1")
    assert math.isnan(m.mean_absolute_error_ms)
|
||||
Reference in New Issue
Block a user