Files
turboquant/tests/metal_integration_test.cpp

78 lines
2.2 KiB
C++

// tests/metal_integration_test.cpp — Validate TurboQuant Metal kernel registration
//
// This test verifies:
// 1. ggml-metal-turbo.h compiles as valid C++ (C compatibility is not exercised by this C++ test)
// 2. The API surface is consistent and complete
// 3. Integration header can be included alongside llama-turbo.h
//
// Note: Actual Metal GPU execution requires macOS with Metal support.
// This test runs on all platforms for API validation.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include "../ggml-metal-turbo.h"
#include "../llama-turbo.h"
namespace {

// Fails the current test by throwing std::runtime_error, which main()
// reports as "FAIL: <what>". Unlike assert(), this check is NOT removed
// when NDEBUG is defined (e.g. release builds), so the tests always run.
void expect(bool ok, const char * what) {
    if (!ok) {
        throw std::runtime_error(what);
    }
}

// Verifies the kernel enum is zero-based and consecutive, and that
// GGML_METAL_TURBO_KERNEL_COUNT equals the number of kernels. The values
// are compile-time constants, so a mismatch breaks the build rather than
// depending on a runtime check.
void test_header_compiles() {
    static_assert(GGML_METAL_TURBO_KERNEL_FWHT_128 == 0,
                  "GGML_METAL_TURBO_KERNEL_FWHT_128 must be 0");
    static_assert(GGML_METAL_TURBO_KERNEL_TURBO4_DEQUANT == 1,
                  "GGML_METAL_TURBO_KERNEL_TURBO4_DEQUANT must be 1");
    static_assert(GGML_METAL_TURBO_KERNEL_ATTENTION_TURBO4 == 2,
                  "GGML_METAL_TURBO_KERNEL_ATTENTION_TURBO4 must be 2");
    static_assert(GGML_METAL_TURBO_KERNEL_COUNT == 3,
                  "GGML_METAL_TURBO_KERNEL_COUNT must equal the number of kernels");
}

// Checks that the CPU reference encode/decode for turbo4 still works when
// the Metal integration header is included in the same translation unit.
void test_cpu_roundtrip_still_works() {
    constexpr int d = 128;

    // Deterministic ramp in [-1, 63/64]; not all zeros, so the norm is non-zero.
    float input[d] = {};
    for (int i = 0; i < d; i++) {
        input[i] = static_cast<float>(i - 64) / 64.0f;
    }

    // Packed buffer is d / 2 bytes (presumably two 4-bit codes per byte).
    uint8_t packed[d / 2] = {};
    float norm = 0.0f;
    polar_quant_encode_turbo4(input, packed, &norm, d);
    // NaN already fails `norm > 0`; isfinite additionally rejects +inf.
    expect(std::isfinite(norm) && norm > 0.0f,
           "polar_quant_encode_turbo4 produced a non-positive or non-finite norm");

    float decoded[d] = {};
    polar_quant_decode_turbo4(packed, decoded, norm, d);
    for (int i = 0; i < d; i++) {
        expect(std::isfinite(decoded[i]),
               "polar_quant_decode_turbo4 produced a non-finite value");
    }
}

// Checks how the Metal API behaves when no kernels have been registered.
void test_api_null_safety() {
    // An out-of-range kernel id must be rejected with nullptr. COUNT is the
    // first id past the valid range and is itself a representable enum value.
    // Arbitrary ints such as -1 or 99 are deliberately not used: for an enum
    // without a fixed underlying type, static_cast of a value outside the
    // enumeration's range is undefined behavior (C++17 [expr.static.cast]).
    expect(ggml_metal_turbo_get_pipeline(
               static_cast<ggml_metal_turbo_kernel>(GGML_METAL_TURBO_KERNEL_COUNT)) == nullptr,
           "ggml_metal_turbo_get_pipeline(COUNT) must return nullptr");

    // Nothing in this test registers kernels, so the backend must report unavailable.
    expect(!ggml_metal_turbo_available(),
           "ggml_metal_turbo_available() must be false before registration");
}

} // namespace
// Runs every test case in order. A test that throws a std::exception is
// reported on stderr as "FAIL: <message>" and the process exits with 1;
// if all tests complete, a PASS line is printed and the exit code is 0.
int main() {
    using test_fn = void (*)();
    const test_fn tests[] = {
        test_header_compiles,
        test_cpu_roundtrip_still_works,
        test_api_null_safety,
    };

    try {
        for (const test_fn run_test : tests) {
            run_test();
        }
    } catch (const std::exception & err) {
        std::fprintf(stderr, "FAIL: %s\n", err.what());
        return 1;
    }

    std::printf("PASS: TurboQuant Metal integration tests\n");
    return 0;
}