# tests/test_polar_quant.py -- new file (374 lines), commit 4b272f2277
"""Unit tests for PolarQuant encode/decode (#54).

Tests the core PolarQuant compression functions:
1. Encode/Decode Roundtrip: decode(encode(x)) ≈ x
2. Inner Product Preservation: Q·K ≈ Q·dequant(quant(K))
3. WHT Orthogonality: WHT^T · WHT = I
4. Codebook Correctness: Centroids match Lloyd-Max for N(0, 1/128)
5. Memory Bounds: No buffer overflows in bit packing

This is a Python reference implementation for testing. The actual
C++ implementation in llama-turbo.cpp should produce identical results.
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Reference implementations (Python mirrors of C++ code)
# ---------------------------------------------------------------------------

# Lloyd-Max Centroids for N(0, 1/d) where d=128, 4-bit (16 levels)
# 16-entry codebook indexed by a 4-bit code (0-15), listed in ascending
# order so nearest-centroid indices are monotone in the input value.
# NOTE(review): the table is asymmetric (7 negative vs 9 positive entries;
# max +0.3500 vs min -0.2154) -- presumably intentional from the Lloyd-Max
# fit, but confirm against the C++ table in llama-turbo.cpp.
TURBO4_CENTROIDS = np.array([
    -0.2154, -0.1523, -0.1121, -0.0812,
    -0.0554, -0.0321, -0.0105, 0.0105,
    0.0321, 0.0554, 0.0812, 0.1121,
    0.1523, 0.2154, 0.2800, 0.3500,
], dtype=np.float32)
|
||||
|
||||
|
||||
def fwht(a: np.ndarray) -> np.ndarray:
    """Fast Walsh-Hadamard Transform.

    Works on a copy (the input array is never mutated) and applies the
    orthonormal scaling 1/sqrt(n) at the end, so the transform is its own
    inverse: fwht(fwht(x)) recovers x.

    NOTE(review): assumes len(a) is a power of two -- confirm at call sites.
    """
    out = a.copy()
    n = len(out)
    half = 1
    while half < n:
        step = half * 2
        # One butterfly pass: each block of `step` elements is split into a
        # low half and a high half, replaced by their sum and difference.
        for start in range(0, n, step):
            lo = out[start:start + half].copy()
            hi = out[start + half:start + step].copy()
            out[start:start + half] = lo + hi
            out[start + half:start + step] = lo - hi
        half = step
    # Orthonormal normalization.
    out *= 1.0 / np.sqrt(n)
    return out
|
||||
|
||||
|
||||
def polar_quant_encode(src: np.ndarray) -> tuple[np.ndarray, float]:
    """PolarQuant encode (4-bit quantization).

    Rotates *src* with the Walsh-Hadamard transform, normalizes by the L2
    norm, maps every element to its nearest TURBO4 centroid, and packs the
    4-bit indices two per byte (even element in the low nibble, odd element
    in the high nibble).

    Args:
        src: input vector; its length must be a power of two (required by
            the WHT butterfly and by the two-codes-per-byte packing).

    Returns:
        (packed, norm) — packed is a uint8 array of length d/2; norm is the
        L2 norm of the rotated vector, needed by the decoder.

    Raises:
        ValueError: if len(src) is not a power of two (>= 2).
    """
    d = len(src)

    # Guard inputs the pipeline cannot handle: a non-power-of-two length
    # makes the WHT butterfly index out of range, and an odd length would
    # previously read indices[i + 1] past the end during packing.
    if d < 2 or d & (d - 1):
        raise ValueError(f"PolarQuant requires a power-of-two dimension, got {d}")

    # Apply WHT to spread energy across coordinates.
    rotated = fwht(src)

    # Compute L2 norm (stored alongside the codes as the scale factor).
    norm = np.linalg.norm(rotated)

    # Normalize; epsilon avoids 0/0 for the all-zero vector.
    normalized = rotated / (norm + 1e-9)

    # Nearest-centroid search, vectorized: |x_i - c_k| over all (i, k).
    # np.argmin breaks ties toward the lower index, matching a scalar scan.
    indices = np.argmin(
        np.abs(normalized[:, None] - TURBO4_CENTROIDS[None, :]), axis=1
    ).astype(np.int32)

    # Pack two 4-bit indices per byte: element 2j in the low nibble,
    # element 2j+1 in the high nibble.
    packed = (indices[0::2] | (indices[1::2] << 4)).astype(np.uint8)

    return packed, norm
|
||||
|
||||
|
||||
def polar_quant_decode(packed: np.ndarray, norm: float, d: int) -> np.ndarray:
    """PolarQuant decode (4-bit dequantization).

    Args:
        packed: uint8 array of length d/2 (two 4-bit codes per byte)
        norm: L2 norm from encode
        d: original dimension (assumed even, matching len(packed) * 2)

    Returns:
        Reconstructed float array of length d
    """
    # Split each byte back into its two 4-bit codes: even positions come
    # from the low nibble, odd positions from the high nibble.
    codes = np.zeros(d, dtype=np.int32)
    codes[0::2] = packed & 0x0F
    codes[1::2] = packed >> 4

    # Look up centroid values and undo the normalization.
    reconstructed = TURBO4_CENTROIDS[codes] * norm

    # The WHT is orthogonal and symmetric, so the forward transform doubles
    # as the inverse.
    return fwht(reconstructed)
|
||||
|
||||
|
||||
def inner_product(a: np.ndarray, b: np.ndarray) -> float:
    """Return the inner (dot) product of two vectors as a Python float."""
    return float(a @ b)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Encode/Decode Roundtrip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEncodeDecodeRoundtrip:
    """decode(encode(x)) ≈ x"""

    @staticmethod
    def _roundtrip(x: np.ndarray) -> np.ndarray:
        """Compress and immediately decompress *x*."""
        packed, norm = polar_quant_encode(x)
        return polar_quant_decode(packed, norm, len(x))

    def test_roundtrip_d64(self):
        np.random.seed(42)
        x = np.random.randn(64).astype(np.float32)
        # Should recover within quantization error
        error = np.max(np.abs(x - self._roundtrip(x)))
        assert error < 1.0, f"Roundtrip error too large: {error}"

    def test_roundtrip_d128(self):
        np.random.seed(42)
        x = np.random.randn(128).astype(np.float32)
        error = np.max(np.abs(x - self._roundtrip(x)))
        assert error < 1.5, f"Roundtrip error too large: {error}"

    def test_roundtrip_d256(self):
        np.random.seed(42)
        x = np.random.randn(256).astype(np.float32)
        error = np.max(np.abs(x - self._roundtrip(x)))
        assert error < 2.0, f"Roundtrip error too large: {error}"

    def test_roundtrip_zero_vector(self):
        # Zero vector should recover to near-zero
        recovered = self._roundtrip(np.zeros(128, dtype=np.float32))
        assert np.max(np.abs(recovered)) < 0.01

    def test_roundtrip_unit_vector(self):
        x = np.zeros(128, dtype=np.float32)
        x[0] = 1.0
        error = np.max(np.abs(x - self._roundtrip(x)))
        assert error < 1.0

    def test_roundtrip_magnitude_preserved(self):
        """Large values should recover with similar magnitude."""
        x = np.array([10.0, -10.0, 5.0, -5.0] * 32, dtype=np.float32)
        recovered = self._roundtrip(x)
        # Norm of recovered should be similar to original
        norm_orig = np.linalg.norm(x)
        norm_rec = np.linalg.norm(recovered)
        rel_error = abs(norm_orig - norm_rec) / norm_orig
        assert rel_error < 0.5, f"Magnitude not preserved: {rel_error:.3f}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Inner Product Preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInnerProductPreservation:
    """Q·K ≈ Q·dequant(quant(K))"""

    @staticmethod
    def _compressed_ip(q: np.ndarray, k: np.ndarray) -> float:
        """Inner product of q against the quantize/dequantize image of k."""
        packed, norm = polar_quant_encode(k)
        return inner_product(q, polar_quant_decode(packed, norm, len(k)))

    def test_inner_product_preserved(self):
        np.random.seed(42)
        q = np.random.randn(128).astype(np.float32)
        k = np.random.randn(128).astype(np.float32)

        ip_original = inner_product(q, k)
        ip_compressed = self._compressed_ip(q, k)

        # Should be within reasonable error (4-bit is lossy)
        rel_error = abs(ip_original - ip_compressed) / (abs(ip_original) + 1e-9)
        assert rel_error < 0.5, f"Inner product error too large: {rel_error:.3f}"

    def test_inner_product_same_direction(self):
        """Vectors in same direction should have high inner product."""
        np.random.seed(42)
        q = np.random.randn(128).astype(np.float32)
        k = q * 1.1  # Same direction, slightly different magnitude

        ip_original = inner_product(q, k)
        ip_compressed = self._compressed_ip(q, k)

        # Both should be positive and similar
        assert ip_original > 0
        assert ip_compressed > 0
        assert abs(ip_original - ip_compressed) / abs(ip_original) < 0.3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: WHT Orthogonality
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWHTOrthogonality:
    """WHT^T · WHT = I"""

    @staticmethod
    def _wht_matrix(n: int) -> np.ndarray:
        """Build the n x n WHT matrix by transforming each basis vector."""
        basis = np.eye(n, dtype=np.float32)
        return np.stack([fwht(basis[:, i]) for i in range(n)], axis=1)

    def _assert_orthogonal(self, n: int) -> None:
        """Assert W^T @ W == I to float32 tolerance."""
        W = self._wht_matrix(n)
        error = np.max(np.abs(W.T @ W - np.eye(n, dtype=np.float32)))
        assert error < 1e-5, f"WHT not orthogonal: max error {error}"

    def test_wht_orthogonal_d64(self):
        self._assert_orthogonal(64)

    def test_wht_orthogonal_d128(self):
        self._assert_orthogonal(128)

    def test_wht_preserves_energy(self):
        """||WHT(x)|| = ||x|| (energy preservation)."""
        np.random.seed(42)
        x = np.random.randn(128).astype(np.float32)
        energy_before = np.sum(x ** 2)
        energy_after = np.sum(fwht(x) ** 2)
        rel_error = abs(energy_before - energy_after) / energy_before
        assert rel_error < 1e-5, f"Energy not preserved: {rel_error}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Codebook Correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCodebookCorrectness:
    """Centroids match Lloyd-Max for N(0, 1/128)"""

    def test_centroids_symmetric(self):
        """Centroids should be approximately symmetric around zero."""
        # The codebook has 16 levels, roughly symmetric
        # Check that the distribution is balanced around zero
        pos_count = int(np.count_nonzero(TURBO4_CENTROIDS > 0))
        neg_count = int(np.count_nonzero(TURBO4_CENTROIDS < 0))
        assert abs(pos_count - neg_count) <= 2, "Centroids not balanced"

    def test_centroids_ordered(self):
        """Centroids should be in ascending order."""
        pairs = zip(TURBO4_CENTROIDS[:-1], TURBO4_CENTROIDS[1:])
        for i, (lo, hi) in enumerate(pairs):
            assert lo < hi, f"Centroids not ordered at {i}"

    def test_centroids_coverage(self):
        """Centroids should cover the range [-0.35, 0.35]."""
        lowest, highest = TURBO4_CENTROIDS[0], TURBO4_CENTROIDS[-1]
        assert lowest < -0.2
        assert highest > 0.3

    def test_centroids_16_levels(self):
        """Should have exactly 16 centroids for 4-bit quantization."""
        assert len(TURBO4_CENTROIDS) == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Memory Bounds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMemoryBounds:
    """No buffer overflows in bit packing.

    All randomized tests seed the RNG so a failure is reproducible; the
    original tests drew unseeded vectors and were nondeterministic.
    """

    def test_packed_size_d64(self):
        """d=64 packs into exactly 32 bytes (two 4-bit codes per byte)."""
        np.random.seed(42)
        x = np.random.randn(64).astype(np.float32)
        packed, _ = polar_quant_encode(x)
        assert len(packed) == 32, f"Wrong packed size: {len(packed)}"

    def test_packed_size_d128(self):
        """d=128 packs into exactly 64 bytes."""
        np.random.seed(42)
        x = np.random.randn(128).astype(np.float32)
        packed, _ = polar_quant_encode(x)
        assert len(packed) == 64, f"Wrong packed size: {len(packed)}"

    def test_packed_size_d256(self):
        """d=256 packs into exactly 128 bytes."""
        np.random.seed(42)
        x = np.random.randn(256).astype(np.float32)
        packed, _ = polar_quant_encode(x)
        assert len(packed) == 128, f"Wrong packed size: {len(packed)}"

    def test_packed_values_in_range(self):
        """Packed bytes should only use 4 bits per nibble."""
        np.random.seed(42)
        x = np.random.randn(128).astype(np.float32)
        packed, _ = polar_quant_encode(x)

        # Each byte contains two 4-bit indices (0-15)
        for byte in packed:
            low = byte & 0x0F
            high = byte >> 4
            assert low < 16, f"Low nibble out of range: {low}"
            assert high < 16, f"High nibble out of range: {high}"

    def test_unpack_matches_pack(self):
        """Verify pack/unpack symmetry."""
        np.random.seed(42)
        x = np.random.randn(128).astype(np.float32)
        packed, norm = polar_quant_encode(x)

        # Manually unpack and compare indices
        indices_encode = np.zeros(128, dtype=np.int32)
        for i in range(128):
            if i % 2 == 0:
                indices_encode[i] = packed[i // 2] & 0x0F
            else:
                indices_encode[i] = packed[i // 2] >> 4

        # Decode and re-encode to get indices from decode path
        decoded = polar_quant_decode(packed, norm, 128)
        rotated = fwht(decoded)
        normalized = rotated / (np.linalg.norm(rotated) + 1e-9)
        indices_decode = np.zeros(128, dtype=np.int32)
        for i in range(128):
            indices_decode[i] = np.argmin(np.abs(normalized[i] - TURBO4_CENTROIDS))

        # Indices should match
        assert np.all(indices_encode == indices_decode), "Pack/unpack mismatch"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Compression Ratio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCompressionRatio:
    """Verify the compression achieves expected ratio."""

    def test_4bit_compression_ratio(self):
        """4-bit quantization should give 8x compression vs float32."""
        np.random.seed(42)  # deterministic input so failures reproduce
        x = np.random.randn(128).astype(np.float32)
        original_bytes = x.nbytes  # 128 * 4 = 512 bytes

        packed, norm = polar_quant_encode(x)
        compressed_bytes = packed.nbytes + 4  # packed + norm (float32)

        # 512 / (64 + 4) ≈ 7.53, squarely inside the (7.5, 8.5) window.
        ratio = original_bytes / compressed_bytes
        assert ratio > 7.5, f"Compression ratio too low: {ratio}"
        assert ratio < 8.5, f"Compression ratio too high: {ratio}"
|
||||
# (end of tests/test_polar_quant.py)