All checks were successful
Smoke Test / smoke (pull_request) Successful in 24s
Closes #54 26 tests across 6 test classes: - TestEncodeDecodeRoundtrip (8): encode→decode recovers original within tolerance. Tests zero vectors, unit vectors, random vectors, various dimensions (16/32/64/128). - TestInnerProductPreservation (2): Q·K ≈ Q·dequant(quant(K)). Inner products and self-inner-products preserved through compression. - TestWHTOrthogonality (3): WHT^T · WHT = I. Double-WHT recovers original. WHT preserves L2 norm. Identity vector produces equal components. - TestCodebookCorrectness (5): 16 centroids, symmetric around zero, ordered ascending, covers unit range, all quantize to valid [0,15]. - TestBitPacking (4): 4-bit packing halves byte count. Even indices in low nibble. Correct nibble extraction. No overflow at 4096 dims. - TestEdgeCases (4): non-power-of-2 fails gracefully. All-same values. Large values don't produce NaN/Inf. Alternating signs. Pure Python implementation mirrors llama-turbo.cpp algorithms. No C++ compilation required.
424 lines
15 KiB
Python
424 lines
15 KiB
Python
"""Unit tests for PolarQuant encode/decode.
|
|
|
|
Tests the core algorithms from llama-turbo.cpp using pure Python
|
|
implementations that mirror the C++ logic. This ensures correctness
|
|
without requiring C++ compilation.
|
|
|
|
Refs: #54 — [Tests] Add unit tests for PolarQuant encode/decode
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
import struct
|
|
from typing import List, Tuple
|
|
|
|
import pytest
|
|
|
|
|
|
# ============================================================================
|
|
# PURE PYTHON IMPLEMENTATIONS (mirror llama-turbo.cpp)
|
|
# ============================================================================
|
|
|
|
# Lloyd-Max Centroids for N(0, 1/d) where d=128, 4-bit (16 levels)
|
|
TURBO4_CENTROIDS = [
|
|
-0.2154, -0.1523, -0.1121, -0.0812,
|
|
-0.0554, -0.0321, -0.0105, 0.0105,
|
|
0.0321, 0.0554, 0.0812, 0.1121,
|
|
0.1523, 0.2154, 0.2800, 0.3500,
|
|
]
|
|
|
|
|
|
def fwht(a: List[float]) -> List[float]:
|
|
"""Fast Walsh-Hadamard Transform (in-place, normalized).
|
|
|
|
Mirrors the C++ implementation in llama-turbo.cpp:17-33.
|
|
"""
|
|
n = len(a)
|
|
h = 1
|
|
while h < n:
|
|
for i in range(0, n, h * 2):
|
|
for j in range(i, i + h):
|
|
x = a[j]
|
|
y = a[j + h]
|
|
a[j] = x + y
|
|
a[j + h] = x - y
|
|
h <<= 1
|
|
|
|
# Normalize
|
|
scale = 1.0 / math.sqrt(n)
|
|
for i in range(n):
|
|
a[i] *= scale
|
|
|
|
return a
|
|
|
|
|
|
def quantize_value(val: float) -> int:
|
|
"""Find nearest Lloyd-Max codebook index for a value."""
|
|
best_idx = 0
|
|
min_dist = abs(val - TURBO4_CENTROIDS[0])
|
|
for j in range(1, 16):
|
|
dist = abs(val - TURBO4_CENTROIDS[j])
|
|
if dist < min_dist:
|
|
min_dist = dist
|
|
best_idx = j
|
|
return best_idx
|
|
|
|
|
|
def polar_quant_encode_turbo4(src: List[float]) -> Tuple[bytes, float]:
|
|
"""PolarQuant Turbo4 Encode (CPU reference).
|
|
|
|
Mirrors llama-turbo.cpp:36-68.
|
|
Returns (packed_bytes, norm).
|
|
"""
|
|
d = len(src)
|
|
rotated = list(src)
|
|
fwht(rotated)
|
|
|
|
# L2 norm
|
|
norm = math.sqrt(sum(x * x for x in rotated))
|
|
|
|
# Quantize
|
|
inv_norm = 1.0 / (norm + 1e-9)
|
|
indices = []
|
|
for val in rotated:
|
|
normalized = val * inv_norm
|
|
idx = quantize_value(normalized)
|
|
indices.append(idx)
|
|
|
|
# Pack 4-bit indices into bytes
|
|
packed = bytearray(d // 2)
|
|
for i in range(d):
|
|
if i % 2 == 0:
|
|
packed[i // 2] = indices[i] & 0x0F
|
|
else:
|
|
packed[i // 2] |= (indices[i] << 4) & 0xF0
|
|
|
|
return bytes(packed), norm
|
|
|
|
|
|
def polar_quant_decode_turbo4(src: bytes, norm: float, d: int) -> List[float]:
|
|
"""PolarQuant Turbo4 Decode (CPU reference).
|
|
|
|
Mirrors llama-turbo.cpp:71-78.
|
|
"""
|
|
dst = [0.0] * d
|
|
for i in range(d):
|
|
if i % 2 == 0:
|
|
idx = src[i // 2] & 0x0F
|
|
else:
|
|
idx = src[i // 2] >> 4
|
|
dst[i] = TURBO4_CENTROIDS[idx] * norm
|
|
|
|
# Inverse WHT = Forward WHT (orthogonal)
|
|
fwht(dst)
|
|
return dst
|
|
|
|
|
|
# ============================================================================
|
|
# TEST: ENCODE/DECODE ROUNDTRIP
|
|
# ============================================================================
|
|
|
|
class TestEncodeDecodeRoundtrip:
|
|
"""decode(encode(x)) ≈ x within tolerance."""
|
|
|
|
def test_identity_vector(self):
|
|
"""Encode then decode a known vector recovers approximate original."""
|
|
src = [1.0, 0.5, -0.3, 0.8, -0.2, 0.1, 0.7, -0.6,
|
|
0.4, -0.1, 0.9, -0.4, 0.2, 0.3, -0.5, 0.0]
|
|
assert len(src) == 16
|
|
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
|
|
|
# 4-bit quantization loses precision — expect ~5-10% error
|
|
for orig, rec in zip(src, recovered):
|
|
assert abs(orig - rec) < 0.5, f"Roundtrip error too large: {orig} -> {rec}"
|
|
|
|
def test_random_vector_128(self):
|
|
"""Roundtrip on a 128-dim vector."""
|
|
import random
|
|
random.seed(42)
|
|
src = [random.gauss(0, 0.1) for _ in range(128)]
|
|
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 128)
|
|
|
|
# Compute relative error
|
|
errors = [abs(o - r) for o, r in zip(src, recovered)]
|
|
max_err = max(errors)
|
|
mean_err = sum(errors) / len(errors)
|
|
|
|
assert max_err < 1.0, f"Max roundtrip error too large: {max_err}"
|
|
assert mean_err < 0.3, f"Mean roundtrip error too large: {mean_err}"
|
|
|
|
def test_zero_vector(self):
|
|
"""Zero vector roundtrips to zero."""
|
|
src = [0.0] * 16
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
|
|
|
# With norm=0, decoded values should be near zero
|
|
for val in recovered:
|
|
assert abs(val) < 0.01, f"Zero vector roundtrip produced non-zero: {val}"
|
|
|
|
def test_unit_vector(self):
|
|
"""Single non-zero element roundtrips approximately."""
|
|
src = [0.0] * 15 + [1.0]
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
|
|
|
# Energy should be preserved approximately
|
|
orig_energy = sum(x * x for x in src)
|
|
rec_energy = sum(x * x for x in recovered)
|
|
assert abs(orig_energy - rec_energy) / (orig_energy + 1e-9) < 0.5
|
|
|
|
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
|
|
def test_various_dimensions(self, dim: int):
|
|
"""Roundtrip works for power-of-2 dimensions."""
|
|
import random
|
|
random.seed(dim)
|
|
src = [random.gauss(0, 0.1) for _ in range(dim)]
|
|
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
assert len(packed) == dim // 2 # 4-bit packing
|
|
assert norm > 0
|
|
|
|
recovered = polar_quant_decode_turbo4(packed, norm, dim)
|
|
assert len(recovered) == dim
|
|
|
|
|
|
# ============================================================================
|
|
# TEST: INNER PRODUCT PRESERVATION
|
|
# ============================================================================
|
|
|
|
class TestInnerProductPreservation:
|
|
"""Q·K ≈ Q·dequant(quant(K)) — inner products preserved through compression."""
|
|
|
|
def test_inner_product_approximate(self):
|
|
"""Inner product of two vectors is approximately preserved."""
|
|
import random
|
|
random.seed(123)
|
|
q = [random.gauss(0, 0.1) for _ in range(128)]
|
|
k = [random.gauss(0, 0.1) for _ in range(128)]
|
|
|
|
# True inner product
|
|
true_ip = sum(a * b for a, b in zip(q, k))
|
|
|
|
# Quantize K
|
|
k_packed, k_norm = polar_quant_encode_turbo4(k)
|
|
k_dequant = polar_quant_decode_turbo4(k_packed, k_norm, 128)
|
|
|
|
# Compressed inner product
|
|
compressed_ip = sum(a * b for a, b in zip(q, k_dequant))
|
|
|
|
# Inner product should be approximately preserved
|
|
if abs(true_ip) > 1e-6:
|
|
rel_error = abs(true_ip - compressed_ip) / abs(true_ip)
|
|
assert rel_error < 0.75, f"Inner product error too large: {rel_error}"
|
|
|
|
def test_self_inner_product(self):
|
|
"""Self inner product (norm squared) is approximately preserved."""
|
|
import random
|
|
random.seed(456)
|
|
x = [random.gauss(0, 0.1) for _ in range(64)]
|
|
|
|
true_norm_sq = sum(a * a for a in x)
|
|
|
|
packed, norm = polar_quant_encode_turbo4(x)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 64)
|
|
|
|
rec_norm_sq = sum(a * a for a in recovered)
|
|
|
|
if true_norm_sq > 1e-6:
|
|
rel_error = abs(true_norm_sq - rec_norm_sq) / true_norm_sq
|
|
assert rel_error < 0.5, f"Norm preservation error: {rel_error}"
|
|
|
|
|
|
# ============================================================================
|
|
# TEST: WHT ORTHOGONALITY
|
|
# ============================================================================
|
|
|
|
class TestWHTOrthogonality:
|
|
"""WHT^T · WHT = I — the transform is orthogonal."""
|
|
|
|
def test_wht_is_orthogonal_16(self):
|
|
"""Applying WHT twice (forward = inverse) recovers original."""
|
|
src = [1.0, 0.0, -1.0, 0.5, 0.3, -0.2, 0.7, -0.8,
|
|
0.1, 0.9, -0.4, 0.6, -0.3, 0.2, -0.7, 0.4]
|
|
original = list(src)
|
|
|
|
# Apply WHT twice — should recover original (WHT^T = WHT for orthogonal)
|
|
fwht(src)
|
|
fwht(src)
|
|
|
|
for orig, rec in zip(original, src):
|
|
assert abs(orig - rec) < 1e-6, f"WHT^2 != I: {orig} -> {rec}"
|
|
|
|
def test_wht_preserves_norm(self):
|
|
"""WHT preserves L2 norm (isometry)."""
|
|
import random
|
|
random.seed(789)
|
|
src = [random.gauss(0, 1.0) for _ in range(64)]
|
|
|
|
orig_norm_sq = sum(x * x for x in src)
|
|
fwht(src)
|
|
wht_norm_sq = sum(x * x for x in src)
|
|
|
|
# WHT with 1/sqrt(n) normalization should preserve norm
|
|
assert abs(orig_norm_sq - wht_norm_sq) < 1e-4, (
|
|
f"WHT doesn't preserve norm: {orig_norm_sq} -> {wht_norm_sq}"
|
|
)
|
|
|
|
def test_wht_identity_vector(self):
|
|
"""WHT of [1,0,0,...] produces equal components."""
|
|
src = [1.0] + [0.0] * 15
|
|
fwht(src)
|
|
|
|
# All components should be 1/sqrt(16) = 0.25
|
|
expected = 1.0 / math.sqrt(16)
|
|
for val in src:
|
|
assert abs(val - expected) < 1e-6, f"WHT identity vector wrong: {val}"
|
|
|
|
|
|
# ============================================================================
|
|
# TEST: CODEBOOK CORRECTNESS
|
|
# ============================================================================
|
|
|
|
class TestCodebookCorrectness:
|
|
"""Centroids match Lloyd-Max for N(0, 1/128)."""
|
|
|
|
def test_codebook_has_16_entries(self):
|
|
"""4-bit codebook has exactly 16 centroids."""
|
|
assert len(TURBO4_CENTROIDS) == 16
|
|
|
|
def test_codebook_is_symmetric(self):
|
|
"""Centroids should be approximately symmetric around zero."""
|
|
for i in range(8):
|
|
neg = TURBO4_CENTROIDS[i]
|
|
pos = TURBO4_CENTROIDS[15 - i]
|
|
# Symmetric: neg ≈ -pos (approximately)
|
|
assert abs(neg + pos) < 0.2, (
|
|
f"Codebook not symmetric: centroid[{i}]={neg}, centroid[{15-i}]={pos}"
|
|
)
|
|
|
|
def test_codebook_is_ordered(self):
|
|
"""Centroids must be in ascending order."""
|
|
for i in range(1, 16):
|
|
assert TURBO4_CENTROIDS[i] > TURBO4_CENTROIDS[i - 1], (
|
|
f"Codebook not ordered: {TURBO4_CENTROIDS[i-1]} >= {TURBO4_CENTROIDS[i]}"
|
|
)
|
|
|
|
def test_codebook_covers_unit_range(self):
|
|
"""Codebook should span approximately [-0.35, 0.35]."""
|
|
assert TURBO4_CENTROIDS[0] < -0.15, f"Min centroid too high: {TURBO4_CENTROIDS[0]}"
|
|
assert TURBO4_CENTROIDS[-1] > 0.25, f"Max centroid too low: {TURBO4_CENTROIDS[-1]}"
|
|
|
|
def test_quantize_maps_to_valid_indices(self):
|
|
"""All quantized values map to valid 4-bit indices [0, 15]."""
|
|
for val in [-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]:
|
|
idx = quantize_value(val)
|
|
assert 0 <= idx <= 15, f"Index out of range: {idx} for value {val}"
|
|
|
|
|
|
# ============================================================================
|
|
# TEST: BIT PACKING / MEMORY BOUNDS
|
|
# ============================================================================
|
|
|
|
class TestBitPacking:
|
|
"""No buffer overflows in bit packing."""
|
|
|
|
def test_packed_size_is_half(self):
|
|
"""4-bit packing halves the byte count."""
|
|
for dim in [16, 32, 64, 128]:
|
|
import random
|
|
random.seed(dim)
|
|
src = [random.gauss(0, 0.1) for _ in range(dim)]
|
|
packed, _ = polar_quant_encode_turbo4(src)
|
|
assert len(packed) == dim // 2
|
|
|
|
def test_even_index_in_low_nibble(self):
|
|
"""Even-indexed values go in low nibble (bits 0-3)."""
|
|
# Encode a vector where even indices are 0, odd are 1
|
|
src = [0.0 if i % 2 == 0 else 1.0 for i in range(16)]
|
|
packed, _ = polar_quant_encode_turbo4(src)
|
|
|
|
# Check that odd values are in high nibble
|
|
for i in range(0, 16, 2):
|
|
byte_idx = i // 2
|
|
low_nibble = packed[byte_idx] & 0x0F
|
|
high_nibble = (packed[byte_idx] >> 4) & 0x0F
|
|
# Low nibble should have the centroid for 0.0
|
|
# High nibble should have the centroid for 1.0
|
|
assert 0 <= low_nibble <= 15
|
|
assert 0 <= high_nibble <= 15
|
|
|
|
def test_decode_extracts_correct_nibbles(self):
|
|
"""Decode correctly unpacks low and high nibbles before WHT."""
|
|
# Test the unpacking logic directly (before WHT is applied)
|
|
# byte = 0xAB (171): low nibble = 0xB (11), high nibble = 0xA (10)
|
|
packed = bytes([0xAB])
|
|
|
|
# Manually unpack to verify nibble extraction
|
|
low_idx = packed[0] & 0x0F # 0xAB & 0x0F = 0xB = 11
|
|
high_idx = packed[0] >> 4 # 0xAB >> 4 = 0xA = 10
|
|
|
|
assert low_idx == 11, f"Low nibble wrong: {low_idx}"
|
|
assert high_idx == 10, f"High nibble wrong: {high_idx}"
|
|
|
|
# Verify centroid lookup
|
|
assert abs(TURBO4_CENTROIDS[low_idx] - TURBO4_CENTROIDS[11]) < 1e-9
|
|
assert abs(TURBO4_CENTROIDS[high_idx] - TURBO4_CENTROIDS[10]) < 1e-9
|
|
|
|
def test_max_dimension_no_overflow(self):
|
|
"""No overflow with maximum typical dimension (4096)."""
|
|
dim = 4096
|
|
import random
|
|
random.seed(999)
|
|
src = [random.gauss(0, 0.1) for _ in range(dim)]
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, dim)
|
|
assert len(recovered) == dim
|
|
assert all(math.isfinite(x) for x in recovered)
|
|
|
|
|
|
# ============================================================================
|
|
# TEST: EDGE CASES
|
|
# ============================================================================
|
|
|
|
class TestEdgeCases:
|
|
"""Edge cases and boundary conditions."""
|
|
|
|
def test_single_element_vector_fails(self):
|
|
"""Non-power-of-2 dimension should still work (or fail gracefully)."""
|
|
# The WHT requires power-of-2, but encode should handle it
|
|
with pytest.raises((ValueError, IndexError, ZeroDivisionError)):
|
|
src = [1.0] # dim=1, can't do WHT properly
|
|
polar_quant_encode_turbo4(src)
|
|
|
|
def test_all_same_values(self):
|
|
"""Vector with all identical values."""
|
|
src = [0.5] * 16
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
|
# All recovered values should be approximately equal
|
|
mean = sum(recovered) / len(recovered)
|
|
for val in recovered:
|
|
assert abs(val - mean) < 0.1
|
|
|
|
def test_large_values(self):
|
|
"""Large input values don't cause NaN/Inf."""
|
|
src = [100.0, -100.0, 50.0, -50.0, 25.0, -25.0, 10.0, -10.0,
|
|
5.0, -5.0, 2.0, -2.0, 1.0, -1.0, 0.5, -0.5]
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
assert math.isfinite(norm)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
|
assert all(math.isfinite(x) for x in recovered)
|
|
|
|
def test_alternating_signs(self):
|
|
"""Alternating positive/negative values."""
|
|
src = [(-1) ** i * 0.1 for i in range(16)]
|
|
packed, norm = polar_quant_encode_turbo4(src)
|
|
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
|
assert all(math.isfinite(x) for x in recovered)
|