Coverage for mindsdb / interfaces / knowledge_base / controller.py: 34%

900 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import copy 

3from typing import Dict, List, Optional, Any, Text, Tuple, Union 

4import json 

5import decimal 

6 

7import pandas as pd 

8import numpy as np 

9from pydantic import BaseModel, ValidationError 

10from sqlalchemy.orm.attributes import flag_modified 

11 

12from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star 

13from mindsdb_sql_parser import parse_sql 

14 

15from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase 

16from mindsdb.integrations.utilities.query_traversal import query_traversal 

17 

18import mindsdb.interfaces.storage.db as db 

19from mindsdb.integrations.libs.vectordatabase_handler import ( 

20 DistanceFunction, 

21 TableField, 

22 VectorStoreHandler, 

23) 

24from mindsdb.integrations.utilities.handler_utils import get_api_key 

25from mindsdb.integrations.utilities.handlers.auth_utilities.snowflake import get_validated_jwt 

26 

27from mindsdb.integrations.utilities.rag.settings import RerankerMode 

28 

29from mindsdb.interfaces.agents.constants import get_default_embeddings_model_class, MAX_INSERT_BATCH_SIZE 

30from mindsdb.interfaces.agents.provider_utils import get_llm_provider 

31 

# The langchain-backed chat-model factory is an optional dependency. When a
# langchain package is missing we degrade gracefully (create_chat_model = None)
# and keep the original error so callers can re-raise it with context.
try:
    from mindsdb.interfaces.agents.langchain_agent import create_chat_model
except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency
    # Only swallow the error when the missing module is part of the langchain stack.
    if getattr(exc, "name", "") and "langchain" in exc.name:
        create_chat_model = None
        _LANGCHAIN_IMPORT_ERROR = exc
    else:  # Unknown import error, surface it
        raise
else:
    _LANGCHAIN_IMPORT_ERROR = None

42from mindsdb.interfaces.database.projects import ProjectController 

43from mindsdb.interfaces.knowledge_base.preprocessing.models import PreprocessingConfig, Document 

44from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor import PreprocessorFactory 

45from mindsdb.interfaces.knowledge_base.evaluate import EvaluateBase 

46from mindsdb.interfaces.knowledge_base.executor import KnowledgeBaseQueryExecutor 

47from mindsdb.interfaces.model.functions import PredictorRecordNotFound 

48from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError 

49from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs 

50from mindsdb.utilities.config import config 

51from mindsdb.utilities.context import context as ctx 

52 

53from mindsdb.api.executor.command_executor import ExecuteCommands 

54from mindsdb.api.executor.utilities.sql import query_df 

55from mindsdb.utilities import log 

56from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker, ListwiseLLMReranker 

57from mindsdb.interfaces.knowledge_base.llm_client import LLMClient 

58 

# Module-level logger for knowledge-base operations.
logger = log.getLogger(__name__)

60 

61 

def _require_agent_extra(feature: str):
    """Raise a descriptive ImportError when the optional langchain stack is unavailable."""
    if create_chat_model is not None:
        return
    raise ImportError(
        f"{feature} requires the optional agent dependencies. Install them via `pip install mindsdb[kb]`."
    ) from _LANGCHAIN_IMPORT_ERROR

67 

68 

class KnowledgeBaseInputParams(BaseModel):
    """Pydantic schema validating user-supplied knowledge-base parameters.

    `extra = "forbid"` turns unknown keys into validation errors instead of
    silently ignoring them.
    """

    # Columns whose values are stored as chunk metadata
    metadata_columns: List[str] | None = None
    # Columns whose values are embedded as document content
    content_columns: List[str] | None = None
    # Source column providing the document id
    id_column: str | None = None
    # Disable upsert semantics on insert
    kb_no_upsert: bool = False
    # Skip rows whose ids already exist in the KB
    kb_skip_existing: bool = False
    # Embedding model configuration (provider, model name, keys, ...)
    embedding_model: Dict[Text, Any] | None = None
    # Whether embeddings are sparse vectors
    is_sparse: bool = False
    # Fixed embedding vector size, when required by the vector store
    vector_size: int | None = None
    # Reranking model configuration, or a boolean toggle
    reranking_model: Union[Dict[Text, Any], bool] | None = None
    # Preprocessing (chunking) configuration
    preprocessing: Dict[Text, Any] | None = None

    class Config:
        extra = "forbid"

83 

84 

def get_model_params(model_params: dict, default_config_key: str):
    """
    Merge user-supplied model parameters over the defaults stored in config.

    :param model_params: user-provided parameters (must be a dict when truthy)
    :param default_config_key: config key holding the default parameter dict
    :return: merged parameter dict; user values win over defaults
    :raises ValueError: when model_params is truthy but not a dict
    """
    merged = copy.deepcopy(config.get(default_config_key, {}))

    if model_params:
        if not isinstance(model_params, dict):
            raise ValueError("Model parameters must be passed as a JSON object")

        # A provider different from the default invalidates the defaults entirely:
        # return the user parameters untouched.
        if "provider" in model_params and model_params["provider"] != merged.get("provider"):
            return model_params

        merged.update(model_params)

    # Internal toggle; must not leak into the final parameter set.
    merged.pop("use_default_llm", None)

    return merged

104 

105 

def adapt_embedding_model_params(embedding_model_params: dict):
    """
    Prepare parameters for an embedding model.

    Normalises the provider name into a class selector, resolves the API key,
    and applies Azure OpenAI specific renames.

    :param embedding_model_params: raw embedding model parameters (not mutated)
    :return: a new dict of adapted parameters
    :raises ValueError: if 'provider' is missing from the parameters
    """
    params_copy = copy.deepcopy(embedding_model_params)

    provider = params_copy.pop("provider", None)
    if provider is None:
        # Previously this crashed with `AttributeError: 'NoneType' object has no
        # attribute 'lower'`; raise an explicit, actionable error instead.
        raise ValueError("Embedding model parameters must include a 'provider'")
    provider = provider.lower()

    api_key = get_api_key(provider, params_copy, strict=False) or params_copy.get("api_key")
    # Underscores are replaced because the provider name ultimately gets mapped to a class name.
    # This is mostly to support Azure OpenAI (azure_openai); the mapped class name is 'AzureOpenAIEmbeddings'.
    params_copy["class"] = provider.replace("_", "")
    if provider == "azure_openai":
        # Azure OpenAI expects the api_key to be passed as 'openai_api_key'.
        params_copy["openai_api_key"] = api_key
        params_copy["azure_endpoint"] = params_copy.pop("base_url")
        if "chunk_size" not in params_copy:
            params_copy["chunk_size"] = 2048
        if "api_version" in params_copy:
            params_copy["openai_api_version"] = params_copy["api_version"]
    else:
        params_copy[f"{provider}_api_key"] = api_key
        params_copy.pop("api_key", None)
    params_copy["model"] = params_copy.pop("model_name", None)

    return params_copy

130 

131 

def get_reranking_model_from_params(reranking_model_params: dict):
    """
    Build a reranker instance from a parameter dict.

    Validates the parameters with RerankerConfig, then instantiates either a
    listwise or a standard LLM reranker depending on the configured mode.
    """
    from mindsdb.integrations.utilities.rag.settings import RerankerConfig

    # Never mutate the caller's dict.
    params = copy.deepcopy(reranking_model_params)

    # Resolve an API key only when the user did not supply one explicitly.
    provider = params.get("provider", "openai").lower()
    if "api_key" not in params:
        params["api_key"] = get_api_key(provider, params, strict=False)

    # `model_name` is a legacy alias for `model`.
    if "model_name" in params and "model" not in params:
        params["model"] = params.pop("model_name")

    # Validate the core fields (e.g. mode) via Pydantic.
    try:
        cfg = RerankerConfig(**params)
    except ValueError as e:
        raise ValueError(f"Invalid reranker configuration: {str(e)}")

    # Validated values win over raw user input; extra user keys still survive.
    reranker_kwargs = {**params, **cfg.model_dump()}

    if cfg.mode == RerankerMode.LISTWISE:
        return ListwiseLLMReranker(**reranker_kwargs)
    return BaseLLMReranker(**reranker_kwargs)

164 

165 

def safe_pandas_is_datetime(value: str) -> bool:
    """
    Report whether *value* carries a datetime64 dtype.

    Any ValueError raised by the dtype inspection is treated as "not a datetime".
    """
    try:
        return pd.api.types.is_datetime64_any_dtype(value)
    except ValueError:
        # Inspection failed — treat as non-datetime rather than propagating.
        return False

175 

176 

def to_json(obj):
    """Serialise *obj* to a JSON string; None and unserialisable values pass through unchanged."""
    if obj is not None:
        try:
            obj = json.dumps(obj)
        except TypeError:
            # Not JSON-serialisable — return the original object untouched.
            pass
    return obj

184 

185 

def rotate_provider_api_key(params):
    """
    Check the api key for specific providers. At the moment it checks and updates
    the jwt token of the snowflake provider.

    :param params: input params, can be modified by this function
    :return: a new api key if it is refreshed, otherwise None
    """
    # `or ""` guards against a missing provider key; previously this crashed
    # with `AttributeError: 'NoneType' object has no attribute 'lower'`.
    provider = (params.get("provider") or "").lower()

    if provider == "snowflake":
        if "snowflake_account_id" in params:
            # `snowflake_account_id` is the old name
            params["account_id"] = params.pop("snowflake_account_id")

        # Without a private key the token cannot be refreshed.
        if "private_key" not in params:
            return None

        api_key = params.get("api_key")
        api_key2 = get_validated_jwt(
            api_key,
            account=params.get("account_id"),
            user=params.get("user"),
            private_key=params["private_key"],
        )
        if api_key2 != api_key:
            # The token was refreshed — persist it into params and report it.
            params["api_key"] = api_key2
            return api_key2
    return None

213 

214 

215class KnowledgeBaseTable: 

216 """ 

217 Knowledge base table interface 

218 Handlers requests to KB table and modifies data in linked vector db table 

219 """ 

220 

221 def __init__(self, kb: db.KnowledgeBase, session): 

222 self._kb = kb 

223 self._vector_db = None 

224 self.session = session 

225 self.document_preprocessor = None 

226 self.document_loader = None 

227 self.model_params = None 

228 

229 self.kb_to_vector_columns = {"id": "_original_doc_id", "chunk_id": "id", "chunk_content": "content"} 

230 if self._kb.params.get("version", 0) < 2: 

231 self.kb_to_vector_columns["id"] = "original_doc_id" 

232 

233 def configure_preprocessing(self, config: Optional[dict] = None): 

234 """Configure preprocessing for the knowledge base table""" 

235 logger.debug(f"Configuring preprocessing with config: {config}") 

236 self.document_preprocessor = None # Reset existing preprocessor 

237 if config is None: 237 ↛ 241line 237 didn't jump to line 241 because the condition on line 237 was always true

238 config = {} 

239 

240 # Ensure content_column is set for JSON chunking if not already specified 

241 if config.get("type") == "json_chunking" and config.get("json_chunking_config"): 241 ↛ 242line 241 didn't jump to line 242 because the condition on line 241 was never true

242 if "content_column" not in config["json_chunking_config"]: 

243 config["json_chunking_config"]["content_column"] = "content" 

244 

245 preprocessing_config = PreprocessingConfig(**config) 

246 self.document_preprocessor = PreprocessorFactory.create_preprocessor(preprocessing_config) 

247 

248 # set doc_id column name 

249 self.document_preprocessor.config.doc_id_column_name = self.kb_to_vector_columns["id"] 

250 

251 logger.debug(f"Created preprocessor of type: {type(self.document_preprocessor)}") 

252 

    def select_query(self, query: Select) -> pd.DataFrame:
        """
        Handles select from KB table.
        Replaces content values with embeddings in where clause. Sends query to vector db.
        :param query: query to KB table
        :return: dataframe with the result table
        """

        # Copy query for complex execution via DuckDB: DISTINCT, GROUP BY etc.
        query_copy = copy.deepcopy(query)

        executor = KnowledgeBaseQueryExecutor(self)
        df = executor.run(query)

        # copy metadata to columns
        if "metadata" in df.columns:
            meta_columns = self._get_allowed_metadata_columns()
            if meta_columns:
                # Expand metadata dicts into one column per metadata key.
                meta_data = pd.json_normalize(df["metadata"])
                # exclude absent columns and used columns
                df_columns = list(df.columns)
                meta_columns = list(set(meta_columns).intersection(meta_data.columns).difference(df_columns))

                # add columns
                df = df.join(meta_data[meta_columns])

                # put metadata in the end
                df_columns.remove("metadata")
                df = df[df_columns + meta_columns + ["metadata"]]

        # Anything beyond a plain `SELECT * ...` (grouping, ordering, having,
        # distinct, explicit targets) is re-executed over the interim dataframe.
        if (
            query_copy.group_by is not None
            or query_copy.order_by is not None
            or query_copy.having is not None
            or query_copy.distinct is True
            or len(query_copy.targets) != 1
            or not isinstance(query_copy.targets[0], Star)
        ):
            # WHERE was already applied by the executor; don't apply it twice.
            query_copy.where = None
            # Serialise metadata dicts so the dataframe engine can handle them.
            if "metadata" in df.columns:
                df["metadata"] = df["metadata"].apply(to_json)

            if query_copy.from_table is None:
                query_copy.from_table = Identifier(parts=[self._kb.name])

            df = query_df(df, query_copy, session=self.session)

        return df

301 

    def select(self, query, disable_reranking=False):
        """
        Run a semantic search (optionally blended with keyword search) against the vector store.

        Pseudo-columns recognised in the WHERE clause:
          - content: search text (replaced with embeddings for vector search)
          - relevance: minimum relevance in [0, 1]; only > and >= operators allowed
          - reranking: false disables the reranker for this query
          - hybrid_search / hybrid_search_alpha: enable and weight keyword search
        Other conditions are forwarded to the vector store as metadata filters.

        :param query: Select AST addressed to the KB table
        :param disable_reranking: skip reranking even when a reranker is configured
        :return: result dataframe sorted by relevance
        """
        logger.debug(f"Processing select query: {query}")

        # Extract the content query text for potential reranking

        db_handler = self.get_vector_db()

        logger.debug("Replaced content with embeddings in where clause")
        # set table name
        query.from_table = Identifier(parts=[self._kb.vector_database_table])
        logger.debug(f"Set table name to: {self._kb.vector_database_table}")

        query.targets = [
            Identifier(TableField.ID.value),
            Identifier(TableField.CONTENT.value),
            Identifier(TableField.METADATA.value),
            Identifier(TableField.DISTANCE.value),
        ]

        # Get response from vector db
        logger.debug(f"Using vector db handler: {type(db_handler)}")

        # extract values from conditions and prepare for vectordb
        conditions = []
        keyword_search_conditions = []
        keyword_search_cols_and_values = []
        query_text = None
        relevance_threshold = None
        relevance_threshold_allowed_operators = [
            FilterOperator.GREATER_THAN_OR_EQUAL.value,
            FilterOperator.GREATER_THAN.value,
        ]
        gt_filtering = False
        query_conditions = db_handler.extract_conditions(query.where)
        hybrid_search_alpha = None
        if query_conditions is not None:
            for item in query_conditions:
                if (item.column == "relevance") and (item.op.value in relevance_threshold_allowed_operators):
                    try:
                        relevance_threshold = float(item.value)
                        # Validate range: must be between 0 and 1
                        if not (0 <= relevance_threshold <= 1):
                            raise ValueError(f"relevance_threshold must be between 0 and 1, got: {relevance_threshold}")
                        if item.op.value == FilterOperator.GREATER_THAN.value:
                            # Strict > is applied as a post-filter at the end.
                            gt_filtering = True
                        logger.debug(f"Found relevance_threshold in query: {relevance_threshold}")
                    except (ValueError, TypeError) as e:
                        error_msg = f"Invalid relevance_threshold value: {item.value}. {e}"
                        logger.error(error_msg)
                        raise ValueError(error_msg) from e
                elif (item.column == "relevance") and (item.op.value not in relevance_threshold_allowed_operators):
                    raise ValueError(
                        f"Invalid operator for relevance: {item.op.value}. Only the following operators are allowed: "
                        f"{','.join(relevance_threshold_allowed_operators)}."
                    )
                elif item.column == "reranking":
                    if item.value is False or (isinstance(item.value, str) and item.value.lower() == "false"):
                        disable_reranking = True
                elif item.column == "hybrid_search":
                    if item.value:
                        # An explicit hybrid_search_alpha takes precedence over this default.
                        if hybrid_search_alpha is None:
                            hybrid_search_alpha = 0.5
                elif item.column == "hybrid_search_alpha":
                    # validate item.value is a float
                    if not isinstance(item.value, (float, int)):
                        raise ValueError(f"Invalid hybrid_search_alpha value: {item.value}. Must be a float or int.")
                    # validate hybrid search alpha is between 0 and 1
                    if not (0 <= item.value <= 1):
                        raise ValueError(f"Invalid hybrid_search_alpha value: {item.value}. Must be between 0 and 1.")
                    hybrid_search_alpha = item.value
                elif item.column == TableField.CONTENT.value:
                    query_text = item.value

                    # replace content with embeddings
                    conditions.append(
                        FilterCondition(
                            column=TableField.EMBEDDINGS.value,
                            value=self._content_to_embeddings(item.value),
                            op=FilterOperator.EQUAL,
                        )
                    )
                    keyword_search_cols_and_values.append((TableField.CONTENT.value, item.value))
                else:
                    conditions.append(item)
                    keyword_search_conditions.append(item)  # keyword search conditions do not use embeddings

        if len(keyword_search_cols_and_values) > 1:
            raise ValueError(
                "Multiple content columns found in query conditions. "
                "Only one content column is allowed for keyword search."
            )

        logger.debug(f"Extracted query text: {query_text}")

        self.addapt_conditions_columns(conditions)

        # Set default limit if query is present
        limit = query.limit.value if query.limit is not None else None
        if query_text is not None:
            if limit is None:
                limit = 10
            elif limit > 100:
                limit = 100

            if not disable_reranking:
                # expand limit, get more records before reranking usage:
                # fetch twice the requested size, capped at limit + 30 extra rows
                query_limit = min(limit * 2, limit + 30)
            else:
                query_limit = limit
            query.limit = Constant(query_limit)

        allowed_metadata_columns = self._get_allowed_metadata_columns()

        # alpha == 1 means pure vector search; alpha == 0 means pure keyword search.
        if hybrid_search_alpha is None:
            hybrid_search_alpha = 1

        if hybrid_search_alpha > 0:
            df = db_handler.dispatch_select(query, conditions, allowed_metadata_columns=allowed_metadata_columns)
            df = self.addapt_result_columns(df)
            logger.debug(f"Query returned {len(df)} rows")
            logger.debug(f"Columns in response: {df.columns.tolist()}")
        else:
            df = pd.DataFrame([], columns=["id", "chunk_id", "chunk_content", "metadata", "distance"])

        # check if db_handler inherits from KeywordSearchBase
        if hybrid_search_alpha < 1:
            if not isinstance(db_handler, KeywordSearchBase):
                raise ValueError(
                    f"Hybrid search is enabled but the db_handler {type(db_handler)} does not support it. "
                )

            # If query_text is present, use it for keyword search
            logger.debug(f"Performing keyword search with query text: {query_text}")
            keyword_search_args = KeywordSearchArgs(query=query_text, column=TableField.CONTENT.value)
            keyword_query_obj = copy.deepcopy(query)

            keyword_query_obj.targets = [
                Identifier(TableField.ID.value),
                Identifier(TableField.CONTENT.value),
                Identifier(TableField.METADATA.value),
            ]

            df_keyword = db_handler.dispatch_select(
                keyword_query_obj,
                keyword_search_conditions,
                allowed_metadata_columns=allowed_metadata_columns,
                keyword_search_args=keyword_search_args,
            )
            df_keyword = self.addapt_result_columns(df_keyword)
            logger.debug(f"Keyword search returned {len(df_keyword)} rows")
            logger.debug(f"Columns in keyword search response: {df_keyword.columns.tolist()}")
            # ensure df and df_keyword_select have exactly the same columns
            if not df_keyword.empty:
                if df.empty:
                    df = df_keyword
                else:
                    # Weight the two result sets by alpha, then merge and dedupe.
                    df_keyword[TableField.DISTANCE.value] = hybrid_search_alpha * df_keyword[TableField.DISTANCE.value]
                    df[TableField.DISTANCE.value] = (1 - hybrid_search_alpha) * df[TableField.DISTANCE.value]

                    df = pd.concat([df, df_keyword], ignore_index=True)
                    # sort by distance if distance column exists
                    if TableField.DISTANCE.value in df.columns:
                        df = df.sort_values(by=TableField.DISTANCE.value, ascending=True)
                    # if chunk_id column exists remove duplicates based on chunk_id
                    if "chunk_id" in df.columns:
                        df = df.drop_duplicates(subset=["chunk_id"])

        # Check if we have a rerank_model configured in KB params
        df = self.add_relevance(df, query_text, relevance_threshold, disable_reranking)
        if limit is not None:
            df = df[:limit]

        # if relevance filtering method is strictly GREATER THAN we filter the df
        if gt_filtering:
            relevance_scores = TableField.RELEVANCE.value
            df = df[df[relevance_scores] > relevance_threshold]

        return df

481 

482 def _get_allowed_metadata_columns(self) -> List[str] | None: 

483 # Return list of KB columns to restrict querying, if None: no restrictions 

484 

485 if self._kb.params.get("version", 0) < 2: 485 ↛ 489line 485 didn't jump to line 489 because the condition on line 485 was always true

486 # disable for old version KBs 

487 return None 

488 

489 user_columns = self._kb.params.get("metadata_columns", []) 

490 dynamic_columns = self._kb.params.get("inserted_metadata", []) 

491 

492 columns = set(user_columns) | set(dynamic_columns) 

493 return [col.lower() for col in columns] 

494 

495 def score_documents(self, query_text, documents, reranking_model_params): 

496 rotate_provider_api_key(reranking_model_params) 

497 reranker = get_reranking_model_from_params(reranking_model_params) 

498 return reranker.get_scores(query_text, documents) 

499 

    def add_relevance(self, df, query_text, relevance_threshold=None, disable_reranking=False):
        """
        Attach a relevance column to *df*, filter by threshold, and sort by it.

        Uses the configured reranker model when available (and not disabled);
        otherwise derives relevance from vector distance as 1 / (1 + distance).
        May persist a refreshed provider api key back into the KB record.

        :param df: result dataframe (expects a 'chunk_content' column for reranking)
        :param query_text: search text used for reranker scoring
        :param relevance_threshold: optional minimum relevance to keep
        :param disable_reranking: force the distance-based fallback
        :return: dataframe sorted by relevance descending
        """
        relevance_column = TableField.RELEVANCE.value

        reranking_model_params = get_model_params(self._kb.params.get("reranking_model"), "default_reranking_model")
        if reranking_model_params and query_text and len(df) > 0 and not disable_reranking:
            # Use reranker for relevance score

            new_api_key = rotate_provider_api_key(reranking_model_params)
            if new_api_key:
                # Persist the rotated key on the KB record so future calls reuse it.
                if "reranking_model" not in self._kb.params:
                    self._kb.params["reranking_model"] = {}
                self._kb.params["reranking_model"]["api_key"] = new_api_key
                flag_modified(self._kb, "params")
                db.session.commit()

            # Apply custom filtering threshold if provided
            if relevance_threshold is not None:
                reranking_model_params["filtering_threshold"] = relevance_threshold
                logger.info(f"Using custom filtering threshold: {relevance_threshold}")

            reranker = get_reranking_model_from_params(reranking_model_params)
            # Get documents to rerank
            documents = df["chunk_content"].tolist()
            # Use the get_scores method with disable_events=True
            scores = reranker.get_scores(query_text, documents)
            # Add scores as the relevance column
            df[relevance_column] = scores

            # Filter by threshold
            scores_array = np.array(scores)
            df = df[scores_array >= reranker.filtering_threshold]

        elif "distance" in df.columns:
            # Calculate relevance from distance
            logger.info("Calculating relevance from vector distance")
            df[relevance_column] = 1 / (1 + df["distance"])
            if relevance_threshold is not None:
                df = df[df[relevance_column] > relevance_threshold]

        else:
            # No reranker and no distance column: relevance cannot be computed.
            df[relevance_column] = None
            df["distance"] = None
        # Sort by relevance
        df = df.sort_values(by=relevance_column, ascending=False)
        return df

546 

547 def addapt_conditions_columns(self, conditions): 

548 if conditions is None: 548 ↛ 549line 548 didn't jump to line 549 because the condition on line 548 was never true

549 return 

550 for condition in conditions: 

551 if condition.column in self.kb_to_vector_columns: 551 ↛ 552line 551 didn't jump to line 552 because the condition on line 551 was never true

552 condition.column = self.kb_to_vector_columns[condition.column] 

553 

554 def addapt_result_columns(self, df): 

555 col_update = {} 

556 for kb_col, vec_col in self.kb_to_vector_columns.items(): 

557 if vec_col in df.columns: 

558 col_update[vec_col] = kb_col 

559 

560 df = df.rename(columns=col_update) 

561 

562 columns = list(df.columns) 

563 # update id, get from metadata 

564 df[TableField.ID.value] = df[TableField.METADATA.value].apply( 

565 lambda m: None if m is None else m.get(self.kb_to_vector_columns["id"]) 

566 ) 

567 

568 # id on first place 

569 return df[[TableField.ID.value] + columns] 

570 

571 def insert_files(self, file_names: List[str]): 

572 """Process and insert files""" 

573 if not self.document_loader: 

574 raise ValueError("Document loader not configured") 

575 

576 documents = list(self.document_loader.load_files(file_names)) 

577 if documents: 

578 self.insert_documents(documents) 

579 

580 def insert_web_pages(self, urls: List[str], crawl_depth: int, limit: int, filters: List[str] = None): 

581 """Process and insert web pages""" 

582 if not self.document_loader: 

583 raise ValueError("Document loader not configured") 

584 

585 documents = list( 

586 self.document_loader.load_web_pages(urls, limit=limit, crawl_depth=crawl_depth, filters=filters) 

587 ) 

588 if documents: 

589 self.insert_documents(documents) 

590 

591 def insert_query_result(self, query: str, project_name: str): 

592 """Process and insert SQL query results""" 

593 ast_query = parse_sql(query) 

594 

595 command_executor = ExecuteCommands(self.session) 

596 response = command_executor.execute_command(ast_query, project_name) 

597 

598 if response.error_code is not None: 

599 raise ValueError(f"Error executing query: {response.error_message}") 

600 

601 if response.data is None: 

602 raise ValueError("Query returned no data") 

603 

604 records = response.data.records 

605 df = pd.DataFrame(records) 

606 

607 self.insert(df) 

608 

609 def insert_rows(self, rows: List[Dict]): 

610 """Process and insert raw data rows""" 

611 if not rows: 

612 return 

613 

614 df = pd.DataFrame(rows) 

615 

616 self.insert(df) 

617 

618 def insert_documents(self, documents: List[Document]): 

619 """Process and insert documents with preprocessing if configured""" 

620 df = pd.DataFrame([doc.model_dump() for doc in documents]) 

621 

622 self.insert(df) 

623 

    def update_query(self, query: Update):
        """
        Handle an UPDATE against the KB table.

        Recomputes embeddings for updated content (running preprocessing first
        when configured) and dispatches the statement to the vector store.
        :param query: Update AST addressed to the KB table
        """
        # add embeddings to content in updated columns
        query = copy.deepcopy(query)

        emb_col = TableField.EMBEDDINGS.value
        cont_col = TableField.CONTENT.value

        db_handler = self.get_vector_db()
        conditions = db_handler.extract_conditions(query.where)
        # Locate the targeted chunk id (if any) so preprocessing can tag the document.
        doc_id = None
        for condition in conditions:
            if condition.column == "chunk_id" and condition.op == FilterOperator.EQUAL:
                doc_id = condition.value

        if cont_col in query.update_columns:
            content = query.update_columns[cont_col]

            # Apply preprocessing to content if configured
            if self.document_preprocessor:
                doc = Document(
                    id=doc_id,
                    content=content.value,
                    metadata={},  # Empty metadata for content-only updates
                )
                processed_chunks = self.document_preprocessor.process_documents([doc])
                if processed_chunks:
                    content.value = processed_chunks[0].content

            # New content requires a freshly computed embedding.
            query.update_columns[emb_col] = Constant(self._content_to_embeddings(content.value))

        if "metadata" not in query.update_columns:
            query.update_columns["metadata"] = Constant({})

        # TODO search content in where clause?

        # set table name
        query.table = Identifier(parts=[self._kb.vector_database_table])

        # send to vectordb
        self.addapt_conditions_columns(conditions)
        db_handler.dispatch_update(query, conditions)

665 

666 def delete_query(self, query: Delete): 

667 """ 

668 Handles delete query to KB table. 

669 Replaces content values with embeddings in WHERE clause. Sends query to vector db 

670 :param query: query to KB table 

671 """ 

672 query_traversal(query.where, self._replace_query_content) 

673 

674 # set table name 

675 query.table = Identifier(parts=[self._kb.vector_database_table]) 

676 

677 # send to vectordb 

678 db_handler = self.get_vector_db() 

679 conditions = db_handler.extract_conditions(query.where) 

680 self.addapt_conditions_columns(conditions) 

681 db_handler.dispatch_delete(query, conditions) 

682 

683 def hybrid_search( 

684 self, 

685 query: str, 

686 keywords: List[str] = None, 

687 metadata: Dict[str, str] = None, 

688 distance_function=DistanceFunction.COSINE_DISTANCE, 

689 ) -> pd.DataFrame: 

690 query_df = pd.DataFrame.from_records([{TableField.CONTENT.value: query}]) 

691 embeddings_df = self._df_to_embeddings(query_df) 

692 if embeddings_df.empty: 

693 return pd.DataFrame([]) 

694 embeddings = embeddings_df.iloc[0][TableField.EMBEDDINGS.value] 

695 keywords_query = None 

696 if keywords is not None: 

697 keywords_query = " ".join(keywords) 

698 db_handler = self.get_vector_db() 

699 return db_handler.hybrid_search( 

700 self._kb.vector_database_table, 

701 embeddings, 

702 query=keywords_query, 

703 metadata=metadata, 

704 distance_function=distance_function, 

705 ) 

706 

707 def clear(self): 

708 """ 

709 Clear data in KB table 

710 Sends delete to vector db table 

711 """ 

712 db_handler = self.get_vector_db() 

713 db_handler.delete(self._kb.vector_database_table) 

714 

715 def insert(self, df: pd.DataFrame, params: dict = None): 

716 """Insert dataframe to KB table. 

717 

718 Args: 

719 df: DataFrame to insert 

720 params: User parameters of insert 

721 """ 

722 if df.empty: 

723 return 

724 

725 if len(df) > MAX_INSERT_BATCH_SIZE: 

726 # auto-batching 

727 batch_size = MAX_INSERT_BATCH_SIZE 

728 

729 chunk_num = 0 

730 while chunk_num * batch_size < len(df): 

731 df2 = df[chunk_num * batch_size : (chunk_num + 1) * batch_size] 

732 self.insert(df2, params=params) 

733 chunk_num += 1 

734 return 

735 

736 try: 

737 run_query_id = ctx.run_query_id 

738 # Link current KB to running query (where KB is used to insert data) 

739 if run_query_id is not None: 

740 self._kb.query_id = run_query_id 

741 db.session.commit() 

742 

743 except AttributeError: 

744 ... 

745 

746 df.replace({np.nan: None}, inplace=True) 

747 

748 # First adapt column names to identify content and metadata columns 

749 adapted_df, normalized_columns = self._adapt_column_names(df) 

750 content_columns = normalized_columns["content_columns"] 

751 

752 # Convert DataFrame rows to documents, creating separate documents for each content column 

753 raw_documents = [] 

754 for idx, row in adapted_df.iterrows(): 

755 base_metadata = self._parse_metadata(row.get(TableField.METADATA.value, {})) 

756 provided_id = row.get(TableField.ID.value) 

757 

758 for col in content_columns: 

759 content = row.get(col) 

760 if content and str(content).strip(): 

761 content_str = str(content) 

762 

763 # Use provided_id directly if it exists, otherwise generate one 

764 doc_id = self._generate_document_id(content_str, col, provided_id) 

765 

766 metadata = { 

767 **base_metadata, 

768 "_original_row_index": str(idx), # provide link to original row index 

769 "_content_column": col, 

770 } 

771 

772 raw_documents.append(Document(content=content_str, id=doc_id, metadata=metadata)) 

773 

774 # Apply preprocessing to all documents if preprocessor exists 

775 if self.document_preprocessor: 

776 processed_chunks = self.document_preprocessor.process_documents(raw_documents) 

777 else: 

778 processed_chunks = raw_documents # Use raw documents if no preprocessing 

779 

780 # Convert processed chunks back to DataFrame with standard structure 

781 df = pd.DataFrame( 

782 [ 

783 { 

784 TableField.CONTENT.value: chunk.content, 

785 TableField.ID.value: chunk.id, 

786 TableField.METADATA.value: chunk.metadata, 

787 } 

788 for chunk in processed_chunks 

789 ] 

790 ) 

791 

792 if df.empty: 

793 logger.warning("No valid content found in any content columns") 

794 return 

795 

796 # Check if we should skip existing items (before calculating embeddings) 

797 if params is not None and params.get("kb_skip_existing", False): 

798 logger.debug(f"Checking for existing items to skip before processing {len(df)} items") 

799 db_handler = self.get_vector_db() 

800 

801 # Get list of IDs from current batch 

802 current_ids = df[TableField.ID.value].dropna().astype(str).tolist() 

803 if current_ids: 

804 # Check which IDs already exist 

805 existing_ids = db_handler.check_existing_ids(self._kb.vector_database_table, current_ids) 

806 if existing_ids: 

807 # Filter out existing items 

808 df = df[~df[TableField.ID.value].astype(str).isin(existing_ids)] 

809 logger.info(f"Skipped {len(existing_ids)} existing items, processing {len(df)} new items") 

810 

811 if df.empty: 

812 logger.info("All items already exist, nothing to insert") 

813 return 

814 

815 # add embeddings and send to vector db 

816 df_emb = self._df_to_embeddings(df) 

817 df = pd.concat([df, df_emb], axis=1) 

818 db_handler = self.get_vector_db() 

819 

820 if params is not None and params.get("kb_no_upsert", False): 

821 # speed up inserting by disable checking existing records 

822 db_handler.insert(self._kb.vector_database_table, df) 

823 else: 

824 db_handler.do_upsert(self._kb.vector_database_table, df) 

825 

    def _adapt_column_names(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
        """
        Convert input columns for vector db input
        - id, content and metadata

        Resolves the id / content / metadata columns from KB params (falling
        back to case-insensitive matching against the dataframe), records the
        metadata columns used so far on the KB record, and returns a new frame
        with the id column, each content column, and a single packed metadata
        dict column.

        :param df: incoming dataframe (arbitrary user columns)
        :return: (adapted dataframe, {"content_columns": [...], "metadata_columns": [...]})
        :raises ValueError: if none of the configured content columns are present
        """
        # Debug incoming data
        logger.debug(f"Input DataFrame columns: {df.columns}")
        logger.debug(f"Input DataFrame first row: {df.iloc[0].to_dict()}")

        params = self._kb.params
        columns = list(df.columns)

        # -- prepare id --
        # a configured id_column that is absent in the data is silently ignored
        id_column = params.get("id_column")
        if id_column is not None and id_column not in columns:
            id_column = None

        if id_column is None and TableField.ID.value in columns:
            id_column = TableField.ID.value

        # Also check for case-insensitive 'id' column
        if id_column is None:
            column_map = {col.lower(): col for col in columns}
            if "id" in column_map:
                id_column = column_map["id"]

        # remove id from the pool so it can't also become content/metadata
        if id_column is not None:
            columns.remove(id_column)
            logger.debug(f"Using ID column: {id_column}")

        # Create output dataframe
        df_out = pd.DataFrame()

        # Add ID if present
        if id_column is not None:
            df_out[TableField.ID.value] = df[id_column]
            logger.debug(f"Added IDs: {df_out[TableField.ID.value].tolist()}")

        # -- prepare content and metadata --
        content_columns = params.get("content_columns", [TableField.CONTENT.value])
        metadata_columns = params.get("metadata_columns")

        logger.debug(f"Processing with: content_columns={content_columns}, metadata_columns={metadata_columns}")

        # Handle SQL query result columns
        if content_columns:
            # Ensure content columns are case-insensitive
            column_map = {col.lower(): col for col in columns}
            content_columns = [column_map.get(col.lower(), col) for col in content_columns]
            logger.debug(f"Mapped content columns: {content_columns}")

        if metadata_columns:
            # Ensure metadata columns are case-insensitive
            column_map = {col.lower(): col for col in columns}
            metadata_columns = [column_map.get(col.lower(), col) for col in metadata_columns]
            logger.debug(f"Mapped metadata columns: {metadata_columns}")

        # keep only content columns actually present in the data
        content_columns = list(set(content_columns).intersection(columns))
        if len(content_columns) == 0:
            raise ValueError(f"Content columns {params.get('content_columns')} not found in dataset: {columns}")

        if metadata_columns is not None:
            metadata_columns = list(set(metadata_columns).intersection(columns))
        else:
            # all the rest columns
            metadata_columns = list(set(columns).difference(content_columns))

        # update list of used columns (persisted on the KB record)
        inserted_metadata = set(self._kb.params.get("inserted_metadata", []))
        inserted_metadata.update(metadata_columns)
        self._kb.params["inserted_metadata"] = list(inserted_metadata)
        flag_modified(self._kb, "params")
        db.session.commit()

        # Add content columns directly (don't combine them)
        for col in content_columns:
            df_out[col] = df[col]

        # Add metadata
        if metadata_columns and len(metadata_columns) > 0:

            def convert_row_to_metadata(row):
                # Pack one row's metadata columns into a plain-Python dict.
                metadata = {}
                for col in metadata_columns:
                    value = row[col]
                    value_type = type(value)
                    # Convert numpy/pandas types to Python native types
                    if safe_pandas_is_datetime(value) or isinstance(value, pd.Timestamp):
                        value = str(value)
                    elif pd.api.types.is_integer_dtype(value_type):
                        value = int(value)
                    elif pd.api.types.is_float_dtype(value_type) or isinstance(value, decimal.Decimal):
                        value = float(value)
                    elif pd.api.types.is_bool_dtype(value_type):
                        value = bool(value)
                    elif isinstance(value, dict):
                        # dict values are flattened into the metadata, not nested
                        metadata.update(value)
                        continue
                    elif value is not None:
                        value = str(value)
                    metadata[col] = value
                return metadata

            metadata_dict = df[metadata_columns].apply(convert_row_to_metadata, axis=1)
            df_out[TableField.METADATA.value] = metadata_dict

        logger.debug(f"Output DataFrame columns: {df_out.columns}")
        logger.debug(f"Output DataFrame first row: {df_out.iloc[0].to_dict() if not df_out.empty else 'Empty'}")

        return df_out, {"content_columns": content_columns, "metadata_columns": metadata_columns}

936 

937 def _replace_query_content(self, node, **kwargs): 

938 if isinstance(node, BinaryOperation): 

939 if isinstance(node.args[0], Identifier) and isinstance(node.args[1], Constant): 

940 col_name = node.args[0].parts[-1] 

941 if col_name.lower() == TableField.CONTENT.value: 

942 # replace 

943 node.args[0].parts = [TableField.EMBEDDINGS.value] 

944 node.args[1].value = [self._content_to_embeddings(node.args[1].value)] 

945 

946 def get_vector_db(self) -> VectorStoreHandler: 

947 """ 

948 helper to get vector db handler 

949 """ 

950 if self._vector_db is None: 950 ↛ 951line 950 didn't jump to line 951 because the condition on line 950 was never true

951 database = db.Integration.query.get(self._kb.vector_database_id) 

952 if database is None: 

953 raise ValueError("Vector database not found. Is it deleted?") 

954 database_name = database.name 

955 self._vector_db = self.session.integration_controller.get_data_handler(database_name) 

956 return self._vector_db 

957 

    def get_vector_db_table_name(self) -> str:
        """
        helper to get underlying table name used for embeddings

        :return: name of the table in the vector database that backs this KB
        """
        return self._kb.vector_database_table

963 

    def _df_to_embeddings(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Returns embeddings for input dataframe.
        Uses model embedding model to convert content to embeddings.
        Automatically detects input and output of model using model description

        Two paths:
          * no `embedding_model_id` on the KB record — embed via LLMClient
            configured from KB params (rotating the provider api key if needed);
          * otherwise — run the attached mindsdb predictor over the dataframe.

        :param df: dataframe with a 'content' column to embed
        :return: dataframe with a single 'embeddings' column, one row per input row
        """

        if df.empty:
            return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])

        model_id = self._kb.embedding_model_id

        if model_id is None:
            messages = list(df[TableField.CONTENT.value])
            embedding_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
            new_api_key = rotate_provider_api_key(embedding_params)
            if new_api_key:
                # update key: persist the rotated api key back onto the KB record
                if "embedding_model" not in self._kb.params:
                    self._kb.params["embedding_model"] = {}
                self._kb.params["embedding_model"]["api_key"] = new_api_key
                flag_modified(self._kb, "params")
                db.session.commit()

            llm_client = LLMClient(embedding_params, session=self.session)
            results = llm_client.embeddings(messages)

            # wrap each vector in a list so each dataframe cell holds one vector
            results = [[val] for val in results]
            return pd.DataFrame(results, columns=[TableField.EMBEDDINGS.value])

        # get the input columns
        model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()

        assert model_rec is not None, f"Model not found: {model_id}"
        model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()

        project_datanode = self.session.datahub.get(model_project.name)

        # the predictor's expected input column may be declared under either key
        model_using = model_rec.learn_args.get("using", {})
        input_col = model_using.get("question_column")
        if input_col is None:
            input_col = model_using.get("input_column")

        if input_col is not None and input_col != TableField.CONTENT.value:
            df = df.rename(columns={TableField.CONTENT.value: input_col})

        df_out = project_datanode.predict(model_name=model_rec.name, df=df, params=self.model_params)

        target = model_rec.to_predict[0]
        if target != TableField.EMBEDDINGS.value:
            # adapt output for vectordb
            df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})

        df_out = df_out[[TableField.EMBEDDINGS.value]]

        return df_out

1022 

1023 def _content_to_embeddings(self, content: str) -> List[float]: 

1024 """ 

1025 Converts string to embeddings 

1026 :param content: input string 

1027 :return: embeddings 

1028 """ 

1029 df = pd.DataFrame([[content]], columns=[TableField.CONTENT.value]) 

1030 res = self._df_to_embeddings(df) 

1031 return res[TableField.EMBEDDINGS.value][0] 

1032 

1033 @staticmethod 

1034 def call_litellm_embedding(session, model_params, messages): 

1035 args = copy.deepcopy(model_params) 

1036 

1037 if "model_name" not in args: 

1038 raise ValueError("'model_name' must be provided for embedding model") 

1039 

1040 llm_model = args.pop("model_name") 

1041 engine = args.pop("provider") 

1042 

1043 module = session.integration_controller.get_handler_module("litellm") 

1044 if module is None or module.Handler is None: 

1045 raise ValueError(f'Unable to use "{engine}" provider. Litellm handler is not installed') 

1046 return module.Handler.embeddings(engine, llm_model, messages, args) 

1047 

    def build_rag_pipeline(self, retrieval_config: dict):
        """
        Builds a RAG pipeline with returned sources

        Args:
            retrieval_config: dict with retrieval config

        Returns:
            RAG: Configured RAG pipeline instance

        Raises:
            ValueError: If the configuration is invalid or required components are missing
        """
        # Get embedding model from knowledge base
        # (imports deferred: langchain-based modules are optional dependencies)
        from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import (
            construct_model_from_args,
        )
        from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
        from mindsdb.integrations.utilities.rag.config_loader import load_rag_config

        embedding_model_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
        if self._kb.embedding_model:
            # Extract embedding model args from knowledge base table
            embedding_args = self._kb.embedding_model.learn_args.get("using", {})
            # Construct the embedding model directly
            embeddings_model = construct_model_from_args(embedding_args)
            logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
        elif embedding_model_params:
            embeddings_model = construct_model_from_args(adapt_embedding_model_params(embedding_model_params))
            logger.debug(f"Using knowledge base embedding model from params: {self._kb.params['embedding_model']}")
        else:
            embeddings_model_class = get_default_embeddings_model_class()
            embeddings_model = embeddings_model_class()
            logger.debug("Using default embedding model as knowledge base has no embedding model")

        # Update retrieval config with knowledge base parameters
        # (gives the retriever direct access to this KB table)
        kb_params = {"vector_store_config": {"kb_table": self}}

        # Load and validate config
        try:
            rag_config = load_rag_config(retrieval_config, kb_params, embeddings_model)

            # Build LLM if specified
            if "llm_model_name" in rag_config:
                llm_args = {"model_name": rag_config.llm_model_name}
                if not rag_config.llm_provider:
                    # infer provider from the model name
                    llm_args["provider"] = get_llm_provider(llm_args)
                else:
                    llm_args["provider"] = rag_config.llm_provider
                _require_agent_extra("Building knowledge base retrieval pipelines")
                rag_config.llm = create_chat_model(llm_args)

            # Create RAG pipeline
            rag = RAG(rag_config)
            logger.debug(f"RAG pipeline created with config: {rag_config}")
            return rag

        except Exception as e:
            # re-raise as ValueError per the documented contract, keeping the cause
            logger.exception("Error building RAG pipeline:")
            raise ValueError(f"Failed to build RAG pipeline: {str(e)}") from e

1108 

1109 def _parse_metadata(self, base_metadata): 

1110 """Helper function to robustly parse metadata string to dict""" 

1111 if isinstance(base_metadata, dict): 

1112 return base_metadata 

1113 if isinstance(base_metadata, str): 

1114 try: 

1115 import ast 

1116 

1117 return ast.literal_eval(base_metadata) 

1118 except (SyntaxError, ValueError): 

1119 logger.warning(f"Could not parse metadata: {base_metadata}. Using empty dict.") 

1120 return {} 

1121 return {} 

1122 

    def _generate_document_id(self, content: str, content_column: str, provided_id: str = None) -> str:
        """Generate a deterministic document ID using the utility function.

        NOTE: `content_column` is accepted for interface compatibility with
        callers but is not used in the id derivation — only `content` and
        `provided_id` are forwarded.
        """
        from mindsdb.interfaces.knowledge_base.utils import generate_document_id

        return generate_document_id(content=content, provided_id=provided_id)

1128 

1129 def _convert_metadata_value(self, value): 

1130 """ 

1131 Convert metadata value to appropriate Python type. 

1132 

1133 Args: 

1134 value: The value to convert 

1135 

1136 Returns: 

1137 Converted value in appropriate Python type 

1138 """ 

1139 if pd.isna(value): 

1140 return None 

1141 

1142 # Handle pandas/numpy types 

1143 if pd.api.types.is_datetime64_any_dtype(value) or isinstance(value, pd.Timestamp): 

1144 return str(value) 

1145 elif pd.api.types.is_integer_dtype(type(value)): 

1146 return int(value) 

1147 elif pd.api.types.is_float_dtype(type(value)): 

1148 return float(value) 

1149 elif pd.api.types.is_bool_dtype(type(value)): 

1150 return bool(value) 

1151 

1152 # Handle basic Python types 

1153 if isinstance(value, (int, float, bool)): 

1154 return value 

1155 

1156 # Convert everything else to string 

1157 return str(value) 

1158 

    def create_index(self):
        """
        Create an index on the knowledge base's underlying vector table.

        Takes no arguments; delegates to the vector store handler.
        (Previous docstring documented `index_name`/`params` parameters that
        this method does not have.)
        """
        db_handler = self.get_vector_db()
        db_handler.create_index(self._kb.vector_database_table)

1167 

1168 

1169class KnowledgeBaseController: 

1170 """ 

1171 Knowledge base controller handles all 

1172 manages knowledge bases 

1173 """ 

1174 

1175 KB_VERSION = 2 

1176 

    def __init__(self, session) -> None:
        # Keep a reference to the user session; all controller operations go
        # through its database/integration/model sub-controllers.
        self.session = session

1179 

1180 def _check_kb_input_params(self, params): 

1181 # check names and types KB params 

1182 try: 

1183 KnowledgeBaseInputParams.model_validate(params) 

1184 except ValidationError as e: 

1185 problems = [] 

1186 for error in e.errors(): 

1187 parameter = ".".join([str(i) for i in error["loc"]]) 

1188 param_type = error["type"] 

1189 if param_type == "extra_forbidden": 

1190 msg = f"Parameter '{parameter}' is not allowed" 

1191 else: 

1192 msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}" 

1193 problems.append(msg) 

1194 

1195 msg = "\n".join(problems) 

1196 if len(problems) > 1: 

1197 msg = "\n" + msg 

1198 raise ValueError(f"Problem with knowledge base parameters: {msg}") from e 

1199 

    def add(
        self,
        name: str,
        project_name: str,
        storage: Identifier,
        params: dict,
        preprocessing_config: Optional[dict] = None,
        if_not_exists: bool = False,
        keyword_search_enabled: bool = False,
        # embedding_model: Identifier = None,  # Legacy: Allow MindsDB models to be passed as embedding_model.
    ) -> db.KnowledgeBase:
        """
        Add a new knowledge base to the database

        :param name: knowledge base name
        :param project_name: project to create the KB in
        :param storage: optional `vector_db.table` identifier; when None a
            default vector store is created (pgvector if KB_PGVECTOR_URL is
            set, otherwise a per-KB chromadb)
        :param params: KB parameters; 'is_sparse' and 'vector_size' are read
            from here (vector_size is required when is_sparse=True)
        :param preprocessing_config: Optional preprocessing configuration to validate and store
        :param if_not_exists: return the existing KB instead of raising
        :param keyword_search_enabled: also create a full-text index on the content column
        :return: the created (or pre-existing) KnowledgeBase record
        :raises EntityExistsError: if the KB exists and if_not_exists is False
        :raises ValueError: on invalid params or unresolvable storage
        """

        # Validate preprocessing config first if provided
        if preprocessing_config is not None:
            PreprocessingConfig(**preprocessing_config)  # Validate before storing
            params = params or {}
            params["preprocessing"] = preprocessing_config

        self._check_kb_input_params(params)

        # Check if vector_size is provided when using sparse vectors
        is_sparse = params.get("is_sparse")
        vector_size = params.get("vector_size")
        if is_sparse and vector_size is None:
            raise ValueError("vector_size is required when is_sparse=True")

        # get project id
        project = self.session.database_controller.get_project(project_name)
        project_id = project.id

        # check if knowledge base already exists
        kb = self.get(name, project_id)
        if kb is not None:
            if if_not_exists:
                return kb
            raise EntityExistsError("Knowledge base already exists", name)

        embedding_params = get_model_params(params.get("embedding_model", {}), "default_embedding_model")
        params["embedding_model"] = embedding_params
        rotate_provider_api_key(embedding_params)

        # if model_name is None:  # Legacy
        # validates the embedding config and returns its dimension
        embed_info = self._check_embedding_model(
            project.name,
            params=embedding_params,
            kb_name=name,
        )

        # if params.get("reranking_model", {}) is bool and False we evaluate it to empty dictionary
        reranking_model_params = params.get("reranking_model", {})

        if isinstance(reranking_model_params, bool) and not reranking_model_params:
            params["reranking_model"] = {}
        else:
            reranking_model_params = get_model_params(reranking_model_params, "default_reranking_model")

            params["reranking_model"] = reranking_model_params
            if reranking_model_params:
                # Get reranking model from params.
                # This is called here to check validaity of the parameters.
                rotate_provider_api_key(reranking_model_params)
                self._test_reranking(reranking_model_params)

        # search for the vector database table
        if storage is None:
            cloud_pg_vector = os.environ.get("KB_PGVECTOR_URL")
            if cloud_pg_vector:
                vector_table_name = name
                # Add sparse vector support for pgvector
                vector_db_params = {}
                # Check both explicit parameter and model configuration
                if is_sparse:
                    vector_db_params["is_sparse"] = True
                    if vector_size is not None:
                        vector_db_params["vector_size"] = vector_size
                vector_db_name = self._create_persistent_pgvector(vector_db_params)
                params["default_vector_storage"] = vector_db_name
            else:
                # create chroma db with same name
                vector_table_name = "default_collection"
                vector_db_name = self._create_persistent_chroma(name)
                # memorize to remove it later
                params["default_vector_storage"] = vector_db_name
        elif len(storage.parts) != 2:
            raise ValueError("Storage param has to be vector db with table")
        else:
            vector_db_name, vector_table_name = storage.parts

        data_node = self.session.datahub.get(vector_db_name)
        if data_node:
            vector_store_handler = data_node.integration_handler
        else:
            raise ValueError(
                f"Unable to find database named {vector_db_name}, please make sure {vector_db_name} is defined"
            )
        # create table in vectordb before creating KB
        if "default_vector_storage" in params:
            # if vector db is a default - drop previous table, if exists
            try:
                vector_store_handler.drop_table(vector_table_name)
            except Exception:
                ...
        vector_store_handler.create_table(vector_table_name)
        # fail early if the table's embedding dimension doesn't match the model
        self._check_vector_table(embed_info, vector_store_handler, vector_table_name)

        if keyword_search_enabled:
            vector_store_handler.add_full_text_index(vector_table_name, TableField.CONTENT.value)
        vector_database_id = self.session.integration_controller.get(vector_db_name)["id"]

        # Store sparse vector settings in params if specified
        if is_sparse:
            params = params or {}
            params["vector_config"] = {"is_sparse": is_sparse}
            if vector_size is not None:
                params["vector_config"]["vector_size"] = vector_size

        params["version"] = self.KB_VERSION
        kb = db.KnowledgeBase(
            name=name,
            project_id=project_id,
            vector_database_id=vector_database_id,
            vector_database_table=vector_table_name,
            embedding_model_id=None,
            params=params,
        )
        db.session.add(kb)
        db.session.commit()
        return kb

1334 

    def _check_vector_table(self, embed_info, vector_store_handler, vector_table_name):
        """Verify the vector table's embedding dimension matches the model.

        Reads the table dimension either via the handler's ``get_dimension``
        (when the handler provides one) or by sampling a single stored
        embedding. Empty tables (or handlers reporting no dimension) pass.

        :param embed_info: dict with the embedding model's 'dimension'
        :param vector_store_handler: vector store handler to inspect
        :param vector_table_name: table to check
        :raises ValueError: when a dimension is found and it doesn't match
        """
        query = Select(
            targets=[Identifier(TableField.EMBEDDINGS.value)],
            from_table=Identifier(parts=[vector_table_name]),
            limit=Constant(1),
        )
        dimension = None
        if hasattr(vector_store_handler, "get_dimension"):
            dimension = vector_store_handler.get_dimension(vector_table_name)
        else:
            # fall back to sampling one row and measuring its embedding length
            df = vector_store_handler.dispatch_select(query, [])
            if len(df) > 0:
                value = df[TableField.EMBEDDINGS.value][0]
                if isinstance(value, str):
                    # some stores persist embeddings as JSON strings
                    value = json.loads(value)
                dimension = len(value)
        if dimension is not None and dimension != embed_info["dimension"]:
            raise ValueError(
                f"Dimension of embedding model doesn't match to dimension of vector table: {embed_info['dimension']} != {dimension}"
            )

1355 

    def update(
        self,
        name: str,
        project_name: str,
        params: dict,
        preprocessing_config: Optional[dict] = None,
    ) -> db.KnowledgeBase:
        """
        Update the knowledge base
        :param name: The name of the knowledge base
        :param project_name: Current project name
        :param params: The parameters to update
        :param preprocessing_config: Optional preprocessing configuration to validate and store
        :return: the refreshed KnowledgeBase record
        :raises EntityNotExistsError: if the KB doesn't exist
        :raises ValueError: on invalid params or attempts to change immutable settings
        """

        # Validate preprocessing config first if provided
        if preprocessing_config is not None:
            PreprocessingConfig(**preprocessing_config)  # Validate before storing
            params = params or {}
            params["preprocessing"] = preprocessing_config

        self._check_kb_input_params(params)

        # get project id
        project = self.session.database_controller.get_project(project_name)
        project_id = project.id

        # get existed KB
        kb = self.get(name.lower(), project_id)
        if kb is None:
            raise EntityNotExistsError("Knowledge base doesn't exists", name)

        if "embedding_model" in params:
            new_config = params["embedding_model"]
            # update embedding
            embed_params = kb.params.get("embedding_model", {})
            if not embed_params:
                # maybe old version of KB
                raise ValueError("No embedding config to update")

            # some parameters are not allowed to update
            for key in ("provider", "model_name"):
                if key in new_config and new_config[key] != embed_params.get(key):
                    raise ValueError(f"You can't update '{key}' setting")

            embed_params.update(new_config)

            # re-validate the merged config before persisting it
            self._check_embedding_model(
                project.name,
                params=embed_params,
                kb_name=name,
            )
            kb.params["embedding_model"] = embed_params

        if "reranking_model" in params:
            new_config = params["reranking_model"]
            # update reranking
            rerank_params = kb.params.get("reranking_model", {})

            if new_config is False:
                # disable reranking
                rerank_params = {}
            elif "provider" in new_config and new_config["provider"] != rerank_params.get("provider"):
                # use new config (and include default config)
                rerank_params = get_model_params(new_config, "default_reranking_model")
            else:
                # update current config
                rerank_params.update(new_config)

            if rerank_params:
                self._test_reranking(rerank_params)

            kb.params["reranking_model"] = rerank_params

        # update other keys
        for key in ["id_column", "metadata_columns", "content_columns", "preprocessing"]:
            if key in params:
                kb.params[key] = params[key]

        # kb.params is a JSON column mutated in place: mark it dirty explicitly
        flag_modified(kb, "params")
        db.session.commit()

        return self.get(name.lower(), project_id)

1439 

    def _test_reranking(self, params):
        """Smoke-test the reranker config by scoring a trivial query/passage pair.

        For openai/azure_openai providers whose models reject logprobs, retries
        once in 'no-logprobs' mode — mutating `params` so the caller persists
        the working mode. Any other failure is re-raised as RuntimeError.
        """
        try:
            reranker = get_reranking_model_from_params(params)
            reranker.get_scores("test", ["test"])
        except (ValueError, RuntimeError) as e:
            if params["provider"] in ("azure_openai", "openai") and params.get("method") != "no-logprobs":
                # check with no-logprobs
                params["method"] = "no-logprobs"
                self._test_reranking(params)
                logger.warning(
                    f"logprobs is not supported for this model: {params.get('model_name')}. using no-logprobs mode"
                )
            else:
                raise RuntimeError(f"Problem with reranker config: {e}") from e

1454 

1455 def _create_persistent_pgvector(self, params=None): 

1456 """Create default vector database for knowledge base, if not specified""" 

1457 vector_store_name = "kb_pgvector_store" 

1458 

1459 # check if exists 

1460 if self.session.integration_controller.get(vector_store_name): 

1461 return vector_store_name 

1462 

1463 self.session.integration_controller.add(vector_store_name, "pgvector", params or {}) 

1464 return vector_store_name 

1465 

1466 def _create_persistent_chroma(self, kb_name, engine="chromadb"): 

1467 """Create default vector database for knowledge base, if not specified""" 

1468 

1469 vector_store_name = f"{kb_name}_{engine}" 

1470 

1471 vector_store_folder_name = f"{vector_store_name}" 

1472 connection_args = {"persist_directory": vector_store_folder_name} 

1473 

1474 # check if exists 

1475 if self.session.integration_controller.get(vector_store_name): 1475 ↛ 1476line 1475 didn't jump to line 1476 because the condition on line 1475 was never true

1476 return vector_store_name 

1477 

1478 self.session.integration_controller.add(vector_store_name, engine, connection_args) 

1479 return vector_store_name 

1480 

    def _check_embedding_model(self, project_name, params: dict = None, kb_name="") -> dict:
        """check embedding model for knowledge base, return embedding model info

        Drops a leftover mindsdb embedding model from older KB versions,
        validates the provider name, and runs one test embedding call (which
        also reveals the embedding dimension).

        :param project_name: project that may hold a legacy 'kb_embedding_<kb_name>' model
        :param params: embedding model config; must include 'provider'
        :param kb_name: knowledge base name, used to locate the legacy model
        :return: dict with embedding 'dimension'
        :raises ValueError: on missing or unsupported provider
        :raises RuntimeError: if the test embedding call fails
        """

        # if mindsdb model from old KB exists - drop it
        model_name = f"kb_embedding_{kb_name}"
        try:
            model = self.session.model_controller.get_model(model_name, project_name=project_name)
            if model is not None:
                self.session.model_controller.delete_model(model_name, project_name)
        except PredictorRecordNotFound:
            pass

        if "provider" not in params:
            raise ValueError("'provider' parameter is required for embedding model")

        # check available providers
        avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google", "ollama", "snowflake")
        if params["provider"] not in avail_providers:
            raise ValueError(
                f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
            )

        llm_client = LLMClient(params, session=self.session)

        try:
            # a single test call both validates the config and yields the dimension
            resp = llm_client.embeddings(["test"])
            return {"dimension": len(resp[0])}
        except Exception as e:
            raise RuntimeError(f"Problem with embedding model config: {e}") from e

1510 

def delete(self, name: str, project_name: str, if_exists: bool = False) -> None:
    """Delete a knowledge base and any companion objects created with it.

    :param name: knowledge base name
    :param project_name: name of the project the KB belongs to
        (annotation fixed: this is a name string passed to get_project, not an id)
    :param if_exists: when True, silently return if the KB does not exist
    :raises ValueError: the project does not exist
    :raises EntityNotExistsError: the KB does not exist and if_exists is False
    """
    try:
        project = self.session.database_controller.get_project(project_name)
    except ValueError as e:
        raise ValueError(f"Project not found: {project_name}") from e
    project_id = project.id

    # Check if knowledge base exists.
    kb = self.get(name, project_id)
    if kb is None:
        if if_exists:
            return
        raise EntityNotExistsError("Knowledge base does not exist", name)

    # Capture fields BEFORE deleting: after delete+commit the ORM instance is
    # expired/detached and its attributes can no longer be safely read
    # (the original read kb.params after the commit).
    kb_params = kb.params or {}
    vector_table = kb.vector_database_table

    db.session.delete(kb)
    db.session.commit()

    # Drop objects that were created automatically alongside the KB.
    if "default_vector_storage" in kb_params:
        try:
            dn = self.session.datahub.get(kb_params["default_vector_storage"])
            dn.integration_handler.drop_table(vector_table)
            # A shared pgvector integration is kept; per-KB engines are removed.
            if dn.ds_type != "pgvector":
                self.session.integration_controller.delete(kb_params["default_vector_storage"])
        except EntityNotExistsError:
            pass
    if "created_embedding_model" in kb_params:
        try:
            self.session.model_controller.delete_model(kb_params["created_embedding_model"], project_name)
        except EntityNotExistsError:
            pass

1548 

def get(self, name: str, project_id: int) -> db.KnowledgeBase:
    """Fetch the knowledge base record matching name + project_id.

    :param name: knowledge base name
    :param project_id: owning project id
    :return: the ``db.KnowledgeBase`` row, or None when no match exists
    """
    return (
        db.session.query(db.KnowledgeBase)
        .filter_by(name=name, project_id=project_id)
        .first()
    )

1563 

def get_table(self, name: str, project_id: int, params: dict = None) -> KnowledgeBaseTable:
    """Build a KnowledgeBaseTable for the named KB with preprocessing configured.

    :param name: table (knowledge base) name
    :param project_id: project id
    :param params: runtime parameters for the KB; the 'model' key carries
        embedding-model parameters for this table instance
    :return: configured KnowledgeBaseTable, or None when the KB does not exist
    """
    kb = self.get(name, project_id)
    if kb is None:
        return None

    table = KnowledgeBaseTable(kb, self.session)
    if params:
        table.model_params = params.get("model")

    # Preprocessing is always configured: stored settings when present,
    # otherwise None, which makes the table create its default preprocessor.
    table.configure_preprocessing((kb.params or {}).get("preprocessing"))
    return table

1585 

def list(self, project_name: str = None) -> List[dict]:
    """List knowledge bases, optionally restricted to one project.

    :param project_name: when given, only KBs in that project are returned
    :return: list of KB dicts, each augmented with its 'project_name'
    """
    project_controller = ProjectController()
    # Fetch projects once (the original queried get_list() twice); the full
    # list serves both the project filter and the id -> name mapping.
    all_projects = project_controller.get_list()

    projects = all_projects
    if project_name is not None:
        projects = [p for p in all_projects if p.name == project_name]

    query = db.session.query(db.KnowledgeBase).filter(
        db.KnowledgeBase.project_id.in_([p.id for p in projects])
    )

    project_names = {p.id: p.name for p in all_projects}

    data = []
    for record in query:
        kb = record.as_dict(with_secrets=self.session.show_secrets)
        kb["project_name"] = project_names[record.project_id]
        data.append(kb)

    return data

1610 

def create_index(self, table_name, project_name):
    """Create the vector index on the knowledge base's underlying table.

    :param table_name: knowledge base name
    :param project_name: project the KB belongs to
    """
    project = self.session.database_controller.get_project(project_name)
    self.get_table(table_name, project.id).create_index()

1615 

def evaluate(self, table_name: str, project_name: str, params: dict = None) -> pd.DataFrame:
    """Run evaluation and/or create test data for a knowledge base.

    :param table_name: name of the KB
    :param project_name: project the KB belongs to
    :param params: evaluation parameters, passed through to EvaluateBase
    :return: DataFrame of evaluation scores
    """
    project = self.session.database_controller.get_project(project_name)
    kb_table = self.get_table(table_name, project.id)
    return EvaluateBase.run(self.session, kb_table, params)