// Source listing metadata: 150 lines, 4.5 KiB, C++
|
|
#include "llama-turbo.h"

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>
|
||
|
|
|
||
|
|
// Simple test for encode/decode round-trip
|
||
|
|
void test_roundtrip() {
|
||
|
|
const int d = 128;
|
||
|
|
std::vector<float> original(d);
|
||
|
|
std::vector<float> decoded(d);
|
||
|
|
std::vector<uint8_t> packed(d / 2);
|
||
|
|
float norm;
|
||
|
|
|
||
|
|
// Generate random test data
|
||
|
|
for (int i = 0; i < d; i++) {
|
||
|
|
original[i] = (float)rand() / RAND_MAX * 2.0f - 1.0f;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Encode
|
||
|
|
polar_quant_encode_turbo4(original.data(), packed.data(), &norm, d);
|
||
|
|
|
||
|
|
// Decode
|
||
|
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm, d);
|
||
|
|
|
||
|
|
// Check round-trip error
|
||
|
|
float max_error = 0.0f;
|
||
|
|
float avg_error = 0.0f;
|
||
|
|
for (int i = 0; i < d; i++) {
|
||
|
|
float error = fabsf(original[i] - decoded[i]);
|
||
|
|
max_error = fmaxf(max_error, error);
|
||
|
|
avg_error += error;
|
||
|
|
}
|
||
|
|
avg_error /= d;
|
||
|
|
|
||
|
|
std::cout << "Round-trip test:" << std::endl;
|
||
|
|
std::cout << " Max error: " << max_error << std::endl;
|
||
|
|
std::cout << " Avg error: " << avg_error << std::endl;
|
||
|
|
std::cout << " Norm: " << norm << std::endl;
|
||
|
|
|
||
|
|
// Check that error is reasonable (should be small due to quantization)
|
||
|
|
assert(max_error < 1.0f && "Round-trip error too large");
|
||
|
|
assert(avg_error < 0.5f && "Average error too large");
|
||
|
|
}
|
||
|
|
|
||
|
|
// Test with known values
|
||
|
|
void test_known_values() {
|
||
|
|
const int d = 128;
|
||
|
|
std::vector<float> zeros(d, 0.0f);
|
||
|
|
std::vector<float> ones(d, 1.0f);
|
||
|
|
std::vector<float> decoded(d);
|
||
|
|
std::vector<uint8_t> packed(d / 2);
|
||
|
|
float norm;
|
||
|
|
|
||
|
|
// Test zeros
|
||
|
|
polar_quant_encode_turbo4(zeros.data(), packed.data(), &norm, d);
|
||
|
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm, d);
|
||
|
|
|
||
|
|
std::cout << "Zero test:" << std::endl;
|
||
|
|
std::cout << " Norm: " << norm << std::endl;
|
||
|
|
|
||
|
|
// Test ones
|
||
|
|
polar_quant_encode_turbo4(ones.data(), packed.data(), &norm, d);
|
||
|
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm, d);
|
||
|
|
|
||
|
|
std::cout << "Ones test:" << std::endl;
|
||
|
|
std::cout << " Norm: " << norm << std::endl;
|
||
|
|
|
||
|
|
// Check that decoded values are approximately 1.0
|
||
|
|
float avg = 0.0f;
|
||
|
|
for (int i = 0; i < d; i++) {
|
||
|
|
avg += decoded[i];
|
||
|
|
}
|
||
|
|
avg /= d;
|
||
|
|
|
||
|
|
std::cout << " Average decoded value: " << avg << std::endl;
|
||
|
|
assert(fabsf(avg - 1.0f) < 0.5f && "Decoded average should be close to 1.0");
|
||
|
|
}
|
||
|
|
|
||
|
|
// Test edge cases
|
||
|
|
void test_edge_cases() {
|
||
|
|
const int d = 128;
|
||
|
|
std::vector<float> large(d);
|
||
|
|
std::vector<float> small(d);
|
||
|
|
std::vector<float> decoded(d);
|
||
|
|
std::vector<uint8_t> packed(d / 2);
|
||
|
|
float norm;
|
||
|
|
|
||
|
|
// Test large values
|
||
|
|
for (int i = 0; i < d; i++) {
|
||
|
|
large[i] = 1000.0f;
|
||
|
|
}
|
||
|
|
|
||
|
|
polar_quant_encode_turbo4(large.data(), packed.data(), &norm, d);
|
||
|
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm, d);
|
||
|
|
|
||
|
|
std::cout << "Large values test:" << std::endl;
|
||
|
|
std::cout << " Norm: " << norm << std::endl;
|
||
|
|
|
||
|
|
// Test small values
|
||
|
|
for (int i = 0; i < d; i++) {
|
||
|
|
small[i] = 0.001f;
|
||
|
|
}
|
||
|
|
|
||
|
|
polar_quant_encode_turbo4(small.data(), packed.data(), &norm, d);
|
||
|
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm, d);
|
||
|
|
|
||
|
|
std::cout << "Small values test:" << std::endl;
|
||
|
|
std::cout << " Norm: " << norm << std::endl;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Error-handling placeholder: prints a banner and a note, and keeps the
// invalid-argument calls commented out so that a normal run does not invoke
// them. The locals below exist only so those calls can be uncommented as-is.
void test_error_handling() {
    const int d = 128;
    float norm;
    std::vector<float> decoded(d);
    std::vector<uint8_t> packed(d / 2);
    std::vector<float> data(d, 1.0f);

    std::cout << "Error handling tests:" << std::endl
              << " Note: These should trigger assertions in debug mode" << std::endl;

    // Null-pointer arguments -- uncomment one at a time to exercise:
    // polar_quant_encode_turbo4(nullptr, packed.data(), &norm, d);
    // polar_quant_encode_turbo4(data.data(), nullptr, &norm, d);
    // polar_quant_encode_turbo4(data.data(), packed.data(), nullptr, d);

    // Invalid dimension (127 is not a power of 2):
    // polar_quant_encode_turbo4(data.data(), packed.data(), &norm, 127);
}
|
||
|
|
|
||
|
|
int main() {
|
||
|
|
std::cout << "TurboQuant Unit Tests" << std::endl;
|
||
|
|
std::cout << "====================" << std::endl;
|
||
|
|
|
||
|
|
try {
|
||
|
|
test_roundtrip();
|
||
|
|
test_known_values();
|
||
|
|
test_edge_cases();
|
||
|
|
test_error_handling();
|
||
|
|
|
||
|
|
std::cout << std::endl;
|
||
|
|
std::cout << "All tests passed!" << std::endl;
|
||
|
|
return 0;
|
||
|
|
} catch (const std::exception& e) {
|
||
|
|
std::cerr << "Test failed: " << e.what() << std::endl;
|
||
|
|
return 1;
|
||
|
|
}
|
||
|
|
}
|