"""Unit tests for PolarQuant encode/decode.

Tests the core algorithms from llama-turbo.cpp using pure Python
implementations that mirror the C++ logic. This ensures correctness
without requiring C++ compilation.

Refs: #54 — [Tests] Add unit tests for PolarQuant encode/decode
"""

from __future__ import annotations

import math
import random
import struct
from typing import List, Tuple

import pytest


# ============================================================================
# PURE PYTHON IMPLEMENTATIONS (mirror llama-turbo.cpp)
# ============================================================================

# Lloyd-Max Centroids for N(0, 1/d) where d=128, 4-bit (16 levels)
# NOTE(review): the table holds 7 negative and 9 positive levels, i.e. it is
# not symmetric around zero — confirm the values against llama-turbo.cpp.
TURBO4_CENTROIDS = [
    -0.2154, -0.1523, -0.1121, -0.0812,
    -0.0554, -0.0321, -0.0105, 0.0105,
    0.0321, 0.0554, 0.0812, 0.1121,
    0.1523, 0.2154, 0.2800, 0.3500,
]


def _require_power_of_two(n: int, what: str) -> None:
    """Raise ValueError unless *n* is a positive power of two."""
    if n < 1 or n & (n - 1):
        raise ValueError(f"{what} must be a positive power of two, got {n}")


def _gauss_vector(seed: int, dim: int, sigma: float = 0.1) -> List[float]:
    """Return *dim* reproducible N(0, sigma) samples.

    A private ``random.Random`` is used so tests never disturb the global
    RNG state; the draws are the same as ``random.seed(seed)`` followed by
    ``random.gauss(0, sigma)`` calls.
    """
    rng = random.Random(seed)
    return [rng.gauss(0, sigma) for _ in range(dim)]


def fwht(a: List[float]) -> List[float]:
    """Fast Walsh-Hadamard Transform (in-place, normalized).

    Mirrors the C++ implementation in llama-turbo.cpp:17-33.

    Args:
        a: Values to transform; overwritten with the result. The length
            must be a positive power of two.

    Returns:
        The same list object ``a``, scaled by ``1/sqrt(len(a))`` so that
        applying the transform twice yields the original values.

    Raises:
        ValueError: If ``len(a)`` is not a positive power of two. The check
            runs before any element is modified.
    """
    n = len(a)
    _require_power_of_two(n, "WHT length")

    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = x + y
                a[j + h] = x - y
        h <<= 1

    # Normalize
    scale = 1.0 / math.sqrt(n)
    a[:] = [v * scale for v in a]
    return a


def quantize_value(val: float) -> int:
    """Find nearest Lloyd-Max codebook index for a value.

    Ties resolve to the lowest index (``min`` returns the first minimal
    item), matching a strict ``<`` linear scan.
    """
    return min(
        range(len(TURBO4_CENTROIDS)),
        key=lambda j: abs(val - TURBO4_CENTROIDS[j]),
    )


def polar_quant_encode_turbo4(src: List[float]) -> Tuple[bytes, float]:
    """PolarQuant Turbo4 Encode (CPU reference).

    Mirrors llama-turbo.cpp:36-68.

    Args:
        src: Input vector; not modified. Its length must be a power of two
            and at least 2 so the 4-bit indices pack into whole bytes.

    Returns:
        ``(packed_bytes, norm)`` where ``packed_bytes`` has ``len(src) // 2``
        bytes (even element in the low nibble, odd element in the high
        nibble) and ``norm`` is the L2 norm of the rotated vector.

    Raises:
        ValueError: If ``len(src)`` is not a power of two >= 2.
    """
    d = len(src)
    _require_power_of_two(d, "vector dimension")
    if d < 2:
        raise ValueError("vector dimension must be at least 2 to pack nibble pairs")

    rotated = fwht(list(src))

    # L2 norm
    norm = math.sqrt(sum(x * x for x in rotated))

    # Quantize; the epsilon keeps the zero vector from dividing by zero.
    inv_norm = 1.0 / (norm + 1e-9)
    indices = [quantize_value(val * inv_norm) for val in rotated]

    # Pack 4-bit indices into bytes: even element low nibble, odd element high.
    packed = bytes(
        (lo & 0x0F) | ((hi << 4) & 0xF0)
        for lo, hi in zip(indices[0::2], indices[1::2])
    )
    return packed, norm


def polar_quant_decode_turbo4(src: bytes, norm: float, d: int) -> List[float]:
    """PolarQuant Turbo4 Decode (CPU reference).

    Mirrors llama-turbo.cpp:71-78.

    Args:
        src: Packed 4-bit indices, two per byte (low nibble first).
        norm: Scale factor returned by the encoder.
        d: Number of elements to decode; must be a positive power of two.

    Returns:
        A new list of ``d`` reconstructed values.

    Raises:
        ValueError: If ``d`` is not a positive power of two or ``src`` holds
            fewer than ``ceil(d / 2)`` bytes.
    """
    _require_power_of_two(d, "vector dimension")
    needed = (d + 1) // 2
    if len(src) < needed:
        raise ValueError(
            f"packed buffer too short: need {needed} bytes for d={d}, got {len(src)}"
        )

    dst = [0.0] * d
    for i in range(d):
        byte = src[i // 2]
        idx = byte & 0x0F if i % 2 == 0 else byte >> 4
        dst[i] = TURBO4_CENTROIDS[idx] * norm

    # Inverse WHT = Forward WHT (orthogonal)
    fwht(dst)
    return dst


# ============================================================================
# TEST: ENCODE/DECODE ROUNDTRIP
# ============================================================================

class TestEncodeDecodeRoundtrip:
    """decode(encode(x)) ≈ x within tolerance."""

    def test_identity_vector(self):
        """Encode then decode a known vector recovers approximate original."""
        src = [1.0, 0.5, -0.3, 0.8, -0.2, 0.1, 0.7, -0.6,
               0.4, -0.1, 0.9, -0.4, 0.2, 0.3, -0.5, 0.0]
        assert len(src) == 16

        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)

        # 4-bit quantization is lossy; the absolute per-element bound is
        # deliberately loose and only guards against gross breakage.
        for orig, rec in zip(src, recovered):
            assert abs(orig - rec) < 0.5, f"Roundtrip error too large: {orig} -> {rec}"

    def test_random_vector_128(self):
        """Roundtrip on a 128-dim vector."""
        src = _gauss_vector(42, 128)

        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 128)

        # Absolute per-element errors
        errors = [abs(o - r) for o, r in zip(src, recovered)]
        max_err = max(errors)
        mean_err = sum(errors) / len(errors)

        assert max_err < 1.0, f"Max roundtrip error too large: {max_err}"
        assert mean_err < 0.3, f"Mean roundtrip error too large: {mean_err}"

    def test_zero_vector(self):
        """Zero vector roundtrips to zero."""
        src = [0.0] * 16
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)

        # With norm=0, decoded values should be near zero
        for val in recovered:
            assert abs(val) < 0.01, f"Zero vector roundtrip produced non-zero: {val}"

    def test_unit_vector(self):
        """Single non-zero element roundtrips approximately."""
        src = [0.0] * 15 + [1.0]
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)

        # Energy should be preserved approximately
        orig_energy = sum(x * x for x in src)
        rec_energy = sum(x * x for x in recovered)
        assert abs(orig_energy - rec_energy) / (orig_energy + 1e-9) < 0.5

    @pytest.mark.parametrize("dim", [16, 32, 64, 128])
    def test_various_dimensions(self, dim: int):
        """Roundtrip works for power-of-2 dimensions."""
        src = _gauss_vector(dim, dim)

        packed, norm = polar_quant_encode_turbo4(src)
        assert len(packed) == dim // 2  # 4-bit packing
        assert norm > 0

        recovered = polar_quant_decode_turbo4(packed, norm, dim)
        assert len(recovered) == dim


# ============================================================================
# TEST: INNER PRODUCT PRESERVATION
# ============================================================================

class TestInnerProductPreservation:
    """Q·K ≈ Q·dequant(quant(K)) — inner products preserved through compression."""

    def test_inner_product_approximate(self):
        """Inner product of two vectors is approximately preserved."""
        # q and k are drawn consecutively from one seeded stream.
        rng = random.Random(123)
        q = [rng.gauss(0, 0.1) for _ in range(128)]
        k = [rng.gauss(0, 0.1) for _ in range(128)]

        # True inner product
        true_ip = sum(a * b for a, b in zip(q, k))

        # Quantize K
        k_packed, k_norm = polar_quant_encode_turbo4(k)
        k_dequant = polar_quant_decode_turbo4(k_packed, k_norm, 128)

        # Compressed inner product
        compressed_ip = sum(a * b for a, b in zip(q, k_dequant))

        # Relative error is only meaningful when the reference is not ~0.
        if abs(true_ip) > 1e-6:
            rel_error = abs(true_ip - compressed_ip) / abs(true_ip)
            assert rel_error < 0.75, f"Inner product error too large: {rel_error}"

    def test_self_inner_product(self):
        """Self inner product (norm squared) is approximately preserved."""
        x = _gauss_vector(456, 64)

        true_norm_sq = sum(a * a for a in x)

        packed, norm = polar_quant_encode_turbo4(x)
        recovered = polar_quant_decode_turbo4(packed, norm, 64)

        rec_norm_sq = sum(a * a for a in recovered)

        if true_norm_sq > 1e-6:
            rel_error = abs(true_norm_sq - rec_norm_sq) / true_norm_sq
            assert rel_error < 0.5, f"Norm preservation error: {rel_error}"


# ============================================================================
# TEST: WHT ORTHOGONALITY
# ============================================================================

class TestWHTOrthogonality:
    """WHT^T · WHT = I — the transform is orthogonal."""

    def test_wht_is_orthogonal_16(self):
        """Applying WHT twice (forward = inverse) recovers original."""
        src = [1.0, 0.0, -1.0, 0.5, 0.3, -0.2, 0.7, -0.8,
               0.1, 0.9, -0.4, 0.6, -0.3, 0.2, -0.7, 0.4]
        original = list(src)

        # Apply WHT twice — should recover original (WHT^T = WHT for orthogonal)
        fwht(src)
        fwht(src)

        for orig, rec in zip(original, src):
            assert abs(orig - rec) < 1e-6, f"WHT^2 != I: {orig} -> {rec}"

    def test_wht_preserves_norm(self):
        """WHT preserves L2 norm (isometry)."""
        src = _gauss_vector(789, 64, sigma=1.0)

        orig_norm_sq = sum(x * x for x in src)
        fwht(src)
        wht_norm_sq = sum(x * x for x in src)

        # WHT with 1/sqrt(n) normalization should preserve norm
        assert abs(orig_norm_sq - wht_norm_sq) < 1e-4, (
            f"WHT doesn't preserve norm: {orig_norm_sq} -> {wht_norm_sq}"
        )

    def test_wht_identity_vector(self):
        """WHT of [1,0,0,...] produces equal components."""
        src = [1.0] + [0.0] * 15
        fwht(src)

        # All components should be 1/sqrt(16) = 0.25
        expected = 1.0 / math.sqrt(16)
        for val in src:
            assert abs(val - expected) < 1e-6, f"WHT identity vector wrong: {val}"


# ============================================================================
# TEST: CODEBOOK CORRECTNESS
# ============================================================================

class TestCodebookCorrectness:
    """Sanity checks on the 4-bit centroid table."""

    def test_codebook_has_16_entries(self):
        """4-bit codebook has exactly 16 centroids."""
        assert len(TURBO4_CENTROIDS) == 16

    def test_codebook_is_symmetric(self):
        """Centroids should be approximately symmetric around zero."""
        for i in range(8):
            neg = TURBO4_CENTROIDS[i]
            pos = TURBO4_CENTROIDS[15 - i]
            # Symmetric: neg ≈ -pos (approximately)
            assert abs(neg + pos) < 0.2, (
                f"Codebook not symmetric: centroid[{i}]={neg}, centroid[{15-i}]={pos}"
            )

    def test_codebook_is_ordered(self):
        """Centroids must be in ascending order."""
        for i in range(1, 16):
            assert TURBO4_CENTROIDS[i] > TURBO4_CENTROIDS[i - 1], (
                f"Codebook not ordered: {TURBO4_CENTROIDS[i-1]} >= {TURBO4_CENTROIDS[i]}"
            )

    def test_codebook_covers_unit_range(self):
        """Extreme centroids reach below -0.15 and above 0.25."""
        assert TURBO4_CENTROIDS[0] < -0.15, f"Min centroid too high: {TURBO4_CENTROIDS[0]}"
        assert TURBO4_CENTROIDS[-1] > 0.25, f"Max centroid too low: {TURBO4_CENTROIDS[-1]}"

    def test_quantize_maps_to_valid_indices(self):
        """All quantized values map to valid 4-bit indices [0, 15]."""
        for val in [-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]:
            idx = quantize_value(val)
            assert 0 <= idx <= 15, f"Index out of range: {idx} for value {val}"


# ============================================================================
# TEST: BIT PACKING / MEMORY BOUNDS
# ============================================================================

class TestBitPacking:
    """No buffer overflows in bit packing."""

    def test_packed_size_is_half(self):
        """4-bit packing halves the byte count."""
        for dim in [16, 32, 64, 128]:
            packed, _ = polar_quant_encode_turbo4(_gauss_vector(dim, dim))
            assert len(packed) == dim // 2

    def test_even_index_in_low_nibble(self):
        """Even-indexed values go in low nibble, odd-indexed in high nibble."""
        src = [0.0 if i % 2 == 0 else 1.0 for i in range(16)]
        packed, norm = polar_quant_encode_turbo4(src)

        # Recompute the codebook indices independently of the packing step,
        # using the same arithmetic as the encoder.
        inv_norm = 1.0 / (norm + 1e-9)
        expected = [quantize_value(v * inv_norm) for v in fwht(list(src))]

        # The first two rotated components quantize to different indices,
        # so a swapped nibble order cannot pass unnoticed.
        assert expected[0] != expected[1]

        for i in range(0, 16, 2):
            byte = packed[i // 2]
            assert byte & 0x0F == expected[i], f"Low nibble wrong at element {i}"
            assert byte >> 4 == expected[i + 1], f"High nibble wrong at element {i + 1}"

    def test_decode_extracts_correct_nibbles(self):
        """Decode reads element 0 from the low nibble and element 1 from the high."""
        # byte = 0xAB: low nibble = 0xB (11), high nibble = 0xA (10)
        decoded = polar_quant_decode_turbo4(bytes([0xAB]), 1.0, 2)

        # Decode ends with a WHT; the transform is its own inverse, so one
        # more pass exposes the centroid values that were looked up.
        fwht(decoded)

        assert abs(decoded[0] - TURBO4_CENTROIDS[11]) < 1e-9, f"Low nibble wrong: {decoded[0]}"
        assert abs(decoded[1] - TURBO4_CENTROIDS[10]) < 1e-9, f"High nibble wrong: {decoded[1]}"

    def test_max_dimension_no_overflow(self):
        """No overflow with maximum typical dimension (4096)."""
        dim = 4096
        src = _gauss_vector(999, dim)
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, dim)
        assert len(recovered) == dim
        assert all(math.isfinite(x) for x in recovered)


# ============================================================================
# TEST: EDGE CASES
# ============================================================================

class TestEdgeCases:
    """Edge cases and boundary conditions."""

    def test_single_element_vector_fails(self):
        """A 1-element vector cannot fill a nibble pair and is rejected."""
        with pytest.raises(ValueError):
            polar_quant_encode_turbo4([1.0])

    def test_non_power_of_two_rejected(self):
        """fwht rejects a non-power-of-2 length without touching the input."""
        data = [1.0, 2.0, 3.0]
        with pytest.raises(ValueError):
            fwht(data)
        assert data == [1.0, 2.0, 3.0]

    def test_decode_rejects_short_buffer(self):
        """Decode refuses a packed buffer smaller than d/2 bytes."""
        with pytest.raises(ValueError):
            polar_quant_decode_turbo4(bytes(3), 1.0, 16)

    def test_all_same_values(self):
        """Vector with all identical values."""
        src = [0.5] * 16
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        # All recovered values should be approximately equal
        mean = sum(recovered) / len(recovered)
        for val in recovered:
            assert abs(val - mean) < 0.1

    def test_large_values(self):
        """Large input values don't cause NaN/Inf."""
        src = [100.0, -100.0, 50.0, -50.0, 25.0, -25.0, 10.0, -10.0,
               5.0, -5.0, 2.0, -2.0, 1.0, -1.0, 0.5, -0.5]
        packed, norm = polar_quant_encode_turbo4(src)
        assert math.isfinite(norm)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        assert all(math.isfinite(x) for x in recovered)

    def test_alternating_signs(self):
        """Alternating positive/negative values."""
        src = [(-1) ** i * 0.1 for i in range(16)]
        packed, norm = polar_quant_encode_turbo4(src)
        recovered = polar_quant_decode_turbo4(packed, norm, 16)
        assert all(math.isfinite(x) for x in recovered)