From ccb041999765f9c1fb2deb28e729c4b646f46c3b Mon Sep 17 00:00:00 2001
From: Alexander Whitestone
Date: Wed, 15 Apr 2026 02:15:02 +0000
Subject: [PATCH] test: Add Python unit tests for PolarQuant (#54)

---
 tests/test_polar_quant.py | 410 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 410 insertions(+)
 create mode 100644 tests/test_polar_quant.py

diff --git a/tests/test_polar_quant.py b/tests/test_polar_quant.py
new file mode 100644
index 00000000..53a15078
--- /dev/null
+++ b/tests/test_polar_quant.py
@@ -0,0 +1,410 @@
"""
Unit tests for PolarQuant Turbo4 encode/decode.

Tests the algorithm logic using Python reference implementations
that mirror the C++/Metal code.
"""

import math
import pytest
import struct
from typing import List, Tuple


# Lloyd-Max Centroids for N(0, 1/d) where d=128
# 4-bit (16 levels) - copied from llama-turbo.cpp
TURBO4_CENTROIDS = [
    -0.2154, -0.1523, -0.1121, -0.0812,
    -0.0554, -0.0321, -0.0105, 0.0105,
    0.0321, 0.0554, 0.0812, 0.1121,
    0.1523, 0.2154, 0.2800, 0.3500
]


def fwht(a: List[float]) -> List[float]:
    """Fast Walsh-Hadamard Transform (Python reference).

    The butterfly below assumes len(a) is a power of two; the 1/sqrt(n)
    scaling makes the transform orthonormal (and therefore self-inverse).
    """
    n = len(a)
    result = a.copy()

    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = result[j]
                y = result[j + h]
                result[j] = x + y
                result[j + h] = x - y
        h <<= 1

    # Normalize
    scale = 1.0 / math.sqrt(n)
    for i in range(n):
        result[i] *= scale

    return result


def polar_quant_encode(src: List[float]) -> Tuple[bytes, float]:
    """
    PolarQuant Turbo4 Encode (Python reference).

    Returns:
        Tuple of (packed_bytes, norm)
    """
    d = len(src)
    assert d % 2 == 0, "Dimension must be even"

    # Apply WHT
    rotated = fwht(src)

    # Calculate L2 norm
    norm = math.sqrt(sum(x * x for x in rotated))

    # Quantize components
    inv_norm = 1.0 / (norm + 1e-9)
    indices = []

    for val in rotated:
        val_normalized = val * inv_norm

        # Find nearest centroid
        best_idx = 0
        min_dist = abs(val_normalized - TURBO4_CENTROIDS[0])
        for j in range(1, 16):
            dist = abs(val_normalized - TURBO4_CENTROIDS[j])
            if dist < min_dist:
                min_dist = dist
                best_idx = j

        indices.append(best_idx)

    # Pack 4-bit indices into bytes
    packed = bytearray(d // 2)
    for i in range(d):
        if i % 2 == 0:
            packed[i // 2] = indices[i]
        else:
            packed[i // 2] |= indices[i] << 4

    return bytes(packed), norm
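

# Worked example added for illustration (the byte value is computed here, not
# taken from the C++/Metal sources): polar_quant_encode stores the even-indexed
# 4-bit code in the low nibble of a byte and the odd-indexed code in the high
# nibble, so indices [3, 12] pack into the single byte 0x3 | (0xC << 4) == 0xC3.
def test_nibble_layout_worked_example():
    """Spot-check the even=low / odd=high nibble convention on one byte."""
    indices = [3, 12]
    packed = bytearray(1)
    packed[0] = indices[0]           # even index -> low nibble
    packed[0] |= indices[1] << 4     # odd index -> high nibble
    assert packed[0] == 0xC3
    # Unpacking recovers both indices.
    assert packed[0] & 0x0F == 3
    assert packed[0] >> 4 == 12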


def polar_quant_decode(src: bytes, norm: float, d: int) -> List[float]:
    """
    PolarQuant Turbo4 Decode (Python reference).

    Returns:
        Reconstructed float array
    """
    # Unpack 4-bit indices
    values = []
    for i in range(d):
        if i % 2 == 0:
            idx = src[i // 2] & 0x0F
        else:
            idx = src[i // 2] >> 4
        values.append(TURBO4_CENTROIDS[idx] * norm)

    # Apply inverse WHT (the normalized WHT is its own inverse)
    return fwht(values)


class TestEncodeDecodeRoundtrip:
    """Test that decode(encode(x)) ≈ x."""

    def test_zero_vector(self):
        """Zero vector should encode/decode to zero."""
        d = 128
        src = [0.0] * d
        packed, norm = polar_quant_encode(src)
        reconstructed = polar_quant_decode(packed, norm, d)

        # Zero has no information, reconstruction will be near-zero
        for i in range(d):
            assert abs(reconstructed[i]) < 0.1, f"Index {i}: {reconstructed[i]}"

    def test_unit_vector(self):
        """Unit vector should roundtrip reasonably."""
        d = 128
        src = [0.0] * d
        src[0] = 1.0  # Unit vector

        packed, norm = polar_quant_encode(src)
        reconstructed = polar_quant_decode(packed, norm, d)

        # Check shape is preserved (first element dominant)
        max_val = max(reconstructed)
        max_idx = reconstructed.index(max_val)
        assert max_idx == 0, f"Peak at index {max_idx}, expected 0"

    def test_random_vectors(self):
        """Random vectors should roundtrip with bounded error."""
        import random
        random.seed(42)

        d = 128
        errors = []

        for trial in range(10):
            src = [random.gauss(0, 0.1) for _ in range(d)]
            packed, norm = polar_quant_encode(src)
            reconstructed = polar_quant_decode(packed, norm, d)

            # Compute relative error
            orig_norm = math.sqrt(sum(x * x for x in src))
            diff_norm = math.sqrt(sum((a - b) ** 2 for a, b in zip(src, reconstructed)))
            rel_error = diff_norm / (orig_norm + 1e-9)
            errors.append(rel_error)

        avg_error = sum(errors) / len(errors)
        assert avg_error < 0.5, f"Average relative error {avg_error} too high"

    def test_various_dimensions(self):
        """Test with different power-of-2 dimensions."""
        for d in [16, 32, 64, 128, 256]:
            src = [math.sin(i * 0.1) for i in range(d)]
            packed, norm = polar_quant_encode(src)
            reconstructed = polar_quant_decode(packed, norm, d)

            # Basic sanity: reconstructed should have similar magnitude
            # 4-bit quantization loses significant energy, especially at small dims
            orig_energy = sum(x * x for x in src)
            recon_energy = sum(x * x for x in reconstructed)
            ratio = recon_energy / (orig_energy + 1e-9)
            assert 0.1 < ratio < 10.0, f"d={d}: energy ratio {ratio}"


class TestInnerProductPreservation:
    """Test that Q·K ≈ Q·dequant(quant(K))."""

    def test_inner_product_preserved(self):
        """Inner products should be approximately preserved."""
        import random
        random.seed(123)

        d = 128

        # Generate two random vectors
        q = [random.gauss(0, 0.1) for _ in range(d)]
        k = [random.gauss(0, 0.1) for _ in range(d)]

        # Original inner product
        orig_ip = sum(a * b for a, b in zip(q, k))

        # Compress k
        k_packed, k_norm = polar_quant_encode(k)
        k_reconstructed = polar_quant_decode(k_packed, k_norm, d)

        # Compressed inner product
        comp_ip = sum(a * b for a, b in zip(q, k_reconstructed))

        # Check relative error
        rel_error = abs(orig_ip - comp_ip) / (abs(orig_ip) + 1e-9)
        # 4-bit quantization has significant error, allow up to 100% error
        assert rel_error < 1.0, f"Inner product error {rel_error} too high"
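
    # Supplementary check (editorial sketch, not part of the original test
    # set): cosine similarity is the inner product of unit-normalized vectors,
    # so it should drift only slightly when one operand is quantized.
    def test_cosine_similarity_preserved(self):
        """Cosine similarity should be roughly preserved after quantization."""
        import random
        random.seed(7)

        d = 128
        q = [random.gauss(0, 0.1) for _ in range(d)]
        k = [random.gauss(0, 0.1) for _ in range(d)]

        def cosine(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            na = math.sqrt(sum(x * x for x in a))
            nb = math.sqrt(sum(x * x for x in b))
            return dot / (na * nb + 1e-9)

        orig = cosine(q, k)

        k_packed, k_norm = polar_quant_encode(k)
        k_reconstructed = polar_quant_decode(k_packed, k_norm, d)
        comp = cosine(q, k_reconstructed)

        # Loose bound, in the spirit of the inner-product tolerance above
        assert abs(orig - comp) < 0.5, f"Cosine drift {abs(orig - comp)} too high"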

    def test_self_inner_product(self):
        """Self inner product should be close to original."""
        d = 128
        x = [math.cos(i * 0.2) for i in range(d)]

        orig_self_ip = sum(a * a for a in x)

        packed, norm = polar_quant_encode(x)
        reconstructed = polar_quant_decode(packed, norm, d)

        comp_self_ip = sum(a * a for a in reconstructed)

        # Self inner product is energy, should be roughly preserved
        # 4-bit quantization has significant error
        ratio = comp_self_ip / (orig_self_ip + 1e-9)
        assert 0.3 < ratio < 3.0, f"Self inner product ratio {ratio}"


class TestWHTOrthogonality:
    """Test that WHT is orthogonal (WHT^T · WHT = I)."""

    def test_wht_orthogonality(self):
        """WHT should be an orthogonal transformation."""
        d = 128

        # Apply the WHT twice. The normalized Hadamard matrix is symmetric and
        # orthogonal, so applying it twice should recover the input.
        src = [float(i) for i in range(d)]

        # First WHT
        result1 = fwht(src)

        # Second WHT (should recover the original for an orthonormal transform)
        result2 = fwht(result1)

        # With the 1/sqrt(d) scaling applied on each pass, WHT(WHT(x)) = x,
        # so result2 should match src up to floating-point rounding.
        for i in range(d):
            ratio = result2[i] / (src[i] + 1e-9) if src[i] != 0 else result2[i]
            # Should be close to 1.0 (or 0 if src[i] is 0)
            if abs(src[i]) > 0.01:
                assert abs(ratio - 1.0) < 0.1, f"Index {i}: ratio {ratio}"

    def test_wht_preserves_norm(self):
        """WHT should preserve L2 norm."""
        d = 128
        src = [math.sin(i) for i in range(d)]

        orig_norm = math.sqrt(sum(x * x for x in src))

        result = fwht(src)
        result_norm = math.sqrt(sum(x * x for x in result))

        ratio = result_norm / orig_norm
        assert abs(ratio - 1.0) < 0.01, f"Norm ratio {ratio}, expected 1.0"

    def test_wht_linearity(self):
        """WHT should be linear: WHT(a+b) = WHT(a) + WHT(b)."""
        d = 64
        a = [float(i) * 0.1 for i in range(d)]
        b = [float(i) * 0.2 for i in range(d)]

        # WHT(a + b)
        a_plus_b = [x + y for x, y in zip(a, b)]
        wht_sum = fwht(a_plus_b)

        # WHT(a) + WHT(b)
        wht_a = fwht(a)
        wht_b = fwht(b)
        sum_wht = [x + y for x, y in zip(wht_a, wht_b)]

        # Should be equal
        for i in range(d):
            assert abs(wht_sum[i] - sum_wht[i]) < 1e-6, f"Linearity failed at {i}"


class TestCodebookCorrectness:
    """Sanity-check the Turbo4 Lloyd-Max centroid table for N(0, 1/128)."""

    def test_centroids_extremes(self):
        """Extreme centroids should cover tails of distribution."""
        min_c = min(TURBO4_CENTROIDS)
        max_c = max(TURBO4_CENTROIDS)
        # Should have reasonable range
        assert min_c < -0.2, f"Min centroid {min_c} should be < -0.2"
        assert max_c > 0.2, f"Max centroid {max_c} should be > 0.2"

    def test_centroids_ordered(self):
        """Centroids should be strictly increasing."""
        for i in range(len(TURBO4_CENTROIDS) - 1):
            assert TURBO4_CENTROIDS[i] < TURBO4_CENTROIDS[i + 1], f"Centroids not ordered at index {i}"

    def test_centroids_cover_range(self):
        """Centroids should cover reasonable range for N(0, 1/128)."""
        # For N(0, 1/128), std = 1/sqrt(128) ≈ 0.088
        # Centroids should cover roughly [-3*std, 3*std]
        min_c = min(TURBO4_CENTROIDS)
        max_c = max(TURBO4_CENTROIDS)

        std = 1.0 / math.sqrt(128)  # ≈ 0.088

        assert min_c < -2 * std, f"Min centroid {min_c} should be < {-2*std}"
        assert max_c > 2 * std, f"Max centroid {max_c} should be > {2*std}"

    def test_centroids_count(self):
        """Should have exactly 16 centroids for 4-bit quantization."""
        assert len(TURBO4_CENTROIDS) == 16, f"Expected 16 centroids, got {len(TURBO4_CENTROIDS)}"
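

# Worked example added for illustration (the test value is chosen here, not
# taken from the C++/Metal sources): a component equal to one standard
# deviation of N(0, 1/128), i.e. 1/sqrt(128) ≈ 0.088, should snap to the
# 0.0812 centroid, which sits at index 10 of TURBO4_CENTROIDS.
def test_nearest_centroid_worked_example():
    """One-sigma value should quantize to the 0.0812 centroid (index 10)."""
    val = 1.0 / math.sqrt(128)  # ≈ 0.0884
    best_idx = min(range(16), key=lambda j: abs(val - TURBO4_CENTROIDS[j]))
    assert best_idx == 10
    assert abs(TURBO4_CENTROIDS[best_idx] - 0.0812) < 1e-9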


class TestBitPacking:
    """Test bit packing/unpacking correctness."""

    def test_packing_roundtrip(self):
        """Packing and unpacking should be lossless for 4-bit values."""
        d = 128

        # Create test indices (0-15)
        indices = [i % 16 for i in range(d)]

        # Pack
        packed = bytearray(d // 2)
        for i in range(d):
            if i % 2 == 0:
                packed[i // 2] = indices[i]
            else:
                packed[i // 2] |= indices[i] << 4

        # Unpack
        unpacked = []
        for i in range(d):
            if i % 2 == 0:
                idx = packed[i // 2] & 0x0F
            else:
                idx = packed[i // 2] >> 4
            unpacked.append(idx)

        assert unpacked == indices, "Packing/unpacking mismatch"

    def test_packing_bounds(self):
        """Packed values should fit in 4 bits (0-15)."""
        d = 128
        indices = [15] * d  # Max value

        packed = bytearray(d // 2)
        for i in range(d):
            if i % 2 == 0:
                packed[i // 2] = indices[i]
            else:
                packed[i // 2] |= indices[i] << 4

        # Each byte should have both nibbles = 15
        for byte in packed:
            assert byte == 0xFF, f"Expected 0xFF, got {hex(byte)}"

    def test_no_overflow(self):
        """Packing should not overflow with valid 4-bit values."""
        d = 256  # Larger dimension

        # All max values
        indices = [15] * d

        packed = bytearray(d // 2)
        for i in range(d):
            if i % 2 == 0:
                packed[i // 2] = indices[i]
            else:
                packed[i // 2] |= indices[i] << 4

        # Should not crash or produce invalid values
        assert len(packed) == d // 2


class TestMemoryBounds:
    """Test memory safety with various dimensions."""

    def test_minimum_dimension(self):
        """Should work with minimum dimension (2)."""
        d = 2
        src = [1.0, 0.5]
        packed, norm = polar_quant_encode(src)
        assert len(packed) == d // 2
        reconstructed = polar_quant_decode(packed, norm, d)
        assert len(reconstructed) == d

    def test_large_dimension(self):
        """Should work with large dimensions."""
        d = 1024
        src = [math.sin(i * 0.01) for i in range(d)]
        packed, norm = polar_quant_encode(src)
        assert len(packed) == d // 2
        reconstructed = polar_quant_decode(packed, norm, d)
        assert len(reconstructed) == d

    def test_odd_dimension_fails(self):
        """Odd dimensions should fail (need even for 4-bit packing)."""
        d = 127  # Odd
        src = [0.0] * d

        with pytest.raises(AssertionError):
            polar_quant_encode(src)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
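
# Note: these tests exercise the pure-Python reference implementations above,
# so they run without building the C++/Metal code. Assuming pytest is
# installed, the whole file can also be run with:
#     pytest tests/test_polar_quant.py -v
# which matches the verbose invocation in the __main__ block above.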