Coverage for mindsdb / integrations / handlers / duckdb_faiss_handler / faiss_index.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2from typing import Iterable, List 

3import numpy as np 

4import psutil 

5 

6import portalocker 

7 

8import faiss # faiss or faiss-gpu 

9from pydantic import BaseModel 

10 

11 

12def _normalize_rows(x: np.ndarray) -> np.ndarray: 

13 norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12 

14 return x / norms 

15 

16 

class FaissParams(BaseModel):
    """Validated configuration options for the FAISS index wrapper."""

    # Distance metric: "cosine" | "ip" | "l1" | "l2" (anything else is rejected
    # by FaissIndex.__init__).
    metric: str | None = "cosine"
    # Try to move the index onto all available GPUs (best effort; falls back to CPU).
    use_gpu: bool | None = False
    # Number of IVF cells -- TODO confirm: not read anywhere in this module
    # (check_ram_usage takes its own nlist parameter).
    nlist: int | None = 1024
    # IVF cells probed at query time -- TODO confirm: not read in this module.
    nprobe: int | None = 32
    # HNSW connectivity -- TODO confirm: not read in this module.
    hnsw_m: int | None = 32
    # HNSW search width -- TODO confirm: not read in this module.
    hnsw_ef_search: int | None = 64

24 

25 

26class FaissIndex: 

27 def __init__(self, path: str, config: dict): 

28 self._normalize_vectors = False 

29 

30 self.config = FaissParams(**config) 

31 

32 metric = self.config.metric 

33 if metric == "cosine": 

34 self._normalize_vectors = True 

35 self.metric = faiss.METRIC_INNER_PRODUCT 

36 elif metric == "ip": 

37 self.metric = faiss.METRIC_INNER_PRODUCT 

38 elif metric == "l1": 

39 self.metric = faiss.METRIC_L1 

40 elif metric == "l2": 

41 self.metric = faiss.METRIC_L2 

42 else: 

43 raise ValueError(f"Unknown metric: {metric}") 

44 

45 self.path = path 

46 

47 self._since_ram_checked = 0 

48 

49 self.index = None 

50 self.dim = None 

51 self.index_fd = None 

52 if os.path.exists(self.path): 

53 self.load_index() 

54 

55 def load_index(self): 

56 # check RAM 

57 index_size = os.path.getsize(self.path) 

58 # according to tests faiss index occupies ~ the same amount of RAM as file size 

59 # add 10% and 1Gb to it, check only if index > 1Gb 

60 _1gb = 1024**3 

61 required_ram = index_size * 1.1 + _1gb 

62 available_ram = psutil.virtual_memory().available 

63 if required_ram > _1gb and available_ram < required_ram: 

64 to_free_gb = round((required_ram - available_ram) / _1gb, 2) 

65 raise ValueError(f"Unable load FAISS index into RAM, free up al least : {to_free_gb} Gb") 

66 

67 if os.name != "nt": 

68 self.index_fd = open(self.path, "rb") 

69 try: 

70 portalocker.lock(self.index_fd, portalocker.LOCK_EX | portalocker.LOCK_NB) 

71 except portalocker.exceptions.AlreadyLocked: 

72 raise ValueError(f"Index is already used: {self.path}") 

73 

74 self.index = faiss.read_index(self.path) 

75 self.dim = self.index.d 

76 

77 def close(self): 

78 if self.index_fd is not None: 

79 self.index_fd.close() 

80 self.index = None 

81 

82 def _build_index(self): 

83 # TODO option to create hnsw 

84 

85 index = faiss.IndexFlat(self.dim, self.metric) 

86 index = faiss.IndexIDMap(index) 

87 

88 if self.config.use_gpu: 

89 try: 

90 index = faiss.index_cpu_to_all_gpus(index) 

91 except Exception: 

92 pass 

93 

94 self.index = index 

95 

96 def check_ram_usage(self, count_vectors, index_type: str = "flat", m=32, nlist=4096): 

97 self._since_ram_checked += count_vectors 

98 

99 # check after every 10k vectors 

100 if self._since_ram_checked < 10000: 

101 return 

102 

103 match index_type: 

104 case "flat": 

105 required = self.dim * 4 * count_vectors 

106 case "hnsw": 

107 required = (self.dim * 4 + m * 2 * 4) * count_vectors 

108 case "ivf": 

109 required = (self.dim * 4 + 8) * count_vectors + self.dim * 4 * nlist 

110 case _: 

111 raise ValueError(f"Unknown index type: {index_type}") 

112 

113 # check RAM usage 

114 # keep extra 1Gb 

115 available = psutil.virtual_memory().available - 1 * 1024**3 

116 

117 if available < required: 

118 raise ValueError("Unable insert records, not enough RAM") 

119 

120 self._since_ram_checked = 0 

121 

122 def insert( 

123 self, 

124 vectors: Iterable[Iterable[float]], 

125 ids: Iterable[float], 

126 ) -> None: 

127 if len(vectors) == 0: 

128 return 

129 

130 vectors = np.array(vectors) 

131 ids = np.array(ids) 

132 

133 if self.index is None: 

134 # this if the first insert, detect dimension 

135 self.dim = vectors.shape[1] 

136 

137 self._build_index() 

138 

139 self.check_ram_usage(len(vectors), "flat") 

140 

141 if vectors.shape[1] != self.dim: 

142 raise ValueError(f"Dimension mismatch: expected {self.dim}, got {vectors.shape[1]}") 

143 

144 if self._normalize_vectors: 

145 vectors = _normalize_rows(vectors) 

146 

147 self.index.add_with_ids(vectors, ids) 

148 

149 def delete_ids(self, ids: List[int]) -> None: 

150 """Mark IDs as deleted for filtering in searches.""" 

151 ids = np.array(ids) 

152 if self.index: 

153 self.index.remove_ids(ids) 

154 

155 def dump(self): 

156 # TODO to not save it every time for big files? 

157 # use two indexes: main and temporal 

158 # temporal is Flat and stores data that wasn't moved into main, and have limit 

159 if self.index: 

160 faiss.write_index(self.index, self.path) 

161 

162 def apply_index(self): 

163 # TODO convert into IndexIVFFlat or IndexHNSWFlat 

164 ... 

165 

166 def drop(self): 

167 self.close() 

168 if os.path.exists(self.path): 

169 os.remove(self.path) 

170 

171 def search( 

172 self, 

173 query: Iterable[Iterable[float]], 

174 limit: int = 10, 

175 # allowed_ids: Optional[Sequence[int]] = None, 

176 ): 

177 if self.index is None: 

178 return [], [] 

179 

180 queries = np.array([query]) 

181 

182 if self._normalize_vectors: 

183 queries = _normalize_rows(queries) 

184 

185 ds, ids = self.index.search(queries, limit) 

186 

187 list_id = [i for i in ids[0] if i != -1] 

188 list_distances = [1 - d for d in ds[0][: len(list_id)]] 

189 

190 return list_distances, list_id