// Unified semantic search worker: handles data loading and query execution
import { env, pipeline } from "@huggingface/transformers"
import "onnxruntime-web/webgpu"
import "onnxruntime-web/wasm"

export {}

type VectorShardMeta = {
  path: string
  rows: number
  rowOffset: number
  byteLength: number
  sha256?: string
  byteStride: number
}

type LevelSection = {
  level: number
  indptr: { offset: number; elements: number; byteLength: number }
  indices: { offset: number; elements: number; byteLength: number }
}

type ChunkMetadata = {
  parentSlug: string
  chunkId: number
}

type Manifest = {
  version: number
  dims: number
  dtype: string
  normalized: boolean
  rows: number
  shardSizeRows: number
  vectors: {
    dtype: string
    rows: number
    dims: number
    shards: VectorShardMeta[]
  }
  ids: string[]
  titles?: string[]
  chunkMetadata?: Record<string, ChunkMetadata>
  hnsw: {
    M: number
    efConstruction: number
    entryPoint: number
    maxLevel: number
    graph: {
      path: string
      sha256?: string
      levels: LevelSection[]
    }
  }
}

type InitMessage = {
  type: "init"
  cfg: any
  manifestUrl: string
  baseUrl?: string
  disableCache?: boolean
}

type SearchMessage = { type: "search"; text: string; k: number; seq: number }

type ResetMessage = { type: "reset" }

type WorkerMessage = InitMessage | SearchMessage | ResetMessage

type ReadyMessage = { type: "ready" }

type ProgressMessage = {
  type: "progress"
  loadedRows: number
  totalRows: number
}

type SearchHit = { id: number; score: number }

type SearchResultMessage = {
  type: "search-result"
  seq: number
  semantic: SearchHit[]
}

type ErrorMessage = { type: "error"; seq?: number; message: string }

type WorkerState = "idle" | "loading" | "ready" | "error"

// IndexedDB configuration
const DB_NAME = "semantic-search-cache"
const STORE_NAME = "assets"
const DB_VERSION = 1

const hasIndexedDB = typeof indexedDB !== "undefined"
const supportsSharedArrayBuffer = typeof SharedArrayBuffer !== "undefined"

// State
let state: WorkerState = "idle"
let manifest: Manifest | null = null
let cfg: any = null
let vectorsView: Float32Array | null = null
let dims = 0
let rows = 0
let classifier: any = null
let envConfigured = false
let entryPoint = -1
let maxLevel = 0
let efDefault = 128
let levelGraph: { indptr: Uint32Array; indices: Uint32Array }[] = []
let abortController: AbortController | null = null
let dbPromise: Promise<IDBDatabase> | null = null

// IndexedDB helpers
function openDatabase(): Promise<IDBDatabase> {
  if (!hasIndexedDB) {
    return Promise.reject(new Error("indexedDB unavailable"))
  }
  if (!dbPromise) {
    dbPromise = new Promise((resolve, reject) => {
      const req = indexedDB.open(DB_NAME, DB_VERSION)
      req.onupgradeneeded = () => {
        const db = req.result
        if (!db.objectStoreNames.contains(STORE_NAME)) {
          db.createObjectStore(STORE_NAME)
        }
      }
      req.onsuccess = () => resolve(req.result)
      req.onerror = () => reject(req.error ?? new Error("failed to open cache store"))
    })
  }
  return dbPromise
}

async function readAsset(hash: string): Promise<ArrayBuffer | null> {
  if (!hasIndexedDB) {
    return null
  }
  const db = await openDatabase()
  return new Promise<ArrayBuffer | null>((resolve, reject) => {
    const tx = db.transaction(STORE_NAME, "readonly")
    const store = tx.objectStore(STORE_NAME)
    const req = store.get(hash)
    req.onsuccess = () => {
      const value = req.result
      if (value instanceof ArrayBuffer) {
        resolve(value)
      } else if (value && value.buffer instanceof ArrayBuffer) {
        resolve(value.buffer as ArrayBuffer)
      } else {
        resolve(null)
      }
    }
    req.onerror = () => reject(req.error ?? new Error("failed to read cached asset"))
  })
}
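// Cache design note: assets are keyed by the sha256 recorded in the manifest,
// so a rebuilt shard gets a new key and stale entries are simply never read
// again. Nothing evicts old entries automatically; a cleanup helper could look
// like the sketch below (hypothetical, not wired into the message protocol):
//
//   async function clearAssetCache(): Promise<void> {
//     const db = await openDatabase()
//     await new Promise<void>((resolve, reject) => {
//       const tx = db.transaction(STORE_NAME, "readwrite")
//       tx.objectStore(STORE_NAME).clear()
//       tx.oncomplete = () => resolve()
//       tx.onerror = () => reject(tx.error ?? new Error("failed to clear asset cache"))
//     })
//   }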
new Error("failed to read cached asset")) }) } async function writeAsset(hash: string, buffer: ArrayBuffer): Promise { if (!hasIndexedDB) { return } const db = await openDatabase() await new Promise((resolve, reject) => { const tx = db.transaction(STORE_NAME, "readwrite") const store = tx.objectStore(STORE_NAME) const req = store.put(buffer, hash) req.onsuccess = () => resolve() req.onerror = () => reject(req.error ?? new Error("failed to cache asset")) }) } function toAbsolute(path: string, baseUrl?: string): string { if (path.startsWith("http://") || path.startsWith("https://")) { return path } const base = baseUrl ?? self.location.origin return new URL(path, base).toString() } async function fetchBinary( path: string, disableCache: boolean, sha?: string, ): Promise { if (!disableCache && sha && hasIndexedDB) { try { const cached = await readAsset(sha) if (cached) { return cached } } catch { // fall through to network fetch on cache errors } } const res = await fetch(path, { signal: abortController?.signal ?? undefined }) if (!res.ok) { throw new Error(`failed to fetch ${path}: ${res.status} ${res.statusText}`) } const payload = await res.arrayBuffer() if (!disableCache && sha && hasIndexedDB) { try { await writeAsset(sha, payload) } catch { // ignore cache write failures } } return payload } async function populateVectors( manifest: Manifest, baseUrl: string | undefined, disableCache: boolean | undefined, ): Promise<{ buffer: Float32Array; rowsLoaded: number }> { if (manifest.vectors.dtype !== "fp32") { throw new Error(`unsupported embedding dtype '${manifest.vectors.dtype}', regenerate with fp32`) } const rows = manifest.rows const dims = manifest.dims const totalBytes = rows * dims * Float32Array.BYTES_PER_ELEMENT const buffer = supportsSharedArrayBuffer ? new Float32Array(new SharedArrayBuffer(totalBytes)) : new Float32Array(totalBytes) let loadedRows = 0 for (const shard of manifest.vectors.shards) { const absolute = toAbsolute(shard.path, baseUrl) const payload = await fetchBinary(absolute, Boolean(disableCache), shard.sha256) const view = new Float32Array(payload) if (view.length !== shard.rows * dims) { throw new Error( `shard ${shard.path} has mismatched length (expected ${shard.rows * dims}, got ${view.length})`, ) } buffer.set(view, shard.rowOffset * dims) loadedRows = Math.min(rows, shard.rowOffset + shard.rows) const progress: ProgressMessage = { type: "progress", loadedRows, totalRows: rows, } self.postMessage(progress) } return { buffer, rowsLoaded: loadedRows } } async function populateGraph( manifest: Manifest, baseUrl: string | undefined, disableCache: boolean | undefined, ): Promise { const graphMeta = manifest.hnsw.graph const absolute = toAbsolute(graphMeta.path, baseUrl) return await fetchBinary(absolute, Boolean(disableCache), graphMeta.sha256) } function configureRuntimeEnv() { if (envConfigured) return env.allowLocalModels = false env.allowRemoteModels = true const wasmBackend = env.backends?.onnx?.wasm if (!wasmBackend) { throw new Error("transformers.js ONNX runtime backend unavailable") } const cdnBase = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/` wasmBackend.wasmPaths = cdnBase envConfigured = true } async function ensureEncoder() { if (classifier) return if (!cfg?.model) { throw new Error("semantic worker missing model identifier") } configureRuntimeEnv() const dtype = typeof cfg?.dtype === "string" && cfg.dtype.length > 0 ? 
cfg.dtype : "fp32" const pipelineOpts: Record = { device: "wasm", dtype, local_files_only: false, } classifier = await pipeline("feature-extraction", cfg.model, pipelineOpts) cfg.dtype = dtype } function vectorSlice(id: number): Float32Array { if (!vectorsView) { throw new Error("vector buffer not configured") } const start = id * dims const end = start + dims return vectorsView.subarray(start, end) } function dot(a: Float32Array, b: Float32Array): number { let s = 0 for (let i = 0; i < dims; i++) { s += a[i] * b[i] } return s } function neighborsFor(level: number, node: number): Uint32Array { const meta = levelGraph[level] if (!meta) return new Uint32Array() const { indptr, indices } = meta if (node < 0 || node + 1 >= indptr.length) return new Uint32Array() const start = indptr[node] const end = indptr[node + 1] return indices.subarray(start, end) } function insertSortedDescending(arr: SearchHit[], item: SearchHit) { let idx = arr.length while (idx > 0 && arr[idx - 1].score < item.score) { idx -= 1 } arr.splice(idx, 0, item) } function bruteForceSearch(query: Float32Array, k: number): SearchHit[] { if (!vectorsView) return [] const hits: SearchHit[] = [] for (let id = 0; id < rows; id++) { const score = dot(query, vectorSlice(id)) if (hits.length < k) { insertSortedDescending(hits, { id, score }) } else if (score > hits[hits.length - 1].score) { insertSortedDescending(hits, { id, score }) hits.length = k } } return hits } function hnswSearch(query: Float32Array, k: number): SearchHit[] { if (!manifest || !vectorsView || entryPoint < 0 || levelGraph.length === 0) { return bruteForceSearch(query, k) } const ef = Math.max(efDefault, k * 10) let ep = entryPoint let epScore = dot(query, vectorSlice(ep)) for (let level = maxLevel; level > 0; level--) { let changed = true while (changed) { changed = false const neigh = neighborsFor(level, ep) for (let i = 0; i < neigh.length; i++) { const candidate = neigh[i] if (candidate >= rows) continue const score = dot(query, vectorSlice(candidate)) if (score > epScore) { epScore = score ep = candidate changed = true } } } } const visited = new Set() const candidateQueue: SearchHit[] = [] const best: SearchHit[] = [] insertSortedDescending(candidateQueue, { id: ep, score: epScore }) insertSortedDescending(best, { id: ep, score: epScore }) visited.add(ep) while (candidateQueue.length > 0) { const current = candidateQueue.shift()! const worstBest = best.length >= ef ? best[best.length - 1].score : -Infinity if (current.score < worstBest && best.length >= ef) { break } const neigh = neighborsFor(0, current.id) for (let i = 0; i < neigh.length; i++) { const candidate = neigh[i] if (candidate >= rows || visited.has(candidate)) continue visited.add(candidate) const score = dot(query, vectorSlice(candidate)) const hit = { id: candidate, score } insertSortedDescending(candidateQueue, hit) if (best.length < ef || score > best[best.length - 1].score) { insertSortedDescending(best, hit) if (best.length > ef) { best.pop() } } } } best.sort((a, b) => b.score - a.score) return best.slice(0, k) } async function embed(text: string, isQuery: boolean = false): Promise { await ensureEncoder() // Apply model-specific prefixes for asymmetric search let prefixedText = text if (cfg?.model) { const modelName = cfg.model.toLowerCase() switch (true) { case modelName.includes("e5"): { // E5 models require query: or passage: prefix prefixedText = isQuery ? 
async function embed(text: string, isQuery: boolean = false): Promise<Float32Array> {
  await ensureEncoder()
  // Apply model-specific prefixes for asymmetric search
  let prefixedText = text
  if (cfg?.model) {
    const modelName = cfg.model.toLowerCase()
    switch (true) {
      case modelName.includes("e5"): {
        // E5 models require a query: or passage: prefix
        prefixedText = isQuery ? `query: ${text}` : `passage: ${text}`
        break
      }
      case modelName.includes("qwen") && modelName.includes("embedding"): {
        // Qwen3-Embedding requires a task instruction for queries only
        if (isQuery) {
          const task = "Given a web search query, retrieve relevant passages that answer the query"
          prefixedText = `Instruct: ${task}\nQuery: ${text}`
        }
        // Documents use plain text (no prefix)
        break
      }
      case modelName.includes("embeddinggemma"): {
        // embeddinggemma requires specific prefixes
        prefixedText = isQuery
          ? `task: search result | query: ${text}`
          : `title: none | text: ${text}`
        break
      }
      default:
        break
    }
  }
  const out = await classifier(prefixedText, { pooling: "mean", normalize: true })
  const data = Array.from(out?.data ?? out) as number[]
  const vec = new Float32Array(dims)
  for (let i = 0; i < dims; i++) vec[i] = data[i] ?? 0
  return vec
}

async function handleInit(msg: InitMessage) {
  if (state === "loading" || state === "ready") {
    throw new Error("worker already initialized or loading")
  }
  state = "loading"
  abortController?.abort()
  abortController = new AbortController()
  try {
    cfg = msg.cfg
    const manifestUrl = toAbsolute(msg.manifestUrl, msg.baseUrl)
    const response = await fetch(manifestUrl, { signal: abortController.signal })
    if (!response.ok) {
      throw new Error(
        `failed to fetch manifest ${manifestUrl}: ${response.status} ${response.statusText}`,
      )
    }
    manifest = (await response.json()) as Manifest
    if (manifest.vectors.dtype !== "fp32") {
      throw new Error(
        `unsupported embedding dtype '${manifest.vectors.dtype}', regenerate with fp32`,
      )
    }
    dims = manifest.dims
    rows = manifest.rows
    const { buffer: vectorBuffer } = await populateVectors(manifest, msg.baseUrl, msg.disableCache)
    vectorsView = vectorBuffer
    const graphBuffer = await populateGraph(manifest, msg.baseUrl, msg.disableCache)
    entryPoint = manifest.hnsw.entryPoint
    maxLevel = manifest.hnsw.maxLevel
    efDefault = Math.max(64, manifest.hnsw.M * 4)
    levelGraph = manifest.hnsw.graph.levels.map((level) => {
      const indptr = new Uint32Array(graphBuffer, level.indptr.offset, level.indptr.elements)
      const indices = new Uint32Array(graphBuffer, level.indices.offset, level.indices.elements)
      return { indptr, indices }
    })
    state = "ready"
    const ready: ReadyMessage = { type: "ready" }
    self.postMessage(ready)
  } catch (err) {
    state = "error"
    throw err
  }
}

async function handleSearch(msg: SearchMessage) {
  if (state !== "ready") {
    throw new Error("worker not ready for search")
  }
  if (!manifest || !vectorsView) {
    throw new Error("semantic worker not configured")
  }
  const queryVec = await embed(msg.text, true)
  const semanticHits = hnswSearch(queryVec, Math.max(1, msg.k))
  const message: SearchResultMessage = {
    type: "search-result",
    seq: msg.seq,
    semantic: semanticHits,
  }
  self.postMessage(message)
}

function handleReset() {
  abortController?.abort()
  abortController = null
  state = "idle"
  manifest = null
  cfg = null
  vectorsView = null
  dims = 0
  rows = 0
  classifier = null
  envConfigured = false
  levelGraph = []
  entryPoint = -1
  maxLevel = 0
}
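// Protocol note: every search carries a caller-assigned `seq`, and the result
// echoes it back, so the main thread can drop out-of-order responses when the
// user types faster than queries complete (see the usage sketch at the end of
// this file).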
self.onmessage = (event: MessageEvent<WorkerMessage>) => {
  const data = event.data
  if (data.type === "reset") {
    handleReset()
    return
  }
  if (data.type === "init") {
    void handleInit(data).catch((err: unknown) => {
      const message: ErrorMessage = {
        type: "error",
        message: err instanceof Error ? err.message : String(err),
      }
      self.postMessage(message)
    })
    return
  }
  if (data.type === "search") {
    void handleSearch(data).catch((err: unknown) => {
      const message: ErrorMessage = {
        type: "error",
        seq: data.seq,
        message: err instanceof Error ? err.message : String(err),
      }
      self.postMessage(message)
    })
  }
}
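// Usage sketch (main thread). The file name, model id, and manifest path are
// illustrative assumptions, not defined by this worker:
//
//   const worker = new Worker(
//     new URL("./semantic-search.worker.ts", import.meta.url),
//     { type: "module" },
//   )
//   let latestSeq = 0
//   worker.onmessage = (e) => {
//     if (e.data.type === "ready") {
//       worker.postMessage({ type: "search", text: "offline search", k: 10, seq: ++latestSeq })
//     } else if (e.data.type === "search-result" && e.data.seq === latestSeq) {
//       console.log(e.data.semantic) // [{ id, score }, ...] sorted by score
//     } else if (e.data.type === "error") {
//       console.error(e.data.message)
//     }
//   }
//   worker.postMessage({
//     type: "init",
//     cfg: { model: "Xenova/all-MiniLM-L6-v2" }, // assumed transformers.js-compatible model id
//     manifestUrl: "/search/manifest.json",
//   })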