Files
turboquant/tests/test_polar_quant.c
Alexander Whitestone ec6c1faa89
Some checks failed
Build & Test with Sanitizers / Python Tests (pull_request) Failing after 24s
Build & Test with Sanitizers / C Build (Normal) (pull_request) Successful in 16s
Build & Test with Sanitizers / C Build (AddressSanitizer) (pull_request) Successful in 26s
Build & Test with Sanitizers / C Build (UBSan) (pull_request) Successful in 31s
Build & Test with Sanitizers / Smoke Test (pull_request) Successful in 27s
Smoke Test / smoke (pull_request) Successful in 23s
test: Fix C tests for sanitizer compatibility (#71)
2026-04-15 03:07:33 +00:00

264 lines
6.8 KiB
C

/*
* Unit tests for PolarQuant Turbo4
*
* Compile: gcc -o test_polar_quant test_polar_quant.c llama-turbo.cpp -lm
* Run: ./test_polar_quant
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include "../llama-turbo.h"
// Running tallies for the whole test binary; modified only by TEST_ASSERT.
static int passes = 0;
static int failures = 0;

// Record the outcome of a single check.
//   cond : expression under test; a true value counts as a pass.
//   msg  : C string printed when the check fails.
// On failure, "FAIL: <msg> (line N)" is written to stderr, where N is the line
// of the TEST_ASSERT invocation, and `failures` is incremented. Execution
// continues after a failure, so a single run reports every failing check.
#define TEST_ASSERT(cond, msg)                                          \
    do {                                                                \
        if (cond) {                                                     \
            passes++;                                                   \
        } else {                                                        \
            fprintf(stderr, "FAIL: %s (line %d)\n", msg, __LINE__);     \
            failures++;                                                 \
        }                                                               \
    } while (0)
// Round-trip test: encode a 128-element sine wave with
// polar_quant_encode_turbo4(), decode it with polar_quant_decode_turbo4(),
// and record two checks through TEST_ASSERT:
//   1. the relative L2 reconstruction error is below 1.5;
//   2. the norm reported by the encoder is positive.
void test_roundtrip(void) {
    printf("Testing encode/decode roundtrip...\n");

    enum { DIM = 128, PACKED_BYTES = DIM / 2 };
    float src[DIM];
    float dst[DIM];
    uint8_t packed[PACKED_BYTES] = {0};
    // Poisoned so that an encoder which never writes the norm fails the
    // "norm > 0" check deterministically instead of reading garbage.
    float norm = NAN;

    // Generate test data
    for (int i = 0; i < DIM; i++) {
        src[i] = sinf(i * 0.1f);
        // Poison the output: any element the decoder leaves untouched turns
        // rel_error into NaN, which fails the "< 1.5" comparison below.
        dst[i] = NAN;
    }

    // Encode
    polar_quant_encode_turbo4(src, packed, &norm, DIM);
    // Decode
    polar_quant_decode_turbo4(packed, dst, norm, DIM);

    // Check reconstruction error
    float orig_norm = 0;
    float diff_norm = 0;
    for (int i = 0; i < DIM; i++) {
        orig_norm += src[i] * src[i];
        float diff = src[i] - dst[i];
        diff_norm += diff * diff;
    }
    orig_norm = sqrtf(orig_norm);
    diff_norm = sqrtf(diff_norm);

    // The small epsilon keeps the division defined if src were all zeros.
    float rel_error = diff_norm / (orig_norm + 1e-9f);
    // NOTE(review): an all-zero reconstruction yields rel_error of about 1.0,
    // which is still under this 1.5 threshold; consider tightening it once the
    // codec's expected accuracy is known.
    TEST_ASSERT(rel_error < 1.5f, "Roundtrip relative error too high");

    // Check the norm reported by the encoder
    TEST_ASSERT(norm > 0, "Norm should be positive");
}
// Zero-vector test: encode and decode a 128-element all-zero input and check
// that (a) the reported norm is non-negative and below 0.1, and (b) every
// decoded element is a finite number (no NaN/Inf from a zero norm).
void test_zero_vector(void) {
    printf("Testing zero vector...\n");

    enum { DIM = 128, PACKED_BYTES = DIM / 2 };
    float src[DIM] = {0};
    float dst[DIM];
    uint8_t packed[PACKED_BYTES] = {0};
    // Poisoned: if the encoder never writes the norm, the range check below
    // fails deterministically (every comparison with NaN is false).
    float norm = NAN;

    // Poison the output so elements the decoder skips are caught by the
    // finiteness check.
    for (int i = 0; i < DIM; i++) {
        dst[i] = NAN;
    }

    polar_quant_encode_turbo4(src, packed, &norm, DIM);
    polar_quant_decode_turbo4(packed, dst, norm, DIM);

    // Zero vector: norm should be 0 or very small (and never negative)
    TEST_ASSERT(norm >= 0.0f && norm < 0.1f, "Zero vector norm should be small");

    // The decoded vector must not contain NaN or infinity.
    int all_finite = 1;
    for (int i = 0; i < DIM; i++) {
        if (!isfinite(dst[i])) {
            all_finite = 0;
        }
    }
    TEST_ASSERT(all_finite, "Zero vector should decode to finite values");
}
// Left-to-right dot product of two n-element float vectors.
static float test_ip_dot(const float *a, const float *b, int n) {
    float acc = 0;
    for (int i = 0; i < n; i++) {
        acc += a[i] * b[i];
    }
    return acc;
}

// Inner-product test: compares <q, k> against <q, k'> where k' is k after an
// encode/decode round trip, and records one check that the relative
// difference is below 5.0.
void test_inner_product() {
    printf("Testing inner product preservation...\n");

    const int d = 128;
    float query[128], key[128], key_roundtrip[128];
    uint8_t key_packed[64];
    float key_norm;

    // Deterministic inputs: a cosine query and a sine key.
    for (int i = 0; i < d; i++) {
        query[i] = cosf(i * 0.1f);
        key[i] = sinf(i * 0.15f);
    }

    // Reference value, taken before the key is handed to the encoder.
    const float reference = test_ip_dot(query, key, d);

    // Round-trip the key through the quantizer.
    polar_quant_encode_turbo4(key, key_packed, &key_norm, d);
    polar_quant_decode_turbo4(key_packed, key_roundtrip, key_norm, d);

    // Same inner product, using the reconstructed key.
    const float approx = test_ip_dot(query, key_roundtrip, d);

    // The epsilon guards the division when the reference is zero.
    const float abs_diff = fabsf(reference - approx);
    const float rel_error = abs_diff / (fabsf(reference) + 1e-9f);
    TEST_ASSERT(rel_error < 5.0f, "Inner product preservation");
}
// Euclidean (L2) norm of an n-element float vector.
static float wht_test_l2norm(const float *v, int n) {
    float sum_sq = 0;
    for (int i = 0; i < n; i++) {
        sum_sq += v[i] * v[i];
    }
    return sqrtf(sum_sq);
}

// Norm-preservation test on a 64-element ramp (0, 1, ..., 63): the L2 norm of
// the encode/decode output must lie strictly between 0.3x and 3.0x the input
// norm. Per the original author's comments the encoder applies a WHT and the
// decoder its inverse; that is not visible from this file.
void test_wht_orthogonality() {
    printf("Testing WHT orthogonality...\n");

    const int d = 64;
    float ramp[64], enc_input[64];

    for (int i = 0; i < d; i++) {
        ramp[i] = (float)i;
    }
    // The encoder is given a separate copy of the ramp.
    memcpy(enc_input, ramp, sizeof(ramp));

    // Norm of the input, measured before encoding.
    const float norm_in = wht_test_l2norm(ramp, d);

    // Encode, then decode into a fresh buffer.
    uint8_t packed[32];
    float enc_norm;
    float decoded[64];
    polar_quant_encode_turbo4(enc_input, packed, &enc_norm, d);
    polar_quant_decode_turbo4(packed, decoded, enc_norm, d);

    // Norm of the reconstruction.
    const float norm_out = wht_test_l2norm(decoded, d);

    // Norms should be similar (within quantization error)
    const float ratio = norm_out / (norm_in + 1e-9f);
    TEST_ASSERT(ratio > 0.3f && ratio < 3.0f, "Norm preservation through WHT");
}
// Nibble-packing self-check. Packs 128 4-bit values (0 at even indices, 15 at
// odd indices) two per byte -- even index in the low nibble, odd index in the
// high nibble -- then unpacks them and records one TEST_ASSERT per value.
// This exercises only the local packing arithmetic; no library function is
// called.
void test_bit_packing() {
    printf("Testing bit packing...\n");

    const int d = 128;
    uint8_t packed[64] = {0};

    // Pack: every byte holds one even-index value (low) and one odd-index
    // value (high).
    for (int byte = 0; byte < d / 2; byte++) {
        const int low_nibble = 0;    // value at even index 2*byte
        const int high_nibble = 15;  // value at odd index 2*byte + 1 (max)
        packed[byte] = (uint8_t)(low_nibble | (high_nibble << 4));
    }

    // Unpack and verify each of the d values individually.
    for (int i = 0; i < d; i++) {
        const int is_odd = i & 1;
        const int expected = is_odd ? 15 : 0;
        const int actual = is_odd ? (packed[i / 2] >> 4)
                                  : (packed[i / 2] & 0x0F);

        char msg[64];
        snprintf(msg, sizeof(msg), "Bit packing at index %d", i);
        TEST_ASSERT(actual == expected, msg);
    }
}
// Dimension sweep: for each d in {16, 32, 64, 128, 256}, round-trip a sine
// wave through encode/decode and check that the reconstructed energy (sum of
// squares) is strictly between 0.1x and 10x the original energy.
// An allocation failure is recorded as a test failure for that dimension and
// the sweep continues with the next one.
void test_dimensions(void) {
    printf("Testing various dimensions...\n");

    static const int dims[] = {16, 32, 64, 128, 256};
    const int num_dims = (int)(sizeof(dims) / sizeof(dims[0]));

    for (int d_idx = 0; d_idx < num_dims; d_idx++) {
        const int d = dims[d_idx];
        char msg[64];

        // Casts are kept from the original so the file still builds if it is
        // ever compiled as C++.
        float *src = (float *)malloc((size_t)d * sizeof(*src));
        float *dst = (float *)malloc((size_t)d * sizeof(*dst));
        uint8_t *packed = (uint8_t *)malloc((size_t)d / 2);

        // Never hand a NULL buffer to the codec: report and move on.
        if (src == NULL || dst == NULL || packed == NULL) {
            snprintf(msg, sizeof(msg), "Dimension %d allocation failed", d);
            TEST_ASSERT(0, msg);
            free(src);      // free(NULL) is a no-op
            free(dst);
            free(packed);
            continue;
        }

        float norm = NAN;

        // Generate test data
        for (int i = 0; i < d; i++) {
            src[i] = sinf(i * 0.1f);
            // Poison the output so elements the decoder skips make the
            // energy ratio NaN, which fails the range check below.
            dst[i] = NAN;
        }

        // Encode/decode
        polar_quant_encode_turbo4(src, packed, &norm, d);
        polar_quant_decode_turbo4(packed, dst, norm, d);

        // Check basic sanity
        float orig_energy = 0, recon_energy = 0;
        for (int i = 0; i < d; i++) {
            orig_energy += src[i] * src[i];
            recon_energy += dst[i] * dst[i];
        }
        float ratio = recon_energy / (orig_energy + 1e-9f);

        snprintf(msg, sizeof(msg), "Dimension %d energy ratio", d);
        TEST_ASSERT(ratio > 0.1f && ratio < 10.0f, msg);

        free(src);
        free(dst);
        free(packed);
    }
}
// Memory-bounds test: encodes a 256-element constant vector into a d/2-byte
// packed buffer that is surrounded by canary-filled guard regions, then
// verifies that no guard byte was modified. The d/2 output size matches how
// every other test in this file sizes its packed buffer.
// Small overruns land in the guard regions, so they are reported by this
// check in every build rather than only under a memory sanitizer.
void test_memory_bounds(void) {
    printf("Testing memory bounds...\n");

    enum { DIM = 256, PACKED_BYTES = DIM / 2, GUARD = 16 };
    const uint8_t canary = 0xA5;

    // Constant input; the original author picked 0.35 as "near max centroid".
    float src[DIM];
    for (int i = 0; i < DIM; i++) {
        src[i] = 0.35f;
    }

    // Layout: [GUARD canary bytes][PACKED_BYTES output][GUARD canary bytes]
    uint8_t buf[GUARD + PACKED_BYTES + GUARD];
    memset(buf, canary, sizeof(buf));
    uint8_t *packed = buf + GUARD;
    float norm = NAN;

    // Should not crash
    polar_quant_encode_turbo4(src, packed, &norm, DIM);

    // Any write before or after the packed region clobbers a canary byte.
    int guards_intact = 1;
    for (int i = 0; i < GUARD; i++) {
        if (buf[i] != canary || buf[GUARD + PACKED_BYTES + i] != canary) {
            guards_intact = 0;
        }
    }
    TEST_ASSERT(guards_intact, "Encoder wrote outside the d/2-byte packed buffer");
}
// Test driver: runs every test in a fixed order, prints the pass/fail
// tallies, and returns 1 if any check failed, 0 otherwise.
int main() {
    printf("=== PolarQuant Turbo4 Unit Tests ===\n\n");

    // Codec round-trip and accuracy checks.
    test_roundtrip();
    test_zero_vector();
    test_inner_product();
    test_wht_orthogonality();

    // Packing arithmetic, dimension sweep, and buffer-bounds checks.
    test_bit_packing();
    test_dimensions();
    test_memory_bounds();

    printf("\n=== Results ===\n");
    printf("Passed: %d\n", passes);
    printf("Failed: %d\n", failures);

    // Non-zero exit status signals failure to the calling CI job.
    if (failures > 0) {
        return 1;
    }
    return 0;
}