Compare commits

...

2 Commits

Author SHA1 Message Date
db225dab48 test: Add C unit tests for PolarQuant (#54)
All checks were successful
Smoke Test / smoke (pull_request) Successful in 16s
2026-04-15 02:15:05 +00:00
ccb0419997 test: Add Python unit tests for PolarQuant (#54) 2026-04-15 02:15:02 +00:00
2 changed files with 673 additions and 0 deletions

263
tests/test_polar_quant.c Normal file
View File

@@ -0,0 +1,263 @@
/*
* Unit tests for PolarQuant Turbo4
*
* Compile: gcc -o test_polar_quant test_polar_quant.c llama-turbo.cpp -lm
* Run: ./test_polar_quant
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include "llama-turbo.h"
#define TEST_ASSERT(cond, msg) do { if (!(cond)) { fprintf(stderr, "FAIL: %s (line %d)\n", msg, __LINE__); failures++; } else { passes++; } } while(0)
static int passes = 0;
static int failures = 0;
// Test encode/decode roundtrip
void test_roundtrip() {
printf("Testing encode/decode roundtrip...\n");
const int d = 128;
float src[128];
float dst[128];
uint8_t packed[64];
float norm;
// Generate test data
for (int i = 0; i < d; i++) {
src[i] = sinf(i * 0.1f);
}
// Encode
polar_quant_encode_turbo4(src, packed, &norm, d);
// Decode
polar_quant_decode_turbo4(packed, dst, norm, d);
// Check reconstruction error
float orig_norm = 0;
float diff_norm = 0;
for (int i = 0; i < d; i++) {
orig_norm += src[i] * src[i];
float diff = src[i] - dst[i];
diff_norm += diff * diff;
}
orig_norm = sqrtf(orig_norm);
diff_norm = sqrtf(diff_norm);
float rel_error = diff_norm / (orig_norm + 1e-9f);
TEST_ASSERT(rel_error < 0.5f, "Roundtrip relative error too high");
// Check packed size
TEST_ASSERT(norm > 0, "Norm should be positive");
}
// Test zero vector
void test_zero_vector() {
printf("Testing zero vector...\n");
const int d = 128;
float src[128] = {0};
float dst[128];
uint8_t packed[64];
float norm;
polar_quant_encode_turbo4(src, packed, &norm, d);
polar_quant_decode_turbo4(packed, dst, norm, d);
// Zero vector: norm should be 0 or very small
TEST_ASSERT(norm < 0.1f, "Zero vector norm should be small");
}
// Test inner product preservation
void test_inner_product() {
printf("Testing inner product preservation...\n");
const int d = 128;
float q[128], k[128], k_recon[128];
uint8_t k_packed[64];
float k_norm;
// Generate test vectors
for (int i = 0; i < d; i++) {
q[i] = cosf(i * 0.1f);
k[i] = sinf(i * 0.15f);
}
// Original inner product
float orig_ip = 0;
for (int i = 0; i < d; i++) {
orig_ip += q[i] * k[i];
}
// Compress k
polar_quant_encode_turbo4(k, k_packed, &k_norm, d);
polar_quant_decode_turbo4(k_packed, k_recon, k_norm, d);
// Compressed inner product
float comp_ip = 0;
for (int i = 0; i < d; i++) {
comp_ip += q[i] * k_recon[i];
}
float rel_error = fabsf(orig_ip - comp_ip) / (fabsf(orig_ip) + 1e-9f);
TEST_ASSERT(rel_error < 0.5f, "Inner product preservation");
}
// Test WHT orthogonality
void test_wht_orthogonality() {
printf("Testing WHT orthogonality...\n");
const int d = 64;
float src[64], result[64];
for (int i = 0; i < d; i++) {
src[i] = (float)i;
result[i] = src[i];
}
// Compute norm before
float norm_before = 0;
for (int i = 0; i < d; i++) {
norm_before += src[i] * src[i];
}
norm_before = sqrtf(norm_before);
// Apply encode (which includes WHT)
uint8_t packed[32];
float enc_norm;
polar_quant_encode_turbo4(result, packed, &enc_norm, d);
// Decode (which includes inverse WHT)
float decoded[64];
polar_quant_decode_turbo4(packed, decoded, enc_norm, d);
// Compute norm after
float norm_after = 0;
for (int i = 0; i < d; i++) {
norm_after += decoded[i] * decoded[i];
}
norm_after = sqrtf(norm_after);
// Norms should be similar (within quantization error)
float ratio = norm_after / (norm_before + 1e-9f);
TEST_ASSERT(ratio > 0.5f && ratio < 2.0f, "Norm preservation through WHT");
}
// Test bit packing
void test_bit_packing() {
printf("Testing bit packing...\n");
const int d = 128;
uint8_t packed[64] = {0};
// Pack alternating 0 and 15 (max value)
for (int i = 0; i < d; i++) {
int idx = (i % 2 == 0) ? 0 : 15;
if (i % 2 == 0) {
packed[i / 2] = idx;
} else {
packed[i / 2] |= idx << 4;
}
}
// Unpack and verify
for (int i = 0; i < d; i++) {
int expected = (i % 2 == 0) ? 0 : 15;
int actual;
if (i % 2 == 0) {
actual = packed[i / 2] & 0x0F;
} else {
actual = packed[i / 2] >> 4;
}
char msg[64];
sprintf(msg, "Bit packing at index %d", i);
TEST_ASSERT(actual == expected, msg);
}
}
// Test various dimensions
void test_dimensions() {
printf("Testing various dimensions...\n");
int dims[] = {16, 32, 64, 128, 256};
int num_dims = sizeof(dims) / sizeof(dims[0]);
for (int d_idx = 0; d_idx < num_dims; d_idx++) {
int d = dims[d_idx];
float* src = malloc(d * sizeof(float));
float* dst = malloc(d * sizeof(float));
uint8_t* packed = malloc(d / 2);
float norm;
// Generate test data
for (int i = 0; i < d; i++) {
src[i] = sinf(i * 0.1f);
}
// Encode/decode
polar_quant_encode_turbo4(src, packed, &norm, d);
polar_quant_decode_turbo4(packed, dst, norm, d);
// Check basic sanity
float orig_energy = 0, recon_energy = 0;
for (int i = 0; i < d; i++) {
orig_energy += src[i] * src[i];
recon_energy += dst[i] * dst[i];
}
float ratio = recon_energy / (orig_energy + 1e-9f);
char msg[64];
sprintf(msg, "Dimension %d energy ratio", d);
TEST_ASSERT(ratio > 0.5f && ratio < 2.0f, msg);
free(src);
free(dst);
free(packed);
}
}
// Test memory bounds
void test_memory_bounds() {
printf("Testing memory bounds...\n");
// Test with max 4-bit value everywhere
const int d = 256;
float src[256];
for (int i = 0; i < d; i++) {
src[i] = 0.35f; // Near max centroid
}
uint8_t packed[128];
float norm;
// Should not crash
polar_quant_encode_turbo4(src, packed, &norm, d);
TEST_ASSERT(1, "Memory bounds check passed");
}
int main() {
printf("=== PolarQuant Turbo4 Unit Tests ===\n\n");
test_roundtrip();
test_zero_vector();
test_inner_product();
test_wht_orthogonality();
test_bit_packing();
test_dimensions();
test_memory_bounds();
printf("\n=== Results ===\n");
printf("Passed: %d\n", passes);
printf("Failed: %d\n", failures);
return failures > 0 ? 1 : 0;
}

410
tests/test_polar_quant.py Normal file
View File

@@ -0,0 +1,410 @@
"""
Unit tests for PolarQuant Turbo4 encode/decode.
Tests the algorithm logic using Python reference implementations
that mirror the C++/Metal code.
"""
import math
import pytest
import struct
from typing import List, Tuple
# Lloyd-Max Centroids for N(0, 1/d) where d=128
# 4-bit (16 levels) - copied from llama-turbo.cpp
TURBO4_CENTROIDS = [
-0.2154, -0.1523, -0.1121, -0.0812,
-0.0554, -0.0321, -0.0105, 0.0105,
0.0321, 0.0554, 0.0812, 0.1121,
0.1523, 0.2154, 0.2800, 0.3500
]
def fwht(a: List[float]) -> List[float]:
"""Fast Walsh-Hadamard Transform (Python reference)."""
n = len(a)
result = a.copy()
h = 1
while h < n:
for i in range(0, n, h * 2):
for j in range(i, i + h):
x = result[j]
y = result[j + h]
result[j] = x + y
result[j + h] = x - y
h <<= 1
# Normalize
scale = 1.0 / math.sqrt(n)
for i in range(n):
result[i] *= scale
return result
def polar_quant_encode(src: List[float]) -> Tuple[bytes, float]:
"""
PolarQuant Turbo4 Encode (Python reference).
Returns:
Tuple of (packed_bytes, norm)
"""
d = len(src)
assert d % 2 == 0, "Dimension must be even"
# Apply WHT
rotated = fwht(src)
# Calculate L2 norm
norm = math.sqrt(sum(x * x for x in rotated))
# Quantize components
inv_norm = 1.0 / (norm + 1e-9)
indices = []
for val in rotated:
val_normalized = val * inv_norm
# Find nearest centroid
best_idx = 0
min_dist = abs(val_normalized - TURBO4_CENTROIDS[0])
for j in range(1, 16):
dist = abs(val_normalized - TURBO4_CENTROIDS[j])
if dist < min_dist:
min_dist = dist
best_idx = j
indices.append(best_idx)
# Pack 4-bit indices into bytes
packed = bytearray(d // 2)
for i in range(d):
if i % 2 == 0:
packed[i // 2] = indices[i]
else:
packed[i // 2] |= indices[i] << 4
return bytes(packed), norm
def polar_quant_decode(src: bytes, norm: float, d: int) -> List[float]:
"""
PolarQuant Turbo4 Decode (Python reference).
Returns:
Reconstructed float array
"""
# Unpack 4-bit indices
values = []
for i in range(d):
if i % 2 == 0:
idx = src[i // 2] & 0x0F
else:
idx = src[i // 2] >> 4
values.append(TURBO4_CENTROIDS[idx] * norm)
# Apply inverse WHT (same as forward for orthogonal)
return fwht(values)
class TestEncodeDecodeRoundtrip:
"""Test that decode(encode(x)) ≈ x."""
def test_zero_vector(self):
"""Zero vector should encode/decode to zero."""
d = 128
src = [0.0] * d
packed, norm = polar_quant_encode(src)
reconstructed = polar_quant_decode(packed, norm, d)
# Zero has no information, reconstruction will be near-zero
for i in range(d):
assert abs(reconstructed[i]) < 0.1, f"Index {i}: {reconstructed[i]}"
def test_unit_vector(self):
"""Unit vector should roundtrip reasonably."""
d = 128
src = [0.0] * d
src[0] = 1.0 # Unit vector
packed, norm = polar_quant_encode(src)
reconstructed = polar_quant_decode(packed, norm, d)
# Check shape is preserved (first element dominant)
max_val = max(reconstructed)
max_idx = reconstructed.index(max_val)
assert max_idx == 0, f"Peak at index {max_idx}, expected 0"
def test_random_vectors(self):
"""Random vectors should roundtrip with bounded error."""
import random
random.seed(42)
d = 128
errors = []
for trial in range(10):
src = [random.gauss(0, 0.1) for _ in range(d)]
packed, norm = polar_quant_encode(src)
reconstructed = polar_quant_decode(packed, norm, d)
# Compute relative error
orig_norm = math.sqrt(sum(x * x for x in src))
diff_norm = math.sqrt(sum((a - b) ** 2 for a, b in zip(src, reconstructed)))
rel_error = diff_norm / (orig_norm + 1e-9)
errors.append(rel_error)
avg_error = sum(errors) / len(errors)
assert avg_error < 0.5, f"Average relative error {avg_error} too high"
def test_various_dimensions(self):
"""Test with different power-of-2 dimensions."""
for d in [16, 32, 64, 128, 256]:
src = [math.sin(i * 0.1) for i in range(d)]
packed, norm = polar_quant_encode(src)
reconstructed = polar_quant_decode(packed, norm, d)
# Basic sanity: reconstructed should have similar magnitude
# 4-bit quantization loses significant energy, especially at small dims
orig_energy = sum(x * x for x in src)
recon_energy = sum(x * x for x in reconstructed)
ratio = recon_energy / (orig_energy + 1e-9)
assert 0.1 < ratio < 10.0, f"d={d}: energy ratio {ratio}"
class TestInnerProductPreservation:
"""Test that Q·K ≈ Q·dequant(quant(K))."""
def test_inner_product_preserved(self):
"""Inner products should be approximately preserved."""
import random
random.seed(123)
d = 128
# Generate two random vectors
q = [random.gauss(0, 0.1) for _ in range(d)]
k = [random.gauss(0, 0.1) for _ in range(d)]
# Original inner product
orig_ip = sum(a * b for a, b in zip(q, k))
# Compress k
k_packed, k_norm = polar_quant_encode(k)
k_reconstructed = polar_quant_decode(k_packed, k_norm, d)
# Compressed inner product
comp_ip = sum(a * b for a, b in zip(q, k_reconstructed))
# Check relative error
rel_error = abs(orig_ip - comp_ip) / (abs(orig_ip) + 1e-9)
# 4-bit quantization has significant error, allow up to 100% error
assert rel_error < 1.0, f"Inner product error {rel_error} too high"
def test_self_inner_product(self):
"""Self inner product should be close to original."""
d = 128
x = [math.cos(i * 0.2) for i in range(d)]
orig_self_ip = sum(a * a for a in x)
packed, norm = polar_quant_encode(x)
reconstructed = polar_quant_decode(packed, norm, d)
comp_self_ip = sum(a * a for a in reconstructed)
# Self inner product is energy, should be roughly preserved
# 4-bit quantization has significant error
ratio = comp_self_ip / (orig_self_ip + 1e-9)
assert 0.3 < ratio < 3.0, f"Self inner product ratio {ratio}"
class TestWHTOrthogonality:
"""Test that WHT is orthogonal (WHT^T · WHT = I)."""
def test_wht_orthogonality(self):
"""WHT should be orthogonal transformation."""
d = 128
# Create identity-like test: apply WHT, then apply again
# For orthogonal matrix, A^T A = I, so applying twice should scale
src = [float(i) for i in range(d)]
# First WHT
result1 = fwht(src)
# Second WHT (should be proportional to original for orthogonal)
result2 = fwht(result1)
# result2 should be proportional to src
# For Walsh-Hadamard, WHT(WHT(x)) = x * (1/sqrt(d))^2 * d = x
# Actually: WHT is self-inverse up to scaling
for i in range(d):
ratio = result2[i] / (src[i] + 1e-9) if src[i] != 0 else result2[i]
# Should be close to 1.0 (or 0 if src[i] is 0)
if abs(src[i]) > 0.01:
assert abs(ratio - 1.0) < 0.1, f"Index {i}: ratio {ratio}"
def test_wht_preserves_norm(self):
"""WHT should preserve L2 norm."""
d = 128
src = [math.sin(i) for i in range(d)]
orig_norm = math.sqrt(sum(x * x for x in src))
result = fwht(src)
result_norm = math.sqrt(sum(x * x for x in result))
ratio = result_norm / orig_norm
assert abs(ratio - 1.0) < 0.01, f"Norm ratio {ratio}, expected 1.0"
def test_wht_linearity(self):
"""WHT should be linear: WHT(a+b) = WHT(a) + WHT(b)."""
d = 64
a = [float(i) * 0.1 for i in range(d)]
b = [float(i) * 0.2 for i in range(d)]
# WHT(a + b)
a_plus_b = [x + y for x, y in zip(a, b)]
wht_sum = fwht(a_plus_b)
# WHT(a) + WHT(b)
wht_a = fwht(a)
wht_b = fwht(b)
sum_wht = [x + y for x, y in zip(wht_a, wht_b)]
# Should be equal
for i in range(d):
assert abs(wht_sum[i] - sum_wht[i]) < 1e-6, f"Linearity failed at {i}"
class TestCodebookCorrectness:
"""Test that centroids match Lloyd-Max for N(0, 1/128)."""
def test_centroids_extremes(self):
"""Extreme centroids should cover tails of distribution."""
min_c = min(TURBO4_CENTROIDS)
max_c = max(TURBO4_CENTROIDS)
# Should have reasonable range
assert min_c < -0.2, f"Min centroid {min_c} should be < -0.2"
assert max_c > 0.2, f"Max centroid {max_c} should be > 0.2"
def test_centroids_ordered(self):
"""Centroids should be strictly increasing."""
for i in range(len(TURBO4_CENTROIDS) - 1):
assert TURBO4_CENTROIDS[i] < TURBO4_CENTROIDS[i + 1], f"Centroids not ordered at index {i}"
def test_centroids_cover_range(self):
"""Centroids should cover reasonable range for N(0, 1/128)."""
# For N(0, 1/128), std = 1/sqrt(128) ≈ 0.088
# Centroids should cover roughly [-3*std, 3*std]
min_c = min(TURBO4_CENTROIDS)
max_c = max(TURBO4_CENTROIDS)
std = 1.0 / math.sqrt(128) # ≈ 0.088
assert min_c < -2 * std, f"Min centroid {min_c} should be < {-2*std}"
assert max_c > 2 * std, f"Max centroid {max_c} should be > {2*std}"
def test_centroids_count(self):
"""Should have exactly 16 centroids for 4-bit quantization."""
assert len(TURBO4_CENTROIDS) == 16, f"Expected 16 centroids, got {len(TURBO4_CENTROIDS)}"
class TestBitPacking:
"""Test bit packing/unpacking correctness."""
def test_packing_roundtrip(self):
"""Packing and unpacking should be lossless for 4-bit values."""
d = 128
# Create test indices (0-15)
indices = [i % 16 for i in range(d)]
# Pack
packed = bytearray(d // 2)
for i in range(d):
if i % 2 == 0:
packed[i // 2] = indices[i]
else:
packed[i // 2] |= indices[i] << 4
# Unpack
unpacked = []
for i in range(d):
if i % 2 == 0:
idx = packed[i // 2] & 0x0F
else:
idx = packed[i // 2] >> 4
unpacked.append(idx)
assert unpacked == indices, "Packing/unpacking mismatch"
def test_packing_bounds(self):
"""Packed values should fit in 4 bits (0-15)."""
d = 128
indices = [15] * d # Max value
packed = bytearray(d // 2)
for i in range(d):
if i % 2 == 0:
packed[i // 2] = indices[i]
else:
packed[i // 2] |= indices[i] << 4
# Each byte should have both nibbles = 15
for byte in packed:
assert byte == 0xFF, f"Expected 0xFF, got {hex(byte)}"
def test_no_overflow(self):
"""Packing should not overflow with valid 4-bit values."""
d = 256 # Larger dimension
# All max values
indices = [15] * d
packed = bytearray(d // 2)
for i in range(d):
if i % 2 == 0:
packed[i // 2] = indices[i]
else:
packed[i // 2] |= indices[i] << 4
# Should not crash or produce invalid values
assert len(packed) == d // 2
class TestMemoryBounds:
"""Test memory safety with various dimensions."""
def test_minimum_dimension(self):
"""Should work with minimum dimension (2)."""
d = 2
src = [1.0, 0.5]
packed, norm = polar_quant_encode(src)
assert len(packed) == d // 2
reconstructed = polar_quant_decode(packed, norm, d)
assert len(reconstructed) == d
def test_large_dimension(self):
"""Should work with large dimensions."""
d = 1024
src = [math.sin(i * 0.01) for i in range(d)]
packed, norm = polar_quant_encode(src)
assert len(packed) == d // 2
reconstructed = polar_quant_decode(packed, norm, d)
assert len(reconstructed) == d
def test_odd_dimension_fails(self):
"""Odd dimensions should fail (need even for 4-bit packing)."""
d = 127 # Odd
src = [0.0] * d
with pytest.raises(AssertionError):
polar_quant_encode(src)
if __name__ == "__main__":
pytest.main([__file__, "-v"])