"""
Unit tests for PolarQuant Turbo4 encode/decode.

Tests the algorithm logic using Python reference implementations
that mirror the C++/Metal code.
"""

import math
import random
import struct  # unused in this module; retained so the import set is unchanged
from typing import List, Sequence, Tuple

import pytest

# Lloyd-Max Centroids for N(0, 1/d) where d=128
# 4-bit (16 levels) - copied from llama-turbo.cpp
# NOTE(review): the table is not symmetric about zero (7 negative entries,
# 9 positive entries, with 0.2800 and 0.3500 having no negative counterpart).
# Confirm against llama-turbo.cpp before relying on it; the values are left
# untouched here so this reference keeps mirroring the C++ source.
TURBO4_CENTROIDS = [
    -0.2154, -0.1523, -0.1121, -0.0812, -0.0554, -0.0321, -0.0105, 0.0105,
    0.0321, 0.0554, 0.0812, 0.1121, 0.1523, 0.2154, 0.2800, 0.3500,
]


def _pack_nibbles(indices: Sequence[int]) -> bytes:
    """Pack 4-bit indices two per byte.

    The even-positioned index goes into the low nibble and the following
    odd-positioned index into the high nibble of the same byte.

    Args:
        indices: Integers in the range 0..15; the count must be even.

    Returns:
        ``len(indices) // 2`` packed bytes.

    Raises:
        ValueError: If the count is odd or any index is outside 0..15.
    """
    if len(indices) % 2 != 0:
        raise ValueError("Number of indices must be even")
    if any(not 0 <= idx <= 0x0F for idx in indices):
        raise ValueError("Indices must fit in 4 bits (0-15)")
    return bytes(
        low | (high << 4) for low, high in zip(indices[0::2], indices[1::2])
    )


def _unpack_nibbles(packed: bytes, count: int) -> List[int]:
    """Unpack ``count`` 4-bit indices from ``packed`` (inverse of packing).

    Args:
        packed: Bytes holding two indices each (low nibble first).
        count: Number of indices to extract.

    Returns:
        List of ``count`` integers in the range 0..15.

    Raises:
        ValueError: If ``packed`` is too short to hold ``count`` indices.
    """
    if 2 * len(packed) < count:
        raise ValueError(
            f"Need {(count + 1) // 2} bytes for {count} indices, "
            f"got {len(packed)}"
        )
    unpacked = []
    for i in range(count):
        byte = packed[i // 2]
        unpacked.append(byte & 0x0F if i % 2 == 0 else byte >> 4)
    return unpacked


def fwht(a: List[float]) -> List[float]:
    """Fast Walsh-Hadamard Transform (Python reference).

    The result is scaled by 1/sqrt(n), which makes the transform orthonormal
    and therefore its own inverse. The input list is not modified.

    Args:
        a: Input values; the length must be a power of two (1, 2, 4, ...).

    Returns:
        A new list holding the transformed values.

    Raises:
        ValueError: If the length is zero or not a power of two. Without this
            check the butterfly loop below would index past the end of the
            list (or divide by zero for an empty input).
    """
    n = len(a)
    if n == 0 or (n & (n - 1)) != 0:
        raise ValueError(f"FWHT length must be a power of two, got {n}")

    result = a.copy()
    h = 1
    while h < n:
        # Butterfly pass: combine elements that are h apart in blocks of 2h.
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = result[j]
                y = result[j + h]
                result[j] = x + y
                result[j + h] = x - y
        h <<= 1

    # Normalize
    scale = 1.0 / math.sqrt(n)
    return [value * scale for value in result]


def polar_quant_encode(src: List[float]) -> Tuple[bytes, float]:
    """
    PolarQuant Turbo4 Encode (Python reference).

    Applies the WHT, normalises by the L2 norm of the rotated vector and
    replaces every component with the index of its nearest centroid in
    TURBO4_CENTROIDS. The codebook is the same for every dimension.

    Args:
        src: Input vector; the length must be a power of two and at least 2.

    Returns:
        Tuple of (packed_bytes, norm), where packed_bytes holds two 4-bit
        indices per byte and norm is the L2 norm of the rotated vector.

    Raises:
        AssertionError: If the dimension is odd.
        ValueError: If the dimension is even but not a power of two
            (raised by fwht).
    """
    d = len(src)
    if d % 2 != 0:
        # Raised explicitly (same exception type as the former assert) so the
        # check also runs when Python is started with -O.
        raise AssertionError("Dimension must be even")

    # Apply WHT
    rotated = fwht(src)

    # Calculate L2 norm
    norm = math.sqrt(sum(x * x for x in rotated))

    # Quantize components; the epsilon avoids dividing by zero for a zero
    # vector.
    inv_norm = 1.0 / (norm + 1e-9)
    levels = range(len(TURBO4_CENTROIDS))
    indices = []
    for val in rotated:
        val_normalized = val * inv_norm
        # Find nearest centroid; min() keeps the lowest index on ties.
        indices.append(
            min(levels, key=lambda j: abs(val_normalized - TURBO4_CENTROIDS[j]))
        )

    # Pack 4-bit indices into bytes
    return _pack_nibbles(indices), norm


def polar_quant_decode(src: bytes, norm: float, d: int) -> List[float]:
    """
    PolarQuant Turbo4 Decode (Python reference).

    Args:
        src: Packed 4-bit indices as produced by polar_quant_encode.
        norm: Norm returned by polar_quant_encode.
        d: Number of components to reconstruct; must be a power of two.

    Returns:
        Reconstructed float array

    Raises:
        ValueError: If src is too short for d components, or d is not a
            power of two (raised by fwht).
    """
    # Unpack 4-bit indices and scale the centroids back by the norm
    values = [TURBO4_CENTROIDS[idx] * norm for idx in _unpack_nibbles(src, d)]

    # Apply inverse WHT (the normalized transform is its own inverse)
    return fwht(values)


class TestEncodeDecodeRoundtrip:
    """Test that decode(encode(x)) ≈ x."""

    def test_zero_vector(self):
        """Zero vector should encode/decode to zero."""
        d = 128
        src = [0.0] * d

        packed, norm = polar_quant_encode(src)
        reconstructed = polar_quant_decode(packed, norm, d)

        # The norm is 0, so every decoded centroid is scaled to zero.
        for i, value in enumerate(reconstructed):
            assert abs(value) < 1e-9, f"Index {i}: {value}"

    def test_unit_vector(self):
        """Unit vector should roundtrip reasonably."""
        d = 128
        src = [0.0] * d
        src[0] = 1.0  # Unit vector

        packed, norm = polar_quant_encode(src)
        reconstructed = polar_quant_decode(packed, norm, d)

        # Check shape is preserved (first element dominant)
        max_val = max(reconstructed)
        max_idx = reconstructed.index(max_val)
        assert max_idx == 0, f"Peak at index {max_idx}, expected 0"

    def test_random_vectors(self):
        """Random vectors should roundtrip with bounded error."""
        # A private generator gives the same sequence as random.seed(42)
        # without touching the global random state.
        rng = random.Random(42)

        d = 128
        errors = []
        for _ in range(10):
            src = [rng.gauss(0, 0.1) for _ in range(d)]

            packed, norm = polar_quant_encode(src)
            reconstructed = polar_quant_decode(packed, norm, d)

            # Compute relative error
            orig_norm = math.sqrt(sum(x * x for x in src))
            diff_norm = math.sqrt(
                sum((a - b) ** 2 for a, b in zip(src, reconstructed))
            )
            errors.append(diff_norm / (orig_norm + 1e-9))

        avg_error = sum(errors) / len(errors)
        assert avg_error < 0.5, f"Average relative error {avg_error} too high"

    def test_various_dimensions(self):
        """Test with different power-of-2 dimensions."""
        for d in [16, 32, 64, 128, 256]:
            src = [math.sin(i * 0.1) for i in range(d)]

            packed, norm = polar_quant_encode(src)
            reconstructed = polar_quant_decode(packed, norm, d)

            # Basic sanity: reconstructed should have similar magnitude
            # 4-bit quantization loses significant energy, especially at
            # small dims
            orig_energy = sum(x * x for x in src)
            recon_energy = sum(x * x for x in reconstructed)
            ratio = recon_energy / (orig_energy + 1e-9)
            assert 0.1 < ratio < 10.0, f"d={d}: energy ratio {ratio}"


class TestInnerProductPreservation:
    """Test that Q·K ≈ Q·dequant(quant(K))."""

    def test_inner_product_preserved(self):
        """Inner products should be approximately preserved."""
        rng = random.Random(123)

        d = 128

        # Generate two random vectors
        q = [rng.gauss(0, 0.1) for _ in range(d)]
        k = [rng.gauss(0, 0.1) for _ in range(d)]

        # Original inner product
        orig_ip = sum(a * b for a, b in zip(q, k))

        # Compress k
        k_packed, k_norm = polar_quant_encode(k)
        k_reconstructed = polar_quant_decode(k_packed, k_norm, d)

        # Compressed inner product
        comp_ip = sum(a * b for a, b in zip(q, k_reconstructed))

        # Check relative error
        rel_error = abs(orig_ip - comp_ip) / (abs(orig_ip) + 1e-9)
        # 4-bit quantization has significant error, allow up to 100% error
        assert rel_error < 1.0, f"Inner product error {rel_error} too high"

    def test_self_inner_product(self):
        """Self inner product should be close to original."""
        d = 128
        x = [math.cos(i * 0.2) for i in range(d)]

        orig_self_ip = sum(a * a for a in x)

        packed, norm = polar_quant_encode(x)
        reconstructed = polar_quant_decode(packed, norm, d)

        comp_self_ip = sum(a * a for a in reconstructed)

        # Self inner product is energy, should be roughly preserved
        # 4-bit quantization has significant error
        ratio = comp_self_ip / (orig_self_ip + 1e-9)
        assert 0.3 < ratio < 3.0, f"Self inner product ratio {ratio}"


class TestWHTOrthogonality:
    """Test that WHT is orthogonal (WHT^T · WHT = I)."""

    def test_wht_orthogonality(self):
        """Applying the normalized WHT twice should return the input."""
        d = 128
        src = [float(i) for i in range(d)]

        # The Hadamard matrix is symmetric and fwht scales by 1/sqrt(d), so
        # the transform is its own inverse: WHT(WHT(x)) = x.
        twice = fwht(fwht(src))

        for i, (expected, actual) in enumerate(zip(src, twice)):
            assert math.isclose(
                actual, expected, rel_tol=1e-9, abs_tol=1e-9
            ), f"Index {i}: got {actual}, expected {expected}"

    def test_wht_preserves_norm(self):
        """WHT should preserve L2 norm."""
        d = 128
        src = [math.sin(i) for i in range(d)]

        orig_norm = math.sqrt(sum(x * x for x in src))

        result = fwht(src)
        result_norm = math.sqrt(sum(x * x for x in result))

        ratio = result_norm / orig_norm
        assert abs(ratio - 1.0) < 0.01, f"Norm ratio {ratio}, expected 1.0"

    def test_wht_linearity(self):
        """WHT should be linear: WHT(a+b) = WHT(a) + WHT(b)."""
        d = 64
        a = [float(i) * 0.1 for i in range(d)]
        b = [float(i) * 0.2 for i in range(d)]

        # WHT(a + b)
        a_plus_b = [x + y for x, y in zip(a, b)]
        wht_sum = fwht(a_plus_b)

        # WHT(a) + WHT(b)
        wht_a = fwht(a)
        wht_b = fwht(b)
        sum_wht = [x + y for x, y in zip(wht_a, wht_b)]

        # Should be equal
        for i in range(d):
            assert abs(wht_sum[i] - sum_wht[i]) < 1e-6, f"Linearity failed at {i}"

    def test_wht_rejects_non_power_of_two(self):
        """Lengths that are not a power of two should be rejected up front."""
        for bad_length in (0, 3, 6, 12):
            with pytest.raises(ValueError):
                fwht([1.0] * bad_length)


class TestCodebookCorrectness:
    """Sanity-check the centroid table (range, ordering and size)."""

    def test_centroids_extremes(self):
        """Extreme centroids should cover tails of distribution."""
        min_c = min(TURBO4_CENTROIDS)
        max_c = max(TURBO4_CENTROIDS)

        # Should have reasonable range
        assert min_c < -0.2, f"Min centroid {min_c} should be < -0.2"
        assert max_c > 0.2, f"Max centroid {max_c} should be > 0.2"

    def test_centroids_ordered(self):
        """Centroids should be strictly increasing."""
        for i in range(len(TURBO4_CENTROIDS) - 1):
            assert (
                TURBO4_CENTROIDS[i] < TURBO4_CENTROIDS[i + 1]
            ), f"Centroids not ordered at index {i}"

    def test_centroids_cover_range(self):
        """Centroids should cover reasonable range for N(0, 1/128)."""
        # For N(0, 1/128), std = 1/sqrt(128) ≈ 0.088
        # The outermost centroids should lie beyond ±2*std
        min_c = min(TURBO4_CENTROIDS)
        max_c = max(TURBO4_CENTROIDS)

        std = 1.0 / math.sqrt(128)  # ≈ 0.088
        assert min_c < -2 * std, f"Min centroid {min_c} should be < {-2*std}"
        assert max_c > 2 * std, f"Max centroid {max_c} should be > {2*std}"

    def test_centroids_count(self):
        """Should have exactly 16 centroids for 4-bit quantization."""
        assert (
            len(TURBO4_CENTROIDS) == 16
        ), f"Expected 16 centroids, got {len(TURBO4_CENTROIDS)}"


class TestBitPacking:
    """Test bit packing/unpacking correctness."""

    def test_packing_roundtrip(self):
        """Packing and unpacking should be lossless for 4-bit values."""
        d = 128

        # Create test indices (0-15)
        indices = [i % 16 for i in range(d)]

        packed = _pack_nibbles(indices)
        unpacked = _unpack_nibbles(packed, d)

        assert unpacked == indices, "Packing/unpacking mismatch"

    def test_packing_nibble_order(self):
        """Even positions use the low nibble, odd positions the high one."""
        assert _pack_nibbles([0x1, 0x2, 0xF, 0x0]) == b"\x21\x0f"

    def test_packing_bounds(self):
        """Packed values should fit in 4 bits (0-15)."""
        d = 128
        indices = [15] * d  # Max value

        packed = _pack_nibbles(indices)

        # Each byte should have both nibbles = 15
        for byte in packed:
            assert byte == 0xFF, f"Expected 0xFF, got {hex(byte)}"

    def test_no_overflow(self):
        """Packing should not overflow with valid 4-bit values."""
        d = 256  # Larger dimension

        # All max values
        indices = [15] * d

        packed = _pack_nibbles(indices)

        # Should not crash or produce invalid values
        assert len(packed) == d // 2

    def test_invalid_input_rejected(self):
        """Out-of-range indices and short buffers should raise ValueError."""
        with pytest.raises(ValueError):
            _pack_nibbles([16, 0])
        with pytest.raises(ValueError):
            _pack_nibbles([-1, 0])
        with pytest.raises(ValueError):
            _unpack_nibbles(b"\x00", 4)


class TestMemoryBounds:
    """Test memory safety with various dimensions."""

    def test_minimum_dimension(self):
        """Should work with minimum dimension (2)."""
        d = 2
        src = [1.0, 0.5]

        packed, norm = polar_quant_encode(src)
        assert len(packed) == d // 2

        reconstructed = polar_quant_decode(packed, norm, d)
        assert len(reconstructed) == d

    def test_large_dimension(self):
        """Should work with large dimensions."""
        d = 1024
        src = [math.sin(i * 0.01) for i in range(d)]

        packed, norm = polar_quant_encode(src)
        assert len(packed) == d // 2

        reconstructed = polar_quant_decode(packed, norm, d)
        assert len(reconstructed) == d

    def test_odd_dimension_fails(self):
        """Odd dimensions should fail (need even for 4-bit packing)."""
        d = 127  # Odd
        src = [0.0] * d

        with pytest.raises(AssertionError):
            polar_quant_encode(src)

    def test_non_power_of_two_dimension_fails(self):
        """Even dimensions that are not a power of two should fail cleanly."""
        with pytest.raises(ValueError):
            polar_quant_encode([1.0] * 6)
        with pytest.raises(ValueError):
            polar_quant_decode(bytes(3), 1.0, 6)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])