Files
quartz-org-roam/quartz/embed_build.py
Aaron Pham f533902c75 feat: semantic search (1/n)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-10-05 19:50:52 -04:00

543 lines
17 KiB
Python

# /// script
# requires-python = ">=3.11"
# dependencies = [
# "langchain-text-splitters",
# "numpy",
# "openai",
# "sentence-transformers",
# "tiktoken",
# ]
# ///
from __future__ import annotations

import argparse
import hashlib
import heapq
import json
import logging
import math
import os
import random
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from pathlib import Path

import numpy as np
import tiktoken
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import OpenAI
logger = logging.getLogger(__name__)

# Default vLLM endpoint: the first non-empty of $VLLM_URL and $VLLM_EMBED_URL,
# otherwise a local server on port 8000.
DEFAULT_VLLM_URL = next(
    (url for url in (os.environ.get("VLLM_URL"), os.environ.get("VLLM_EMBED_URL")) if url),
    "http://127.0.0.1:8000/v1",
)
def resolve_vllm_base_url(url: str) -> str:
    """Normalize a vLLM endpoint URL to its OpenAI-compatible ``.../v1`` base.

    Trailing slashes and a trailing ``/embeddings`` path segment are removed, and
    ``/v1`` is appended when the result does not already end with it.

    Raises:
        ValueError: if ``url`` is empty.
    """
    if not url:
        raise ValueError("vLLM URL must be non-empty")
    base = url.rstrip("/").removesuffix("/embeddings")
    return base if base.endswith("/v1") else f"{base}/v1"
def load_jsonl(fp: str) -> Iterable[dict]:
    """Lazily yield one parsed JSON value per non-blank line of the UTF-8 file ``fp``.

    Lines are stripped of surrounding whitespace before parsing; lines that are empty
    after stripping are skipped.
    """
    with open(fp, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        for payload in filter(None, stripped_lines):
            yield json.loads(payload)
def l2_normalize_rows(x: np.ndarray) -> np.ndarray:
    """Return a copy of the 2-D array ``x`` with every row scaled to unit L2 norm.

    Rows whose norm is exactly zero are divided by 1 instead, so they stay all-zero
    rather than turning into NaN. ``x`` itself is not modified.
    """
    row_norms = np.linalg.norm(x, ord=2, axis=1, keepdims=True)
    divisor = np.where(row_norms == 0, np.ones_like(row_norms), row_norms)
    return x / divisor
@lru_cache(maxsize=1)
def get_tiktoken_encoder():
    """Return the ``o200k_base`` tiktoken encoding (the GPT-4o tokenizer).

    The result is memoized by ``lru_cache``, so the encoding is built once and reused
    by every caller. Change ``encoding_name`` to tokenize with a different scheme.
    """
    encoding_name = "o200k_base"
    return tiktoken.get_encoding(encoding_name)
def count_tokens(text: str) -> int:
    """Return the number of ``o200k_base`` tokens in ``text``.

    Uses ``encode_ordinary`` so special-token strings such as ``<|endoftext|>`` that
    appear literally in a note are counted as ordinary text. By default ``encode``
    raises ``ValueError`` on such strings, which would abort the whole build.
    """
    encoder = get_tiktoken_encoder()
    return len(encoder.encode_ordinary(text))
def get_text_splitter(chunk_size: int, overlap: int):
    """Build a recursive text splitter whose sizes are measured in ``o200k_base`` tokens.

    Args:
        chunk_size: Maximum tokens per chunk.
        overlap: Tokens shared between consecutive chunks.

    Returns:
        A ``RecursiveCharacterTextSplitter`` that tries paragraph, line, sentence and
        word separators (in that order) before falling back to single characters.
    """
    encoder = get_tiktoken_encoder()

    def token_length(chunk: str) -> int:
        # encode_ordinary treats literal special-token text (e.g. "<|endoftext|>") as
        # plain text; plain encode() would raise ValueError on it.
        return len(encoder.encode_ordinary(chunk))

    return RecursiveCharacterTextSplitter(
        # length_function counts tokens, so these are token budgets already. Do not
        # scale them by a chars-per-token factor, or chunks grow past chunk_size tokens.
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
        length_function=token_length,
        is_separator_regex=False,
    )
def chunk_document(
    doc: dict, max_tokens: int = 512, overlap_tokens: int = 128, min_chunk_size: int = 100
) -> list[dict]:
    """
    Chunk a document if it exceeds max_tokens.

    Documents of at most ``max_tokens`` tokens are returned unchanged as a single
    record. Longer documents are split with ``get_text_splitter`` and chunks shorter
    than ``min_chunk_size`` tokens are dropped. If that filter would discard every
    chunk (e.g. when ``max_tokens`` < ``min_chunk_size``), the unfiltered chunks are
    kept instead, so a document never silently disappears from the index.

    Args:
        doc: {'slug': str, 'title': str, 'text': str}; 'title' is optional and
            defaults to the slug.
        max_tokens: Maximum tokens per chunk
        overlap_tokens: Overlap between chunks
        min_chunk_size: Minimum chunk size in tokens (avoid tiny chunks)
    Returns:
        List of chunk dicts with keys slug, title, text, chunk_id, parent_slug and
        is_chunked. Chunked entries get slugs of the form "<slug>#chunk<i>".
    """
    slug = doc["slug"]
    title = doc.get("title", slug)
    text = doc["text"]

    def make_record(chunk_slug: str, chunk_text: str, chunk_id: int, is_chunked: bool) -> dict:
        # Key order matches what downstream consumers have always received.
        return {
            "slug": chunk_slug,
            "title": title,
            "text": chunk_text,
            "chunk_id": chunk_id,
            "parent_slug": slug,
            "is_chunked": is_chunked,
        }

    # No chunking needed
    if count_tokens(text) <= max_tokens:
        return [make_record(slug, text, 0, False)]

    splitter = get_text_splitter(max_tokens, overlap_tokens)
    raw_chunks = splitter.split_text(text)
    # Drop tiny chunks, but never all of them.
    kept = [c for c in raw_chunks if count_tokens(c) >= min_chunk_size] or raw_chunks
    if not kept:
        # Splitter produced nothing usable; keep the whole document rather than lose it.
        return [make_record(slug, text, 0, False)]
    return [make_record(f"{slug}#chunk{i}", chunk, i, True) for i, chunk in enumerate(kept)]
def write_shards(vectors: np.ndarray, shard_size: int, dtype: str, out_dir: Path) -> list[dict]:
    """Write ``vectors`` into ``out_dir`` as consecutive row shards ``vectors-NNN.bin``.

    Each shard holds up to ``shard_size`` rows serialized in C (row-major) order, cast
    to float16 when ``dtype == "fp16"`` and to float32 otherwise.

    Args:
        vectors: 2-D array of shape ``(rows, dims)``.
        shard_size: Maximum rows per shard; must be positive.
        dtype: ``"fp16"`` or ``"fp32"`` (any other value is written as float32).
        out_dir: Directory to create if needed and write the shards into.

    Returns:
        One dict per shard with its public ``path`` (always under ``/embeddings/``),
        ``rows``, ``rowOffset``, ``byteLength``, the payload's ``sha256`` hex digest and
        ``byteStride`` (bytes per row).

    Raises:
        ValueError: if ``shard_size`` is not positive.
    """
    if shard_size <= 0:
        # range() would raise for 0 and silently yield no shards for negative sizes.
        raise ValueError(f"shard_size must be positive, got {shard_size}")
    out_dir.mkdir(parents=True, exist_ok=True)
    rows, dims = vectors.shape
    np_dtype = np.float16 if dtype == "fp16" else np.float32
    byte_stride = dims * np.dtype(np_dtype).itemsize
    shards_meta: list[dict] = []
    for si, start in enumerate(range(0, rows, shard_size)):
        shard = vectors[start : start + shard_size]  # [n, dims]; slicing clamps at rows
        bin_path = out_dir / f"vectors-{si:03d}.bin"
        payload = shard.astype(np_dtype, copy=False).tobytes(order="C")
        bin_path.write_bytes(payload)
        shards_meta.append(
            {
                "path": f"/embeddings/{bin_path.name}",
                "rows": int(shard.shape[0]),
                # Every earlier shard is full, so the first row of this one is `start`.
                "rowOffset": start,
                "byteLength": len(payload),
                "sha256": hashlib.sha256(payload).hexdigest(),
                "byteStride": byte_stride,
            },
        )
    return shards_meta
def write_hnsw_graph(levels: list[list[list[int]]], rows: int, out_path: Path) -> tuple[list[dict], str]:
    """Serialize per-level adjacency lists to ``out_path`` as back-to-back CSR arrays.

    For each level, in order, the file receives an ``indptr`` array of ``rows + 1``
    uint32 offsets (native byte order), followed by an ``indices`` array of uint32
    neighbour ids. Node ids at or beyond ``len(level)`` are treated as having no
    neighbours.

    Returns:
        ``(meta, sha256_hex)``: one dict per level with the byte ``offset``, element
        count and ``byteLength`` of both arrays, plus the SHA-256 of the whole file.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    hasher = hashlib.sha256()
    level_meta: list[dict] = []
    position = 0
    with open(out_path, "wb") as fh:

        def emit(buf: bytes) -> int:
            # Append buf to the file and the running hash; return its starting offset.
            nonlocal position
            start = position
            fh.write(buf)
            hasher.update(buf)
            position += len(buf)
            return start

        for level_no, adjacency in enumerate(levels):
            neighbor_lists = [adjacency[node] if node < len(adjacency) else [] for node in range(rows)]
            indptr = np.zeros(rows + 1, dtype=np.uint32)
            indptr[1:] = np.cumsum([len(nb) for nb in neighbor_lists], dtype=np.int64)
            indices = np.asarray([nid for nb in neighbor_lists for nid in nb], dtype=np.uint32)
            indptr_bytes = indptr.tobytes(order="C")
            indices_bytes = indices.tobytes(order="C")
            indptr_at = emit(indptr_bytes)
            indices_at = emit(indices_bytes)
            level_meta.append(
                {
                    "level": level_no,
                    "indptr": {
                        "offset": indptr_at,
                        "elements": int(indptr.shape[0]),
                        "byteLength": len(indptr_bytes),
                    },
                    "indices": {
                        "offset": indices_at,
                        "elements": int(indices.shape[0]),
                        "byteLength": len(indices_bytes),
                    },
                },
            )
    return level_meta, hasher.hexdigest()
def embed_vllm(
    texts: list[str],
    model_id: str,
    vllm_url: str,
    batch_size: int = 64,
    concurrency: int = 8,
) -> np.ndarray:
    """Embed ``texts`` with a vLLM server through its OpenAI-compatible embeddings API.

    The server's model list is queried first. If ``model_id`` is not served, a warning
    is logged and the first served model is used instead. A document-side prefix is
    chosen from the (possibly substituted) model name. Texts are sent in batches of
    ``batch_size`` with up to ``concurrency`` requests in flight, and the embeddings
    are returned in input order.

    Args:
        texts: Documents to embed.
        model_id: Requested model name.
        vllm_url: Server URL; both ``.../v1`` and ``.../v1/embeddings`` are accepted.
        batch_size: Texts per embeddings request.
        concurrency: Maximum concurrent requests when there is more than one batch.

    Returns:
        float32 array of shape ``(len(texts), D)``, where ``D`` is the width the server
        returns; an empty ``(0, 0)`` array when ``texts`` is empty.

    Raises:
        RuntimeError: if the model list cannot be fetched, the server reports no
            models, or a response does not contain exactly one embedding per input.
    """
    if not texts:
        # Nothing to embed; np.stack([]) below would raise on an empty list.
        return np.zeros((0, 0), dtype=np.float32)

    base_url = resolve_vllm_base_url(vllm_url)
    api_key = os.environ.get("VLLM_API_KEY") or os.environ.get("OPENAI_API_KEY") or "not-set"
    client = OpenAI(base_url=base_url, api_key=api_key, timeout=300)

    def list_available_models() -> list[str]:
        models: list[str] = []
        page = client.models.list()
        models.extend(model.id for model in page.data)
        while getattr(page, "has_more", False) and page.data:
            cursor = page.data[-1].id
            # models.list() has no `after` keyword; forward the cursor as a raw query param.
            page = client.models.list(extra_query={"after": cursor})
            if page.data and page.data[-1].id == cursor:
                break  # server ignored the cursor; stop instead of looping forever
            models.extend(model.id for model in page.data)
        return models

    try:
        available_models = list_available_models()
    except Exception as exc:
        raise RuntimeError(f"failed to query {base_url}/models: {exc}") from exc
    if model_id not in available_models:
        if not available_models:
            raise RuntimeError(f"vLLM at {base_url} reports no served models; cannot embed with '{model_id}'")
        suggestions = ", ".join(sorted(available_models))
        logger.warning(
            "model '%s' not served by vLLM at %s. Available models: %s. Use the first model, results may differ during semantic search (you can omit this message if your weights is a ONNX checkpoint of the same model.)",
            model_id,
            base_url,
            suggestions,
        )
        model_id = available_models[0]

    # Apply model-specific prefixes for documents (asymmetric search)
    model_lower = model_id.lower()
    if "e5" in model_lower:
        # E5 models: use "passage:" prefix for documents
        prefixed = [f"passage: {t}" for t in texts]
    elif "qwen" in model_lower and "embedding" in model_lower:
        # Qwen3-Embedding: documents use plain text (no prefix)
        prefixed = texts
    elif "embeddinggemma" in model_lower:
        # embeddinggemma: use "title: none | text:" prefix for documents
        prefixed = [f"title: none | text: {t}" for t in texts]
    else:
        # Default: no prefix for unknown models
        prefixed = texts
    print(
        "Embedding"
        f" {len(prefixed)} texts with vLLM"
        f" (model={model_id}, batch_size={batch_size}, concurrency={concurrency})",
    )

    # (start offset, batch texts); the offset restores input order after concurrent completion
    batches = [(i, prefixed[i : i + batch_size]) for i in range(0, len(prefixed), batch_size)]

    def send_batch(batch_info: tuple[int, list[str]]) -> tuple[int, list[np.ndarray]]:
        idx, batch = batch_info
        response = client.embeddings.create(model=model_id, input=batch)
        # Order by each item's `index` instead of trusting list order; a missing index
        # falls back to the list order because sorted() is stable.
        items = sorted(response.data, key=lambda item: getattr(item, "index", 0))
        if len(items) != len(batch):
            raise RuntimeError(
                f"vLLM returned {len(items)} embeddings for a batch of {len(batch)} inputs (offset {idx})"
            )
        return idx, [np.asarray(item.embedding, dtype=np.float32) for item in items]

    results: dict[int, list[np.ndarray]] = {}
    if len(batches) == 1:
        # Single batch - no need for threading
        idx, embeddings = send_batch(batches[0])
        results[idx] = embeddings
    else:
        # Multiple batches - use concurrent requests
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [executor.submit(send_batch, batch_info) for batch_info in batches]
            completed = 0
            for future in as_completed(futures):
                idx, embeddings = future.result()
                results[idx] = embeddings
                completed += 1
                if completed % max(1, len(batches) // 10) == 0 or completed == len(batches):
                    print(f"  Completed {completed}/{len(batches)} batches ({completed * 100 // len(batches)}%)")

    # Reconstruct in input order
    out: list[np.ndarray] = []
    for i in sorted(results):
        out.extend(results[i])
    return np.stack(out, axis=0)
def embed_hf(texts: list[str], model_id: str, device: str) -> np.ndarray:
    """Embed ``texts`` locally with sentence-transformers and return float32 vectors.

    A document-side prefix is chosen from the lower-cased ``model_id``:

    - names containing ``e5`` get ``passage: ``;
    - Qwen embedding models get no prefix;
    - ``embeddinggemma`` gets ``title: none | text: ``;
    - any other model gets no prefix.

    Encoding runs in batches of 64 with ``normalize_embeddings=True``.
    """
    # Imported lazily so the vLLM path does not need sentence-transformers installed.
    from sentence_transformers import SentenceTransformer

    name = model_id.lower()
    if "e5" in name:
        template = "passage: {}"
    elif "qwen" in name and "embedding" in name:
        template = None  # Qwen3-Embedding documents are embedded as plain text
    elif "embeddinggemma" in name:
        template = "title: none | text: {}"
    else:
        template = None
    docs = texts if template is None else [template.format(t) for t in texts]

    encoder = SentenceTransformer(model_id, device=device)
    embeddings = encoder.encode(
        docs,
        batch_size=64,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )
    return embeddings.astype(np.float32, copy=False)
def _env_flag(name: str) -> bool:
    """Return True when env var ``name`` is set to anything but an explicit negative.

    Unset or blank values and ``0``/``false``/``no``/``off`` (case-insensitive) are
    false; any other value is true.
    """
    return os.environ.get(name, "").strip().lower() not in {"", "0", "false", "no", "off"}


def _chunk_records(recs: list[dict], chunk_size: int, chunk_overlap: int) -> tuple[list[dict], dict[str, dict]]:
    """Chunk every record and return ``(chunks, chunk_metadata)``.

    ``chunk_metadata`` maps each chunked entry's slug to its ``parentSlug`` and
    ``chunkId``; documents that were not split get no entry.
    """
    chunks: list[dict] = []
    chunk_metadata: dict[str, dict] = {}
    for rec in recs:
        doc_chunks = chunk_document(rec, max_tokens=chunk_size, overlap_tokens=chunk_overlap)
        chunks.extend(doc_chunks)
        for chunk in doc_chunks:
            if chunk["is_chunked"]:
                chunk_metadata[chunk["slug"]] = {
                    "parentSlug": chunk["parent_slug"],
                    "chunkId": chunk["chunk_id"],
                }
    return chunks, chunk_metadata


def _hnsw_build(data: np.ndarray, M: int = 16, efC: int = 200, seed: int = 0) -> dict:
    """Build a lightweight HNSW-style graph over the rows of ``data``.

    Similarity is the dot product of two rows (cosine for L2-normalized input). Node
    levels are drawn from a geometric distribution with continuation probability 1/e,
    seeded by ``seed``. Node 0 is the fixed entry point. Each later node is linked to
    its top-``M`` most similar nodes per level, and adjacency lists are finally trimmed
    to ``M`` entries.

    NOTE(review): unlike textbook HNSW, the search result list is not capped at ``ef``,
    the entry point is never promoted to a higher-level node, and ``ep`` is not updated
    between the lower layers. These quirks are preserved so the emitted graph is
    unchanged; revisit deliberately.

    Returns:
        dict with ``M``, ``efConstruction``, ``entryPoint`` (-1 when empty),
        ``maxLevel`` and ``levels`` (``levels[L][i]`` = neighbours of node ``i`` at ``L``).
    """
    rng = random.Random(seed)
    N = data.shape[0]
    node_levels: list[int] = []
    for _ in range(N):
        lvl = 0
        while rng.random() < 1 / math.e:
            lvl += 1
        node_levels.append(lvl)
    max_level = max(node_levels) if N > 0 else 0
    levels: list[list[list[int]]] = [[[] for _ in range(N)] for _ in range(max_level + 1)]

    def sim(i: int, j: int) -> float:
        return float((data[i] * data[j]).sum())

    entry = 0 if N > 0 else -1

    def search_layer(q: int, ep: int, ef: int, L: int) -> list[int]:
        if ep < 0:
            return []
        visited: set[int] = set()
        # Max-heap via negated keys: pops highest similarity first, ties -> highest node id,
        # i.e. the same order as repeatedly sorting (sim, node) descending and taking the head.
        cand: list[tuple[float, int]] = []
        top: list[tuple[float, int]] = []

        def push(node: int) -> None:
            if node in visited:
                return
            visited.add(node)
            heapq.heappush(cand, (-sim(q, node), -node))

        push(ep)
        while cand:
            neg_s, neg_v = heapq.heappop(cand)
            s, v = -neg_s, -neg_v
            if len(top) >= ef and s <= top[-1][0]:
                break
            top.append((s, v))
            for u in levels[L][v]:
                push(u)
        top.sort(reverse=True)
        return [n for _, n in top]

    for i in range(1, N):  # node 0 is the entry point and is not inserted
        lvl = node_levels[i]
        ep = entry
        # Greedy descent through the levels above this node's own level.
        for L in range(max_level, lvl, -1):
            c = search_layer(i, ep, 1, L)
            if c:
                ep = c[0]
        for L in range(min(max_level, lvl), -1, -1):
            W = search_layer(i, ep, efC, L)
            # Select top M by similarity and link bidirectionally.
            neigh = sorted(((sim(i, j), j) for j in W if j != i), reverse=True)[:M]
            for _, e in neigh:
                if e not in levels[L][i]:
                    levels[L][i].append(e)
                if i not in levels[L][e]:
                    levels[L][e].append(i)

    # Trim every adjacency list to its M most similar neighbours.
    for layer in levels:
        for i in range(N):
            if len(layer[i]) > M:
                layer[i] = sorted(layer[i], key=lambda j: sim(i, j), reverse=True)[:M]

    return {
        "M": M,
        "efConstruction": efC,
        "entryPoint": entry,
        "maxLevel": max_level,
        "levels": levels,
    }


def main():
    """Build the semantic-search artifacts from the site's JSONL export.

    The steps are:

    1. Read ``--jsonl`` and optionally chunk each document.
    2. Embed the texts via vLLM or sentence-transformers.
    3. Truncate or zero-pad the vectors to ``--dims`` and L2-normalize them.
    4. Write vector shards, an HNSW graph (``hnsw.bin``) and ``manifest.json`` into ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--jsonl", default="public/embeddings-text.jsonl")
    ap.add_argument("--model", default=os.environ.get("SEM_MODEL", "intfloat/multilingual-e5-large"))
    ap.add_argument("--dims", type=int, default=int(os.environ.get("SEM_DIMS", "1024")))
    ap.add_argument("--dtype", choices=["fp16", "fp32"], default=os.environ.get("SEM_DTYPE", "fp32"))
    ap.add_argument("--shard-size", type=int, default=int(os.environ.get("SEM_SHARD", "1024")))
    ap.add_argument("--out", default="public/embeddings")
    # bool("0") is True, so parse USE_VLLM explicitly: 0/false/no/off must not enable vLLM.
    ap.add_argument("--use-vllm", action="store_true", default=_env_flag("USE_VLLM"))
    ap.add_argument(
        "--vllm-url",
        default=DEFAULT_VLLM_URL,
        help="Base URL for the vLLM OpenAI-compatible server (accepts either /v1 or /v1/embeddings)",
    )
    ap.add_argument("--chunk-size", type=int, default=512, help="Max tokens per chunk")
    ap.add_argument("--chunk-overlap", type=int, default=128, help="Overlap tokens between chunks")
    ap.add_argument("--no-chunking", action="store_true", help="Disable chunking (embed full docs)")
    ap.add_argument(
        "--concurrency",
        type=int,
        default=int(os.environ.get("VLLM_CONCURRENCY", "8")),
        help="Number of concurrent requests to vLLM (default: 8)",
    )
    ap.add_argument(
        "--batch-size",
        type=int,
        default=int(os.environ.get("VLLM_BATCH_SIZE", "64")),
        help="Batch size for vLLM requests (default: 64)",
    )
    args = ap.parse_args()

    recs = list(load_jsonl(args.jsonl))
    if not recs:
        print(f"No input found in {args.jsonl}; run the site build first to emit JSONL.")
        return

    if args.no_chunking:
        chunks = recs
        chunk_metadata: dict[str, dict] = {}
        print(f"Chunking disabled. Processing {len(chunks)} full documents")
    else:
        chunks, chunk_metadata = _chunk_records(recs, args.chunk_size, args.chunk_overlap)
        chunked_count = sum(1 for c in chunks if c.get("is_chunked", False))
        print(f"Chunked {len(recs)} documents into {len(chunks)} chunks ({chunked_count} chunked, {len(chunks) - chunked_count} unchanged)")
        print(f"  Chunk size: {args.chunk_size} tokens, overlap: {args.chunk_overlap} tokens")

    ids = [c["slug"] for c in chunks]
    titles = [c.get("title", c["slug"]) for c in chunks]
    texts = [c["text"] for c in chunks]

    if args.use_vllm:
        vecs = embed_vllm(
            texts,
            args.model,
            args.vllm_url,
            batch_size=args.batch_size,
            concurrency=args.concurrency,
        )
    else:
        # NOTE(review): any non-empty CUDA_VISIBLE_DEVICES (including "-1") selects cuda here.
        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        vecs = embed_hf(texts, args.model, device)

    # Coerce to the manifest's dims (truncate or zero-pad), then re-normalize.
    if vecs.shape[1] > args.dims:
        vecs = vecs[:, : args.dims]
    elif vecs.shape[1] < args.dims:
        vecs = np.pad(vecs, ((0, 0), (0, args.dims - vecs.shape[1])))
    vecs = l2_normalize_rows(vecs.astype(np.float32, copy=False))
    rows = int(vecs.shape[0])

    out_dir = Path(args.out)
    shards = write_shards(vecs, args.shard_size, args.dtype, out_dir)

    # Build a lightweight HNSW graph and store it in a compact binary layout.
    hnsw = _hnsw_build(vecs, M=16, efC=200)
    hnsw_meta, hnsw_sha = write_hnsw_graph(hnsw["levels"], rows, out_dir / "hnsw.bin")

    manifest = {
        "version": 2,
        "dims": args.dims,
        "dtype": args.dtype,
        "normalized": True,
        "rows": rows,
        "shardSizeRows": args.shard_size,
        "vectors": {
            "dtype": args.dtype,
            "rows": rows,
            "dims": args.dims,
            "shards": shards,
        },
        "ids": ids,
        "titles": titles,
        "chunkMetadata": chunk_metadata,
        "hnsw": {
            "M": hnsw["M"],
            "efConstruction": hnsw["efConstruction"],
            "entryPoint": hnsw["entryPoint"],
            "maxLevel": hnsw["maxLevel"],
            "graph": {
                "path": "/embeddings/hnsw.bin",
                "sha256": hnsw_sha,
                "levels": hnsw_meta,
            },
        },
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False), encoding="utf-8")
    print(f"Wrote {len(shards)} vector shard(s), HNSW graph, and manifest to {out_dir}")
# Run the build only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()