78 lines
2.2 KiB
C++
78 lines
2.2 KiB
C++
// tests/metal_integration_test.cpp — Validate TurboQuant Metal kernel registration
//
// This test verifies:
// 1. ggml-metal-turbo.h compiles as valid C++ (C compatibility is not
//    exercised by this translation unit)
// 2. The API surface is consistent and complete (kernel enum values,
//    invalid-kernel lookup, availability query)
// 3. The integration header can be included alongside llama-turbo.h
//
// Note: Actual Metal GPU execution requires macOS with Metal support.
// This test runs on all platforms for API validation only.
|
|
|
|
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <stdexcept>

#include "../ggml-metal-turbo.h"
#include "../llama-turbo.h"
|
|
|
|
namespace {
|
|
|
|
void test_header_compiles() {
|
|
// Verify enum values are consecutive and complete
|
|
assert(GGML_METAL_TURBO_KERNEL_FWHT_128 == 0);
|
|
assert(GGML_METAL_TURBO_KERNEL_TURBO4_DEQUANT == 1);
|
|
assert(GGML_METAL_TURBO_KERNEL_ATTENTION_TURBO4 == 2);
|
|
assert(GGML_METAL_TURBO_KERNEL_COUNT == 3);
|
|
}
|
|
|
|
void test_cpu_roundtrip_still_works() {
|
|
// Verify the CPU reference implementation still functions
|
|
// alongside the Metal integration header
|
|
constexpr int d = 128;
|
|
float input[d] = {};
|
|
for (int i = 0; i < d; i++) {
|
|
input[i] = (float)(i - 64) / 64.0f;
|
|
}
|
|
|
|
uint8_t packed[d / 2] = {};
|
|
float norm = 0.0f;
|
|
polar_quant_encode_turbo4(input, packed, &norm, d);
|
|
assert(norm > 0.0f);
|
|
|
|
float decoded[d] = {};
|
|
polar_quant_decode_turbo4(packed, decoded, norm, d);
|
|
|
|
// All decoded values should be finite
|
|
for (int i = 0; i < d; i++) {
|
|
assert(std::isfinite(decoded[i]));
|
|
}
|
|
}
|
|
|
|
void test_api_null_safety() {
|
|
// API functions should handle NULL gracefully
|
|
assert(ggml_metal_turbo_get_pipeline(
|
|
static_cast<ggml_metal_turbo_kernel>(-1)) == nullptr);
|
|
assert(ggml_metal_turbo_get_pipeline(
|
|
static_cast<ggml_metal_turbo_kernel>(99)) == nullptr);
|
|
|
|
// Before registration, should report unavailable
|
|
assert(!ggml_metal_turbo_available());
|
|
}
|
|
|
|
} // namespace
|
|
|
|
int main() {
|
|
try {
|
|
test_header_compiles();
|
|
test_cpu_roundtrip_still_works();
|
|
test_api_null_safety();
|
|
std::printf("PASS: TurboQuant Metal integration tests\n");
|
|
return 0;
|
|
} catch (const std::exception & exc) {
|
|
std::fprintf(stderr, "FAIL: %s\n", exc.what());
|
|
return 1;
|
|
}
|
|
}
|