Coverage for mindsdb / integrations / handlers / duckdb_faiss_handler / faiss_index.py: 0%
121 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2from typing import Iterable, List
3import numpy as np
4import psutil
6import portalocker
8import faiss # faiss or faiss-gpu
9from pydantic import BaseModel
12def _normalize_rows(x: np.ndarray) -> np.ndarray:
13 norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
14 return x / norms
class FaissParams(BaseModel):
    """Validated configuration for a FAISS index.

    Only ``metric`` and ``use_gpu`` are read by the current flat-index code
    path; the ``nlist``/``nprobe``/``hnsw_*`` fields are presumably intended
    for the planned IVF / HNSW index types — confirm before relying on them.
    """

    # similarity metric: "cosine" | "ip" | "l1" | "l2"
    metric: str | None = "cosine"
    # best-effort: move the index to all available GPUs after build
    use_gpu: bool | None = False
    # IVF: number of coarse clusters
    nlist: int | None = 1024
    # IVF: number of clusters probed at search time
    nprobe: int | None = 32
    # HNSW: neighbors per graph node
    hnsw_m: int | None = 32
    # HNSW: search-time exploration width
    hnsw_ef_search: int | None = 64
26class FaissIndex:
27 def __init__(self, path: str, config: dict):
28 self._normalize_vectors = False
30 self.config = FaissParams(**config)
32 metric = self.config.metric
33 if metric == "cosine":
34 self._normalize_vectors = True
35 self.metric = faiss.METRIC_INNER_PRODUCT
36 elif metric == "ip":
37 self.metric = faiss.METRIC_INNER_PRODUCT
38 elif metric == "l1":
39 self.metric = faiss.METRIC_L1
40 elif metric == "l2":
41 self.metric = faiss.METRIC_L2
42 else:
43 raise ValueError(f"Unknown metric: {metric}")
45 self.path = path
47 self._since_ram_checked = 0
49 self.index = None
50 self.dim = None
51 self.index_fd = None
52 if os.path.exists(self.path):
53 self.load_index()
55 def load_index(self):
56 # check RAM
57 index_size = os.path.getsize(self.path)
58 # according to tests faiss index occupies ~ the same amount of RAM as file size
59 # add 10% and 1Gb to it, check only if index > 1Gb
60 _1gb = 1024**3
61 required_ram = index_size * 1.1 + _1gb
62 available_ram = psutil.virtual_memory().available
63 if required_ram > _1gb and available_ram < required_ram:
64 to_free_gb = round((required_ram - available_ram) / _1gb, 2)
65 raise ValueError(f"Unable load FAISS index into RAM, free up al least : {to_free_gb} Gb")
67 if os.name != "nt":
68 self.index_fd = open(self.path, "rb")
69 try:
70 portalocker.lock(self.index_fd, portalocker.LOCK_EX | portalocker.LOCK_NB)
71 except portalocker.exceptions.AlreadyLocked:
72 raise ValueError(f"Index is already used: {self.path}")
74 self.index = faiss.read_index(self.path)
75 self.dim = self.index.d
77 def close(self):
78 if self.index_fd is not None:
79 self.index_fd.close()
80 self.index = None
82 def _build_index(self):
83 # TODO option to create hnsw
85 index = faiss.IndexFlat(self.dim, self.metric)
86 index = faiss.IndexIDMap(index)
88 if self.config.use_gpu:
89 try:
90 index = faiss.index_cpu_to_all_gpus(index)
91 except Exception:
92 pass
94 self.index = index
96 def check_ram_usage(self, count_vectors, index_type: str = "flat", m=32, nlist=4096):
97 self._since_ram_checked += count_vectors
99 # check after every 10k vectors
100 if self._since_ram_checked < 10000:
101 return
103 match index_type:
104 case "flat":
105 required = self.dim * 4 * count_vectors
106 case "hnsw":
107 required = (self.dim * 4 + m * 2 * 4) * count_vectors
108 case "ivf":
109 required = (self.dim * 4 + 8) * count_vectors + self.dim * 4 * nlist
110 case _:
111 raise ValueError(f"Unknown index type: {index_type}")
113 # check RAM usage
114 # keep extra 1Gb
115 available = psutil.virtual_memory().available - 1 * 1024**3
117 if available < required:
118 raise ValueError("Unable insert records, not enough RAM")
120 self._since_ram_checked = 0
122 def insert(
123 self,
124 vectors: Iterable[Iterable[float]],
125 ids: Iterable[float],
126 ) -> None:
127 if len(vectors) == 0:
128 return
130 vectors = np.array(vectors)
131 ids = np.array(ids)
133 if self.index is None:
134 # this if the first insert, detect dimension
135 self.dim = vectors.shape[1]
137 self._build_index()
139 self.check_ram_usage(len(vectors), "flat")
141 if vectors.shape[1] != self.dim:
142 raise ValueError(f"Dimension mismatch: expected {self.dim}, got {vectors.shape[1]}")
144 if self._normalize_vectors:
145 vectors = _normalize_rows(vectors)
147 self.index.add_with_ids(vectors, ids)
149 def delete_ids(self, ids: List[int]) -> None:
150 """Mark IDs as deleted for filtering in searches."""
151 ids = np.array(ids)
152 if self.index:
153 self.index.remove_ids(ids)
155 def dump(self):
156 # TODO to not save it every time for big files?
157 # use two indexes: main and temporal
158 # temporal is Flat and stores data that wasn't moved into main, and have limit
159 if self.index:
160 faiss.write_index(self.index, self.path)
162 def apply_index(self):
163 # TODO convert into IndexIVFFlat or IndexHNSWFlat
164 ...
166 def drop(self):
167 self.close()
168 if os.path.exists(self.path):
169 os.remove(self.path)
171 def search(
172 self,
173 query: Iterable[Iterable[float]],
174 limit: int = 10,
175 # allowed_ids: Optional[Sequence[int]] = None,
176 ):
177 if self.index is None:
178 return [], []
180 queries = np.array([query])
182 if self._normalize_vectors:
183 queries = _normalize_rows(queries)
185 ds, ids = self.index.search(queries, limit)
187 list_id = [i for i in ids[0] if i != -1]
188 list_distances = [1 - d for d in ds[0][: len(list_id)]]
190 return list_distances, list_id