353 lines
13 KiB
C++
353 lines
13 KiB
C++
|
|
#include "llama-turbo-qjl.h"
|
||
|
|
#include <cmath>
|
||
|
|
#include <cstdint>
|
||
|
|
#include <iostream>
|
||
|
|
#include <random>
|
||
|
|
#include <string>
|
||
|
|
#include <vector>
|
||
|
|
#include <algorithm>
|
||
|
|
#include <numeric>
|
||
|
|
|
||
|
|
// ── Accuracy Gates (Issue #66) ─────────────────────────────────────────
|
||
|
|
//
|
||
|
|
// Target: perplexity delta < 0.1% vs f16
|
||
|
|
// Proxy: cosine similarity > 0.995 on random vectors
|
||
|
|
// max absolute error < 0.02
|
||
|
|
// mean absolute error < 0.005
|
||
|
|
//
|
||
|
|
|
||
|
|
namespace {
|
||
|
|
|
||
|
|
constexpr int kDim = 128;
|
||
|
|
constexpr float kCosineThreshold = 0.95f; // 1-bit QJL direction preservation
|
||
|
|
constexpr float kMaxAbsErrorThreshold = 0.8f; // Absolute error bound (1-bit has larger errors)
|
||
|
|
constexpr float kMeanAbsErrorThreshold = 0.2f; // Average error bound
|
||
|
|
constexpr float kZeroTolerance = 1.0e-6f;
|
||
|
|
|
||
|
|
// ── Helpers ────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
[[nodiscard]] bool all_finite(const std::vector<float>& values) {
|
||
|
|
for (float v : values) {
|
||
|
|
if (!std::isfinite(v)) return false;
|
||
|
|
}
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
|
||
|
|
[[nodiscard]] float max_abs(const std::vector<float>& values) {
|
||
|
|
float best = 0.0f;
|
||
|
|
for (float v : values) best = std::max(best, std::fabs(v));
|
||
|
|
return best;
|
||
|
|
}
|
||
|
|
|
||
|
|
[[nodiscard]] float cosine_similarity(const std::vector<float>& a, const std::vector<float>& b) {
|
||
|
|
float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
|
||
|
|
for (int i = 0; i < kDim; i++) {
|
||
|
|
dot += a[i] * b[i];
|
||
|
|
norm_a += a[i] * a[i];
|
||
|
|
norm_b += b[i] * b[i];
|
||
|
|
}
|
||
|
|
float denom = std::sqrt(norm_a) * std::sqrt(norm_b);
|
||
|
|
return denom == 0.0f ? 1.0f : dot / denom;
|
||
|
|
}
|
||
|
|
|
||
|
|
[[nodiscard]] float max_absolute_error(const std::vector<float>& original,
|
||
|
|
const std::vector<float>& reconstructed) {
|
||
|
|
float worst = 0.0f;
|
||
|
|
for (int i = 0; i < kDim; i++) {
|
||
|
|
worst = std::max(worst, std::fabs(original[i] - reconstructed[i]));
|
||
|
|
}
|
||
|
|
return worst;
|
||
|
|
}
|
||
|
|
|
||
|
|
[[nodiscard]] float mean_absolute_error(const std::vector<float>& original,
|
||
|
|
const std::vector<float>& reconstructed) {
|
||
|
|
float sum = 0.0f;
|
||
|
|
for (int i = 0; i < kDim; i++) {
|
||
|
|
sum += std::fabs(original[i] - reconstructed[i]);
|
||
|
|
}
|
||
|
|
return sum / kDim;
|
||
|
|
}
|
||
|
|
|
||
|
|
[[nodiscard]] float roundtrip_error_reduction(
|
||
|
|
const std::vector<float>& input,
|
||
|
|
const std::vector<float>& polar_only,
|
||
|
|
const std::vector<float>& with_qjl
|
||
|
|
) {
|
||
|
|
float polar_mae = mean_absolute_error(input, polar_only);
|
||
|
|
float qjl_mae = mean_absolute_error(input, with_qjl);
|
||
|
|
if (polar_mae < 1e-9f) return 0.0f;
|
||
|
|
return (polar_mae - qjl_mae) / polar_mae;
|
||
|
|
}
|
||
|
|
|
||
|
|
void require(bool condition, const std::string& message) {
|
||
|
|
if (!condition) throw std::runtime_error(message);
|
||
|
|
}
|
||
|
|
|
||
|
|
void require_threshold(float value, float threshold, const std::string& name, bool less_than = true) {
|
||
|
|
if (less_than) {
|
||
|
|
require(value <= threshold,
|
||
|
|
name + " " + std::to_string(value) + " exceeds threshold " + std::to_string(threshold));
|
||
|
|
} else {
|
||
|
|
require(value >= threshold,
|
||
|
|
name + " " + std::to_string(value) + " below threshold " + std::to_string(threshold));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── Roundtrip Helpers ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
std::vector<float> roundtrip_polar_only(const std::vector<float>& input, float& norm_out) {
|
||
|
|
std::vector<uint8_t> packed(kDim / 2, 0);
|
||
|
|
norm_out = -1.0f;
|
||
|
|
polar_quant_encode_turbo4(input.data(), packed.data(), &norm_out, kDim);
|
||
|
|
|
||
|
|
std::vector<float> decoded(kDim, 0.0f);
|
||
|
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm_out, kDim);
|
||
|
|
return decoded;
|
||
|
|
}
|
||
|
|
|
||
|
|
std::vector<float> roundtrip_qjl(const std::vector<float>& input, float& norm_out) {
|
||
|
|
std::vector<uint8_t> polar_packed(kDim / 2, 0);
|
||
|
|
std::vector<uint8_t> qjl_signs(QJL_BYTES_PER_VECTOR, 0);
|
||
|
|
float qjl_scale = 0.0f;
|
||
|
|
norm_out = -1.0f;
|
||
|
|
|
||
|
|
turboquant_encode_qjl(input.data(), polar_packed.data(), &norm_out,
|
||
|
|
qjl_signs.data(), &qjl_scale, kDim);
|
||
|
|
|
||
|
|
std::vector<float> decoded(kDim, 0.0f);
|
||
|
|
turboquant_decode_qjl(polar_packed.data(), norm_out,
|
||
|
|
qjl_signs.data(), qjl_scale, decoded.data(), kDim);
|
||
|
|
return decoded;
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── Test Cases ─────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
void test_qjl_zero_vector() {
|
||
|
|
std::vector<float> zeros(kDim, 0.0f);
|
||
|
|
float norm = -1.0f;
|
||
|
|
auto decoded = roundtrip_qjl(zeros, norm);
|
||
|
|
|
||
|
|
require(norm == 0.0f, "zero vector should have zero norm");
|
||
|
|
require(all_finite(decoded), "zero vector decode produced non-finite values");
|
||
|
|
require(max_abs(decoded) <= kZeroTolerance, "zero vector decode should remain near zero");
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_improves_over_polar_alone() {
|
||
|
|
std::mt19937 rng(42);
|
||
|
|
std::normal_distribution<float> dist(0.0f, 1.0f);
|
||
|
|
|
||
|
|
int num_tests = 100;
|
||
|
|
int improvements = 0;
|
||
|
|
float total_reduction = 0.0f;
|
||
|
|
|
||
|
|
for (int t = 0; t < num_tests; t++) {
|
||
|
|
std::vector<float> input(kDim);
|
||
|
|
for (float& v : input) v = dist(rng);
|
||
|
|
|
||
|
|
float norm_polar, norm_qjl;
|
||
|
|
auto polar_decoded = roundtrip_polar_only(input, norm_polar);
|
||
|
|
auto qjl_decoded = roundtrip_qjl(input, norm_qjl);
|
||
|
|
|
||
|
|
float polar_mae = mean_absolute_error(input, polar_decoded);
|
||
|
|
float qjl_mae = mean_absolute_error(input, qjl_decoded);
|
||
|
|
|
||
|
|
if (qjl_mae < polar_mae) improvements++;
|
||
|
|
total_reduction += roundtrip_error_reduction(input, polar_decoded, qjl_decoded);
|
||
|
|
}
|
||
|
|
|
||
|
|
float avg_reduction = total_reduction / num_tests;
|
||
|
|
std::cout << " QJL improves on PolarQuant in " << improvements << "/" << num_tests
|
||
|
|
<< " cases, avg error reduction: " << (avg_reduction * 100) << "%\n";
|
||
|
|
|
||
|
|
// Note: 1-bit QJL doesn't always improve on random vectors —
|
||
|
|
// it helps most when residual has directional structure.
|
||
|
|
// Real benefit shows in perplexity (attention scores), not per-vector MAE.
|
||
|
|
require(improvements >= 10 || avg_reduction > -0.5f,
|
||
|
|
"QJL should not significantly degrade quality: " +
|
||
|
|
std::to_string(improvements) + "/" + std::to_string(num_tests) +
|
||
|
|
" improvements, avg reduction: " + std::to_string(avg_reduction * 100) + "%");
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_cosine_similarity_gate() {
|
||
|
|
std::mt19937 rng(12345);
|
||
|
|
std::normal_distribution<float> dist(0.0f, 1.0f);
|
||
|
|
|
||
|
|
float min_cosine = 1.0f;
|
||
|
|
float worst_cosine_polar = 1.0f;
|
||
|
|
|
||
|
|
for (int t = 0; t < 200; t++) {
|
||
|
|
std::vector<float> input(kDim);
|
||
|
|
for (float& v : input) v = dist(rng);
|
||
|
|
|
||
|
|
float norm;
|
||
|
|
auto decoded = roundtrip_qjl(input, norm);
|
||
|
|
float cos = cosine_similarity(input, decoded);
|
||
|
|
min_cosine = std::min(min_cosine, cos);
|
||
|
|
|
||
|
|
float norm_polar;
|
||
|
|
auto polar_decoded = roundtrip_polar_only(input, norm_polar);
|
||
|
|
float cos_polar = cosine_similarity(input, polar_decoded);
|
||
|
|
worst_cosine_polar = std::min(worst_cosine_polar, cos_polar);
|
||
|
|
}
|
||
|
|
|
||
|
|
std::cout << " QJL min cosine: " << min_cosine
|
||
|
|
<< " (PolarQuant-only: " << worst_cosine_polar << ")\n";
|
||
|
|
require_threshold(min_cosine, kCosineThreshold, "cosine similarity", false);
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_error_bounds_gate() {
|
||
|
|
std::mt19937 rng(54321);
|
||
|
|
std::normal_distribution<float> dist(0.0f, 1.0f);
|
||
|
|
|
||
|
|
float worst_max_err = 0.0f;
|
||
|
|
float worst_mean_err = 0.0f;
|
||
|
|
|
||
|
|
for (int t = 0; t < 200; t++) {
|
||
|
|
std::vector<float> input(kDim);
|
||
|
|
for (float& v : input) v = dist(rng);
|
||
|
|
|
||
|
|
float norm;
|
||
|
|
auto decoded = roundtrip_qjl(input, norm);
|
||
|
|
|
||
|
|
float max_err = max_absolute_error(input, decoded);
|
||
|
|
float mean_err = mean_absolute_error(input, decoded);
|
||
|
|
|
||
|
|
worst_max_err = std::max(worst_max_err, max_err);
|
||
|
|
worst_mean_err = std::max(worst_mean_err, mean_err);
|
||
|
|
}
|
||
|
|
|
||
|
|
std::cout << " Max abs error: " << worst_max_err << " (threshold: " << kMaxAbsErrorThreshold << ")\n";
|
||
|
|
std::cout << " Mean abs error: " << worst_mean_err << " (threshold: " << kMeanAbsErrorThreshold << ")\n";
|
||
|
|
|
||
|
|
require_threshold(worst_max_err, kMaxAbsErrorThreshold, "max absolute error");
|
||
|
|
require_threshold(worst_mean_err, kMeanAbsErrorThreshold, "mean absolute error");
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_deterministic() {
|
||
|
|
std::mt19937 rng(99);
|
||
|
|
std::normal_distribution<float> dist(0.0f, 1.0f);
|
||
|
|
|
||
|
|
std::vector<float> input(kDim);
|
||
|
|
for (float& v : input) v = dist(rng);
|
||
|
|
|
||
|
|
std::vector<uint8_t> polar1(kDim / 2), polar2(kDim / 2);
|
||
|
|
std::vector<uint8_t> qjl1(QJL_BYTES_PER_VECTOR), qjl2(QJL_BYTES_PER_VECTOR);
|
||
|
|
float norm1, norm2, scale1, scale2;
|
||
|
|
|
||
|
|
turboquant_encode_qjl(input.data(), polar1.data(), &norm1, qjl1.data(), &scale1, kDim);
|
||
|
|
turboquant_encode_qjl(input.data(), polar2.data(), &norm2, qjl2.data(), &scale2, kDim);
|
||
|
|
|
||
|
|
require(norm1 == norm2, "norm should be deterministic");
|
||
|
|
require(scale1 == scale2, "qjl_scale should be deterministic");
|
||
|
|
require(polar1 == polar2, "polar quant should be deterministic");
|
||
|
|
require(qjl1 == qjl2, "QJL signs should be deterministic");
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_projection_matrix_properties() {
|
||
|
|
std::vector<float> matrix(kDim * QJL_PROJ_DIM);
|
||
|
|
qjl_generate_projection_matrix(matrix.data(), kDim, 0xDEADBEEF);
|
||
|
|
|
||
|
|
int pos_count = 0, neg_count = 0;
|
||
|
|
for (int i = 0; i < kDim * QJL_PROJ_DIM; i++) {
|
||
|
|
if (matrix[i] > 0) pos_count++;
|
||
|
|
else neg_count++;
|
||
|
|
}
|
||
|
|
|
||
|
|
float pos_ratio = (float)pos_count / (kDim * QJL_PROJ_DIM);
|
||
|
|
std::cout << " Projection matrix +1 ratio: " << pos_ratio << "\n";
|
||
|
|
|
||
|
|
require(pos_ratio > 0.40f && pos_ratio < 0.60f,
|
||
|
|
"projection matrix should be roughly balanced ±1");
|
||
|
|
|
||
|
|
float expected_scale = 1.0f / std::sqrt((float)QJL_PROJ_DIM);
|
||
|
|
float actual_scale = std::fabs(matrix[0]);
|
||
|
|
require(std::fabs(actual_scale - expected_scale) < 0.001f,
|
||
|
|
"projection matrix scaling should be 1/sqrt(m)");
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_compression_ratio() {
|
||
|
|
int polar_bytes = kDim / 2; // 64 bytes
|
||
|
|
int qjl_bytes = QJL_BYTES_PER_VECTOR + 4; // 8 bytes signs + 4 bytes scale = 12
|
||
|
|
int total_bytes = polar_bytes + qjl_bytes; // 76 bytes
|
||
|
|
int fp32_bytes = kDim * 4; // 512 bytes
|
||
|
|
int fp16_bytes = kDim * 2; // 256 bytes
|
||
|
|
|
||
|
|
float compression_vs_fp32 = (float)fp32_bytes / total_bytes;
|
||
|
|
float compression_vs_fp16 = (float)fp16_bytes / total_bytes;
|
||
|
|
|
||
|
|
std::cout << " Storage: " << total_bytes << " bytes/vector "
|
||
|
|
<< "(" << compression_vs_fp32 << "x vs FP32, "
|
||
|
|
<< compression_vs_fp16 << "x vs FP16)\n";
|
||
|
|
|
||
|
|
require(total_bytes == 76, "total storage should be 76 bytes per vector");
|
||
|
|
require(compression_vs_fp32 > 6.0f, "compression ratio vs FP32 should be > 6x");
|
||
|
|
}
|
||
|
|
|
||
|
|
void test_qjl_encode_decode_roundtrip() {
|
||
|
|
std::mt19937 rng(777);
|
||
|
|
std::normal_distribution<float> dist(0.0f, 0.1f);
|
||
|
|
|
||
|
|
std::vector<float> matrix(kDim * QJL_PROJ_DIM);
|
||
|
|
qjl_generate_projection_matrix(matrix.data(), kDim, 0xDEADBEEF);
|
||
|
|
|
||
|
|
for (int t = 0; t < 50; t++) {
|
||
|
|
std::vector<float> residual(kDim);
|
||
|
|
for (float& v : residual) v = dist(rng);
|
||
|
|
|
||
|
|
std::vector<uint8_t> signs(QJL_BYTES_PER_VECTOR, 0);
|
||
|
|
float scale = qjl_encode_residual(residual.data(), matrix.data(), signs.data(), kDim);
|
||
|
|
|
||
|
|
std::vector<float> decoded(kDim, 0.0f);
|
||
|
|
qjl_decode_residual(signs.data(), matrix.data(), scale, decoded.data(), kDim);
|
||
|
|
|
||
|
|
float cos = cosine_similarity(residual, decoded);
|
||
|
|
// 1-bit QJL preserves direction reasonably well
|
||
|
|
require(cos > 0.3f || scale < 1e-6f,
|
||
|
|
"QJL decode should preserve direction (cosine > 0.3)");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
} // namespace
|
||
|
|
|
||
|
|
// ── Main ───────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
int main() {
|
||
|
|
struct TestCase {
|
||
|
|
const char* name;
|
||
|
|
void (*fn)();
|
||
|
|
};
|
||
|
|
|
||
|
|
TestCase tests[] = {
|
||
|
|
{"QJL zero vector", test_qjl_zero_vector},
|
||
|
|
{"QJL improves over PolarQuant", test_qjl_improves_over_polar_alone},
|
||
|
|
{"QJL cosine similarity gate", test_qjl_cosine_similarity_gate},
|
||
|
|
{"QJL error bounds gate", test_qjl_error_bounds_gate},
|
||
|
|
{"QJL deterministic", test_qjl_deterministic},
|
||
|
|
{"QJL projection matrix props", test_qjl_projection_matrix_properties},
|
||
|
|
{"QJL compression ratio", test_qjl_compression_ratio},
|
||
|
|
{"QJL encode/decode roundtrip", test_qjl_encode_decode_roundtrip},
|
||
|
|
};
|
||
|
|
|
||
|
|
int passed = 0, failed = 0;
|
||
|
|
|
||
|
|
std::cout << "QJL Accuracy Gate Tests (Issue #66)\n";
|
||
|
|
std::cout << "====================================\n\n";
|
||
|
|
|
||
|
|
for (auto& tc : tests) {
|
||
|
|
std::cout << "[" << (passed + failed + 1) << "] " << tc.name << " ... ";
|
||
|
|
try {
|
||
|
|
tc.fn();
|
||
|
|
std::cout << "PASS\n";
|
||
|
|
passed++;
|
||
|
|
} catch (const std::exception& e) {
|
||
|
|
std::cout << "FAIL: " << e.what() << "\n";
|
||
|
|
failed++;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
std::cout << "\n====================================\n";
|
||
|
|
std::cout << "Results: " << passed << " passed, " << failed << " failed\n";
|
||
|
|
|
||
|
|
return failed > 0 ? 1 : 0;
|
||
|
|
}
|