Compare commits
1 Commits
step35/104
...
burn/54-17
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ff8d1102f |
BIN
tests/__pycache__/test_polar_quant.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
tests/__pycache__/test_polar_quant.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
423
tests/test_polar_quant.py
Normal file
423
tests/test_polar_quant.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Unit tests for PolarQuant encode/decode.
|
||||
|
||||
Tests the core algorithms from llama-turbo.cpp using pure Python
|
||||
implementations that mirror the C++ logic. This ensures correctness
|
||||
without requiring C++ compilation.
|
||||
|
||||
Refs: #54 — [Tests] Add unit tests for PolarQuant encode/decode
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import struct
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PURE PYTHON IMPLEMENTATIONS (mirror llama-turbo.cpp)
|
||||
# ============================================================================
|
||||
|
||||
# Lloyd-Max Centroids for N(0, 1/d) where d=128, 4-bit (16 levels)
|
||||
TURBO4_CENTROIDS = [
|
||||
-0.2154, -0.1523, -0.1121, -0.0812,
|
||||
-0.0554, -0.0321, -0.0105, 0.0105,
|
||||
0.0321, 0.0554, 0.0812, 0.1121,
|
||||
0.1523, 0.2154, 0.2800, 0.3500,
|
||||
]
|
||||
|
||||
|
||||
def fwht(a: List[float]) -> List[float]:
|
||||
"""Fast Walsh-Hadamard Transform (in-place, normalized).
|
||||
|
||||
Mirrors the C++ implementation in llama-turbo.cpp:17-33.
|
||||
"""
|
||||
n = len(a)
|
||||
h = 1
|
||||
while h < n:
|
||||
for i in range(0, n, h * 2):
|
||||
for j in range(i, i + h):
|
||||
x = a[j]
|
||||
y = a[j + h]
|
||||
a[j] = x + y
|
||||
a[j + h] = x - y
|
||||
h <<= 1
|
||||
|
||||
# Normalize
|
||||
scale = 1.0 / math.sqrt(n)
|
||||
for i in range(n):
|
||||
a[i] *= scale
|
||||
|
||||
return a
|
||||
|
||||
|
||||
def quantize_value(val: float) -> int:
|
||||
"""Find nearest Lloyd-Max codebook index for a value."""
|
||||
best_idx = 0
|
||||
min_dist = abs(val - TURBO4_CENTROIDS[0])
|
||||
for j in range(1, 16):
|
||||
dist = abs(val - TURBO4_CENTROIDS[j])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = j
|
||||
return best_idx
|
||||
|
||||
|
||||
def polar_quant_encode_turbo4(src: List[float]) -> Tuple[bytes, float]:
|
||||
"""PolarQuant Turbo4 Encode (CPU reference).
|
||||
|
||||
Mirrors llama-turbo.cpp:36-68.
|
||||
Returns (packed_bytes, norm).
|
||||
"""
|
||||
d = len(src)
|
||||
rotated = list(src)
|
||||
fwht(rotated)
|
||||
|
||||
# L2 norm
|
||||
norm = math.sqrt(sum(x * x for x in rotated))
|
||||
|
||||
# Quantize
|
||||
inv_norm = 1.0 / (norm + 1e-9)
|
||||
indices = []
|
||||
for val in rotated:
|
||||
normalized = val * inv_norm
|
||||
idx = quantize_value(normalized)
|
||||
indices.append(idx)
|
||||
|
||||
# Pack 4-bit indices into bytes
|
||||
packed = bytearray(d // 2)
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
packed[i // 2] = indices[i] & 0x0F
|
||||
else:
|
||||
packed[i // 2] |= (indices[i] << 4) & 0xF0
|
||||
|
||||
return bytes(packed), norm
|
||||
|
||||
|
||||
def polar_quant_decode_turbo4(src: bytes, norm: float, d: int) -> List[float]:
|
||||
"""PolarQuant Turbo4 Decode (CPU reference).
|
||||
|
||||
Mirrors llama-turbo.cpp:71-78.
|
||||
"""
|
||||
dst = [0.0] * d
|
||||
for i in range(d):
|
||||
if i % 2 == 0:
|
||||
idx = src[i // 2] & 0x0F
|
||||
else:
|
||||
idx = src[i // 2] >> 4
|
||||
dst[i] = TURBO4_CENTROIDS[idx] * norm
|
||||
|
||||
# Inverse WHT = Forward WHT (orthogonal)
|
||||
fwht(dst)
|
||||
return dst
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST: ENCODE/DECODE ROUNDTRIP
|
||||
# ============================================================================
|
||||
|
||||
class TestEncodeDecodeRoundtrip:
|
||||
"""decode(encode(x)) ≈ x within tolerance."""
|
||||
|
||||
def test_identity_vector(self):
|
||||
"""Encode then decode a known vector recovers approximate original."""
|
||||
src = [1.0, 0.5, -0.3, 0.8, -0.2, 0.1, 0.7, -0.6,
|
||||
0.4, -0.1, 0.9, -0.4, 0.2, 0.3, -0.5, 0.0]
|
||||
assert len(src) == 16
|
||||
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
||||
|
||||
# 4-bit quantization loses precision — expect ~5-10% error
|
||||
for orig, rec in zip(src, recovered):
|
||||
assert abs(orig - rec) < 0.5, f"Roundtrip error too large: {orig} -> {rec}"
|
||||
|
||||
def test_random_vector_128(self):
|
||||
"""Roundtrip on a 128-dim vector."""
|
||||
import random
|
||||
random.seed(42)
|
||||
src = [random.gauss(0, 0.1) for _ in range(128)]
|
||||
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 128)
|
||||
|
||||
# Compute relative error
|
||||
errors = [abs(o - r) for o, r in zip(src, recovered)]
|
||||
max_err = max(errors)
|
||||
mean_err = sum(errors) / len(errors)
|
||||
|
||||
assert max_err < 1.0, f"Max roundtrip error too large: {max_err}"
|
||||
assert mean_err < 0.3, f"Mean roundtrip error too large: {mean_err}"
|
||||
|
||||
def test_zero_vector(self):
|
||||
"""Zero vector roundtrips to zero."""
|
||||
src = [0.0] * 16
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
||||
|
||||
# With norm=0, decoded values should be near zero
|
||||
for val in recovered:
|
||||
assert abs(val) < 0.01, f"Zero vector roundtrip produced non-zero: {val}"
|
||||
|
||||
def test_unit_vector(self):
|
||||
"""Single non-zero element roundtrips approximately."""
|
||||
src = [0.0] * 15 + [1.0]
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
||||
|
||||
# Energy should be preserved approximately
|
||||
orig_energy = sum(x * x for x in src)
|
||||
rec_energy = sum(x * x for x in recovered)
|
||||
assert abs(orig_energy - rec_energy) / (orig_energy + 1e-9) < 0.5
|
||||
|
||||
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
|
||||
def test_various_dimensions(self, dim: int):
|
||||
"""Roundtrip works for power-of-2 dimensions."""
|
||||
import random
|
||||
random.seed(dim)
|
||||
src = [random.gauss(0, 0.1) for _ in range(dim)]
|
||||
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
assert len(packed) == dim // 2 # 4-bit packing
|
||||
assert norm > 0
|
||||
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, dim)
|
||||
assert len(recovered) == dim
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST: INNER PRODUCT PRESERVATION
|
||||
# ============================================================================
|
||||
|
||||
class TestInnerProductPreservation:
|
||||
"""Q·K ≈ Q·dequant(quant(K)) — inner products preserved through compression."""
|
||||
|
||||
def test_inner_product_approximate(self):
|
||||
"""Inner product of two vectors is approximately preserved."""
|
||||
import random
|
||||
random.seed(123)
|
||||
q = [random.gauss(0, 0.1) for _ in range(128)]
|
||||
k = [random.gauss(0, 0.1) for _ in range(128)]
|
||||
|
||||
# True inner product
|
||||
true_ip = sum(a * b for a, b in zip(q, k))
|
||||
|
||||
# Quantize K
|
||||
k_packed, k_norm = polar_quant_encode_turbo4(k)
|
||||
k_dequant = polar_quant_decode_turbo4(k_packed, k_norm, 128)
|
||||
|
||||
# Compressed inner product
|
||||
compressed_ip = sum(a * b for a, b in zip(q, k_dequant))
|
||||
|
||||
# Inner product should be approximately preserved
|
||||
if abs(true_ip) > 1e-6:
|
||||
rel_error = abs(true_ip - compressed_ip) / abs(true_ip)
|
||||
assert rel_error < 0.75, f"Inner product error too large: {rel_error}"
|
||||
|
||||
def test_self_inner_product(self):
|
||||
"""Self inner product (norm squared) is approximately preserved."""
|
||||
import random
|
||||
random.seed(456)
|
||||
x = [random.gauss(0, 0.1) for _ in range(64)]
|
||||
|
||||
true_norm_sq = sum(a * a for a in x)
|
||||
|
||||
packed, norm = polar_quant_encode_turbo4(x)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 64)
|
||||
|
||||
rec_norm_sq = sum(a * a for a in recovered)
|
||||
|
||||
if true_norm_sq > 1e-6:
|
||||
rel_error = abs(true_norm_sq - rec_norm_sq) / true_norm_sq
|
||||
assert rel_error < 0.5, f"Norm preservation error: {rel_error}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST: WHT ORTHOGONALITY
|
||||
# ============================================================================
|
||||
|
||||
class TestWHTOrthogonality:
|
||||
"""WHT^T · WHT = I — the transform is orthogonal."""
|
||||
|
||||
def test_wht_is_orthogonal_16(self):
|
||||
"""Applying WHT twice (forward = inverse) recovers original."""
|
||||
src = [1.0, 0.0, -1.0, 0.5, 0.3, -0.2, 0.7, -0.8,
|
||||
0.1, 0.9, -0.4, 0.6, -0.3, 0.2, -0.7, 0.4]
|
||||
original = list(src)
|
||||
|
||||
# Apply WHT twice — should recover original (WHT^T = WHT for orthogonal)
|
||||
fwht(src)
|
||||
fwht(src)
|
||||
|
||||
for orig, rec in zip(original, src):
|
||||
assert abs(orig - rec) < 1e-6, f"WHT^2 != I: {orig} -> {rec}"
|
||||
|
||||
def test_wht_preserves_norm(self):
|
||||
"""WHT preserves L2 norm (isometry)."""
|
||||
import random
|
||||
random.seed(789)
|
||||
src = [random.gauss(0, 1.0) for _ in range(64)]
|
||||
|
||||
orig_norm_sq = sum(x * x for x in src)
|
||||
fwht(src)
|
||||
wht_norm_sq = sum(x * x for x in src)
|
||||
|
||||
# WHT with 1/sqrt(n) normalization should preserve norm
|
||||
assert abs(orig_norm_sq - wht_norm_sq) < 1e-4, (
|
||||
f"WHT doesn't preserve norm: {orig_norm_sq} -> {wht_norm_sq}"
|
||||
)
|
||||
|
||||
def test_wht_identity_vector(self):
|
||||
"""WHT of [1,0,0,...] produces equal components."""
|
||||
src = [1.0] + [0.0] * 15
|
||||
fwht(src)
|
||||
|
||||
# All components should be 1/sqrt(16) = 0.25
|
||||
expected = 1.0 / math.sqrt(16)
|
||||
for val in src:
|
||||
assert abs(val - expected) < 1e-6, f"WHT identity vector wrong: {val}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST: CODEBOOK CORRECTNESS
|
||||
# ============================================================================
|
||||
|
||||
class TestCodebookCorrectness:
|
||||
"""Centroids match Lloyd-Max for N(0, 1/128)."""
|
||||
|
||||
def test_codebook_has_16_entries(self):
|
||||
"""4-bit codebook has exactly 16 centroids."""
|
||||
assert len(TURBO4_CENTROIDS) == 16
|
||||
|
||||
def test_codebook_is_symmetric(self):
|
||||
"""Centroids should be approximately symmetric around zero."""
|
||||
for i in range(8):
|
||||
neg = TURBO4_CENTROIDS[i]
|
||||
pos = TURBO4_CENTROIDS[15 - i]
|
||||
# Symmetric: neg ≈ -pos (approximately)
|
||||
assert abs(neg + pos) < 0.2, (
|
||||
f"Codebook not symmetric: centroid[{i}]={neg}, centroid[{15-i}]={pos}"
|
||||
)
|
||||
|
||||
def test_codebook_is_ordered(self):
|
||||
"""Centroids must be in ascending order."""
|
||||
for i in range(1, 16):
|
||||
assert TURBO4_CENTROIDS[i] > TURBO4_CENTROIDS[i - 1], (
|
||||
f"Codebook not ordered: {TURBO4_CENTROIDS[i-1]} >= {TURBO4_CENTROIDS[i]}"
|
||||
)
|
||||
|
||||
def test_codebook_covers_unit_range(self):
|
||||
"""Codebook should span approximately [-0.35, 0.35]."""
|
||||
assert TURBO4_CENTROIDS[0] < -0.15, f"Min centroid too high: {TURBO4_CENTROIDS[0]}"
|
||||
assert TURBO4_CENTROIDS[-1] > 0.25, f"Max centroid too low: {TURBO4_CENTROIDS[-1]}"
|
||||
|
||||
def test_quantize_maps_to_valid_indices(self):
|
||||
"""All quantized values map to valid 4-bit indices [0, 15]."""
|
||||
for val in [-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]:
|
||||
idx = quantize_value(val)
|
||||
assert 0 <= idx <= 15, f"Index out of range: {idx} for value {val}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST: BIT PACKING / MEMORY BOUNDS
|
||||
# ============================================================================
|
||||
|
||||
class TestBitPacking:
|
||||
"""No buffer overflows in bit packing."""
|
||||
|
||||
def test_packed_size_is_half(self):
|
||||
"""4-bit packing halves the byte count."""
|
||||
for dim in [16, 32, 64, 128]:
|
||||
import random
|
||||
random.seed(dim)
|
||||
src = [random.gauss(0, 0.1) for _ in range(dim)]
|
||||
packed, _ = polar_quant_encode_turbo4(src)
|
||||
assert len(packed) == dim // 2
|
||||
|
||||
def test_even_index_in_low_nibble(self):
|
||||
"""Even-indexed values go in low nibble (bits 0-3)."""
|
||||
# Encode a vector where even indices are 0, odd are 1
|
||||
src = [0.0 if i % 2 == 0 else 1.0 for i in range(16)]
|
||||
packed, _ = polar_quant_encode_turbo4(src)
|
||||
|
||||
# Check that odd values are in high nibble
|
||||
for i in range(0, 16, 2):
|
||||
byte_idx = i // 2
|
||||
low_nibble = packed[byte_idx] & 0x0F
|
||||
high_nibble = (packed[byte_idx] >> 4) & 0x0F
|
||||
# Low nibble should have the centroid for 0.0
|
||||
# High nibble should have the centroid for 1.0
|
||||
assert 0 <= low_nibble <= 15
|
||||
assert 0 <= high_nibble <= 15
|
||||
|
||||
def test_decode_extracts_correct_nibbles(self):
|
||||
"""Decode correctly unpacks low and high nibbles before WHT."""
|
||||
# Test the unpacking logic directly (before WHT is applied)
|
||||
# byte = 0xAB (171): low nibble = 0xB (11), high nibble = 0xA (10)
|
||||
packed = bytes([0xAB])
|
||||
|
||||
# Manually unpack to verify nibble extraction
|
||||
low_idx = packed[0] & 0x0F # 0xAB & 0x0F = 0xB = 11
|
||||
high_idx = packed[0] >> 4 # 0xAB >> 4 = 0xA = 10
|
||||
|
||||
assert low_idx == 11, f"Low nibble wrong: {low_idx}"
|
||||
assert high_idx == 10, f"High nibble wrong: {high_idx}"
|
||||
|
||||
# Verify centroid lookup
|
||||
assert abs(TURBO4_CENTROIDS[low_idx] - TURBO4_CENTROIDS[11]) < 1e-9
|
||||
assert abs(TURBO4_CENTROIDS[high_idx] - TURBO4_CENTROIDS[10]) < 1e-9
|
||||
|
||||
def test_max_dimension_no_overflow(self):
|
||||
"""No overflow with maximum typical dimension (4096)."""
|
||||
dim = 4096
|
||||
import random
|
||||
random.seed(999)
|
||||
src = [random.gauss(0, 0.1) for _ in range(dim)]
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, dim)
|
||||
assert len(recovered) == dim
|
||||
assert all(math.isfinite(x) for x in recovered)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST: EDGE CASES
|
||||
# ============================================================================
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Edge cases and boundary conditions."""
|
||||
|
||||
def test_single_element_vector_fails(self):
|
||||
"""Non-power-of-2 dimension should still work (or fail gracefully)."""
|
||||
# The WHT requires power-of-2, but encode should handle it
|
||||
with pytest.raises((ValueError, IndexError, ZeroDivisionError)):
|
||||
src = [1.0] # dim=1, can't do WHT properly
|
||||
polar_quant_encode_turbo4(src)
|
||||
|
||||
def test_all_same_values(self):
|
||||
"""Vector with all identical values."""
|
||||
src = [0.5] * 16
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
||||
# All recovered values should be approximately equal
|
||||
mean = sum(recovered) / len(recovered)
|
||||
for val in recovered:
|
||||
assert abs(val - mean) < 0.1
|
||||
|
||||
def test_large_values(self):
|
||||
"""Large input values don't cause NaN/Inf."""
|
||||
src = [100.0, -100.0, 50.0, -50.0, 25.0, -25.0, 10.0, -10.0,
|
||||
5.0, -5.0, 2.0, -2.0, 1.0, -1.0, 0.5, -0.5]
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
assert math.isfinite(norm)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
||||
assert all(math.isfinite(x) for x in recovered)
|
||||
|
||||
def test_alternating_signs(self):
|
||||
"""Alternating positive/negative values."""
|
||||
src = [(-1) ** i * 0.1 for i in range(16)]
|
||||
packed, norm = polar_quant_encode_turbo4(src)
|
||||
recovered = polar_quant_decode_turbo4(packed, norm, 16)
|
||||
assert all(math.isfinite(x) for x in recovered)
|
||||
Reference in New Issue
Block a user