#!/usr/bin/env python3
"""Apple Silicon DFlash planning helpers and CLI (issue #152)."""

from __future__ import annotations

import argparse
import json
import platform
import subprocess
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterable, Optional


@dataclass(frozen=True)
class DFlashPair:
    """A base model, its DFlash draft model, and the sizing figures used for planning."""

    # Short identifier accepted by the ``--pair`` CLI option.
    slug: str
    # Model identifier passed to the benchmark's ``--model`` flag.
    base_model: str
    # Model identifier passed to the benchmark's ``--draft-model`` flag.
    draft_model: str
    # Figure shown in the report as the "estimated total weight footprint".
    estimated_total_weights_gb: float
    # Machines with at least this much memory are considered able to fit the pair.
    minimum_recommended_memory_gb: float
    # Value passed to the benchmark's ``--draft-sliding-window-size`` flag.
    draft_sliding_window_size: int = 4096


SUPPORTED_PAIRS: tuple[DFlashPair, ...] = (
    DFlashPair(
        slug="qwen35-4b",
        base_model="Qwen/Qwen3.5-4B",
        draft_model="z-lab/Qwen3.5-4B-DFlash",
        estimated_total_weights_gb=9.68,
        minimum_recommended_memory_gb=16.0,
    ),
    DFlashPair(
        slug="qwen35-9b",
        base_model="Qwen/Qwen3.5-9B",
        draft_model="z-lab/Qwen3.5-9B-DFlash",
        estimated_total_weights_gb=19.93,
        minimum_recommended_memory_gb=28.0,
    ),
)


def detect_total_memory_gb() -> float:
    """Detect total system memory in GiB, rounded to one decimal place.

    macOS is queried through ``sysctl -n hw.memsize`` (bytes); Linux through
    the ``MemTotal:`` line of ``/proc/meminfo`` (kB).

    Raises:
        RuntimeError: If the platform is neither macOS nor Linux, or if
            ``/proc/meminfo`` contains no ``MemTotal:`` entry.
    """
    system = platform.system()
    if system == "Darwin":
        mem_bytes = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]).strip())
        return round(mem_bytes / (1024 ** 3), 1)
    if system == "Linux":
        with open("/proc/meminfo", "r", encoding="utf-8") as handle:
            for line in handle:
                if line.startswith("MemTotal:"):
                    mem_kb = int(line.split()[1])
                    return round(mem_kb / (1024 ** 2), 1)
        # Reaching here means the platform IS supported but the expected field
        # is missing; report that instead of a misleading "unsupported" error.
        raise RuntimeError("Could not find a MemTotal entry in /proc/meminfo")
    raise RuntimeError(f"Unsupported platform for memory detection: {system}")


def get_pair(slug: str) -> DFlashPair:
    """Return the supported pair whose slug equals *slug*.

    Raises:
        ValueError: If no entry in ``SUPPORTED_PAIRS`` has that slug.
    """
    for pair in SUPPORTED_PAIRS:
        if pair.slug == slug:
            return pair
    raise ValueError(f"Unknown DFlash pair: {slug}")


def select_pair(total_memory_gb: float, preferred_slug: Optional[str] = None) -> DFlashPair:
    """Pick the strongest upstream-supported pair likely to fit the machine.

    A non-empty *preferred_slug* always wins, regardless of memory. Otherwise
    the pair with the highest memory requirement that still fits is chosen;
    when nothing fits, the first entry of ``SUPPORTED_PAIRS`` is returned.

    Raises:
        ValueError: If *preferred_slug* is given but unknown.
    """
    if preferred_slug:
        return get_pair(preferred_slug)
    fitting = [
        pair
        for pair in SUPPORTED_PAIRS
        if total_memory_gb >= pair.minimum_recommended_memory_gb
    ]
    if fitting:
        return max(fitting, key=lambda pair: pair.minimum_recommended_memory_gb)
    return SUPPORTED_PAIRS[0]


def build_mlx_benchmark_command(
    pair: DFlashPair,
    *,
    dataset: str = "gsm8k",
    max_samples: int = 128,
    enable_thinking: bool = True,
) -> str:
    """Build the upstream MLX benchmark command from the DFlash README.

    Returns a single shell command split over several lines using
    backslash-newline continuations.
    """
    parts = [
        "python -m dflash.benchmark --backend mlx",
        f"--model {pair.base_model}",
        f"--draft-model {pair.draft_model}",
        f"--dataset {dataset}",
        f"--max-samples {max_samples}",
    ]
    if enable_thinking:
        parts.append("--enable-thinking")
    parts.append(f"--draft-sliding-window-size {pair.draft_sliding_window_size}")
    return " \\\n ".join(parts)


def build_setup_commands(pair: DFlashPair) -> list[str]:
    """Return the shell commands that set up DFlash and run the benchmark for *pair*."""
    return [
        "python3 -m venv .venv-dflash",
        "source .venv-dflash/bin/activate",
        "git clone https://github.com/z-lab/dflash.git",
        "cd dflash",
        # Quoted so zsh (the default macOS shell) does not treat ``[mlx]`` as a
        # glob character class and abort with "no matches found".
        'pip install -e ".[mlx]"',
        build_mlx_benchmark_command(pair),
    ]


def render_report_template(machine_label: str, pair: DFlashPair) -> str:
    """Render the Markdown benchmark-report skeleton for *pair* on *machine_label*."""
    # Reuse the canonical command list so the report and the JSON plan can
    # never drift apart.
    setup_block = "\n".join(build_setup_commands(pair))
    return f"""# DFlash Apple Silicon Benchmark Report

## Machine
- Label: {machine_label}
- Selected pair: {pair.slug}
- Base model: {pair.base_model}
- Draft model: {pair.draft_model}
- Estimated total weight footprint: {pair.estimated_total_weights_gb:.2f} GB

## Setup
```bash
{setup_block}
```

## Baseline comparison
Compare against **plain MLX or llama.cpp speculative decoding** on the same prompt set.

## Results
- Throughput (tok/s):
- Peak memory (GB):
- Notes on acceptance / behavior:

## Verdict
Worth operationalizing locally?
- [ ] Yes
- [ ] No
- [ ] Needs more data

## Recommendation
Explain whether this should become part of the local inference stack.
"""


def build_plan(total_memory_gb: float, preferred_slug: Optional[str] = None) -> dict:
    """Return a JSON-serialisable plan for the pair selected for this machine.

    Raises:
        ValueError: If *preferred_slug* is given but unknown.
    """
    pair = select_pair(total_memory_gb=total_memory_gb, preferred_slug=preferred_slug)
    return {
        "machine_memory_gb": total_memory_gb,
        "selected_pair": asdict(pair),
        "setup_commands": build_setup_commands(pair),
        "benchmark_command": build_mlx_benchmark_command(pair),
        "baseline_note": "Compare against plain MLX or llama.cpp speculative decoding on the same prompt set.",
    }


def write_output(path: Path, content: str) -> None:
    """Write *content* to *path* as UTF-8, creating missing parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding="utf-8")


def main(argv: Optional[Iterable[str]] = None) -> int:
    """Run the planning CLI and return the process exit code.

    Args:
        argv: Command-line arguments without the program name; ``None`` makes
            argparse read ``sys.argv``.

    Returns:
        ``0`` after the plan or report has been printed or written.
    """
    parser = argparse.ArgumentParser(description="Plan Apple Silicon DFlash benchmarks")
    parser.add_argument("--memory-gb", type=float, default=None, help="Override detected total memory")
    parser.add_argument("--pair", choices=[pair.slug for pair in SUPPORTED_PAIRS], default=None)
    parser.add_argument("--machine-label", default="Apple Silicon Mac")
    parser.add_argument("--format", choices=["json", "markdown"], default="markdown")
    parser.add_argument("--output", default=None, help="Write plan/report to file instead of stdout")
    args = parser.parse_args(list(argv) if argv is not None else None)

    # Only probe the machine when the user did not supply an override.
    memory_gb = args.memory_gb if args.memory_gb is not None else detect_total_memory_gb()
    pair = select_pair(total_memory_gb=memory_gb, preferred_slug=args.pair)

    if args.format == "json":
        # Pass the resolved slug so the plan describes exactly the pair chosen above.
        content = json.dumps(build_plan(memory_gb, preferred_slug=pair.slug), indent=2)
    else:
        content = render_report_template(args.machine_label, pair)

    if args.output:
        write_output(Path(args.output), content)
    else:
        print(content)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())