190 lines
6.0 KiB
Python
190 lines
6.0 KiB
Python
#!/usr/bin/env python3
|
|
"""Apple Silicon DFlash planning helpers and CLI (issue #152)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import platform
|
|
import subprocess
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Iterable, Optional
|
|
|
|
|
|
@dataclass(frozen=True)
class DFlashPair:
    """Immutable description of one base/draft model pairing supported for DFlash benchmarking."""

    # Short identifier used for CLI selection (`--pair`) and in generated plans.
    slug: str
    # Repo id of the full-size base model (e.g. "Qwen/Qwen3.5-4B").
    base_model: str
    # Repo id of the matching DFlash draft model.
    draft_model: str
    # Approximate combined weight footprint of both models, in GB.
    estimated_total_weights_gb: float
    # Planning threshold: select_pair only picks this pair when total memory meets it.
    minimum_recommended_memory_gb: float
    # Forwarded to the benchmark's --draft-sliding-window-size flag.
    draft_sliding_window_size: int = 4096
|
|
|
|
|
|
# Upstream-supported base/draft pairings, listed smallest first; select_pair
# falls back to SUPPORTED_PAIRS[0] when no pair fits the machine's memory.
SUPPORTED_PAIRS: tuple[DFlashPair, ...] = (
    DFlashPair(
        slug="qwen35-4b",
        base_model="Qwen/Qwen3.5-4B",
        draft_model="z-lab/Qwen3.5-4B-DFlash",
        estimated_total_weights_gb=9.68,
        minimum_recommended_memory_gb=16.0,
    ),
    DFlashPair(
        slug="qwen35-9b",
        base_model="Qwen/Qwen3.5-9B",
        draft_model="z-lab/Qwen3.5-9B-DFlash",
        estimated_total_weights_gb=19.93,
        minimum_recommended_memory_gb=28.0,
    ),
)
|
|
|
|
|
|
def detect_total_memory_gb() -> float:
    """Detect total system memory in GiB, rounded to one decimal place.

    Returns:
        Total physical memory in GiB (e.g. ``16.0``).

    Raises:
        RuntimeError: If the platform is neither macOS nor Linux, or if
            ``/proc/meminfo`` has no ``MemTotal`` entry on Linux.
    """
    system = platform.system()
    if system == "Darwin":
        # `sysctl -n hw.memsize` prints the physical memory size in bytes.
        mem_bytes = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]).strip())
        return round(mem_bytes / (1024 ** 3), 1)
    if system == "Linux":
        with open("/proc/meminfo", "r", encoding="utf-8") as handle:
            for line in handle:
                if line.startswith("MemTotal:"):
                    # MemTotal is reported in kiB; convert to GiB.
                    mem_kb = int(line.split()[1])
                    return round(mem_kb / (1024 ** 2), 1)
        # Previously this fell through to the "Unsupported platform" error
        # below, which was misleading on a Linux box with an odd /proc.
        raise RuntimeError("Could not find MemTotal in /proc/meminfo")
    raise RuntimeError(f"Unsupported platform for memory detection: {system}")
|
|
|
|
|
|
def get_pair(slug: str) -> DFlashPair:
    """Return the supported pair matching *slug*, or raise ``ValueError``."""
    match = next((candidate for candidate in SUPPORTED_PAIRS if candidate.slug == slug), None)
    if match is None:
        raise ValueError(f"Unknown DFlash pair: {slug}")
    return match
|
|
|
|
|
|
def select_pair(total_memory_gb: float, preferred_slug: Optional[str] = None) -> DFlashPair:
    """Pick the strongest upstream-supported pair likely to fit the machine."""
    if preferred_slug:
        # Explicit preference wins, even if it exceeds the machine's memory.
        return get_pair(preferred_slug)

    best: Optional[DFlashPair] = None
    for candidate in SUPPORTED_PAIRS:
        if total_memory_gb < candidate.minimum_recommended_memory_gb:
            continue
        if best is None or candidate.minimum_recommended_memory_gb > best.minimum_recommended_memory_gb:
            best = candidate
    # Nothing fits comfortably: fall back to the smallest supported pair.
    return best if best is not None else SUPPORTED_PAIRS[0]
|
|
|
|
|
|
def build_mlx_benchmark_command(
    pair: DFlashPair,
    *,
    dataset: str = "gsm8k",
    max_samples: int = 128,
    enable_thinking: bool = True,
) -> str:
    """Build the upstream MLX benchmark command from the DFlash README."""
    segments = ["python -m dflash.benchmark --backend mlx"]
    segments.append(f"--model {pair.base_model}")
    segments.append(f"--draft-model {pair.draft_model}")
    segments.append(f"--dataset {dataset}")
    segments.append(f"--max-samples {max_samples}")
    if enable_thinking:
        segments.append("--enable-thinking")
    segments.append(f"--draft-sliding-window-size {pair.draft_sliding_window_size}")
    # One flag per line, shell-continued, for readable copy-paste.
    separator = " \\\n    "
    return separator.join(segments)
|
|
|
|
|
|
def build_setup_commands(pair: DFlashPair) -> list[str]:
    """Return the shell commands that create a venv, install DFlash, and run the benchmark."""
    steps = [
        "python3 -m venv .venv-dflash",
        "source .venv-dflash/bin/activate",
        "git clone https://github.com/z-lab/dflash.git",
        "cd dflash",
        "pip install -e .[mlx]",
    ]
    # Final step is the benchmark invocation itself.
    steps.append(build_mlx_benchmark_command(pair))
    return steps
|
|
|
|
|
|
def render_report_template(machine_label: str, pair: DFlashPair) -> str:
    """Render a Markdown report skeleton: machine facts, copy-paste setup, and blank result sections."""
    command = build_mlx_benchmark_command(pair)
    return f"""# DFlash Apple Silicon Benchmark Report

## Machine
- Label: {machine_label}
- Selected pair: {pair.slug}
- Base model: {pair.base_model}
- Draft model: {pair.draft_model}
- Estimated total weight footprint: {pair.estimated_total_weights_gb:.2f} GB

## Setup
```bash
python3 -m venv .venv-dflash
source .venv-dflash/bin/activate
git clone https://github.com/z-lab/dflash.git
cd dflash
pip install -e .[mlx]
{command}
```

## Baseline comparison
Compare against **plain MLX or llama.cpp speculative decoding** on the same prompt set.

## Results
- Throughput (tok/s):
- Peak memory (GB):
- Notes on acceptance / behavior:

## Verdict
Worth operationalizing locally?
- [ ] Yes
- [ ] No
- [ ] Needs more data

## Recommendation
Explain whether this should become part of the local inference stack.
"""
|
|
|
|
|
|
def build_plan(total_memory_gb: float, preferred_slug: Optional[str] = None) -> dict:
    """Assemble the JSON-serializable benchmark plan for a machine."""
    chosen = select_pair(total_memory_gb=total_memory_gb, preferred_slug=preferred_slug)
    # Key order is preserved in the serialized JSON output.
    plan: dict = {}
    plan["machine_memory_gb"] = total_memory_gb
    plan["selected_pair"] = asdict(chosen)
    plan["setup_commands"] = build_setup_commands(chosen)
    plan["benchmark_command"] = build_mlx_benchmark_command(chosen)
    plan["baseline_note"] = "Compare against plain MLX or llama.cpp speculative decoding on the same prompt set."
    return plan
|
|
|
|
|
|
def write_output(path: Path, content: str) -> None:
    """Write *content* to *path* as UTF-8, creating parent directories as needed."""
    parent = path.parent
    parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding="utf-8")
|
|
|
|
|
|
def main(argv: Optional[Iterable[str]] = None) -> int:
    """CLI entry point: detect memory, pick a pair, and emit a plan or report.

    Returns 0 on success; output goes to ``--output`` or stdout.
    """
    cli = argparse.ArgumentParser(description="Plan Apple Silicon DFlash benchmarks")
    cli.add_argument("--memory-gb", type=float, default=None, help="Override detected total memory")
    cli.add_argument("--pair", choices=[p.slug for p in SUPPORTED_PAIRS], default=None)
    cli.add_argument("--machine-label", default="Apple Silicon Mac")
    cli.add_argument("--format", choices=["json", "markdown"], default="markdown")
    cli.add_argument("--output", default=None, help="Write plan/report to file instead of stdout")
    opts = cli.parse_args(None if argv is None else list(argv))

    # --memory-gb lets plans be made for a machine other than the current one.
    if opts.memory_gb is not None:
        memory_gb = opts.memory_gb
    else:
        memory_gb = detect_total_memory_gb()
    pair = select_pair(total_memory_gb=memory_gb, preferred_slug=opts.pair)

    if opts.format == "json":
        content = json.dumps(build_plan(memory_gb, preferred_slug=pair.slug), indent=2)
    else:
        content = render_report_template(opts.machine_label, pair)

    if opts.output:
        write_output(Path(opts.output), content)
    else:
        print(content)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s integer return code as the process exit status.
    raise SystemExit(main())
|