Files
quartz/quartz/workers/semantic.worker.ts
Aaron Pham f533902c75 feat: semantic search (1/n)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-10-05 19:50:52 -04:00

549 lines
15 KiB
TypeScript

// Unified semantic search worker: handles data loading and query execution
import { env, pipeline } from "@huggingface/transformers"
import "onnxruntime-web/webgpu"
import "onnxruntime-web/wasm"
export {}
type VectorShardMeta = {
path: string
rows: number
rowOffset: number
byteLength: number
sha256?: string
byteStride: number
}
type LevelSection = {
level: number
indptr: { offset: number; elements: number; byteLength: number }
indices: { offset: number; elements: number; byteLength: number }
}
type ChunkMetadata = {
parentSlug: string
chunkId: number
}
type Manifest = {
version: number
dims: number
dtype: string
normalized: boolean
rows: number
shardSizeRows: number
vectors: {
dtype: string
rows: number
dims: number
shards: VectorShardMeta[]
}
ids: string[]
titles?: string[]
chunkMetadata?: Record<string, ChunkMetadata>
hnsw: {
M: number
efConstruction: number
entryPoint: number
maxLevel: number
graph: {
path: string
sha256?: string
levels: LevelSection[]
}
}
}
type InitMessage = {
type: "init"
cfg: any
manifestUrl: string
baseUrl?: string
disableCache?: boolean
}
type SearchMessage = { type: "search"; text: string; k: number; seq: number }
type ResetMessage = { type: "reset" }
type WorkerMessage = InitMessage | SearchMessage | ResetMessage
type ReadyMessage = { type: "ready" }
type ProgressMessage = {
type: "progress"
loadedRows: number
totalRows: number
}
type SearchHit = { id: number; score: number }
type SearchResultMessage = {
type: "search-result"
seq: number
semantic: SearchHit[]
}
type ErrorMessage = { type: "error"; seq?: number; message: string }
type WorkerState = "idle" | "loading" | "ready" | "error"
// IndexedDB configuration
const DB_NAME = "semantic-search-cache"
const STORE_NAME = "assets"
const DB_VERSION = 1
const hasIndexedDB = typeof indexedDB !== "undefined"
const supportsSharedArrayBuffer = typeof SharedArrayBuffer !== "undefined"
// State
let state: WorkerState = "idle"
let manifest: Manifest | null = null
let cfg: any = null
let vectorsView: Float32Array | null = null
let dims = 0
let rows = 0
let classifier: any = null
let envConfigured = false
let entryPoint = -1
let maxLevel = 0
let efDefault = 128
let levelGraph: { indptr: Uint32Array; indices: Uint32Array }[] = []
let abortController: AbortController | null = null
let dbPromise: Promise<IDBDatabase> | null = null
// IndexedDB helpers
function openDatabase(): Promise<IDBDatabase> {
if (!hasIndexedDB) {
return Promise.reject(new Error("indexedDB unavailable"))
}
if (!dbPromise) {
dbPromise = new Promise((resolve, reject) => {
const req = indexedDB.open(DB_NAME, DB_VERSION)
req.onupgradeneeded = () => {
const db = req.result
if (!db.objectStoreNames.contains(STORE_NAME)) {
db.createObjectStore(STORE_NAME)
}
}
req.onsuccess = () => resolve(req.result)
req.onerror = () => reject(req.error ?? new Error("failed to open cache store"))
})
}
return dbPromise
}
async function readAsset(hash: string): Promise<ArrayBuffer | null> {
if (!hasIndexedDB) {
return null
}
const db = await openDatabase()
return new Promise((resolve, reject) => {
const tx = db.transaction(STORE_NAME, "readonly")
const store = tx.objectStore(STORE_NAME)
const req = store.get(hash)
req.onsuccess = () => {
const value = req.result
if (value instanceof ArrayBuffer) {
resolve(value)
} else if (value && value.buffer instanceof ArrayBuffer) {
resolve(value.buffer as ArrayBuffer)
} else {
resolve(null)
}
}
req.onerror = () => reject(req.error ?? new Error("failed to read cached asset"))
})
}
async function writeAsset(hash: string, buffer: ArrayBuffer): Promise<void> {
if (!hasIndexedDB) {
return
}
const db = await openDatabase()
await new Promise<void>((resolve, reject) => {
const tx = db.transaction(STORE_NAME, "readwrite")
const store = tx.objectStore(STORE_NAME)
const req = store.put(buffer, hash)
req.onsuccess = () => resolve()
req.onerror = () => reject(req.error ?? new Error("failed to cache asset"))
})
}
function toAbsolute(path: string, baseUrl?: string): string {
if (path.startsWith("http://") || path.startsWith("https://")) {
return path
}
const base = baseUrl ?? self.location.origin
return new URL(path, base).toString()
}
async function fetchBinary(
path: string,
disableCache: boolean,
sha?: string,
): Promise<ArrayBuffer> {
if (!disableCache && sha && hasIndexedDB) {
try {
const cached = await readAsset(sha)
if (cached) {
return cached
}
} catch {
// fall through to network fetch on cache errors
}
}
const res = await fetch(path, { signal: abortController?.signal ?? undefined })
if (!res.ok) {
throw new Error(`failed to fetch ${path}: ${res.status} ${res.statusText}`)
}
const payload = await res.arrayBuffer()
if (!disableCache && sha && hasIndexedDB) {
try {
await writeAsset(sha, payload)
} catch {
// ignore cache write failures
}
}
return payload
}
async function populateVectors(
manifest: Manifest,
baseUrl: string | undefined,
disableCache: boolean | undefined,
): Promise<{ buffer: Float32Array; rowsLoaded: number }> {
if (manifest.vectors.dtype !== "fp32") {
throw new Error(`unsupported embedding dtype '${manifest.vectors.dtype}', regenerate with fp32`)
}
const rows = manifest.rows
const dims = manifest.dims
const totalBytes = rows * dims * Float32Array.BYTES_PER_ELEMENT
const buffer = supportsSharedArrayBuffer
? new Float32Array(new SharedArrayBuffer(totalBytes))
: new Float32Array(totalBytes)
let loadedRows = 0
for (const shard of manifest.vectors.shards) {
const absolute = toAbsolute(shard.path, baseUrl)
const payload = await fetchBinary(absolute, Boolean(disableCache), shard.sha256)
const view = new Float32Array(payload)
if (view.length !== shard.rows * dims) {
throw new Error(
`shard ${shard.path} has mismatched length (expected ${shard.rows * dims}, got ${view.length})`,
)
}
buffer.set(view, shard.rowOffset * dims)
loadedRows = Math.min(rows, shard.rowOffset + shard.rows)
const progress: ProgressMessage = {
type: "progress",
loadedRows,
totalRows: rows,
}
self.postMessage(progress)
}
return { buffer, rowsLoaded: loadedRows }
}
async function populateGraph(
manifest: Manifest,
baseUrl: string | undefined,
disableCache: boolean | undefined,
): Promise<ArrayBuffer> {
const graphMeta = manifest.hnsw.graph
const absolute = toAbsolute(graphMeta.path, baseUrl)
return await fetchBinary(absolute, Boolean(disableCache), graphMeta.sha256)
}
function configureRuntimeEnv() {
if (envConfigured) return
env.allowLocalModels = false
env.allowRemoteModels = true
const wasmBackend = env.backends?.onnx?.wasm
if (!wasmBackend) {
throw new Error("transformers.js ONNX runtime backend unavailable")
}
const cdnBase = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`
wasmBackend.wasmPaths = cdnBase
envConfigured = true
}
async function ensureEncoder() {
if (classifier) return
if (!cfg?.model) {
throw new Error("semantic worker missing model identifier")
}
configureRuntimeEnv()
const dtype = typeof cfg?.dtype === "string" && cfg.dtype.length > 0 ? cfg.dtype : "fp32"
const pipelineOpts: Record<string, unknown> = {
device: "wasm",
dtype,
local_files_only: false,
}
classifier = await pipeline("feature-extraction", cfg.model, pipelineOpts)
cfg.dtype = dtype
}
function vectorSlice(id: number): Float32Array {
if (!vectorsView) {
throw new Error("vector buffer not configured")
}
const start = id * dims
const end = start + dims
return vectorsView.subarray(start, end)
}
function dot(a: Float32Array, b: Float32Array): number {
let s = 0
for (let i = 0; i < dims; i++) {
s += a[i] * b[i]
}
return s
}
function neighborsFor(level: number, node: number): Uint32Array {
const meta = levelGraph[level]
if (!meta) return new Uint32Array()
const { indptr, indices } = meta
if (node < 0 || node + 1 >= indptr.length) return new Uint32Array()
const start = indptr[node]
const end = indptr[node + 1]
return indices.subarray(start, end)
}
function insertSortedDescending(arr: SearchHit[], item: SearchHit) {
let idx = arr.length
while (idx > 0 && arr[idx - 1].score < item.score) {
idx -= 1
}
arr.splice(idx, 0, item)
}
function bruteForceSearch(query: Float32Array, k: number): SearchHit[] {
if (!vectorsView) return []
const hits: SearchHit[] = []
for (let id = 0; id < rows; id++) {
const score = dot(query, vectorSlice(id))
if (hits.length < k) {
insertSortedDescending(hits, { id, score })
} else if (score > hits[hits.length - 1].score) {
insertSortedDescending(hits, { id, score })
hits.length = k
}
}
return hits
}
function hnswSearch(query: Float32Array, k: number): SearchHit[] {
if (!manifest || !vectorsView || entryPoint < 0 || levelGraph.length === 0) {
return bruteForceSearch(query, k)
}
const ef = Math.max(efDefault, k * 10)
let ep = entryPoint
let epScore = dot(query, vectorSlice(ep))
for (let level = maxLevel; level > 0; level--) {
let changed = true
while (changed) {
changed = false
const neigh = neighborsFor(level, ep)
for (let i = 0; i < neigh.length; i++) {
const candidate = neigh[i]
if (candidate >= rows) continue
const score = dot(query, vectorSlice(candidate))
if (score > epScore) {
epScore = score
ep = candidate
changed = true
}
}
}
}
const visited = new Set<number>()
const candidateQueue: SearchHit[] = []
const best: SearchHit[] = []
insertSortedDescending(candidateQueue, { id: ep, score: epScore })
insertSortedDescending(best, { id: ep, score: epScore })
visited.add(ep)
while (candidateQueue.length > 0) {
const current = candidateQueue.shift()!
const worstBest = best.length >= ef ? best[best.length - 1].score : -Infinity
if (current.score < worstBest && best.length >= ef) {
break
}
const neigh = neighborsFor(0, current.id)
for (let i = 0; i < neigh.length; i++) {
const candidate = neigh[i]
if (candidate >= rows || visited.has(candidate)) continue
visited.add(candidate)
const score = dot(query, vectorSlice(candidate))
const hit = { id: candidate, score }
insertSortedDescending(candidateQueue, hit)
if (best.length < ef || score > best[best.length - 1].score) {
insertSortedDescending(best, hit)
if (best.length > ef) {
best.pop()
}
}
}
}
best.sort((a, b) => b.score - a.score)
return best.slice(0, k)
}
async function embed(text: string, isQuery: boolean = false): Promise<Float32Array> {
await ensureEncoder()
// Apply model-specific prefixes for asymmetric search
let prefixedText = text
if (cfg?.model) {
const modelName = cfg.model.toLowerCase()
switch (true) {
case modelName.includes("e5"): {
// E5 models require query: or passage: prefix
prefixedText = isQuery ? `query: ${text}` : `passage: ${text}`
break
}
case modelName.includes("qwen") && modelName.includes("embedding"): {
// Qwen3-Embedding requires task instruction for queries only
if (isQuery) {
const task = "Given a web search query, retrieve relevant passages that answer the query"
prefixedText = `Instruct: ${task}\nQuery: ${text}`
}
// Documents use plain text (no prefix)
break
}
case modelName.includes("embeddinggemma"): {
// embeddinggemma requires specific prefixes
prefixedText = isQuery
? `task: search result | query: ${text}`
: `title: none | text: ${text}`
break
}
default:
break
}
}
const out = await classifier(prefixedText, { pooling: "mean", normalize: true })
const data = Array.from(out?.data ?? out) as number[]
const vec = new Float32Array(dims)
for (let i = 0; i < dims; i++) vec[i] = data[i] ?? 0
return vec
}
async function handleInit(msg: InitMessage) {
if (state === "loading" || state === "ready") {
throw new Error("worker already initialized or loading")
}
state = "loading"
abortController?.abort()
abortController = new AbortController()
try {
cfg = msg.cfg
const manifestUrl = toAbsolute(msg.manifestUrl, msg.baseUrl)
const response = await fetch(manifestUrl, { signal: abortController.signal })
if (!response.ok) {
throw new Error(
`failed to fetch manifest ${manifestUrl}: ${response.status} ${response.statusText}`,
)
}
manifest = (await response.json()) as Manifest
if (manifest.vectors.dtype !== "fp32") {
throw new Error(
`unsupported embedding dtype '${manifest.vectors.dtype}', regenerate with fp32`,
)
}
dims = manifest.dims
rows = manifest.rows
const { buffer: vectorBuffer } = await populateVectors(manifest, msg.baseUrl, msg.disableCache)
vectorsView = vectorBuffer
const graphBuffer = await populateGraph(manifest, msg.baseUrl, msg.disableCache)
entryPoint = manifest.hnsw.entryPoint
maxLevel = manifest.hnsw.maxLevel
efDefault = Math.max(64, manifest.hnsw.M * 4)
levelGraph = manifest.hnsw.graph.levels.map((level) => {
const indptr = new Uint32Array(graphBuffer, level.indptr.offset, level.indptr.elements)
const indices = new Uint32Array(graphBuffer, level.indices.offset, level.indices.elements)
return { indptr, indices }
})
state = "ready"
const ready: ReadyMessage = { type: "ready" }
self.postMessage(ready)
} catch (err) {
state = "error"
throw err
}
}
async function handleSearch(msg: SearchMessage) {
if (state !== "ready") {
throw new Error("worker not ready for search")
}
if (!manifest || !vectorsView) {
throw new Error("semantic worker not configured")
}
const queryVec = await embed(msg.text, true)
const semanticHits = hnswSearch(queryVec, Math.max(1, msg.k))
const message: SearchResultMessage = {
type: "search-result",
seq: msg.seq,
semantic: semanticHits,
}
self.postMessage(message)
}
function handleReset() {
abortController?.abort()
abortController = null
state = "idle"
manifest = null
cfg = null
vectorsView = null
dims = 0
rows = 0
classifier = null
envConfigured = false
levelGraph = []
entryPoint = -1
maxLevel = 0
}
self.onmessage = (event: MessageEvent<WorkerMessage>) => {
const data = event.data
if (data.type === "reset") {
handleReset()
return
}
if (data.type === "init") {
void handleInit(data).catch((err: unknown) => {
const message: ErrorMessage = {
type: "error",
message: err instanceof Error ? err.message : String(err),
}
self.postMessage(message)
})
return
}
if (data.type === "search") {
void handleSearch(data).catch((err: unknown) => {
const message: ErrorMessage = {
type: "error",
seq: data.seq,
message: err instanceof Error ? err.message : String(err),
}
self.postMessage(message)
})
}
}