Files
turboquant/wasm/inference-worker.js

52 lines
2.2 KiB
JavaScript

/* inference-worker.js — Web Worker for non-blocking WASM inference.
*
* Messages IN: {cmd: 'init'} | {cmd: 'load', data: ArrayBuffer} | {cmd: 'generate', prompt, maxTokens, temperature} | {cmd: 'benchmark', runs}
* Messages OUT: {event: 'ready'} | {event: 'loaded', ok, ms} | {event: 'generated', text, tokensPerSec, tokens, ms} | {event: 'benchmark', msPerToken, runs} | {event: 'error', msg}
*/
/** True once a 'load' command's model_load call has returned 0; gates 'generate' and 'benchmark'. */
let modelLoaded = false;

/** Emscripten module instance produced by createModule() during 'init'; null until then. */
let Module = null;
/**
 * Worker message dispatcher: routes each incoming command to the WASM module
 * and reports the outcome to the main thread through postMessage.
 *
 * Commands (msg = e.data):
 *   - 'init'      imports llama-turbo-wasm.js, instantiates the module and
 *                 posts {event: 'ready'}.
 *   - 'load'      copies msg.data into the WASM heap, calls model_load and
 *                 posts {event: 'loaded', ok, ms}.
 *   - 'generate'  calls generate with msg.prompt (maxTokens defaults to 64,
 *                 temperature to 0.7) and posts
 *                 {event: 'generated', text, tokensPerSec, tokens, ms}.
 *   - 'benchmark' calls benchmark with msg.runs (default 100) and posts
 *                 {event: 'benchmark', msPerToken, runs}.
 *
 * Every failure, including an unrecognised command, is reported as
 * {event: 'error', msg} rather than escaping the handler.
 *
 * @param {MessageEvent} e - event whose `data` carries the command object.
 * @returns {Promise<void>}
 */
self.onmessage = async function(e) {
  const msg = e.data;
  try {
    switch (msg.cmd) {
      case 'init': {
        importScripts('llama-turbo-wasm.js');
        Module = await createModule();
        // model_load has not been called on this new instance yet, so a flag
        // left over from a previous instance must not carry across.
        modelLoaded = false;
        self.postMessage({event: 'ready'});
        break;
      }

      case 'load': {
        if (!Module) throw new Error('Module not initialized');
        const buf = new Uint8Array(msg.data);
        const ptr = Module._malloc(buf.length);
        // A null pointer for a non-empty request means the allocation failed;
        // copying to address 0 would corrupt the heap.
        if (!ptr && buf.length > 0) {
          throw new Error('Out of memory: cannot allocate model buffer');
        }
        let rc;
        let ms;
        try {
          // HEAPU8 is read after _malloc so a heap growth triggered by the
          // allocation cannot leave us with a stale view.
          Module.HEAPU8.set(buf, ptr);
          const t0 = performance.now();
          rc = Module.ccall('model_load', 'number', ['number','number'], [ptr, buf.length]);
          ms = performance.now() - t0;
        } finally {
          // Release the staging buffer even when model_load throws.
          // NOTE(review): freeing here assumes model_load does not keep
          // pointers into this buffer — confirm against the C side.
          Module._free(ptr);
        }
        modelLoaded = rc === 0;
        self.postMessage({event: 'loaded', ok: modelLoaded, ms});
        break;
      }

      case 'generate': {
        if (!modelLoaded) throw new Error('Model not loaded');
        const maxTok = msg.maxTokens || 64;
        // An explicit temperature of 0 is a real setting and must not be
        // replaced by the default; every other falsy value still falls back.
        const temp = msg.temperature === 0 ? 0 : (msg.temperature || 0.7);
        const outPtr = Module._malloc(maxTok * 4);
        if (!outPtr) {
          throw new Error('Out of memory: cannot allocate output buffer');
        }
        let n;
        let ms;
        let text;
        try {
          const t0 = performance.now();
          n = Module.ccall('generate', 'number', ['string','number','number','number'], [msg.prompt, outPtr, maxTok, temp]);
          ms = performance.now() - t0;
          // The text must be decoded before the buffer is freed.
          // NOTE(review): n is used both as the byte limit for UTF8ToString
          // and as the token count below — confirm what generate returns.
          text = n > 0 ? Module.UTF8ToString(outPtr, n) : '';
        } finally {
          // Release the output buffer even when generate or decoding throws.
          Module._free(outPtr);
        }
        const tokensPerSec = n > 0 ? parseFloat((n / (ms / 1000)).toFixed(1)) : 0;
        self.postMessage({event: 'generated', text, tokensPerSec, tokens: n, ms});
        break;
      }

      case 'benchmark': {
        if (!modelLoaded) throw new Error('Model not loaded');
        const runs = msg.runs || 100;
        const msPerToken = Module.ccall('benchmark', 'number', ['number'], [runs]);
        self.postMessage({event: 'benchmark', msPerToken, runs});
        break;
      }

      default:
        // Report protocol mistakes instead of leaving the sender waiting for
        // a reply that never arrives.
        throw new Error(`Unknown command: ${String(msg.cmd)}`);
    }
  } catch (err) {
    // Thrown values are not always Error objects (they may be numbers,
    // strings or even null), so guard the property access.
    const detail = (err && err.message) || String(err);
    self.postMessage({event: 'error', msg: detail});
  }
};