52 lines
2.2 KiB
JavaScript
52 lines
2.2 KiB
JavaScript
/* inference-worker.js — Web Worker for non-blocking WASM inference.
 *
 * Messages IN:  {cmd: 'init'} | {cmd: 'load', data: ArrayBuffer} | {cmd: 'generate', prompt, maxTokens, temperature} | {cmd: 'benchmark', runs}
 * Messages OUT: {event: 'ready'} | {event: 'loaded', ok, ms} | {event: 'generated', text, tokensPerSec, tokens, ms} | {event: 'benchmark', msPerToken, runs} | {event: 'error', msg}
 */
|
|
|
|
/** Emscripten module instance; stays null until an 'init' command completes. */
let Module = null;

/** True once 'model_load' returned 0 for the current Module instance. */
let modelLoaded = false;

/**
 * Allocates `size` bytes on the WASM heap, passes the pointer to `fn`, and
 * frees the allocation even when `fn` throws, so a failing native call does
 * not leak heap memory.
 *
 * @param {number} size - Number of bytes to allocate.
 * @param {function(number): *} fn - Callback receiving the heap pointer.
 * @returns {*} Whatever `fn` returns.
 * @throws {Error} If the allocator returns a null pointer.
 */
function withHeapBuffer(size, fn) {
  const ptr = Module._malloc(size);
  if (ptr === 0) {
    throw new Error(`WASM heap allocation of ${size} bytes failed`);
  }
  try {
    return fn(ptr);
  } finally {
    Module._free(ptr);
  }
}

/**
 * Handles {cmd: 'init'}: imports the WASM glue script, instantiates the
 * module and posts {event: 'ready'}.
 */
async function handleInitCommand() {
  importScripts('llama-turbo-wasm.js');
  Module = await createModule();
  // A freshly created instance has had no model loaded into it yet.
  modelLoaded = false;
  self.postMessage({ event: 'ready' });
}

/**
 * Handles {cmd: 'load', data}: copies the model bytes onto the WASM heap,
 * calls 'model_load' and posts {event: 'loaded', ok, ms}.
 * `ms` measures only the native 'model_load' call, not the heap copy.
 *
 * @param {{data: ArrayBuffer}} msg - Incoming command.
 * @throws {Error} If 'init' has not completed.
 */
function handleLoadCommand(msg) {
  if (!Module) throw new Error('Module not initialized');

  const bytes = new Uint8Array(msg.data);
  const { rc, ms } = withHeapBuffer(bytes.length, (ptr) => {
    Module.HEAPU8.set(bytes, ptr);
    const startedAt = performance.now();
    const status = Module.ccall(
      'model_load',
      'number',
      ['number', 'number'],
      [ptr, bytes.length],
    );
    return { rc: status, ms: performance.now() - startedAt };
  });

  modelLoaded = rc === 0;
  self.postMessage({ event: 'loaded', ok: modelLoaded, ms });
}

/**
 * Handles {cmd: 'generate', prompt, maxTokens, temperature} and posts
 * {event: 'generated', text, tokensPerSec, tokens, ms}.
 * Defaults: maxTokens 64, temperature 0.7.
 *
 * @param {{prompt: string, maxTokens?: number, temperature?: number}} msg
 * @throws {Error} If no model has been loaded successfully.
 */
function handleGenerateCommand(msg) {
  if (!modelLoaded) throw new Error('Model not loaded');

  const maxTokens = msg.maxTokens || 64;
  // An explicit temperature of 0 must be honoured; only a missing or
  // otherwise falsy value falls back to the default.
  const temperature = msg.temperature === 0 ? 0 : (msg.temperature || 0.7);

  const { n, ms, text } = withHeapBuffer(maxTokens * 4, (outPtr) => {
    const startedAt = performance.now();
    const count = Module.ccall(
      'generate',
      'number',
      ['string', 'number', 'number', 'number'],
      [msg.prompt, outPtr, maxTokens, temperature],
    );
    const elapsed = performance.now() - startedAt;
    // NOTE(review): `count` is used both as the read limit for UTF8ToString
    // and as the token count reported to the caller — confirm against the
    // native 'generate' implementation which of the two it returns.
    const decoded = count > 0 ? Module.UTF8ToString(outPtr, count) : '';
    return { n: count, ms: elapsed, text: decoded };
  });

  // Throughput rounded to one decimal place.
  const tokensPerSec = n > 0 ? parseFloat((n / (ms / 1000)).toFixed(1)) : 0;
  self.postMessage({ event: 'generated', text, tokensPerSec, tokens: n, ms });
}

/**
 * Handles {cmd: 'benchmark', runs}: calls the native 'benchmark' export
 * (default 100 runs) and posts {event: 'benchmark', msPerToken, runs}.
 *
 * @param {{runs?: number}} msg - Incoming command.
 * @throws {Error} If no model has been loaded successfully.
 */
function handleBenchmarkCommand(msg) {
  if (!modelLoaded) throw new Error('Model not loaded');

  const runs = msg.runs || 100;
  const msPerToken = Module.ccall('benchmark', 'number', ['number'], [runs]);
  self.postMessage({ event: 'benchmark', msPerToken, runs });
}

/**
 * Worker entry point: dispatches on `e.data.cmd`. Anything thrown while
 * handling a command is reported as {event: 'error', msg}. Commands other
 * than 'init', 'load', 'generate' and 'benchmark' are ignored.
 */
self.onmessage = async function (e) {
  const msg = e.data;
  try {
    switch (msg.cmd) {
      case 'init':
        await handleInitCommand();
        break;
      case 'load':
        handleLoadCommand(msg);
        break;
      case 'generate':
        handleGenerateCommand(msg);
        break;
      case 'benchmark':
        handleBenchmarkCommand(msg);
        break;
      default:
        break;
    }
  } catch (err) {
    // The thrown value is not guaranteed to be an Error object, so guard the
    // property access before falling back to its string form.
    self.postMessage({ event: 'error', msg: (err && err.message) || String(err) });
  }
};