#!/usr/bin/env python3
"""
Benchmark local Ollama models against the 50 tok/s UX threshold.

Usage:
    python3 scripts/benchmark_local_models.py [--models MODEL1,MODEL2] [--prompt PROMPT] [--rounds N]
    python3 scripts/benchmark_local_models.py --all    # test all pulled models
    python3 scripts/benchmark_local_models.py --json   # JSON output for CI
"""
import argparse
import json
import os
import platform
import sys
import time
import urllib.request
import urllib.error
from dataclasses import dataclass, asdict
from typing import Optional

OLLAMA_BASE = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
THRESHOLD_TOK_S = 50.0
BENCHMARK_PROMPT = (
    "Explain the difference between TCP and UDP protocols. "
    "Cover reliability, ordering, speed, and use cases. "
    "Be thorough but concise. Write at least 300 words."
)


@dataclass
class BenchmarkResult:
    model: str
    size_gb: float
    prompt_tokens: int
    eval_tokens: int
    eval_duration_s: float
    tokens_per_second: float
    total_duration_s: float
    rounds: int
    avg_tok_s: float
    meets_threshold: bool
    error: Optional[str] = None


def get_models() -> list[dict]:
    """List all pulled Ollama models."""
    url = f"{OLLAMA_BASE}/api/tags"
    try:
        req = urllib.request.Request(url)
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
        return data.get("models", [])
    except Exception as e:
        print(f"Error connecting to Ollama at {OLLAMA_BASE}: {e}", file=sys.stderr)
        sys.exit(1)


def benchmark_model(model: str, prompt: str, num_predict: int = 512) -> dict:
    """Run a single benchmark generation, return timing stats."""
    url = f"{OLLAMA_BASE}/api/generate"
    payload = json.dumps({
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "num_predict": num_predict,
            "temperature": 0.1,  # low temp for consistent output
        },
    }).encode()
    req = urllib.request.Request(url, data=payload, method="POST")
    req.add_header("Content-Type", "application/json")

    start = time.monotonic()
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            data = json.loads(resp.read())
    except urllib.error.HTTPError as e:
        body = e.read().decode() if e.fp else str(e)
        raise RuntimeError(f"HTTP {e.code}: {body[:200]}")
    except Exception as e:
        raise RuntimeError(str(e))
    elapsed = time.monotonic() - start

    prompt_tokens = data.get("prompt_eval_count", 0)
    eval_tokens = data.get("eval_count", 0)
    eval_duration_ns = data.get("eval_duration", 0)
    total_duration_ns = data.get("total_duration", 0)
    eval_duration_s = eval_duration_ns / 1e9 if eval_duration_ns else elapsed
    total_duration_s = total_duration_ns / 1e9 if total_duration_ns else elapsed
    tok_s = eval_tokens / eval_duration_s if eval_duration_s > 0 else 0.0

    return {
        "prompt_tokens": prompt_tokens,
        "eval_tokens": eval_tokens,
        "eval_duration_s": round(eval_duration_s, 2),
        "total_duration_s": round(total_duration_s, 2),
        "tokens_per_second": round(tok_s, 1),
    }
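
# For orientation (example values are invented, not measurements): benchmark_model()
# reads "prompt_eval_count", "eval_count", "eval_duration", and "total_duration"
# from the non-streaming /api/generate response, treating the durations as
# nanoseconds (hence the division by 1e9 above). E.g. eval_count=512 with
# eval_duration=9.8e9 ns gives 512 / 9.8 ≈ 52.2 tok/s, which clears the
# 50 tok/s threshold.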
def run_benchmark(
    model_name: str,
    model_size: float,
    prompt: str,
    rounds: int,
    num_predict: int,
    threshold: float = 50.0,
) -> BenchmarkResult:
    """Run multiple rounds and compute average."""
    results = []
    errors = []
    for i in range(rounds):
        try:
            r = benchmark_model(model_name, prompt, num_predict)
            results.append(r)
            print(f"    Round {i+1}/{rounds}: {r['tokens_per_second']} tok/s "
                  f"({r['eval_tokens']} tokens in {r['eval_duration_s']}s)")
        except Exception as e:
            errors.append(str(e))
            print(f"    Round {i+1}/{rounds}: ERROR - {e}")

    if not results:
        return BenchmarkResult(
            model=model_name, size_gb=model_size, prompt_tokens=0, eval_tokens=0,
            eval_duration_s=0, tokens_per_second=0, total_duration_s=0, rounds=rounds,
            avg_tok_s=0, meets_threshold=False, error="; ".join(errors),
        )

    avg_tok_s = sum(r["tokens_per_second"] for r in results) / len(results)
    avg_tok_s = round(avg_tok_s, 1)
    return BenchmarkResult(
        model=model_name,
        size_gb=model_size,
        prompt_tokens=sum(r["prompt_tokens"] for r in results) // len(results),
        eval_tokens=sum(r["eval_tokens"] for r in results) // len(results),
        eval_duration_s=round(sum(r["eval_duration_s"] for r in results) / len(results), 2),
        tokens_per_second=avg_tok_s,
        total_duration_s=round(sum(r["total_duration_s"] for r in results) / len(results), 2),
        rounds=len(results),
        avg_tok_s=avg_tok_s,
        meets_threshold=avg_tok_s >= threshold,
    )


def format_report(results: list[BenchmarkResult], threshold: float = 50.0) -> str:
    """Format a human-readable benchmark report."""
    lines = []
    lines.append("")
    lines.append("=" * 72)
    lines.append(f" LOCAL MODEL BENCHMARK — {threshold:.0f} tok/s UX Threshold")
    lines.append("=" * 72)
    lines.append("")

    # Summary table (header padded to line up with the " X model" rows below)
    header = f"   {'Model':<23} {'Size':>6} {'tok/s':>8} {'Threshold':>10} {'Status':>8}"
    lines.append(header)
    lines.append("-" * 72)

    passed = 0
    failed = 0
    errors = 0
    for r in sorted(results, key=lambda x: x.avg_tok_s, reverse=True):
        size_str = f"{r.size_gb:.1f}GB"
        tok_s_str = f"{r.avg_tok_s:.1f}"
        if r.error:
            status = "ERROR"
            errors += 1
        elif r.meets_threshold:
            status = "PASS"
            passed += 1
        else:
            status = "FAIL"
            failed += 1
        marker = ">" if r.meets_threshold else "X" if r.error else "!"
        thresh_str = f">= {threshold:.0f}"
        lines.append(f" {marker} {r.model:<23} {size_str:>6} {tok_s_str:>8} {thresh_str:>10} {status:>8}")

    lines.append("-" * 72)
    lines.append(f" Passed: {passed} | Failed: {failed} | Errors: {errors} | Total: {len(results)}")
    lines.append("")

    # Detail section for failures
    failures = [r for r in results if not r.meets_threshold and not r.error]
    if failures:
        lines.append(" FAILED MODELS (below threshold):")
        for r in sorted(failures, key=lambda x: x.avg_tok_s):
            gap = threshold - r.avg_tok_s
            lines.append(f"   - {r.model}: {r.avg_tok_s:.1f} tok/s "
                         f"({gap:.1f} tok/s short, {r.eval_tokens} avg tokens/round)")
        lines.append("")

    error_list = [r for r in results if r.error]
    if error_list:
        lines.append(" ERRORS:")
        for r in error_list:
            lines.append(f"   - {r.model}: {r.error}")
        lines.append("")

    # Hardware info
    lines.append(f" Host: {platform.node()} | {platform.system()} {platform.release()}")
    lines.append(f" Ollama: {OLLAMA_BASE}")
    lines.append("")
    return "\n".join(lines)
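
# Shape of the --json output consumed by CI (a sketch derived from main() below;
# the counts shown are placeholders, and each "results" entry is an
# asdict(BenchmarkResult) dict):
#
#   {
#     "threshold_tok_s": 50.0,
#     "ollama_base": "http://localhost:11434",
#     "rounds": 3,
#     "results": [{"model": "...", "avg_tok_s": 0.0, "meets_threshold": false, ...}],
#     "passed": 1,
#     "failed": 0,
#     "errors": 0
#   }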
def main():
    parser = argparse.ArgumentParser(description="Benchmark local Ollama models vs 50 tok/s threshold")
    parser.add_argument("--models", help="Comma-separated model names (default: all)")
    parser.add_argument("--prompt", default=BENCHMARK_PROMPT, help="Benchmark prompt")
    parser.add_argument("--rounds", type=int, default=3, help="Rounds per model (default: 3)")
    parser.add_argument("--tokens", type=int, default=512, help="Max tokens to generate (default: 512)")
    parser.add_argument("--json", action="store_true", help="JSON output for CI")
    parser.add_argument("--all", action="store_true", help="Test all pulled models")
    parser.add_argument("--threshold", type=float, default=THRESHOLD_TOK_S, help="tok/s threshold")
    args = parser.parse_args()
    threshold = args.threshold

    # Get model list
    available = get_models()
    if not available:
        print("No models found. Pull a model first: ollama pull <model>", file=sys.stderr)
        sys.exit(1)

    if args.models:
        names = [m.strip() for m in args.models.split(",")]
        models = [m for m in available if m["name"] in names]
        missing = set(names) - set(m["name"] for m in models)
        if missing:
            print(f"Models not found: {', '.join(missing)}", file=sys.stderr)
            print(f"Available: {', '.join(m['name'] for m in available)}", file=sys.stderr)
        if not models:
            sys.exit(1)
    else:
        models = available

    print(f"Benchmarking {len(models)} model(s) against {threshold} tok/s threshold")
    print(f"Ollama: {OLLAMA_BASE} | Rounds: {args.rounds} | Max tokens: {args.tokens}")
    print()

    results = []
    for m in models:
        name = m["name"]
        size_gb = m.get("size", 0) / (1024**3)
        print(f" {name} ({size_gb:.1f}GB):")
        result = run_benchmark(name, size_gb, args.prompt, args.rounds, args.tokens, threshold)
        results.append(result)

    # Output
    report = format_report(results, threshold)
    if args.json:
        output = {
            "threshold_tok_s": threshold,
            "ollama_base": OLLAMA_BASE,
            "rounds": args.rounds,
            "results": [asdict(r) for r in results],
            "passed": sum(1 for r in results if r.meets_threshold),
            "failed": sum(1 for r in results if not r.meets_threshold and not r.error),
            "errors": sum(1 for r in results if r.error),
        }
        print(json.dumps(output, indent=2))
    else:
        print(report)

    # Exit code: 0 if all pass, 1 if any fail/error
    if any(not r.meets_threshold or r.error for r in results):
        sys.exit(1)
    sys.exit(0)


if __name__ == "__main__":
    main()
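
# Example CI usage (a hypothetical pipeline step, not part of this repo): rely on
# the exit code for pass/fail and keep the JSON for inspection, e.g.
#
#   python3 scripts/benchmark_local_models.py --json > benchmark.json
#   jq '.results[] | {model, avg_tok_s, meets_threshold}' benchmark.json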