#!/usr/bin/env python3
"""
TurboQuant Constant-Time Benchmark — Issue #72

Benchmarks constant-time (side-channel resistant) vs original quantization.
Measures encode latency, decode latency, and encoded size (memory footprint).

NOTE: the kernels below are pure-Python reference implementations. CPython
itself branches inside conditional expressions, ``min``/``max`` and ``round``,
so the "constant-time" variant only mirrors the *structure* of a constant-time
kernel (no early exits, no data-dependent control flow in the algorithm). It
is not a hardware-level timing guarantee.

Usage:
    python3 benchmarks/constant_time_benchmark.py --size 4096 --iterations 100
    python3 benchmarks/constant_time_benchmark.py --json
"""

import argparse
import json
import math
import os
import statistics
import struct
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable

# ---------------------------------------------------------------------------
# Quantization kernels (Python reference implementations)
# ---------------------------------------------------------------------------

# Q4_0 block geometry: 32 weights -> 2-byte FP16 scale + 16 bytes of packed nibbles.
Q4_0_BLOCK_SIZE = 32
Q4_0_BYTES_PER_BLOCK = 18

# Nibble value 8 encodes the quantized value 0 (decoders subtract 8), so a block
# of 0x88 bytes is what the quantization loop produces for an all-zero block.
_ZERO_QUANTS = bytes([0x88] * 16)


def quantize_fp16_to_q4_0_original(weights: list[float]) -> bytes:
    """Original quantization: FP16 → Q4_0 (block size 32).

    Each block: 2 bytes scale (FP16, little-endian) + 16 bytes quants
    (4-bit packed; element ``2k`` in the low nibble, ``2k+1`` in the high).
    Non-constant-time: early exits, branching on zero detection.

    Trailing weights that do not fill a whole block are not encoded.
    """
    block_size = Q4_0_BLOCK_SIZE
    n_blocks = len(weights) // block_size
    output = bytearray()
    for b in range(n_blocks):
        block = weights[b * block_size:(b + 1) * block_size]
        # Find absmax
        absmax = 0.0
        for w in block:
            absmax = max(absmax, abs(w))
        if absmax == 0.0:
            # Early exit — branch prediction leak: all-zero blocks skip the
            # quantization loop entirely.
            output.extend(struct.pack('<e', 0.0))
            output.extend(_ZERO_QUANTS)
            continue
        d = absmax / 7.0
        id_val = 1.0 / d
        packed = bytearray(16)
        for i in range(0, block_size, 2):
            xi0 = int(round(block[i] * id_val)) + 8
            xi1 = int(round(block[i + 1] * id_val)) + 8
            xi0 = max(0, min(15, xi0))
            xi1 = max(0, min(15, xi1))
            packed[i // 2] = xi0 | (xi1 << 4)
        output.extend(struct.pack('<e', d))
        output.extend(packed)
    return bytes(output)


def quantize_fp16_to_q4_0_constant_time(weights: list[float]) -> bytes:
    """Constant-time quantization: FP16 → Q4_0.

    No early exits and no data-dependent control flow in the algorithm. Every
    block runs the same absmax, scale and packing steps. All-zero blocks are
    handled arithmetically (the inverse scale is forced to 0) rather than by a
    separate path. The output is byte-identical to
    ``quantize_fp16_to_q4_0_original``, but the work done does not depend on
    the weight distribution.

    NOTE: the selects below are Python conditional expressions, which the
    interpreter still implements with branches. A native constant-time kernel
    would use branch-free selects (masks / cmov) at these points.
    """
    block_size = Q4_0_BLOCK_SIZE
    n_blocks = len(weights) // block_size
    output = bytearray()
    for b in range(n_blocks):
        block = weights[b * block_size:(b + 1) * block_size]
        # Find absmax — no early exit on zero
        absmax = 0.0
        for w in block:
            absval = abs(w)
            # Select-style max: the same operations run whatever the data
            absmax = absval if absval > absmax else absmax
        # Scale computation — no separate path for zero blocks
        d = absmax / 7.0
        # Inverse scale: use a safe divisor and mask the result to 0 when d == 0
        d_nonzero = 1.0 if d != 0.0 else 0.0
        safe_d = d if d != 0.0 else 1.0  # Avoid division by zero
        id_val = (1.0 / safe_d) * d_nonzero
        # Always compute quants (when scale=0 every quant is the zero nibble 8)
        packed = bytearray(16)
        for i in range(0, block_size, 2):
            xi0 = int(round(block[i] * id_val)) + 8
            xi1 = int(round(block[i + 1] * id_val)) + 8 if i + 1 < block_size else 8
            # Clamp to the 4-bit range [0, 15]
            xi0 = max(0, min(15, xi0))
            xi1 = max(0, min(15, xi1))
            packed[i // 2] = xi0 | (xi1 << 4)
        output.extend(struct.pack('<e', d))
        output.extend(packed)
    return bytes(output)


def dequantize_q4_0_original(data: bytes, n: int) -> list[float]:
    """Original dequantization: Q4_0 → FP32.

    Decodes ``n // 32`` blocks from ``data`` and returns at most ``n`` floats.
    """
    block_size = Q4_0_BLOCK_SIZE
    bytes_per_block = Q4_0_BYTES_PER_BLOCK  # 2 scale + 16 quants
    n_blocks = n // block_size
    weights = []
    for b in range(n_blocks):
        offset = b * bytes_per_block
        d = struct.unpack_from('<e', data, offset)[0]
        for i in range(16):
            byte = data[offset + 2 + i]
            xi0 = (byte & 0x0F) - 8
            xi1 = ((byte >> 4) & 0x0F) - 8
            weights.append(xi0 * d)
            if len(weights) < n:
                weights.append(xi1 * d)
    return weights[:n]


def dequantize_q4_0_constant_time(data: bytes, n: int) -> list[float]:
    """Constant-time dequantization: Q4_0 → FP32.

    Same layout and result as ``dequantize_q4_0_original``. Both nibbles of
    every byte are always unpacked before the length checks.
    """
    block_size = Q4_0_BLOCK_SIZE
    bytes_per_block = Q4_0_BYTES_PER_BLOCK
    n_blocks = n // block_size
    weights = []
    for b in range(n_blocks):
        offset = b * bytes_per_block
        d = struct.unpack_from('<e', data, offset)[0]
        for i in range(16):
            byte = data[offset + 2 + i]
            xi0 = (byte & 0x0F) - 8
            xi1 = ((byte >> 4) & 0x0F) - 8
            if len(weights) < n:
                weights.append(xi0 * d)
            if len(weights) < n:
                weights.append(xi1 * d)
    return weights[:n]


# ---------------------------------------------------------------------------
# Benchmark harness
# ---------------------------------------------------------------------------

def benchmark(fn: Callable, args: tuple, iterations: int) -> dict:
    """Benchmark a function over N iterations.

    Runs up to 3 untimed warmup calls, then times ``iterations`` calls of
    ``fn(*args)`` with ``time.perf_counter``. Returns summary statistics in
    milliseconds (rounded to 4 decimals).

    Raises:
        ValueError: if ``iterations`` < 1, since there would be no samples to summarize.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    # Warmup
    for _ in range(min(3, iterations)):
        fn(*args)

    latencies = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn(*args)
        elapsed = time.perf_counter() - start
        latencies.append(elapsed * 1000)  # ms

    ordered = sorted(latencies)
    return {
        "iterations": iterations,
        "mean_ms": round(statistics.mean(latencies), 4),
        "median_ms": round(statistics.median(latencies), 4),
        "std_ms": round(statistics.stdev(latencies) if len(latencies) > 1 else 0, 4),
        "min_ms": round(ordered[0], 4),
        "max_ms": round(ordered[-1], 4),
        # int(len * p) is always < len for p < 1, so these indices are in range.
        "p95_ms": round(ordered[int(len(ordered) * 0.95)], 4),
        "p99_ms": round(ordered[int(len(ordered) * 0.99)], 4),
    }


def generate_weights(size: int) -> list[float]:
    """Generate deterministic N(0, 1) test weights (seed 42).

    Uses a private ``random.Random`` instance so the module-level random
    state of the caller is not reseeded.
    """
    import random
    rng = random.Random(42)
    return [rng.gauss(0, 1) for _ in range(size)]


def run_benchmarks(size: int, iterations: int) -> dict:
    """Run the full benchmark suite and return a JSON-serializable report.

    Raises:
        ValueError: if ``iterations`` < 1 (propagated from ``benchmark``).
    """
    weights = generate_weights(size)
    print(f"Benchmarking {size} weights x {iterations} iterations...", file=sys.stderr)

    # Encode benchmarks
    print("  Encode original...", file=sys.stderr)
    encode_orig = benchmark(quantize_fp16_to_q4_0_original, (weights,), iterations)

    print("  Encode constant-time...", file=sys.stderr)
    encode_ct = benchmark(quantize_fp16_to_q4_0_constant_time, (weights,), iterations)

    # Decode benchmarks
    encoded_orig = quantize_fp16_to_q4_0_original(weights)
    print("  Decode original...", file=sys.stderr)
    decode_orig = benchmark(dequantize_q4_0_original, (encoded_orig, size), iterations)

    encoded_ct = quantize_fp16_to_q4_0_constant_time(weights)
    print("  Decode constant-time...", file=sys.stderr)
    decode_ct = benchmark(dequantize_q4_0_constant_time, (encoded_ct, size), iterations)

    # Correctness check (default=0.0: sizes below one block decode to nothing)
    decoded_orig = dequantize_q4_0_original(encoded_orig, size)
    decoded_ct = dequantize_q4_0_constant_time(encoded_ct, size)
    max_diff = max((abs(a - b) for a, b in zip(decoded_orig, decoded_ct)), default=0.0)

    # Overhead analysis (floor the denominator to avoid dividing by ~0 ms)
    encode_overhead = (encode_ct["mean_ms"] / max(encode_orig["mean_ms"], 0.001) - 1) * 100
    decode_overhead = (decode_ct["mean_ms"] / max(decode_orig["mean_ms"], 0.001) - 1) * 100

    # Ratio is relative to 4-byte (FP32) weights; 0.0 if nothing was encoded.
    compression_ratio = size * 4 / len(encoded_orig) if encoded_orig else 0.0

    return {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "config": {"weight_count": size, "iterations": iterations, "block_size": Q4_0_BLOCK_SIZE},
        "encode": {"original": encode_orig, "constant_time": encode_ct},
        "decode": {"original": decode_orig, "constant_time": decode_ct},
        "correctness": {
            "max_decode_diff": round(max_diff, 10),
            "outputs_match": max_diff < 1e-6,
        },
        "overhead": {
            "encode_pct": round(encode_overhead, 2),
            "decode_pct": round(decode_overhead, 2),
        },
        "memory": {
            "original_bytes": len(encoded_orig),
            "constant_time_bytes": len(encoded_ct),
            "compression_ratio": round(compression_ratio, 2),
        },
    }


def to_markdown(report: dict) -> str:
    """Render a ``run_benchmarks`` report as a Markdown document.

    Overhead cells are printed with an explicit sign, so a faster
    constant-time path shows as e.g. ``-5.0%`` rather than ``+-5.0%``.
    """
    enc = report["encode"]
    dec = report["decode"]
    ov = report["overhead"]
    mem = report["memory"]
    cor = report["correctness"]

    lines = [
        "# Constant-Time Benchmark Report",
        "",
        f"Generated: {report['generated_at'][:16]}",
        f"Config: {report['config']['weight_count']} weights, {report['config']['iterations']} iterations",
        "",
        "## Encode Latency",
        "",
        "| Impl | Mean (ms) | Median | P95 | P99 | Overhead |",
        "|------|-----------|--------|-----|-----|----------|",
        f"| Original | {enc['original']['mean_ms']:.2f} | {enc['original']['median_ms']:.2f} | {enc['original']['p95_ms']:.2f} | {enc['original']['p99_ms']:.2f} | baseline |",
        f"| Constant-time | {enc['constant_time']['mean_ms']:.2f} | {enc['constant_time']['median_ms']:.2f} | {enc['constant_time']['p95_ms']:.2f} | {enc['constant_time']['p99_ms']:.2f} | {ov['encode_pct']:+.1f}% |",
        "",
        "## Decode Latency",
        "",
        "| Impl | Mean (ms) | Median | P95 | P99 | Overhead |",
        "|------|-----------|--------|-----|-----|----------|",
        f"| Original | {dec['original']['mean_ms']:.2f} | {dec['original']['median_ms']:.2f} | {dec['original']['p95_ms']:.2f} | {dec['original']['p99_ms']:.2f} | baseline |",
        f"| Constant-time | {dec['constant_time']['mean_ms']:.2f} | {dec['constant_time']['median_ms']:.2f} | {dec['constant_time']['p95_ms']:.2f} | {dec['constant_time']['p99_ms']:.2f} | {ov['decode_pct']:+.1f}% |",
        "",
        "## Correctness",
        "",
        f"- Max decode difference: {cor['max_decode_diff']:.10f}",
        f"- Outputs match: {'✅ Yes' if cor['outputs_match'] else '❌ No'}",
        "",
        "## Memory",
        "",
        f"- Compressed size: {mem['original_bytes']} bytes ({mem['compression_ratio']:.1f}x compression)",
        f"- Constant-time size: {mem['constant_time_bytes']} bytes (same format)",
        "",
        "## Verdict",
        "",
    ]

    if ov['encode_pct'] < 10 and ov['decode_pct'] < 10:
        lines.append("**Constant-time overhead is acceptable (<10%).** Safe for production.")
    elif ov['encode_pct'] < 25 and ov['decode_pct'] < 25:
        lines.append("**Constant-time overhead is moderate (10-25%).** Acceptable for security-sensitive deployments.")
    else:
        lines.append("**Constant-time overhead is significant (>25%).** Consider optimizing or using original for non-sensitive workloads.")

    return "\n".join(lines)


def main():
    """CLI entry point: validate arguments, run the suite, print JSON or Markdown."""
    parser = argparse.ArgumentParser(description="Constant-time benchmark")
    parser.add_argument("--size", type=int, default=4096, help="Weight count")
    parser.add_argument("--iterations", type=int, default=100, help="Iterations")
    parser.add_argument("--json", action="store_true", help="JSON output")
    args = parser.parse_args()

    # Partial trailing blocks are silently dropped by the kernels, and zero
    # iterations leaves nothing to summarize; reject both up front.
    if args.size < Q4_0_BLOCK_SIZE or args.size % Q4_0_BLOCK_SIZE:
        parser.error(f"--size must be a positive multiple of {Q4_0_BLOCK_SIZE}")
    if args.iterations < 1:
        parser.error("--iterations must be >= 1")

    report = run_benchmarks(args.size, args.iterations)

    if args.json:
        print(json.dumps(report, indent=2))
    else:
        print(to_markdown(report))


if __name__ == "__main__":
    main()