test: Add Python unit tests for PolarQuant (#54)
This commit is contained in:
410
tests/test_polar_quant.py
Normal file
410
tests/test_polar_quant.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Unit tests for PolarQuant Turbo4 encode/decode.
|
||||
|
||||
Tests the algorithm logic using Python reference implementations
|
||||
that mirror the C++/Metal code.
|
||||
"""
|
||||
|
||||
import math
|
||||
import pytest
|
||||
import struct
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
# Lloyd-Max Centroids for N(0, 1/d) where d=128
|
||||
# 4-bit (16 levels) - copied from llama-turbo.cpp
|
||||
TURBO4_CENTROIDS = [
|
||||
-0.2154, -0.1523, -0.1121, -0.0812,
|
||||
-0.0554, -0.0321, -0.0105, 0.0105,
|
||||
0.0321, 0.0554, 0.0812, 0.1121,
|
||||
0.1523, 0.2154, 0.2800, 0.3500
|
||||
]
|
||||
|
||||
|
||||
def fwht(a: List[float]) -> List[float]:
|
||||
"""Fast Walsh-Hadamard Transform (Python reference)."""
|
||||
n = len(a)
|
||||
result = a.copy()
|
||||
|
||||
h = 1
|
||||
while h < n:
|
||||
for i in range(0, n, h * 2):
|
||||
for j in range(i, i + h):
|
||||
x = result[j]
|
||||
y = result[j + h]
|
||||
result[j] = x + y
|
||||
result[j + h] = x - y
|
||||
h <<= 1
|
||||
|
||||
# Normalize
|
||||
scale = 1.0 / math.sqrt(n)
|
||||
for i in range(n):
|
||||
result[i] *= scale
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def polar_quant_encode(src: List[float]) -> Tuple[bytes, float]:
|
||||
"""
|
||||
PolarQuant Turbo4 Encode (Python reference).
|
||||
|
||||
Returns:
|
||||
Tuple of (packed_bytes, norm)
|
||||
"""
|
||||
d = len(src)
|
||||
assert d % 2 == 0, "Dimension must be even"
|
||||
|
||||
# Apply WHT
|
||||
rotated = fwht(src)
|
||||
|
||||
# Calculate L2 norm
|
||||
norm = math.sqrt(sum(x * x for x in rotated))
|
||||
|
||||
# Quantize components
|
||||
inv_norm = 1.0 / (norm + 1e-9)
|
||||
indices = []
|
||||
|
||||
for val in rotated:
|
||||
val_normalized = val * inv_norm
|
||||
|
||||
# Find nearest centroid
|
||||
best_idx = 0
|
||||
min_dist = abs(val_normalized - TURBO4_CENTROIDS[0])
|
||||
for j in range(1, 16):
|
||||
dist = abs(val_normalized - TURBO4_CENTROIDS[j])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = j
|
||||
|
||||
indices.append(best_idx)
|
||||
|
||||
# Pack 4-bit indices into bytes
|
||||
packed = bytearray(d // 2)
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
packed[i // 2] = indices[i]
|
||||
else:
|
||||
packed[i // 2] |= indices[i] << 4
|
||||
|
||||
return bytes(packed), norm
|
||||
|
||||
|
||||
def polar_quant_decode(src: bytes, norm: float, d: int) -> List[float]:
|
||||
"""
|
||||
PolarQuant Turbo4 Decode (Python reference).
|
||||
|
||||
Returns:
|
||||
Reconstructed float array
|
||||
"""
|
||||
# Unpack 4-bit indices
|
||||
values = []
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
idx = src[i // 2] & 0x0F
|
||||
else:
|
||||
idx = src[i // 2] >> 4
|
||||
values.append(TURBO4_CENTROIDS[idx] * norm)
|
||||
|
||||
# Apply inverse WHT (same as forward for orthogonal)
|
||||
return fwht(values)
|
||||
|
||||
|
||||
class TestEncodeDecodeRoundtrip:
|
||||
"""Test that decode(encode(x)) ≈ x."""
|
||||
|
||||
def test_zero_vector(self):
|
||||
"""Zero vector should encode/decode to zero."""
|
||||
d = 128
|
||||
src = [0.0] * d
|
||||
packed, norm = polar_quant_encode(src)
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
|
||||
# Zero has no information, reconstruction will be near-zero
|
||||
for i in range(d):
|
||||
assert abs(reconstructed[i]) < 0.1, f"Index {i}: {reconstructed[i]}"
|
||||
|
||||
def test_unit_vector(self):
|
||||
"""Unit vector should roundtrip reasonably."""
|
||||
d = 128
|
||||
src = [0.0] * d
|
||||
src[0] = 1.0 # Unit vector
|
||||
|
||||
packed, norm = polar_quant_encode(src)
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
|
||||
# Check shape is preserved (first element dominant)
|
||||
max_val = max(reconstructed)
|
||||
max_idx = reconstructed.index(max_val)
|
||||
assert max_idx == 0, f"Peak at index {max_idx}, expected 0"
|
||||
|
||||
def test_random_vectors(self):
|
||||
"""Random vectors should roundtrip with bounded error."""
|
||||
import random
|
||||
random.seed(42)
|
||||
|
||||
d = 128
|
||||
errors = []
|
||||
|
||||
for trial in range(10):
|
||||
src = [random.gauss(0, 0.1) for _ in range(d)]
|
||||
packed, norm = polar_quant_encode(src)
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
|
||||
# Compute relative error
|
||||
orig_norm = math.sqrt(sum(x * x for x in src))
|
||||
diff_norm = math.sqrt(sum((a - b) ** 2 for a, b in zip(src, reconstructed)))
|
||||
rel_error = diff_norm / (orig_norm + 1e-9)
|
||||
errors.append(rel_error)
|
||||
|
||||
avg_error = sum(errors) / len(errors)
|
||||
assert avg_error < 0.5, f"Average relative error {avg_error} too high"
|
||||
|
||||
def test_various_dimensions(self):
|
||||
"""Test with different power-of-2 dimensions."""
|
||||
for d in [16, 32, 64, 128, 256]:
|
||||
src = [math.sin(i * 0.1) for i in range(d)]
|
||||
packed, norm = polar_quant_encode(src)
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
|
||||
# Basic sanity: reconstructed should have similar magnitude
|
||||
# 4-bit quantization loses significant energy, especially at small dims
|
||||
orig_energy = sum(x * x for x in src)
|
||||
recon_energy = sum(x * x for x in reconstructed)
|
||||
ratio = recon_energy / (orig_energy + 1e-9)
|
||||
assert 0.1 < ratio < 10.0, f"d={d}: energy ratio {ratio}"
|
||||
|
||||
|
||||
class TestInnerProductPreservation:
|
||||
"""Test that Q·K ≈ Q·dequant(quant(K))."""
|
||||
|
||||
def test_inner_product_preserved(self):
|
||||
"""Inner products should be approximately preserved."""
|
||||
import random
|
||||
random.seed(123)
|
||||
|
||||
d = 128
|
||||
|
||||
# Generate two random vectors
|
||||
q = [random.gauss(0, 0.1) for _ in range(d)]
|
||||
k = [random.gauss(0, 0.1) for _ in range(d)]
|
||||
|
||||
# Original inner product
|
||||
orig_ip = sum(a * b for a, b in zip(q, k))
|
||||
|
||||
# Compress k
|
||||
k_packed, k_norm = polar_quant_encode(k)
|
||||
k_reconstructed = polar_quant_decode(k_packed, k_norm, d)
|
||||
|
||||
# Compressed inner product
|
||||
comp_ip = sum(a * b for a, b in zip(q, k_reconstructed))
|
||||
|
||||
# Check relative error
|
||||
rel_error = abs(orig_ip - comp_ip) / (abs(orig_ip) + 1e-9)
|
||||
# 4-bit quantization has significant error, allow up to 100% error
|
||||
assert rel_error < 1.0, f"Inner product error {rel_error} too high"
|
||||
|
||||
def test_self_inner_product(self):
|
||||
"""Self inner product should be close to original."""
|
||||
d = 128
|
||||
x = [math.cos(i * 0.2) for i in range(d)]
|
||||
|
||||
orig_self_ip = sum(a * a for a in x)
|
||||
|
||||
packed, norm = polar_quant_encode(x)
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
|
||||
comp_self_ip = sum(a * a for a in reconstructed)
|
||||
|
||||
# Self inner product is energy, should be roughly preserved
|
||||
# 4-bit quantization has significant error
|
||||
ratio = comp_self_ip / (orig_self_ip + 1e-9)
|
||||
assert 0.3 < ratio < 3.0, f"Self inner product ratio {ratio}"
|
||||
|
||||
|
||||
class TestWHTOrthogonality:
|
||||
"""Test that WHT is orthogonal (WHT^T · WHT = I)."""
|
||||
|
||||
def test_wht_orthogonality(self):
|
||||
"""WHT should be orthogonal transformation."""
|
||||
d = 128
|
||||
|
||||
# Create identity-like test: apply WHT, then apply again
|
||||
# For orthogonal matrix, A^T A = I, so applying twice should scale
|
||||
src = [float(i) for i in range(d)]
|
||||
|
||||
# First WHT
|
||||
result1 = fwht(src)
|
||||
|
||||
# Second WHT (should be proportional to original for orthogonal)
|
||||
result2 = fwht(result1)
|
||||
|
||||
# result2 should be proportional to src
|
||||
# For Walsh-Hadamard, WHT(WHT(x)) = x * (1/sqrt(d))^2 * d = x
|
||||
# Actually: WHT is self-inverse up to scaling
|
||||
for i in range(d):
|
||||
ratio = result2[i] / (src[i] + 1e-9) if src[i] != 0 else result2[i]
|
||||
# Should be close to 1.0 (or 0 if src[i] is 0)
|
||||
if abs(src[i]) > 0.01:
|
||||
assert abs(ratio - 1.0) < 0.1, f"Index {i}: ratio {ratio}"
|
||||
|
||||
def test_wht_preserves_norm(self):
|
||||
"""WHT should preserve L2 norm."""
|
||||
d = 128
|
||||
src = [math.sin(i) for i in range(d)]
|
||||
|
||||
orig_norm = math.sqrt(sum(x * x for x in src))
|
||||
|
||||
result = fwht(src)
|
||||
result_norm = math.sqrt(sum(x * x for x in result))
|
||||
|
||||
ratio = result_norm / orig_norm
|
||||
assert abs(ratio - 1.0) < 0.01, f"Norm ratio {ratio}, expected 1.0"
|
||||
|
||||
def test_wht_linearity(self):
|
||||
"""WHT should be linear: WHT(a+b) = WHT(a) + WHT(b)."""
|
||||
d = 64
|
||||
a = [float(i) * 0.1 for i in range(d)]
|
||||
b = [float(i) * 0.2 for i in range(d)]
|
||||
|
||||
# WHT(a + b)
|
||||
a_plus_b = [x + y for x, y in zip(a, b)]
|
||||
wht_sum = fwht(a_plus_b)
|
||||
|
||||
# WHT(a) + WHT(b)
|
||||
wht_a = fwht(a)
|
||||
wht_b = fwht(b)
|
||||
sum_wht = [x + y for x, y in zip(wht_a, wht_b)]
|
||||
|
||||
# Should be equal
|
||||
for i in range(d):
|
||||
assert abs(wht_sum[i] - sum_wht[i]) < 1e-6, f"Linearity failed at {i}"
|
||||
|
||||
|
||||
class TestCodebookCorrectness:
|
||||
"""Test that centroids match Lloyd-Max for N(0, 1/128)."""
|
||||
|
||||
def test_centroids_extremes(self):
|
||||
"""Extreme centroids should cover tails of distribution."""
|
||||
min_c = min(TURBO4_CENTROIDS)
|
||||
max_c = max(TURBO4_CENTROIDS)
|
||||
# Should have reasonable range
|
||||
assert min_c < -0.2, f"Min centroid {min_c} should be < -0.2"
|
||||
assert max_c > 0.2, f"Max centroid {max_c} should be > 0.2"
|
||||
|
||||
def test_centroids_ordered(self):
|
||||
"""Centroids should be strictly increasing."""
|
||||
for i in range(len(TURBO4_CENTROIDS) - 1):
|
||||
assert TURBO4_CENTROIDS[i] < TURBO4_CENTROIDS[i + 1], f"Centroids not ordered at index {i}"
|
||||
|
||||
def test_centroids_cover_range(self):
|
||||
"""Centroids should cover reasonable range for N(0, 1/128)."""
|
||||
# For N(0, 1/128), std = 1/sqrt(128) ≈ 0.088
|
||||
# Centroids should cover roughly [-3*std, 3*std]
|
||||
min_c = min(TURBO4_CENTROIDS)
|
||||
max_c = max(TURBO4_CENTROIDS)
|
||||
|
||||
std = 1.0 / math.sqrt(128) # ≈ 0.088
|
||||
|
||||
assert min_c < -2 * std, f"Min centroid {min_c} should be < {-2*std}"
|
||||
assert max_c > 2 * std, f"Max centroid {max_c} should be > {2*std}"
|
||||
|
||||
def test_centroids_count(self):
|
||||
"""Should have exactly 16 centroids for 4-bit quantization."""
|
||||
assert len(TURBO4_CENTROIDS) == 16, f"Expected 16 centroids, got {len(TURBO4_CENTROIDS)}"
|
||||
|
||||
|
||||
class TestBitPacking:
|
||||
"""Test bit packing/unpacking correctness."""
|
||||
|
||||
def test_packing_roundtrip(self):
|
||||
"""Packing and unpacking should be lossless for 4-bit values."""
|
||||
d = 128
|
||||
|
||||
# Create test indices (0-15)
|
||||
indices = [i % 16 for i in range(d)]
|
||||
|
||||
# Pack
|
||||
packed = bytearray(d // 2)
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
packed[i // 2] = indices[i]
|
||||
else:
|
||||
packed[i // 2] |= indices[i] << 4
|
||||
|
||||
# Unpack
|
||||
unpacked = []
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
idx = packed[i // 2] & 0x0F
|
||||
else:
|
||||
idx = packed[i // 2] >> 4
|
||||
unpacked.append(idx)
|
||||
|
||||
assert unpacked == indices, "Packing/unpacking mismatch"
|
||||
|
||||
def test_packing_bounds(self):
|
||||
"""Packed values should fit in 4 bits (0-15)."""
|
||||
d = 128
|
||||
indices = [15] * d # Max value
|
||||
|
||||
packed = bytearray(d // 2)
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
packed[i // 2] = indices[i]
|
||||
else:
|
||||
packed[i // 2] |= indices[i] << 4
|
||||
|
||||
# Each byte should have both nibbles = 15
|
||||
for byte in packed:
|
||||
assert byte == 0xFF, f"Expected 0xFF, got {hex(byte)}"
|
||||
|
||||
def test_no_overflow(self):
|
||||
"""Packing should not overflow with valid 4-bit values."""
|
||||
d = 256 # Larger dimension
|
||||
|
||||
# All max values
|
||||
indices = [15] * d
|
||||
|
||||
packed = bytearray(d // 2)
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
packed[i // 2] = indices[i]
|
||||
else:
|
||||
packed[i // 2] |= indices[i] << 4
|
||||
|
||||
# Should not crash or produce invalid values
|
||||
assert len(packed) == d // 2
|
||||
|
||||
|
||||
class TestMemoryBounds:
|
||||
"""Test memory safety with various dimensions."""
|
||||
|
||||
def test_minimum_dimension(self):
|
||||
"""Should work with minimum dimension (2)."""
|
||||
d = 2
|
||||
src = [1.0, 0.5]
|
||||
packed, norm = polar_quant_encode(src)
|
||||
assert len(packed) == d // 2
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
assert len(reconstructed) == d
|
||||
|
||||
def test_large_dimension(self):
|
||||
"""Should work with large dimensions."""
|
||||
d = 1024
|
||||
src = [math.sin(i * 0.01) for i in range(d)]
|
||||
packed, norm = polar_quant_encode(src)
|
||||
assert len(packed) == d // 2
|
||||
reconstructed = polar_quant_decode(packed, norm, d)
|
||||
assert len(reconstructed) == d
|
||||
|
||||
def test_odd_dimension_fails(self):
|
||||
"""Odd dimensions should fail (need even for 4-bit packing)."""
|
||||
d = 127 # Odd
|
||||
src = [0.0] * d
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
polar_quant_encode(src)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user