feat: add Apple Silicon DFlash benchmark planner (refs #152)
All checks were successful
Smoke Test / smoke (pull_request) Successful in 18s
All checks were successful
Smoke Test / smoke (pull_request) Successful in 18s
This commit is contained in:
189
benchmarks/dflash_apple_silicon.py
Normal file
189
benchmarks/dflash_apple_silicon.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Apple Silicon DFlash planning helpers and CLI (issue #152)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import platform
|
||||
import subprocess
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DFlashPair:
|
||||
slug: str
|
||||
base_model: str
|
||||
draft_model: str
|
||||
estimated_total_weights_gb: float
|
||||
minimum_recommended_memory_gb: float
|
||||
draft_sliding_window_size: int = 4096
|
||||
|
||||
|
||||
SUPPORTED_PAIRS: tuple[DFlashPair, ...] = (
|
||||
DFlashPair(
|
||||
slug="qwen35-4b",
|
||||
base_model="Qwen/Qwen3.5-4B",
|
||||
draft_model="z-lab/Qwen3.5-4B-DFlash",
|
||||
estimated_total_weights_gb=9.68,
|
||||
minimum_recommended_memory_gb=16.0,
|
||||
),
|
||||
DFlashPair(
|
||||
slug="qwen35-9b",
|
||||
base_model="Qwen/Qwen3.5-9B",
|
||||
draft_model="z-lab/Qwen3.5-9B-DFlash",
|
||||
estimated_total_weights_gb=19.93,
|
||||
minimum_recommended_memory_gb=28.0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def detect_total_memory_gb() -> float:
|
||||
"""Detect total system memory in GiB, rounded to a whole number for planning."""
|
||||
system = platform.system()
|
||||
if system == "Darwin":
|
||||
mem_bytes = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]).strip())
|
||||
return round(mem_bytes / (1024 ** 3), 1)
|
||||
if system == "Linux":
|
||||
with open("/proc/meminfo", "r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if line.startswith("MemTotal:"):
|
||||
mem_kb = int(line.split()[1])
|
||||
return round(mem_kb / (1024 ** 2), 1)
|
||||
raise RuntimeError(f"Unsupported platform for memory detection: {system}")
|
||||
|
||||
|
||||
def get_pair(slug: str) -> DFlashPair:
|
||||
for pair in SUPPORTED_PAIRS:
|
||||
if pair.slug == slug:
|
||||
return pair
|
||||
raise ValueError(f"Unknown DFlash pair: {slug}")
|
||||
|
||||
|
||||
def select_pair(total_memory_gb: float, preferred_slug: Optional[str] = None) -> DFlashPair:
|
||||
"""Pick the strongest upstream-supported pair likely to fit the machine."""
|
||||
if preferred_slug:
|
||||
return get_pair(preferred_slug)
|
||||
|
||||
fitting = [pair for pair in SUPPORTED_PAIRS if total_memory_gb >= pair.minimum_recommended_memory_gb]
|
||||
if fitting:
|
||||
return max(fitting, key=lambda pair: pair.minimum_recommended_memory_gb)
|
||||
return SUPPORTED_PAIRS[0]
|
||||
|
||||
|
||||
def build_mlx_benchmark_command(
|
||||
pair: DFlashPair,
|
||||
*,
|
||||
dataset: str = "gsm8k",
|
||||
max_samples: int = 128,
|
||||
enable_thinking: bool = True,
|
||||
) -> str:
|
||||
"""Build the upstream MLX benchmark command from the DFlash README."""
|
||||
parts = [
|
||||
"python -m dflash.benchmark --backend mlx",
|
||||
f"--model {pair.base_model}",
|
||||
f"--draft-model {pair.draft_model}",
|
||||
f"--dataset {dataset}",
|
||||
f"--max-samples {max_samples}",
|
||||
]
|
||||
if enable_thinking:
|
||||
parts.append("--enable-thinking")
|
||||
parts.append(f"--draft-sliding-window-size {pair.draft_sliding_window_size}")
|
||||
return " \\\n ".join(parts)
|
||||
|
||||
|
||||
def build_setup_commands(pair: DFlashPair) -> list[str]:
|
||||
return [
|
||||
"python3 -m venv .venv-dflash",
|
||||
"source .venv-dflash/bin/activate",
|
||||
"git clone https://github.com/z-lab/dflash.git",
|
||||
"cd dflash",
|
||||
"pip install -e .[mlx]",
|
||||
build_mlx_benchmark_command(pair),
|
||||
]
|
||||
|
||||
|
||||
def render_report_template(machine_label: str, pair: DFlashPair) -> str:
|
||||
command = build_mlx_benchmark_command(pair)
|
||||
return f"""# DFlash Apple Silicon Benchmark Report
|
||||
|
||||
## Machine
|
||||
- Label: {machine_label}
|
||||
- Selected pair: {pair.slug}
|
||||
- Base model: {pair.base_model}
|
||||
- Draft model: {pair.draft_model}
|
||||
- Estimated total weight footprint: {pair.estimated_total_weights_gb:.2f} GB
|
||||
|
||||
## Setup
|
||||
```bash
|
||||
python3 -m venv .venv-dflash
|
||||
source .venv-dflash/bin/activate
|
||||
git clone https://github.com/z-lab/dflash.git
|
||||
cd dflash
|
||||
pip install -e .[mlx]
|
||||
{command}
|
||||
```
|
||||
|
||||
## Baseline comparison
|
||||
Compare against **plain MLX or llama.cpp speculative decoding** on the same prompt set.
|
||||
|
||||
## Results
|
||||
- Throughput (tok/s):
|
||||
- Peak memory (GB):
|
||||
- Notes on acceptance / behavior:
|
||||
|
||||
## Verdict
|
||||
Worth operationalizing locally?
|
||||
- [ ] Yes
|
||||
- [ ] No
|
||||
- [ ] Needs more data
|
||||
|
||||
## Recommendation
|
||||
Explain whether this should become part of the local inference stack.
|
||||
"""
|
||||
|
||||
|
||||
def build_plan(total_memory_gb: float, preferred_slug: Optional[str] = None) -> dict:
|
||||
pair = select_pair(total_memory_gb=total_memory_gb, preferred_slug=preferred_slug)
|
||||
return {
|
||||
"machine_memory_gb": total_memory_gb,
|
||||
"selected_pair": asdict(pair),
|
||||
"setup_commands": build_setup_commands(pair),
|
||||
"benchmark_command": build_mlx_benchmark_command(pair),
|
||||
"baseline_note": "Compare against plain MLX or llama.cpp speculative decoding on the same prompt set.",
|
||||
}
|
||||
|
||||
|
||||
def write_output(path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def main(argv: Optional[Iterable[str]] = None) -> int:
|
||||
parser = argparse.ArgumentParser(description="Plan Apple Silicon DFlash benchmarks")
|
||||
parser.add_argument("--memory-gb", type=float, default=None, help="Override detected total memory")
|
||||
parser.add_argument("--pair", choices=[pair.slug for pair in SUPPORTED_PAIRS], default=None)
|
||||
parser.add_argument("--machine-label", default="Apple Silicon Mac")
|
||||
parser.add_argument("--format", choices=["json", "markdown"], default="markdown")
|
||||
parser.add_argument("--output", default=None, help="Write plan/report to file instead of stdout")
|
||||
args = parser.parse_args(list(argv) if argv is not None else None)
|
||||
|
||||
memory_gb = args.memory_gb if args.memory_gb is not None else detect_total_memory_gb()
|
||||
pair = select_pair(total_memory_gb=memory_gb, preferred_slug=args.pair)
|
||||
|
||||
if args.format == "json":
|
||||
content = json.dumps(build_plan(memory_gb, preferred_slug=pair.slug), indent=2)
|
||||
else:
|
||||
content = render_report_template(args.machine_label, pair)
|
||||
|
||||
if args.output:
|
||||
write_output(Path(args.output), content)
|
||||
else:
|
||||
print(content)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user