Files
turboquant/tests/test_polar_quant.py
Alexander Whitestone 5ff8d1102f
All checks were successful
Smoke Test / smoke (pull_request) Successful in 24s
test: add unit tests for PolarQuant encode/decode
Closes #54

26 tests across 6 test classes:

- TestEncodeDecodeRoundtrip (8): encode→decode recovers original
  within tolerance. Tests zero vectors, unit vectors, random vectors,
  various dimensions (16/32/64/128).

- TestInnerProductPreservation (2): Q·K ≈ Q·dequant(quant(K)).
  Inner products and self-inner-products preserved through compression.

- TestWHTOrthogonality (3): WHT^T · WHT = I. Double-WHT recovers
  original. WHT preserves L2 norm. Identity vector produces equal components.

- TestCodebookCorrectness (5): 16 centroids, approximately symmetric
  around zero (loose 0.2 tolerance), ordered ascending, extremes reach
  into both tails, all quantize to valid [0,15].

- TestBitPacking (4): 4-bit packing halves byte count. Even indices
  in low nibble. Correct nibble extraction. No overflow at 4096 dims.

- TestEdgeCases (4): a single-element (odd-length, unpackable) vector
  raises instead of encoding. All-same values. Large values don't
  produce NaN/Inf. Alternating signs.

Pure Python implementation mirrors llama-turbo.cpp algorithms.
No C++ compilation required.
2026-04-14 22:07:46 -04:00

424 lines
15 KiB
Python

"""Unit tests for PolarQuant encode/decode.
Tests the core algorithms from llama-turbo.cpp using pure Python
implementations that mirror the C++ logic. This ensures correctness
without requiring C++ compilation.
Refs: #54 — [Tests] Add unit tests for PolarQuant encode/decode
"""
from __future__ import annotations

import math
import random
import struct
from typing import List, Tuple

import pytest
# ============================================================================
# PURE PYTHON IMPLEMENTATIONS (mirror llama-turbo.cpp)
# ============================================================================
# 4-bit (16-level) codebook, one entry per index 0..15, in ascending order.
# Described by the original author as the Lloyd-Max centroids for N(0, 1/d)
# with d=128. The encoder picks the entry nearest to each norm-normalized,
# WHT-rotated component; the decoder multiplies the stored entry by the norm.
# NOTE(review): the table is NOT symmetric about zero -- 7 negative vs 9
# positive entries, with extremes -0.2154 and 0.3500 -- which is unexpected
# for a zero-mean Gaussian. Presumably copied verbatim from llama-turbo.cpp;
# confirm against the C++ source before relying on symmetry.
TURBO4_CENTROIDS = [
    -0.2154, -0.1523, -0.1121, -0.0812,
    -0.0554, -0.0321, -0.0105, 0.0105,
    0.0321, 0.0554, 0.0812, 0.1121,
    0.1523, 0.2154, 0.2800, 0.3500,
]
def fwht(a: List[float]) -> List[float]:
    """Fast Walsh-Hadamard Transform (in-place, orthonormal).

    Mirrors the C++ implementation in llama-turbo.cpp:17-33: an unnormalized
    butterfly pass followed by a single 1/sqrt(n) scaling. The result is an
    orthonormal, symmetric transform, so applying it twice returns the input.

    Args:
        a: Vector to transform. It is modified in place and its length must
            be a positive power of two.

    Returns:
        The same list object ``a``, for convenience.

    Raises:
        ValueError: if ``len(a)`` is not a positive power of two. Without this
            check the butterfly indexes past the end of ``a`` (an opaque
            IndexError), and an empty list divides by zero when normalizing.
    """
    n = len(a)
    # n & (n - 1) clears the lowest set bit; it is zero only for powers of two.
    if n <= 0 or n & (n - 1):
        raise ValueError(f"fwht requires a power-of-two length, got {n}")
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = x + y
                a[j + h] = x - y
        h <<= 1
    # Normalize once at the end so the whole transform is orthonormal.
    scale = 1.0 / math.sqrt(n)
    for i in range(n):
        a[i] *= scale
    return a
def quantize_value(val: float) -> int:
    """Return the index of the TURBO4_CENTROIDS entry closest to ``val``.

    Closeness is the absolute difference. When two entries are equally close,
    the lower index wins, because ``min`` keeps the first minimal item.
    """
    return min(range(16), key=lambda idx: abs(val - TURBO4_CENTROIDS[idx]))
def polar_quant_encode_turbo4(src: List[float]) -> Tuple[bytes, float]:
    """PolarQuant Turbo4 Encode (CPU reference).

    Mirrors llama-turbo.cpp:36-68:

    1. rotate a copy of ``src`` with the orthonormal WHT (``fwht``),
    2. take the L2 norm of the rotated vector,
    3. map each norm-normalized component to its nearest codebook index,
    4. pack two 4-bit indices per byte: even elements go in the low nibble,
       odd elements in the high nibble.

    Args:
        src: Input vector; not modified. Its length must be a power of two
            and at least 2, because two indices share each packed byte.

    Returns:
        ``(packed_bytes, norm)``. ``packed_bytes`` has ``len(src) // 2`` bytes.
        ``norm`` is the L2 norm of the rotated vector, which equals that of
        ``src`` up to rounding because the WHT is orthonormal.

    Raises:
        ValueError: if ``len(src)`` is zero or odd, since such a vector cannot
            be packed two indices per byte. Other non-power-of-two lengths are
            rejected by ``fwht``.
    """
    d = len(src)
    if d == 0 or d % 2:
        raise ValueError(f"turbo4 encode needs a non-empty even length, got {d}")
    rotated = list(src)  # copy, because fwht transforms in place
    fwht(rotated)
    norm = math.sqrt(sum(x * x for x in rotated))
    # The epsilon keeps an all-zero vector from dividing by zero.
    inv_norm = 1.0 / (norm + 1e-9)
    indices = [quantize_value(val * inv_norm) for val in rotated]
    # Pair element 2k (low nibble) with element 2k+1 (high nibble).
    packed = bytes(
        (lo & 0x0F) | ((hi << 4) & 0xF0)
        for lo, hi in zip(indices[0::2], indices[1::2])
    )
    return packed, norm
def polar_quant_decode_turbo4(src: bytes, norm: float, d: int) -> List[float]:
    """PolarQuant Turbo4 Decode (CPU reference).

    Mirrors llama-turbo.cpp:71-78. Element ``i`` takes its 4-bit index from
    byte ``i // 2``: the low nibble for even ``i``, the high nibble for odd
    ``i``. That centroid is scaled by ``norm``, and the assembled vector is
    passed through ``fwht``, which is its own inverse.
    """
    out = []
    for pos in range(d):
        byte = src[pos // 2]
        code = byte >> 4 if pos % 2 else byte & 0x0F
        out.append(TURBO4_CENTROIDS[code] * norm)
    # The normalized WHT is orthonormal and symmetric, so the forward
    # transform undoes the encoder's rotation.
    fwht(out)
    return out
# ============================================================================
# TEST: ENCODE/DECODE ROUNDTRIP
# ============================================================================
class TestEncodeDecodeRoundtrip:
    """decode(encode(x)) ≈ x within tolerance.

    Random inputs come from a local ``random.Random(seed)`` instance. This
    keeps them reproducible without reseeding the global ``random`` state
    that other tests may share.
    """

    def test_identity_vector(self):
        """Encode then decode a fixed 16-dim vector; every element stays close."""
        src = [1.0, 0.5, -0.3, 0.8, -0.2, 0.1, 0.7, -0.6,
               0.4, -0.1, 0.9, -0.4, 0.2, 0.3, -0.5, 0.0]
        assert len(src) == 16
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        # 4-bit quantization is lossy; 0.5 is a coarse per-element absolute bound.
        for orig, rec in zip(src, recovered):
            assert abs(orig - rec) < 0.5, f"Roundtrip error too large: {orig} -> {rec}"

    def test_random_vector_128(self):
        """Roundtrip on a 128-dim Gaussian vector."""
        rng = random.Random(42)
        src = [rng.gauss(0, 0.1) for _ in range(128)]
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 128)
        # Per-element absolute error.
        errors = [abs(o - r) for o, r in zip(src, recovered)]
        max_err = max(errors)
        mean_err = sum(errors) / len(errors)
        assert max_err < 1.0, f"Max roundtrip error too large: {max_err}"
        assert mean_err < 0.3, f"Mean roundtrip error too large: {mean_err}"

    def test_zero_vector(self):
        """Zero vector roundtrips to zero."""
        src = [0.0] * 16
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        # With norm == 0 every decoded centroid is scaled to zero.
        for val in recovered:
            assert abs(val) < 0.01, f"Zero vector roundtrip produced non-zero: {val}"

    def test_unit_vector(self):
        """Single non-zero element roundtrips with approximately the same energy."""
        src = [0.0] * 15 + [1.0]
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        orig_energy = sum(x * x for x in src)
        rec_energy = sum(x * x for x in recovered)
        assert abs(orig_energy - rec_energy) / (orig_energy + 1e-9) < 0.5

    @pytest.mark.parametrize("dim", [16, 32, 64, 128])
    def test_various_dimensions(self, dim: int):
        """Roundtrip works for power-of-2 dimensions and yields finite output."""
        rng = random.Random(dim)
        src = [rng.gauss(0, 0.1) for _ in range(dim)]
        packed, norm = polar_quant_encode_turbo4(src)
        assert len(packed) == dim // 2  # two 4-bit indices per byte
        assert norm > 0
        recovered = polar_quant_decode_turbo4(packed, norm, dim)
        assert len(recovered) == dim
        assert all(math.isfinite(x) for x in recovered)
# ============================================================================
# TEST: INNER PRODUCT PRESERVATION
# ============================================================================
class TestInnerProductPreservation:
    """Q·K ≈ Q·dequant(quant(K)) — inner products preserved through compression.

    Inputs come from local ``random.Random`` instances so the global
    ``random`` state is left untouched.
    """

    def test_inner_product_approximate(self):
        """Inner product of two vectors is approximately preserved."""
        rng = random.Random(123)
        q = [rng.gauss(0, 0.1) for _ in range(128)]
        k = [rng.gauss(0, 0.1) for _ in range(128)]
        true_ip = sum(a * b for a, b in zip(q, k))
        # Quantize K only, as a compressed KV cache would.
        k_packed, k_norm = polar_quant_encode_turbo4(k)
        k_dequant = polar_quant_decode_turbo4(k_packed, k_norm, 128)
        compressed_ip = sum(a * b for a, b in zip(q, k_dequant))
        # Relative error needs a non-negligible reference; assert that
        # precondition instead of silently skipping the check.
        assert abs(true_ip) > 1e-6, f"Reference inner product too small: {true_ip}"
        rel_error = abs(true_ip - compressed_ip) / abs(true_ip)
        assert rel_error < 0.75, f"Inner product error too large: {rel_error}"

    def test_self_inner_product(self):
        """Self inner product (norm squared) is approximately preserved."""
        rng = random.Random(456)
        x = [rng.gauss(0, 0.1) for _ in range(64)]
        true_norm_sq = sum(a * a for a in x)
        packed, norm = polar_quant_encode_turbo4(x)
        recovered = polar_quant_decode_turbo4(packed, norm, 64)
        rec_norm_sq = sum(a * a for a in recovered)
        assert true_norm_sq > 1e-6, f"Reference norm too small: {true_norm_sq}"
        rel_error = abs(true_norm_sq - rec_norm_sq) / true_norm_sq
        assert rel_error < 0.5, f"Norm preservation error: {rel_error}"
# ============================================================================
# TEST: WHT ORTHOGONALITY
# ============================================================================
class TestWHTOrthogonality:
    """WHT^T · WHT = I — the transform is orthogonal."""

    def test_wht_is_orthogonal_16(self):
        """Applying the WHT twice recovers the original (it is self-inverse)."""
        src = [1.0, 0.0, -1.0, 0.5, 0.3, -0.2, 0.7, -0.8,
               0.1, 0.9, -0.4, 0.6, -0.3, 0.2, -0.7, 0.4]
        original = list(src)
        # The normalized WHT matrix is symmetric and orthogonal, so WHT·WHT = I.
        fwht(src)
        fwht(src)
        for orig, rec in zip(original, src):
            assert abs(orig - rec) < 1e-6, f"WHT^2 != I: {orig} -> {rec}"

    def test_wht_columns_are_orthonormal(self):
        """The WHT matrix itself satisfies WHT^T · WHT = I."""
        n = 8
        columns = []
        for j in range(n):
            basis = [0.0] * n
            basis[j] = 1.0
            columns.append(fwht(basis))  # column j = WHT applied to e_j
        for i in range(n):
            for j in range(n):
                dot = sum(a * b for a, b in zip(columns[i], columns[j]))
                expected = 1.0 if i == j else 0.0
                assert abs(dot - expected) < 1e-9, (
                    f"Columns {i} and {j} not orthonormal: dot={dot}"
                )

    def test_wht_preserves_norm(self):
        """WHT preserves L2 norm (isometry)."""
        rng = random.Random(789)
        src = [rng.gauss(0, 1.0) for _ in range(64)]
        orig_norm_sq = sum(x * x for x in src)
        fwht(src)
        wht_norm_sq = sum(x * x for x in src)
        # With 1/sqrt(n) normalization the squared norm is unchanged.
        assert abs(orig_norm_sq - wht_norm_sq) < 1e-4, (
            f"WHT doesn't preserve norm: {orig_norm_sq} -> {wht_norm_sq}"
        )

    def test_wht_identity_vector(self):
        """WHT of the basis vector [1, 0, 0, ...] has all components equal."""
        src = [1.0] + [0.0] * 15
        fwht(src)
        # Every component should be 1/sqrt(16) = 0.25.
        expected = 1.0 / math.sqrt(16)
        for val in src:
            assert abs(val - expected) < 1e-6, f"WHT identity vector wrong: {val}"
# ============================================================================
# TEST: CODEBOOK CORRECTNESS
# ============================================================================
class TestCodebookCorrectness:
    """Structural checks on the 4-bit codebook and nearest-centroid lookup."""

    def test_codebook_has_16_entries(self):
        """4-bit codebook has exactly 16 centroids."""
        assert len(TURBO4_CENTROIDS) == 16

    def test_codebook_is_symmetric(self):
        """Centroids are roughly mirrored around zero.

        NOTE: the table is not exactly symmetric (7 negative vs 9 positive
        entries; extremes -0.2154 and 0.3500). This check only bounds
        |c[i] + c[15 - i]| by 0.2, so it tolerates that asymmetry.
        """
        for i in range(8):
            neg = TURBO4_CENTROIDS[i]
            pos = TURBO4_CENTROIDS[15 - i]
            assert abs(neg + pos) < 0.2, (
                f"Codebook not symmetric: centroid[{i}]={neg}, centroid[{15-i}]={pos}"
            )

    def test_codebook_is_ordered(self):
        """Centroids must be in strictly ascending order."""
        for i in range(1, 16):
            assert TURBO4_CENTROIDS[i] > TURBO4_CENTROIDS[i - 1], (
                f"Codebook not ordered: {TURBO4_CENTROIDS[i-1]} >= {TURBO4_CENTROIDS[i]}"
            )

    def test_codebook_covers_unit_range(self):
        """The extreme centroids reach into both tails (min < -0.15, max > 0.25)."""
        assert TURBO4_CENTROIDS[0] < -0.15, f"Min centroid too high: {TURBO4_CENTROIDS[0]}"
        assert TURBO4_CENTROIDS[-1] > 0.25, f"Max centroid too low: {TURBO4_CENTROIDS[-1]}"

    def test_quantize_maps_to_valid_indices(self):
        """All quantized values map to valid 4-bit indices [0, 15]."""
        for val in [-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]:
            idx = quantize_value(val)
            assert 0 <= idx <= 15, f"Index out of range: {idx} for value {val}"

    def test_each_centroid_quantizes_to_itself(self):
        """Quantizing an exact centroid value returns that centroid's index."""
        for idx, centroid in enumerate(TURBO4_CENTROIDS):
            assert quantize_value(centroid) == idx, (
                f"Centroid {centroid} mapped to {quantize_value(centroid)}, expected {idx}"
            )

    def test_out_of_range_values_clamp_to_extremes(self):
        """Values beyond the codebook map to the first or last index."""
        assert quantize_value(-1.0) == 0
        assert quantize_value(1.0) == 15
# ============================================================================
# TEST: BIT PACKING / MEMORY BOUNDS
# ============================================================================
class TestBitPacking:
    """Nibble layout of the packed indices and absence of buffer overflows."""

    def test_packed_size_is_half(self):
        """4-bit packing halves the byte count."""
        for dim in [16, 32, 64, 128]:
            rng = random.Random(dim)
            src = [rng.gauss(0, 0.1) for _ in range(dim)]
            packed, _ = polar_quant_encode_turbo4(src)
            assert len(packed) == dim // 2

    def test_even_index_in_low_nibble(self):
        """Element 2k goes in the low nibble of byte k, element 2k+1 in the high nibble."""
        src = [0.0 if i % 2 == 0 else 1.0 for i in range(16)]
        packed, norm = polar_quant_encode_turbo4(src)
        # Recompute the per-element codebook indices the same way the encoder does.
        rotated = fwht(list(src))
        inv_norm = 1.0 / (norm + 1e-9)
        expected = [quantize_value(v * inv_norm) for v in rotated]
        # Elements 0 and 1 must differ, or a swapped nibble order would go unnoticed.
        assert expected[0] != expected[1]
        for k, byte in enumerate(packed):
            assert byte & 0x0F == expected[2 * k], (
                f"byte {k}: low nibble {byte & 0x0F} != index {expected[2 * k]}"
            )
            assert byte >> 4 == expected[2 * k + 1], (
                f"byte {k}: high nibble {byte >> 4} != index {expected[2 * k + 1]}"
            )

    def test_decode_extracts_correct_nibbles(self):
        """Decode reads the low nibble for element 0 and the high nibble for element 1."""
        # byte 0xAB: low nibble 0xB (11) -> element 0, high nibble 0xA (10) -> element 1
        packed = bytes([0xAB])
        assert packed[0] & 0x0F == 11
        assert packed[0] >> 4 == 10
        recovered = polar_quant_decode_turbo4(packed, 1.0, 2)
        # For d=2 the normalized WHT is the butterfly [[1, 1], [1, -1]] / sqrt(2).
        lo = TURBO4_CENTROIDS[11]
        hi = TURBO4_CENTROIDS[10]
        scale = 1.0 / math.sqrt(2)
        assert abs(recovered[0] - (lo + hi) * scale) < 1e-12, f"element 0 wrong: {recovered}"
        # A swapped nibble order would flip the sign of element 1.
        assert abs(recovered[1] - (lo - hi) * scale) < 1e-12, f"element 1 wrong: {recovered}"

    def test_max_dimension_no_overflow(self):
        """No overflow with maximum typical dimension (4096)."""
        dim = 4096
        rng = random.Random(999)
        src = [rng.gauss(0, 0.1) for _ in range(dim)]
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, dim)
        assert len(recovered) == dim
        assert all(math.isfinite(x) for x in recovered)
# ============================================================================
# TEST: EDGE CASES
# ============================================================================
class TestEdgeCases:
    """Edge cases and boundary conditions."""

    def test_single_element_vector_fails(self):
        """A one-element vector cannot be encoded and must raise.

        A length of 1 is a power of two (2**0), so the WHT accepts it. The
        problem is packing: two 4-bit indices share each byte, so an odd
        length cannot be packed.
        """
        src = [1.0]
        with pytest.raises((ValueError, IndexError, ZeroDivisionError)):
            polar_quant_encode_turbo4(src)

    @pytest.mark.parametrize("dim", [3, 6, 12])
    def test_non_power_of_two_dimension_fails(self, dim: int):
        """Lengths that are not powers of two raise (the WHT needs a power of two)."""
        src = [1.0] * dim
        with pytest.raises((ValueError, IndexError)):
            polar_quant_encode_turbo4(src)

    def test_empty_vector_fails(self):
        """An empty vector has nothing to transform or pack and must raise."""
        with pytest.raises((ValueError, ZeroDivisionError)):
            polar_quant_encode_turbo4([])

    def test_all_same_values(self):
        """Vector with all identical values decodes to roughly identical values."""
        src = [0.5] * 16
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        mean = sum(recovered) / len(recovered)
        for val in recovered:
            assert abs(val - mean) < 0.1

    def test_large_values(self):
        """Large input values don't cause NaN/Inf."""
        src = [100.0, -100.0, 50.0, -50.0, 25.0, -25.0, 10.0, -10.0,
               5.0, -5.0, 2.0, -2.0, 1.0, -1.0, 0.5, -0.5]
        packed, norm = polar_quant_encode_turbo4(src)
        assert math.isfinite(norm)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        assert all(math.isfinite(x) for x in recovered)

    def test_alternating_signs(self):
        """Alternating positive/negative values decode to finite output."""
        src = [(-1) ** i * 0.1 for i in range(16)]
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        assert all(math.isfinite(x) for x in recovered)