forked from github/quartz
feat: semantic search (1/n)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
542
quartz/embed_build.py
Normal file
542
quartz/embed_build.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = [
|
||||
# "langchain-text-splitters",
|
||||
# "numpy",
|
||||
# "openai",
|
||||
# "sentence-transformers",
|
||||
# "tiktoken",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os, json, argparse, hashlib, math, random, logging
|
||||
|
||||
from pathlib import Path
|
||||
from functools import lru_cache
|
||||
from collections.abc import Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import tiktoken, numpy as np
|
||||
|
||||
from openai import OpenAI
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_VLLM_URL = os.environ.get("VLLM_URL") or os.environ.get("VLLM_EMBED_URL") or "http://127.0.0.1:8000/v1"
|
||||
|
||||
|
||||
def resolve_vllm_base_url(url: str) -> str:
|
||||
if not url:
|
||||
raise ValueError("vLLM URL must be non-empty")
|
||||
|
||||
trimmed = url.rstrip("/")
|
||||
if trimmed.endswith("/v1/embeddings"):
|
||||
trimmed = trimmed[: -len("/embeddings")]
|
||||
elif trimmed.endswith("/embeddings"):
|
||||
trimmed = trimmed[: trimmed.rfind("/")]
|
||||
|
||||
if not trimmed.endswith("/v1"):
|
||||
trimmed = f"{trimmed}/v1"
|
||||
|
||||
return trimmed
|
||||
|
||||
|
||||
def load_jsonl(fp: str) -> Iterable[dict]:
|
||||
with open(fp, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
def l2_normalize_rows(x: np.ndarray) -> np.ndarray:
|
||||
# x: [N, D]
|
||||
norms = np.linalg.norm(x, ord=2, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1.0
|
||||
return x / norms
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_tiktoken_encoder():
|
||||
# Get the o200k_base tokenizer (GPT-4o) with caching
|
||||
# change this if you want something else.
|
||||
return tiktoken.get_encoding("o200k_base")
|
||||
|
||||
|
||||
def count_tokens(text: str) -> int:
|
||||
# Count tokens using o200k_base encoding
|
||||
encoder = get_tiktoken_encoder()
|
||||
return len(encoder.encode(text))
|
||||
|
||||
|
||||
def get_text_splitter(chunk_size: int, overlap: int):
|
||||
encoder = get_tiktoken_encoder()
|
||||
return RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size * 4, # character approximation
|
||||
chunk_overlap=overlap * 4,
|
||||
separators=["\n\n", "\n", ". ", " ", ""],
|
||||
length_function=lambda t: len(encoder.encode(t)),
|
||||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
|
||||
def chunk_document(
|
||||
doc: dict, max_tokens: int = 512, overlap_tokens: int = 128, min_chunk_size: int = 100
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Chunk a document if it exceeds max_tokens
|
||||
|
||||
Args:
|
||||
doc: {'slug': str, 'title': str, 'text': str}
|
||||
max_tokens: Maximum tokens per chunk
|
||||
overlap_tokens: Overlap between chunks
|
||||
min_chunk_size: Minimum chunk size (avoid tiny chunks)
|
||||
|
||||
Returns:
|
||||
List of chunk dicts with metadata
|
||||
"""
|
||||
text = doc["text"]
|
||||
token_count = count_tokens(text)
|
||||
|
||||
# No chunking needed
|
||||
if token_count <= max_tokens:
|
||||
return [
|
||||
{
|
||||
"slug": doc["slug"],
|
||||
"title": doc.get("title", doc["slug"]),
|
||||
"text": text,
|
||||
"chunk_id": 0,
|
||||
"parent_slug": doc["slug"],
|
||||
"is_chunked": False,
|
||||
}
|
||||
]
|
||||
|
||||
# Apply chunking
|
||||
splitter = get_text_splitter(max_tokens, overlap_tokens)
|
||||
raw_chunks = splitter.split_text(text)
|
||||
|
||||
# Filter out tiny chunks
|
||||
valid_chunks = [c for c in raw_chunks if count_tokens(c) >= min_chunk_size]
|
||||
|
||||
return [
|
||||
{
|
||||
"slug": f"{doc['slug']}#chunk{i}",
|
||||
"title": doc.get("title", doc["slug"]),
|
||||
"text": chunk,
|
||||
"chunk_id": i,
|
||||
"parent_slug": doc["slug"],
|
||||
"is_chunked": True,
|
||||
}
|
||||
for i, chunk in enumerate(valid_chunks)
|
||||
]
|
||||
|
||||
|
||||
def write_shards(vectors: np.ndarray, shard_size: int, dtype: str, out_dir: Path) -> list[dict]:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
rows, dims = vectors.shape
|
||||
shards_meta: list[dict] = []
|
||||
np_dtype = np.float16 if dtype == "fp16" else np.float32
|
||||
bytes_per_value = np.dtype(np_dtype).itemsize
|
||||
row_offset = 0
|
||||
for si, start in enumerate(range(0, rows, shard_size)):
|
||||
end = min(start + shard_size, rows)
|
||||
shard = vectors[start:end] # [n, dims]
|
||||
bin_path = out_dir / f"vectors-{si:03d}.bin"
|
||||
payload = shard.astype(np_dtype, copy=False).tobytes(order="C")
|
||||
digest = hashlib.sha256(payload).hexdigest()
|
||||
with open(bin_path, "wb") as f:
|
||||
f.write(payload)
|
||||
shard_rows = int(shard.shape[0])
|
||||
shards_meta.append(
|
||||
{
|
||||
"path": f"/embeddings/{bin_path.name}",
|
||||
"rows": shard_rows,
|
||||
"rowOffset": row_offset,
|
||||
"byteLength": len(payload),
|
||||
"sha256": digest,
|
||||
"byteStride": dims * bytes_per_value,
|
||||
},
|
||||
)
|
||||
row_offset += shard_rows
|
||||
return shards_meta
|
||||
|
||||
|
||||
def write_hnsw_graph(levels: list[list[list[int]]], rows: int, out_path: Path) -> tuple[list[dict], str]:
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
offset = 0
|
||||
meta: list[dict] = []
|
||||
digest = hashlib.sha256()
|
||||
with open(out_path, "wb") as f:
|
||||
for lvl in levels:
|
||||
indptr = np.zeros(rows + 1, dtype=np.uint32)
|
||||
edge_accum: list[int] = []
|
||||
for idx in range(rows):
|
||||
neighbors = lvl[idx] if idx < len(lvl) else []
|
||||
indptr[idx + 1] = indptr[idx] + len(neighbors)
|
||||
edge_accum.extend(neighbors)
|
||||
indptr_bytes = indptr.tobytes(order="C")
|
||||
indptr_offset = offset
|
||||
f.write(indptr_bytes)
|
||||
digest.update(indptr_bytes)
|
||||
offset += len(indptr_bytes)
|
||||
|
||||
if edge_accum:
|
||||
indices = np.asarray(edge_accum, dtype=np.uint32)
|
||||
indices_bytes = indices.tobytes(order="C")
|
||||
else:
|
||||
indices = np.zeros(0, dtype=np.uint32)
|
||||
indices_bytes = indices.tobytes(order="C")
|
||||
indices_offset = offset
|
||||
f.write(indices_bytes)
|
||||
digest.update(indices_bytes)
|
||||
offset += len(indices_bytes)
|
||||
|
||||
meta.append(
|
||||
{
|
||||
"level": len(meta),
|
||||
"indptr": {
|
||||
"offset": indptr_offset,
|
||||
"elements": int(indptr.shape[0]),
|
||||
"byteLength": len(indptr_bytes),
|
||||
},
|
||||
"indices": {
|
||||
"offset": indices_offset,
|
||||
"elements": int(indices.shape[0]),
|
||||
"byteLength": len(indices_bytes),
|
||||
},
|
||||
},
|
||||
)
|
||||
return meta, digest.hexdigest()
|
||||
|
||||
|
||||
|
||||
def embed_vllm(
|
||||
texts: list[str],
|
||||
model_id: str,
|
||||
vllm_url: str,
|
||||
batch_size: int = 64,
|
||||
concurrency: int = 8,
|
||||
) -> np.ndarray:
|
||||
base_url = resolve_vllm_base_url(vllm_url)
|
||||
api_key = os.environ.get("VLLM_API_KEY") or os.environ.get("OPENAI_API_KEY") or "not-set"
|
||||
client = OpenAI(base_url=base_url, api_key=api_key, timeout=300)
|
||||
|
||||
def list_available_models() -> list[str]:
|
||||
models: list[str] = []
|
||||
page = client.models.list()
|
||||
models.extend(model.id for model in page.data)
|
||||
while getattr(page, "has_more", False) and page.data:
|
||||
cursor = page.data[-1].id
|
||||
page = client.models.list(after=cursor)
|
||||
models.extend(model.id for model in page.data)
|
||||
return models
|
||||
|
||||
try:
|
||||
available_models = list_available_models()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"failed to query {base_url}/models: {exc}") from exc
|
||||
|
||||
if model_id not in available_models:
|
||||
suggestions = ", ".join(sorted(available_models)) if available_models else "<none>"
|
||||
logger.warning(
|
||||
"model '%s' not served by vLLM at %s. Available models: %s. Use the first model, results may differ during semantic search (you can omit this message if your weights is a ONNX checkpoint of the same model.)", model_id, base_url, suggestions,
|
||||
)
|
||||
model_id = available_models[0]
|
||||
|
||||
# Apply model-specific prefixes for documents (asymmetric search)
|
||||
model_lower = model_id.lower()
|
||||
if "e5" in model_lower:
|
||||
# E5 models: use "passage:" prefix for documents
|
||||
prefixed = [f"passage: {t}" for t in texts]
|
||||
elif "qwen" in model_lower and "embedding" in model_lower:
|
||||
# Qwen3-Embedding: documents use plain text (no prefix)
|
||||
prefixed = texts
|
||||
elif "embeddinggemma" in model_lower:
|
||||
# embeddinggemma: use "title: none | text:" prefix for documents
|
||||
prefixed = [f"title: none | text: {t}" for t in texts]
|
||||
else:
|
||||
# Default: no prefix for unknown models
|
||||
prefixed = texts
|
||||
|
||||
print(
|
||||
"Embedding"
|
||||
f" {len(prefixed)} texts with vLLM"
|
||||
f" (model={model_id}, batch_size={batch_size}, concurrency={concurrency})",
|
||||
)
|
||||
|
||||
# Create batches
|
||||
batches = []
|
||||
for i in range(0, len(prefixed), batch_size):
|
||||
batch = prefixed[i : i + batch_size]
|
||||
batches.append((i, batch))
|
||||
|
||||
# Function to send a single batch request
|
||||
def send_batch(batch_info: tuple[int, list[str]]) -> tuple[int, list[np.ndarray]]:
|
||||
idx, batch = batch_info
|
||||
response = client.embeddings.create(model=model_id, input=batch)
|
||||
embeddings = [np.asarray(item.embedding, dtype=np.float32) for item in response.data]
|
||||
return (idx, embeddings)
|
||||
|
||||
# Send batches concurrently (or sequentially if only 1 batch)
|
||||
results: dict[int, list[np.ndarray]] = {}
|
||||
if len(batches) == 1:
|
||||
# Single batch - no need for threading
|
||||
idx, embeddings = send_batch(batches[0])
|
||||
results[idx] = embeddings
|
||||
else:
|
||||
# Multiple batches - use concurrent requests
|
||||
with ThreadPoolExecutor(max_workers=concurrency) as executor:
|
||||
futures = {executor.submit(send_batch, batch_info): batch_info[0] for batch_info in batches}
|
||||
completed = 0
|
||||
for future in as_completed(futures):
|
||||
idx, embeddings = future.result()
|
||||
results[idx] = embeddings
|
||||
completed += 1
|
||||
if completed % max(1, len(batches) // 10) == 0 or completed == len(batches):
|
||||
print(f" Completed {completed}/{len(batches)} batches ({completed * 100 // len(batches)}%)")
|
||||
|
||||
# Reconstruct in order
|
||||
out: list[np.ndarray] = []
|
||||
for i in sorted(results.keys()):
|
||||
out.extend(results[i])
|
||||
|
||||
return np.stack(out, axis=0)
|
||||
|
||||
|
||||
def embed_hf(texts: list[str], model_id: str, device: str) -> np.ndarray:
|
||||
# Prefer sentence-transformers for E5 and similar embed models
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer(model_id, device=device)
|
||||
|
||||
# Apply model-specific prefixes for documents (asymmetric search)
|
||||
model_lower = model_id.lower()
|
||||
if "e5" in model_lower:
|
||||
# E5 models: use "passage:" prefix for documents
|
||||
prefixed = [f"passage: {t}" for t in texts]
|
||||
elif "qwen" in model_lower and "embedding" in model_lower:
|
||||
# Qwen3-Embedding: documents use plain text (no prefix)
|
||||
prefixed = texts
|
||||
elif "embeddinggemma" in model_lower:
|
||||
# embeddinggemma: use "title: none | text:" prefix for documents
|
||||
prefixed = [f"title: none | text: {t}" for t in texts]
|
||||
else:
|
||||
# Default: no prefix for unknown models
|
||||
prefixed = texts
|
||||
|
||||
vecs = model.encode(
|
||||
prefixed,
|
||||
batch_size=64,
|
||||
normalize_embeddings=True,
|
||||
convert_to_numpy=True,
|
||||
show_progress_bar=True,
|
||||
)
|
||||
return vecs.astype(np.float32, copy=False)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--jsonl", default="public/embeddings-text.jsonl")
|
||||
ap.add_argument("--model", default=os.environ.get("SEM_MODEL", "intfloat/multilingual-e5-large"))
|
||||
ap.add_argument("--dims", type=int, default=int(os.environ.get("SEM_DIMS", "1024")))
|
||||
ap.add_argument("--dtype", choices=["fp16", "fp32"], default=os.environ.get("SEM_DTYPE", "fp32"))
|
||||
ap.add_argument("--shard-size", type=int, default=int(os.environ.get("SEM_SHARD", "1024")))
|
||||
ap.add_argument("--out", default="public/embeddings")
|
||||
ap.add_argument("--use-vllm", action="store_true", default=bool(os.environ.get("USE_VLLM", "")))
|
||||
ap.add_argument(
|
||||
"--vllm-url",
|
||||
default=DEFAULT_VLLM_URL,
|
||||
help="Base URL for the vLLM OpenAI-compatible server (accepts either /v1 or /v1/embeddings)",
|
||||
)
|
||||
ap.add_argument("--chunk-size", type=int, default=512, help="Max tokens per chunk")
|
||||
ap.add_argument("--chunk-overlap", type=int, default=128, help="Overlap tokens between chunks")
|
||||
ap.add_argument("--no-chunking", action="store_true", help="Disable chunking (embed full docs)")
|
||||
ap.add_argument(
|
||||
"--concurrency",
|
||||
type=int,
|
||||
default=int(os.environ.get("VLLM_CONCURRENCY", "8")),
|
||||
help="Number of concurrent requests to vLLM (default: 8)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=int(os.environ.get("VLLM_BATCH_SIZE", "64")),
|
||||
help="Batch size for vLLM requests (default: 64)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
recs = list(load_jsonl(args.jsonl))
|
||||
if not recs:
|
||||
print("No input found in public/embeddings-text.jsonl; run the site build first to emit JSONL.")
|
||||
return
|
||||
|
||||
# Apply chunking
|
||||
if args.no_chunking:
|
||||
chunks = recs
|
||||
chunk_metadata = {}
|
||||
print(f"Chunking disabled. Processing {len(chunks)} full documents")
|
||||
else:
|
||||
chunks = []
|
||||
chunk_metadata = {}
|
||||
for rec in recs:
|
||||
doc_chunks = chunk_document(rec, max_tokens=args.chunk_size, overlap_tokens=args.chunk_overlap)
|
||||
chunks.extend(doc_chunks)
|
||||
# Build chunk metadata map
|
||||
for chunk in doc_chunks:
|
||||
if chunk["is_chunked"]:
|
||||
chunk_metadata[chunk["slug"]] = {
|
||||
"parentSlug": chunk["parent_slug"],
|
||||
"chunkId": chunk["chunk_id"],
|
||||
}
|
||||
chunked_count = sum(1 for c in chunks if c.get("is_chunked", False))
|
||||
print(f"Chunked {len(recs)} documents into {len(chunks)} chunks ({chunked_count} chunked, {len(chunks) - chunked_count} unchanged)")
|
||||
print(f" Chunk size: {args.chunk_size} tokens, overlap: {args.chunk_overlap} tokens")
|
||||
|
||||
ids = [c["slug"] for c in chunks]
|
||||
titles = [c.get("title", c["slug"]) for c in chunks]
|
||||
texts = [c["text"] for c in chunks]
|
||||
|
||||
if args.use_vllm:
|
||||
vecs = embed_vllm(
|
||||
texts,
|
||||
args.model,
|
||||
args.vllm_url,
|
||||
batch_size=args.batch_size,
|
||||
concurrency=args.concurrency,
|
||||
)
|
||||
else:
|
||||
device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
|
||||
vecs = embed_hf(texts, args.model, device)
|
||||
|
||||
# Coerce dims and re-normalize
|
||||
if vecs.shape[1] != args.dims:
|
||||
if vecs.shape[1] > args.dims:
|
||||
vecs = vecs[:, : args.dims]
|
||||
else:
|
||||
vecs = np.pad(vecs, ((0, 0), (0, args.dims - vecs.shape[1])))
|
||||
vecs = l2_normalize_rows(vecs.astype(np.float32, copy=False))
|
||||
|
||||
out_dir = Path(args.out)
|
||||
shards = write_shards(vecs, args.shard_size, args.dtype, out_dir)
|
||||
|
||||
# Build a lightweight HNSW graph and store it in a compact binary layout
|
||||
def hnsw_build(data: np.ndarray, M: int = 16, efC: int = 200, seed: int = 0) -> dict:
|
||||
rng = random.Random(seed)
|
||||
N, D = data.shape
|
||||
levels: list[list[list[int]]] = [] # levels[L][i] = neighbors of node i at level L
|
||||
|
||||
# random level assignment using 1/e distribution
|
||||
node_levels = []
|
||||
for _ in range(N):
|
||||
lvl = 0
|
||||
while rng.random() < 1 / math.e:
|
||||
lvl += 1
|
||||
node_levels.append(lvl)
|
||||
max_level = max(node_levels) if N > 0 else 0
|
||||
for _ in range(max_level + 1):
|
||||
levels.append([[] for _ in range(N)])
|
||||
|
||||
def sim(i: int, j: int) -> float:
|
||||
return float((data[i] * data[j]).sum())
|
||||
|
||||
entry = 0 if N > 0 else -1
|
||||
|
||||
def search_layer(q: int, ep: int, ef: int, L: int) -> list[int]:
|
||||
if ep < 0:
|
||||
return []
|
||||
visited = set()
|
||||
cand: list[tuple[float, int]] = []
|
||||
top: list[tuple[float, int]] = []
|
||||
def push(node: int):
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
cand.append((sim(q, node), node))
|
||||
push(ep)
|
||||
while cand:
|
||||
cand.sort(reverse=True)
|
||||
s, v = cand.pop(0)
|
||||
if len(top) >= ef and s <= top[-1][0]:
|
||||
break
|
||||
top.append((s, v))
|
||||
for u in levels[L][v]:
|
||||
push(u)
|
||||
top.sort(reverse=True)
|
||||
return [n for _, n in top]
|
||||
|
||||
for i in range(N):
|
||||
if i == 0:
|
||||
continue
|
||||
lvl = node_levels[i]
|
||||
ep = entry
|
||||
for L in range(max_level, lvl, -1):
|
||||
c = search_layer(i, ep, 1, L)
|
||||
if c:
|
||||
ep = c[0]
|
||||
for L in range(min(max_level, lvl), -1, -1):
|
||||
W = search_layer(i, ep, efC, L)
|
||||
# Select top M by similarity
|
||||
neigh = sorted(((sim(i, j), j) for j in W if j != i), reverse=True)[:M]
|
||||
for _, e in neigh:
|
||||
if e not in levels[L][i]:
|
||||
levels[L][i].append(e)
|
||||
if i not in levels[L][e]:
|
||||
levels[L][e].append(i)
|
||||
|
||||
# trim neighbors to M
|
||||
for L in range(len(levels)):
|
||||
for i in range(N):
|
||||
if len(levels[L][i]) > M:
|
||||
# keep top M by sim
|
||||
nb = levels[L][i]
|
||||
nb = sorted(nb, key=lambda j: sim(i, j), reverse=True)[:M]
|
||||
levels[L][i] = nb
|
||||
|
||||
return {
|
||||
"M": M,
|
||||
"efConstruction": efC,
|
||||
"entryPoint": entry,
|
||||
"maxLevel": max_level,
|
||||
"levels": levels,
|
||||
}
|
||||
|
||||
hnsw = hnsw_build(vecs, M=16, efC=200)
|
||||
hnsw_meta, hnsw_sha = write_hnsw_graph(hnsw["levels"], int(vecs.shape[0]), out_dir / "hnsw.bin")
|
||||
|
||||
manifest = {
|
||||
"version": 2,
|
||||
"dims": args.dims,
|
||||
"dtype": args.dtype,
|
||||
"normalized": True,
|
||||
"rows": int(vecs.shape[0]),
|
||||
"shardSizeRows": args.shard_size,
|
||||
"vectors": {
|
||||
"dtype": args.dtype,
|
||||
"rows": int(vecs.shape[0]),
|
||||
"dims": args.dims,
|
||||
"shards": shards,
|
||||
},
|
||||
"ids": ids,
|
||||
"titles": titles,
|
||||
"chunkMetadata": chunk_metadata,
|
||||
"hnsw": {
|
||||
"M": hnsw["M"],
|
||||
"efConstruction": hnsw["efConstruction"],
|
||||
"entryPoint": hnsw["entryPoint"],
|
||||
"maxLevel": hnsw["maxLevel"],
|
||||
"graph": {
|
||||
"path": "/embeddings/hnsw.bin",
|
||||
"sha256": hnsw_sha,
|
||||
"levels": hnsw_meta,
|
||||
},
|
||||
},
|
||||
}
|
||||
(out_dir / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False), encoding="utf-8")
|
||||
print(f"Wrote {len(shards)} vector shard(s), HNSW graph, and manifest to {out_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user