Coverage for mindsdb / interfaces / knowledge_base / controller.py: 34%
900 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import copy
3from typing import Dict, List, Optional, Any, Text, Tuple, Union
4import json
5import decimal
7import pandas as pd
8import numpy as np
9from pydantic import BaseModel, ValidationError
10from sqlalchemy.orm.attributes import flag_modified
12from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star
13from mindsdb_sql_parser import parse_sql
15from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
16from mindsdb.integrations.utilities.query_traversal import query_traversal
18import mindsdb.interfaces.storage.db as db
19from mindsdb.integrations.libs.vectordatabase_handler import (
20 DistanceFunction,
21 TableField,
22 VectorStoreHandler,
23)
24from mindsdb.integrations.utilities.handler_utils import get_api_key
25from mindsdb.integrations.utilities.handlers.auth_utilities.snowflake import get_validated_jwt
27from mindsdb.integrations.utilities.rag.settings import RerankerMode
29from mindsdb.interfaces.agents.constants import get_default_embeddings_model_class, MAX_INSERT_BATCH_SIZE
30from mindsdb.interfaces.agents.provider_utils import get_llm_provider
# Optional dependency: the langchain-backed chat-model factory is only needed
# for agent/reranking features. If langchain is not installed, remember the
# ImportError and defer the failure to the point of use (_require_agent_extra).
try:
    from mindsdb.interfaces.agents.langchain_agent import create_chat_model
except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency
    if getattr(exc, "name", "") and "langchain" in exc.name:
        # Missing langchain extra: disable the feature instead of failing import.
        create_chat_model = None
        _LANGCHAIN_IMPORT_ERROR = exc
    else:  # Unknown import error, surface it
        raise
else:
    _LANGCHAIN_IMPORT_ERROR = None
42from mindsdb.interfaces.database.projects import ProjectController
43from mindsdb.interfaces.knowledge_base.preprocessing.models import PreprocessingConfig, Document
44from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor import PreprocessorFactory
45from mindsdb.interfaces.knowledge_base.evaluate import EvaluateBase
46from mindsdb.interfaces.knowledge_base.executor import KnowledgeBaseQueryExecutor
47from mindsdb.interfaces.model.functions import PredictorRecordNotFound
48from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
49from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs
50from mindsdb.utilities.config import config
51from mindsdb.utilities.context import context as ctx
53from mindsdb.api.executor.command_executor import ExecuteCommands
54from mindsdb.api.executor.utilities.sql import query_df
55from mindsdb.utilities import log
56from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker, ListwiseLLMReranker
57from mindsdb.interfaces.knowledge_base.llm_client import LLMClient
# Module-level logger for this module.
logger = log.getLogger(__name__)
def _require_agent_extra(feature: str):
    """Fail fast if the optional langchain-based agent stack is not installed.

    :param feature: human-readable name of the feature that needs the extras
    :raises ImportError: chained to the original import failure
    """
    if create_chat_model is not None:
        return
    message = (
        f"{feature} requires the optional agent dependencies. "
        "Install them via `pip install mindsdb[kb]`."
    )
    raise ImportError(message) from _LANGCHAIN_IMPORT_ERROR
class KnowledgeBaseInputParams(BaseModel):
    """Schema for user-supplied knowledge-base parameters.

    Unknown keys are rejected (`extra = "forbid"`), so typos fail fast
    at validation time instead of being silently ignored.
    """

    # Source columns copied into chunk metadata.
    metadata_columns: List[str] | None = None
    # Source columns treated as document content.
    content_columns: List[str] | None = None
    # Source column holding the document id.
    id_column: str | None = None
    # Insert without the upsert existence check (faster; may duplicate rows).
    kb_no_upsert: bool = False
    # Skip rows whose ids already exist in the vector store.
    kb_skip_existing: bool = False
    # Embedding model configuration (provider, model_name, api_key, ...).
    embedding_model: Dict[Text, Any] | None = None
    # Whether embeddings are sparse vectors.
    is_sparse: bool = False
    # Dimensionality of the embedding vectors, if fixed.
    vector_size: int | None = None
    # Reranking model configuration; a bool can be used to toggle reranking.
    reranking_model: Union[Dict[Text, Any], bool] | None = None
    # Preprocessing (chunking) configuration.
    preprocessing: Dict[Text, Any] | None = None

    class Config:
        extra = "forbid"
def get_model_params(model_params: dict, default_config_key: str):
    """Combine the configured defaults with user-provided model parameters.

    User values win over defaults. If the user names a provider different
    from the default config's provider, the defaults are discarded entirely.

    :param model_params: user parameters (must be a dict if truthy)
    :param default_config_key: config key holding the default parameters
    :raises ValueError: if model_params is not a dict
    """
    merged = copy.deepcopy(config.get(default_config_key, {}))

    if model_params:
        if not isinstance(model_params, dict):
            raise ValueError("Model parameters must be passed as a JSON object")

        # Provider mismatch: the defaults belong to another provider, drop them.
        if "provider" in model_params and model_params["provider"] != merged.get("provider"):
            return model_params

        merged.update(model_params)

    # Internal flag; never forwarded to the model.
    merged.pop("use_default_llm", None)

    return merged
def adapt_embedding_model_params(embedding_model_params: dict):
    """
    Prepare parameters for the embedding model.

    Normalizes user-facing keys into the keyword arguments expected by the
    langchain embedding classes (`model_name` -> `model`, provider-specific
    api-key argument names, Azure endpoint handling).

    :param embedding_model_params: raw embedding-model params (not mutated)
    :raises ValueError: if 'provider' is missing, or an Azure OpenAI config
        lacks 'base_url'
    """
    params_copy = copy.deepcopy(embedding_model_params)

    provider = params_copy.pop("provider", None)
    if not provider:
        # Previously this crashed with AttributeError (None.lower()).
        raise ValueError("Embedding model parameters must include a 'provider'")
    provider = provider.lower()

    api_key = get_api_key(provider, params_copy, strict=False) or params_copy.get("api_key")
    # Underscores are replaced because the provider name ultimately gets mapped to a class name.
    # This is mostly to support Azure OpenAI (azure_openai); the mapped class name is 'AzureOpenAIEmbeddings'.
    params_copy["class"] = provider.replace("_", "")
    if provider == "azure_openai":
        if "base_url" not in params_copy:
            # Previously this crashed with a bare KeyError.
            raise ValueError("Azure OpenAI embedding model parameters must include 'base_url'")
        # Azure OpenAI expects the api_key to be passed as 'openai_api_key'.
        params_copy["openai_api_key"] = api_key
        params_copy["azure_endpoint"] = params_copy.pop("base_url")
        if "chunk_size" not in params_copy:
            params_copy["chunk_size"] = 2048
        if "api_version" in params_copy:
            params_copy["openai_api_version"] = params_copy["api_version"]
    else:
        # Other providers take '<provider>_api_key'.
        params_copy[f"{provider}_api_key"] = api_key

    params_copy.pop("api_key", None)
    params_copy["model"] = params_copy.pop("model_name", None)

    return params_copy
def get_reranking_model_from_params(reranking_model_params: dict):
    """
    Create a reranking model instance from parameters.

    Resolves a missing api_key via the provider's configured key, accepts the
    legacy 'model_name' alias, validates core fields with Pydantic, and picks
    the reranker class from the validated mode.

    :param reranking_model_params: reranker parameters (not mutated)
    :raises ValueError: if the configuration fails validation
    """
    from mindsdb.integrations.utilities.rag.settings import RerankerConfig

    # Work on a copy; do not mutate caller's dict
    params_copy = copy.deepcopy(reranking_model_params)

    # Handle API key if not provided
    provider = params_copy.get("provider", "openai").lower()
    if "api_key" not in params_copy:
        params_copy["api_key"] = get_api_key(provider, params_copy, strict=False)

    # Handle model_name -> model alias for backward compatibility
    if "model_name" in params_copy and "model" not in params_copy:
        params_copy["model"] = params_copy.pop("model_name")

    # Validate core fields (e.g. mode) via Pydantic
    try:
        cfg = RerankerConfig(**params_copy)
    except ValueError as e:
        # Chain the original validation error so the pydantic details survive.
        raise ValueError(f"Invalid reranker configuration: {str(e)}") from e

    # Merge validated fields back, preserving any extra user fields
    validated = cfg.model_dump()
    reranker_params = {**params_copy, **validated}

    # Choose reranker class based on validated mode
    if cfg.mode == RerankerMode.LISTWISE:
        return ListwiseLLMReranker(**reranker_params)
    return BaseLLMReranker(**reranker_params)
def safe_pandas_is_datetime(value: str) -> bool:
    """
    Check if the value represents a datetime-like dtype, per pandas.

    pandas may raise ValueError for some inputs; that is treated as
    "not a datetime" rather than propagated.
    """
    try:
        return pd.api.types.is_datetime64_any_dtype(value)
    except ValueError:
        return False
def to_json(obj):
    """Serialize `obj` to a JSON string.

    None passes through as None; objects json can't serialize are
    returned unchanged instead of raising.
    """
    if obj is None:
        return None
    try:
        serialized = json.dumps(obj)
    except TypeError:
        # Not JSON-serializable (e.g. sets, custom objects): leave as-is.
        return obj
    return serialized
def rotate_provider_api_key(params):
    """
    Check api key for specific providers. At the moment it checks and updates
    the jwt token of the snowflake provider.

    :param params: input params, can be modified by this function
    :return: a new api key if it is refreshed, otherwise None
    """
    # `get("provider")` may be absent/None for merged param dicts;
    # previously this crashed with AttributeError on None.lower().
    provider = (params.get("provider") or "").lower()

    if provider == "snowflake":
        if "snowflake_account_id" in params:
            # `snowflake_account_id` is the old name
            params["account_id"] = params.pop("snowflake_account_id")

        if "private_key" not in params:
            return

        api_key = params.get("api_key")
        api_key2 = get_validated_jwt(
            api_key,
            account=params.get("account_id"),
            user=params.get("user"),
            private_key=params["private_key"],
        )
        if api_key2 != api_key:
            # update keys
            params["api_key"] = api_key2
            return api_key2
215class KnowledgeBaseTable:
216 """
217 Knowledge base table interface
218 Handlers requests to KB table and modifies data in linked vector db table
219 """
221 def __init__(self, kb: db.KnowledgeBase, session):
222 self._kb = kb
223 self._vector_db = None
224 self.session = session
225 self.document_preprocessor = None
226 self.document_loader = None
227 self.model_params = None
229 self.kb_to_vector_columns = {"id": "_original_doc_id", "chunk_id": "id", "chunk_content": "content"}
230 if self._kb.params.get("version", 0) < 2:
231 self.kb_to_vector_columns["id"] = "original_doc_id"
233 def configure_preprocessing(self, config: Optional[dict] = None):
234 """Configure preprocessing for the knowledge base table"""
235 logger.debug(f"Configuring preprocessing with config: {config}")
236 self.document_preprocessor = None # Reset existing preprocessor
237 if config is None: 237 ↛ 241line 237 didn't jump to line 241 because the condition on line 237 was always true
238 config = {}
240 # Ensure content_column is set for JSON chunking if not already specified
241 if config.get("type") == "json_chunking" and config.get("json_chunking_config"): 241 ↛ 242line 241 didn't jump to line 242 because the condition on line 241 was never true
242 if "content_column" not in config["json_chunking_config"]:
243 config["json_chunking_config"]["content_column"] = "content"
245 preprocessing_config = PreprocessingConfig(**config)
246 self.document_preprocessor = PreprocessorFactory.create_preprocessor(preprocessing_config)
248 # set doc_id column name
249 self.document_preprocessor.config.doc_id_column_name = self.kb_to_vector_columns["id"]
251 logger.debug(f"Created preprocessor of type: {type(self.document_preprocessor)}")
    def select_query(self, query: Select) -> pd.DataFrame:
        """
        Handles select from KB table.
        Replaces content values with embeddings in where clause. Sends query to vector db
        :param query: query to KB table
        :return: dataframe with the result table
        """

        # Copy query for complex execution via DuckDB: DISTINCT, GROUP BY etc.
        query_copy = copy.deepcopy(query)

        executor = KnowledgeBaseQueryExecutor(self)
        df = executor.run(query)

        # copy metadata to columns
        if "metadata" in df.columns:
            meta_columns = self._get_allowed_metadata_columns()
            if meta_columns:
                # Expand metadata dicts into one column per metadata key.
                meta_data = pd.json_normalize(df["metadata"])
                # exclude absent columns and already-used columns
                df_columns = list(df.columns)
                meta_columns = list(set(meta_columns).intersection(meta_data.columns).difference(df_columns))

                # add columns
                df = df.join(meta_data[meta_columns])

                # put metadata in the end
                df_columns.remove("metadata")
                df = df[df_columns + meta_columns + ["metadata"]]

        # Anything beyond a plain `SELECT * ...` (grouping, ordering, having,
        # distinct, explicit targets) is post-processed over the dataframe.
        if (
            query_copy.group_by is not None
            or query_copy.order_by is not None
            or query_copy.having is not None
            or query_copy.distinct is True
            or len(query_copy.targets) != 1
            or not isinstance(query_copy.targets[0], Star)
        ):
            # Conditions were already handled by the executor run above.
            query_copy.where = None
            # Serialize metadata dicts so the dataframe SQL engine can handle them.
            if "metadata" in df.columns:
                df["metadata"] = df["metadata"].apply(to_json)

            if query_copy.from_table is None:
                query_copy.from_table = Identifier(parts=[self._kb.name])

            df = query_df(df, query_copy, session=self.session)

        return df
    def select(self, query, disable_reranking=False):
        """Execute semantic (and optionally hybrid keyword) search on the KB.

        Special WHERE columns are interpreted as search options rather than
        data filters: `relevance` (threshold), `reranking` (false disables the
        reranker), `hybrid_search` / `hybrid_search_alpha` (keyword mixing),
        and `content` (the query text, replaced by its embeddings).

        :param query: SELECT ast addressed to the KB table (mutated in place)
        :param disable_reranking: skip reranking regardless of KB config
        :return: dataframe with id, chunk_id, chunk_content, metadata,
            distance and relevance columns
        """
        logger.debug(f"Processing select query: {query}")

        # Extract the content query text for potential reranking

        db_handler = self.get_vector_db()

        logger.debug("Replaced content with embeddings in where clause")
        # set table name
        query.from_table = Identifier(parts=[self._kb.vector_database_table])
        logger.debug(f"Set table name to: {self._kb.vector_database_table}")

        query.targets = [
            Identifier(TableField.ID.value),
            Identifier(TableField.CONTENT.value),
            Identifier(TableField.METADATA.value),
            Identifier(TableField.DISTANCE.value),
        ]

        # Get response from vector db
        logger.debug(f"Using vector db handler: {type(db_handler)}")

        # extract values from conditions and prepare for vectordb
        conditions = []
        keyword_search_conditions = []
        keyword_search_cols_and_values = []
        query_text = None
        relevance_threshold = None
        relevance_threshold_allowed_operators = [
            FilterOperator.GREATER_THAN_OR_EQUAL.value,
            FilterOperator.GREATER_THAN.value,
        ]
        gt_filtering = False
        query_conditions = db_handler.extract_conditions(query.where)
        hybrid_search_alpha = None
        if query_conditions is not None:
            for item in query_conditions:
                if (item.column == "relevance") and (item.op.value in relevance_threshold_allowed_operators):
                    try:
                        relevance_threshold = float(item.value)
                        # Validate range: must be between 0 and 1
                        if not (0 <= relevance_threshold <= 1):
                            raise ValueError(f"relevance_threshold must be between 0 and 1, got: {relevance_threshold}")
                        if item.op.value == FilterOperator.GREATER_THAN.value:
                            # Strict '>' filtering is applied after reranking (see below).
                            gt_filtering = True
                        logger.debug(f"Found relevance_threshold in query: {relevance_threshold}")
                    except (ValueError, TypeError) as e:
                        error_msg = f"Invalid relevance_threshold value: {item.value}. {e}"
                        logger.error(error_msg)
                        raise ValueError(error_msg) from e
                elif (item.column == "relevance") and (item.op.value not in relevance_threshold_allowed_operators):
                    raise ValueError(
                        f"Invalid operator for relevance: {item.op.value}. Only the following operators are allowed: "
                        f"{','.join(relevance_threshold_allowed_operators)}."
                    )
                elif item.column == "reranking":
                    # `reranking = false` (bool or string) disables the reranker.
                    if item.value is False or (isinstance(item.value, str) and item.value.lower() == "false"):
                        disable_reranking = True
                elif item.column == "hybrid_search":
                    if item.value:
                        # Enable keyword mixing with the default alpha unless one was given.
                        if hybrid_search_alpha is None:
                            hybrid_search_alpha = 0.5
                elif item.column == "hybrid_search_alpha":
                    # validate item.value is a float
                    if not isinstance(item.value, (float, int)):
                        raise ValueError(f"Invalid hybrid_search_alpha value: {item.value}. Must be a float or int.")
                    # validate hybrid search alpha is between 0 and 1
                    if not (0 <= item.value <= 1):
                        raise ValueError(f"Invalid hybrid_search_alpha value: {item.value}. Must be between 0 and 1.")
                    hybrid_search_alpha = item.value
                elif item.column == TableField.CONTENT.value:
                    query_text = item.value

                    # replace content with embeddings
                    conditions.append(
                        FilterCondition(
                            column=TableField.EMBEDDINGS.value,
                            value=self._content_to_embeddings(item.value),
                            op=FilterOperator.EQUAL,
                        )
                    )
                    keyword_search_cols_and_values.append((TableField.CONTENT.value, item.value))
                else:
                    conditions.append(item)
                    keyword_search_conditions.append(item)  # keyword search conditions do not use embeddings

        if len(keyword_search_cols_and_values) > 1:
            raise ValueError(
                "Multiple content columns found in query conditions. "
                "Only one content column is allowed for keyword search."
            )

        logger.debug(f"Extracted query text: {query_text}")

        self.addapt_conditions_columns(conditions)

        # Set default limit if query is present
        limit = query.limit.value if query.limit is not None else None
        if query_text is not None:
            if limit is None:
                limit = 10
            elif limit > 100:
                limit = 100

            if not disable_reranking:
                # expand limit, get more records before reranking usage:
                # get twice size of input but not greater than 30
                query_limit = min(limit * 2, limit + 30)
            else:
                query_limit = limit
            query.limit = Constant(query_limit)

        allowed_metadata_columns = self._get_allowed_metadata_columns()

        # alpha semantics: 1 = pure vector search, 0 = pure keyword search.
        if hybrid_search_alpha is None:
            hybrid_search_alpha = 1

        if hybrid_search_alpha > 0:
            df = db_handler.dispatch_select(query, conditions, allowed_metadata_columns=allowed_metadata_columns)
            df = self.addapt_result_columns(df)
            logger.debug(f"Query returned {len(df)} rows")
            logger.debug(f"Columns in response: {df.columns.tolist()}")
        else:
            df = pd.DataFrame([], columns=["id", "chunk_id", "chunk_content", "metadata", "distance"])

        # check if db_handler inherits from KeywordSearchBase
        if hybrid_search_alpha < 1:
            if not isinstance(db_handler, KeywordSearchBase):
                raise ValueError(
                    f"Hybrid search is enabled but the db_handler {type(db_handler)} does not support it. "
                )

            # If query_text is present, use it for keyword search
            logger.debug(f"Performing keyword search with query text: {query_text}")
            keyword_search_args = KeywordSearchArgs(query=query_text, column=TableField.CONTENT.value)
            keyword_query_obj = copy.deepcopy(query)

            keyword_query_obj.targets = [
                Identifier(TableField.ID.value),
                Identifier(TableField.CONTENT.value),
                Identifier(TableField.METADATA.value),
            ]

            df_keyword = db_handler.dispatch_select(
                keyword_query_obj,
                keyword_search_conditions,
                allowed_metadata_columns=allowed_metadata_columns,
                keyword_search_args=keyword_search_args,
            )
            df_keyword = self.addapt_result_columns(df_keyword)
            logger.debug(f"Keyword search returned {len(df_keyword)} rows")
            logger.debug(f"Columns in keyword search response: {df_keyword.columns.tolist()}")
            # ensure df and df_keyword_select have exactly the same columns
            if not df_keyword.empty:
                if df.empty:
                    df = df_keyword
                else:
                    # Blend distances: alpha weights the keyword results,
                    # (1 - alpha) the vector results, then merge and dedupe.
                    df_keyword[TableField.DISTANCE.value] = hybrid_search_alpha * df_keyword[TableField.DISTANCE.value]
                    df[TableField.DISTANCE.value] = (1 - hybrid_search_alpha) * df[TableField.DISTANCE.value]

                    df = pd.concat([df, df_keyword], ignore_index=True)
                    # sort by distance if distance column exists
                    if TableField.DISTANCE.value in df.columns:
                        df = df.sort_values(by=TableField.DISTANCE.value, ascending=True)
                    # if chunk_id column exists remove duplicates based on chunk_id
                    if "chunk_id" in df.columns:
                        df = df.drop_duplicates(subset=["chunk_id"])

        # Check if we have a rerank_model configured in KB params
        df = self.add_relevance(df, query_text, relevance_threshold, disable_reranking)

        if limit is not None:
            df = df[:limit]

        # if relevance filtering method is strictly GREATER THAN we filter the df
        if gt_filtering:
            relevance_scores = TableField.RELEVANCE.value
            df = df[df[relevance_scores] > relevance_threshold]

        return df
482 def _get_allowed_metadata_columns(self) -> List[str] | None:
483 # Return list of KB columns to restrict querying, if None: no restrictions
485 if self._kb.params.get("version", 0) < 2: 485 ↛ 489line 485 didn't jump to line 489 because the condition on line 485 was always true
486 # disable for old version KBs
487 return None
489 user_columns = self._kb.params.get("metadata_columns", [])
490 dynamic_columns = self._kb.params.get("inserted_metadata", [])
492 columns = set(user_columns) | set(dynamic_columns)
493 return [col.lower() for col in columns]
495 def score_documents(self, query_text, documents, reranking_model_params):
496 rotate_provider_api_key(reranking_model_params)
497 reranker = get_reranking_model_from_params(reranking_model_params)
498 return reranker.get_scores(query_text, documents)
    def add_relevance(self, df, query_text, relevance_threshold=None, disable_reranking=False):
        """Attach a `relevance` column to search results, filter and sort by it.

        Uses the configured reranking model when one is set and usable;
        otherwise derives relevance from vector distance as 1 / (1 + distance).

        :param df: search results (expects `chunk_content`, optionally `distance`)
        :param query_text: query text used for reranking
        :param relevance_threshold: optional minimal relevance to keep a row
        :param disable_reranking: skip the reranker even if configured
        :return: dataframe with `relevance` column, sorted by relevance desc
        """
        relevance_column = TableField.RELEVANCE.value

        reranking_model_params = get_model_params(self._kb.params.get("reranking_model"), "default_reranking_model")
        if reranking_model_params and query_text and len(df) > 0 and not disable_reranking:
            # Use reranker for relevance score

            new_api_key = rotate_provider_api_key(reranking_model_params)
            if new_api_key:
                # Persist the refreshed key in the KB record.
                if "reranking_model" not in self._kb.params:
                    self._kb.params["reranking_model"] = {}
                self._kb.params["reranking_model"]["api_key"] = new_api_key
                flag_modified(self._kb, "params")
                db.session.commit()

            # Apply custom filtering threshold if provided
            if relevance_threshold is not None:
                reranking_model_params["filtering_threshold"] = relevance_threshold
                logger.info(f"Using custom filtering threshold: {relevance_threshold}")

            reranker = get_reranking_model_from_params(reranking_model_params)
            # Get documents to rerank
            documents = df["chunk_content"].tolist()
            # Score every chunk against the query text.
            scores = reranker.get_scores(query_text, documents)
            # Add scores as the relevance column
            df[relevance_column] = scores

            # Filter by the reranker's effective threshold.
            scores_array = np.array(scores)
            df = df[scores_array >= reranker.filtering_threshold]

        elif "distance" in df.columns:
            # Calculate relevance from distance
            logger.info("Calculating relevance from vector distance")
            df[relevance_column] = 1 / (1 + df["distance"])
            if relevance_threshold is not None:
                df = df[df[relevance_column] > relevance_threshold]

        else:
            # No reranker and no distance available: leave relevance undefined.
            df[relevance_column] = None
            df["distance"] = None
        # Sort by relevance
        df = df.sort_values(by=relevance_column, ascending=False)
        return df
547 def addapt_conditions_columns(self, conditions):
548 if conditions is None: 548 ↛ 549line 548 didn't jump to line 549 because the condition on line 548 was never true
549 return
550 for condition in conditions:
551 if condition.column in self.kb_to_vector_columns: 551 ↛ 552line 551 didn't jump to line 552 because the condition on line 551 was never true
552 condition.column = self.kb_to_vector_columns[condition.column]
554 def addapt_result_columns(self, df):
555 col_update = {}
556 for kb_col, vec_col in self.kb_to_vector_columns.items():
557 if vec_col in df.columns:
558 col_update[vec_col] = kb_col
560 df = df.rename(columns=col_update)
562 columns = list(df.columns)
563 # update id, get from metadata
564 df[TableField.ID.value] = df[TableField.METADATA.value].apply(
565 lambda m: None if m is None else m.get(self.kb_to_vector_columns["id"])
566 )
568 # id on first place
569 return df[[TableField.ID.value] + columns]
571 def insert_files(self, file_names: List[str]):
572 """Process and insert files"""
573 if not self.document_loader:
574 raise ValueError("Document loader not configured")
576 documents = list(self.document_loader.load_files(file_names))
577 if documents:
578 self.insert_documents(documents)
580 def insert_web_pages(self, urls: List[str], crawl_depth: int, limit: int, filters: List[str] = None):
581 """Process and insert web pages"""
582 if not self.document_loader:
583 raise ValueError("Document loader not configured")
585 documents = list(
586 self.document_loader.load_web_pages(urls, limit=limit, crawl_depth=crawl_depth, filters=filters)
587 )
588 if documents:
589 self.insert_documents(documents)
591 def insert_query_result(self, query: str, project_name: str):
592 """Process and insert SQL query results"""
593 ast_query = parse_sql(query)
595 command_executor = ExecuteCommands(self.session)
596 response = command_executor.execute_command(ast_query, project_name)
598 if response.error_code is not None:
599 raise ValueError(f"Error executing query: {response.error_message}")
601 if response.data is None:
602 raise ValueError("Query returned no data")
604 records = response.data.records
605 df = pd.DataFrame(records)
607 self.insert(df)
609 def insert_rows(self, rows: List[Dict]):
610 """Process and insert raw data rows"""
611 if not rows:
612 return
614 df = pd.DataFrame(rows)
616 self.insert(df)
618 def insert_documents(self, documents: List[Document]):
619 """Process and insert documents with preprocessing if configured"""
620 df = pd.DataFrame([doc.model_dump() for doc in documents])
622 self.insert(df)
    def update_query(self, query: Update):
        """
        Handle an UPDATE against the KB table.

        Re-embeds updated content (after optional preprocessing) and
        dispatches the query to the vector store.
        :param query: update query addressed to the KB table
        """
        # Work on a copy; callers keep their original AST.
        query = copy.deepcopy(query)

        emb_col = TableField.EMBEDDINGS.value
        cont_col = TableField.CONTENT.value

        db_handler = self.get_vector_db()
        conditions = db_handler.extract_conditions(query.where)
        # Pick up the chunk id from the WHERE clause to use as the doc id.
        doc_id = None
        for condition in conditions:
            if condition.column == "chunk_id" and condition.op == FilterOperator.EQUAL:
                doc_id = condition.value

        if cont_col in query.update_columns:
            content = query.update_columns[cont_col]

            # Apply preprocessing to content if configured
            if self.document_preprocessor:
                doc = Document(
                    id=doc_id,
                    content=content.value,
                    metadata={},  # Empty metadata for content-only updates
                )
                processed_chunks = self.document_preprocessor.process_documents([doc])
                if processed_chunks:
                    content.value = processed_chunks[0].content

            # New content requires new embeddings.
            query.update_columns[emb_col] = Constant(self._content_to_embeddings(content.value))

        if "metadata" not in query.update_columns:
            query.update_columns["metadata"] = Constant({})

        # TODO search content in where clause?

        # set table name
        query.table = Identifier(parts=[self._kb.vector_database_table])

        # send to vectordb
        self.addapt_conditions_columns(conditions)
        db_handler.dispatch_update(query, conditions)
666 def delete_query(self, query: Delete):
667 """
668 Handles delete query to KB table.
669 Replaces content values with embeddings in WHERE clause. Sends query to vector db
670 :param query: query to KB table
671 """
672 query_traversal(query.where, self._replace_query_content)
674 # set table name
675 query.table = Identifier(parts=[self._kb.vector_database_table])
677 # send to vectordb
678 db_handler = self.get_vector_db()
679 conditions = db_handler.extract_conditions(query.where)
680 self.addapt_conditions_columns(conditions)
681 db_handler.dispatch_delete(query, conditions)
683 def hybrid_search(
684 self,
685 query: str,
686 keywords: List[str] = None,
687 metadata: Dict[str, str] = None,
688 distance_function=DistanceFunction.COSINE_DISTANCE,
689 ) -> pd.DataFrame:
690 query_df = pd.DataFrame.from_records([{TableField.CONTENT.value: query}])
691 embeddings_df = self._df_to_embeddings(query_df)
692 if embeddings_df.empty:
693 return pd.DataFrame([])
694 embeddings = embeddings_df.iloc[0][TableField.EMBEDDINGS.value]
695 keywords_query = None
696 if keywords is not None:
697 keywords_query = " ".join(keywords)
698 db_handler = self.get_vector_db()
699 return db_handler.hybrid_search(
700 self._kb.vector_database_table,
701 embeddings,
702 query=keywords_query,
703 metadata=metadata,
704 distance_function=distance_function,
705 )
707 def clear(self):
708 """
709 Clear data in KB table
710 Sends delete to vector db table
711 """
712 db_handler = self.get_vector_db()
713 db_handler.delete(self._kb.vector_database_table)
    def insert(self, df: pd.DataFrame, params: dict = None):
        """Insert dataframe to KB table.

        Rows are split into one Document per content column, optionally
        preprocessed/chunked, embedded, and written to the vector store.
        Oversized inputs are recursively inserted in batches.

        Args:
            df: DataFrame to insert (mutated: NaN values are replaced with None)
            params: User parameters of insert (supports `kb_skip_existing`
                and `kb_no_upsert`)
        """
        if df.empty:
            return

        if len(df) > MAX_INSERT_BATCH_SIZE:
            # auto-batching: recurse on fixed-size slices.
            batch_size = MAX_INSERT_BATCH_SIZE

            chunk_num = 0
            while chunk_num * batch_size < len(df):
                df2 = df[chunk_num * batch_size : (chunk_num + 1) * batch_size]
                self.insert(df2, params=params)
                chunk_num += 1
            return

        try:
            run_query_id = ctx.run_query_id
            # Link current KB to running query (where KB is used to insert data)
            if run_query_id is not None:
                self._kb.query_id = run_query_id
                db.session.commit()

        except AttributeError:
            # Context has no run_query_id attribute; nothing to link.
            ...

        df.replace({np.nan: None}, inplace=True)

        # First adapt column names to identify content and metadata columns
        adapted_df, normalized_columns = self._adapt_column_names(df)
        content_columns = normalized_columns["content_columns"]

        # Convert DataFrame rows to documents, creating separate documents for each content column
        raw_documents = []
        for idx, row in adapted_df.iterrows():
            base_metadata = self._parse_metadata(row.get(TableField.METADATA.value, {}))
            provided_id = row.get(TableField.ID.value)

            for col in content_columns:
                content = row.get(col)
                # Skip empty / whitespace-only content.
                if content and str(content).strip():
                    content_str = str(content)

                    # Use provided_id directly if it exists, otherwise generate one
                    doc_id = self._generate_document_id(content_str, col, provided_id)

                    metadata = {
                        **base_metadata,
                        "_original_row_index": str(idx),  # provide link to original row index
                        "_content_column": col,
                    }

                    raw_documents.append(Document(content=content_str, id=doc_id, metadata=metadata))

        # Apply preprocessing to all documents if preprocessor exists
        if self.document_preprocessor:
            processed_chunks = self.document_preprocessor.process_documents(raw_documents)
        else:
            processed_chunks = raw_documents  # Use raw documents if no preprocessing

        # Convert processed chunks back to DataFrame with standard structure
        df = pd.DataFrame(
            [
                {
                    TableField.CONTENT.value: chunk.content,
                    TableField.ID.value: chunk.id,
                    TableField.METADATA.value: chunk.metadata,
                }
                for chunk in processed_chunks
            ]
        )

        if df.empty:
            logger.warning("No valid content found in any content columns")
            return

        # Check if we should skip existing items (before calculating embeddings)
        if params is not None and params.get("kb_skip_existing", False):
            logger.debug(f"Checking for existing items to skip before processing {len(df)} items")
            db_handler = self.get_vector_db()

            # Get list of IDs from current batch
            current_ids = df[TableField.ID.value].dropna().astype(str).tolist()
            if current_ids:
                # Check which IDs already exist
                existing_ids = db_handler.check_existing_ids(self._kb.vector_database_table, current_ids)
                if existing_ids:
                    # Filter out existing items
                    df = df[~df[TableField.ID.value].astype(str).isin(existing_ids)]
                    logger.info(f"Skipped {len(existing_ids)} existing items, processing {len(df)} new items")

                    if df.empty:
                        logger.info("All items already exist, nothing to insert")
                        return

        # add embeddings and send to vector db
        df_emb = self._df_to_embeddings(df)
        df = pd.concat([df, df_emb], axis=1)
        db_handler = self.get_vector_db()

        if params is not None and params.get("kb_no_upsert", False):
            # speed up inserting by disable checking existing records
            db_handler.insert(self._kb.vector_database_table, df)
        else:
            db_handler.do_upsert(self._kb.vector_database_table, df)
826 def _adapt_column_names(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
827 """
828 Convert input columns for vector db input
829 - id, content and metadata
830 """
831 # Debug incoming data
832 logger.debug(f"Input DataFrame columns: {df.columns}")
833 logger.debug(f"Input DataFrame first row: {df.iloc[0].to_dict()}")
835 params = self._kb.params
836 columns = list(df.columns)
838 # -- prepare id --
839 id_column = params.get("id_column")
840 if id_column is not None and id_column not in columns:
841 id_column = None
843 if id_column is None and TableField.ID.value in columns:
844 id_column = TableField.ID.value
846 # Also check for case-insensitive 'id' column
847 if id_column is None:
848 column_map = {col.lower(): col for col in columns}
849 if "id" in column_map:
850 id_column = column_map["id"]
852 if id_column is not None:
853 columns.remove(id_column)
854 logger.debug(f"Using ID column: {id_column}")
856 # Create output dataframe
857 df_out = pd.DataFrame()
859 # Add ID if present
860 if id_column is not None:
861 df_out[TableField.ID.value] = df[id_column]
862 logger.debug(f"Added IDs: {df_out[TableField.ID.value].tolist()}")
864 # -- prepare content and metadata --
865 content_columns = params.get("content_columns", [TableField.CONTENT.value])
866 metadata_columns = params.get("metadata_columns")
868 logger.debug(f"Processing with: content_columns={content_columns}, metadata_columns={metadata_columns}")
870 # Handle SQL query result columns
871 if content_columns:
872 # Ensure content columns are case-insensitive
873 column_map = {col.lower(): col for col in columns}
874 content_columns = [column_map.get(col.lower(), col) for col in content_columns]
875 logger.debug(f"Mapped content columns: {content_columns}")
877 if metadata_columns:
878 # Ensure metadata columns are case-insensitive
879 column_map = {col.lower(): col for col in columns}
880 metadata_columns = [column_map.get(col.lower(), col) for col in metadata_columns]
881 logger.debug(f"Mapped metadata columns: {metadata_columns}")
883 content_columns = list(set(content_columns).intersection(columns))
884 if len(content_columns) == 0:
885 raise ValueError(f"Content columns {params.get('content_columns')} not found in dataset: {columns}")
887 if metadata_columns is not None:
888 metadata_columns = list(set(metadata_columns).intersection(columns))
889 else:
890 # all the rest columns
891 metadata_columns = list(set(columns).difference(content_columns))
893 # update list of used columns
894 inserted_metadata = set(self._kb.params.get("inserted_metadata", []))
895 inserted_metadata.update(metadata_columns)
896 self._kb.params["inserted_metadata"] = list(inserted_metadata)
897 flag_modified(self._kb, "params")
898 db.session.commit()
900 # Add content columns directly (don't combine them)
901 for col in content_columns:
902 df_out[col] = df[col]
904 # Add metadata
905 if metadata_columns and len(metadata_columns) > 0:
907 def convert_row_to_metadata(row):
908 metadata = {}
909 for col in metadata_columns:
910 value = row[col]
911 value_type = type(value)
912 # Convert numpy/pandas types to Python native types
913 if safe_pandas_is_datetime(value) or isinstance(value, pd.Timestamp):
914 value = str(value)
915 elif pd.api.types.is_integer_dtype(value_type):
916 value = int(value)
917 elif pd.api.types.is_float_dtype(value_type) or isinstance(value, decimal.Decimal):
918 value = float(value)
919 elif pd.api.types.is_bool_dtype(value_type):
920 value = bool(value)
921 elif isinstance(value, dict):
922 metadata.update(value)
923 continue
924 elif value is not None:
925 value = str(value)
926 metadata[col] = value
927 return metadata
929 metadata_dict = df[metadata_columns].apply(convert_row_to_metadata, axis=1)
930 df_out[TableField.METADATA.value] = metadata_dict
932 logger.debug(f"Output DataFrame columns: {df_out.columns}")
933 logger.debug(f"Output DataFrame first row: {df_out.iloc[0].to_dict() if not df_out.empty else 'Empty'}")
935 return df_out, {"content_columns": content_columns, "metadata_columns": metadata_columns}
937 def _replace_query_content(self, node, **kwargs):
938 if isinstance(node, BinaryOperation):
939 if isinstance(node.args[0], Identifier) and isinstance(node.args[1], Constant):
940 col_name = node.args[0].parts[-1]
941 if col_name.lower() == TableField.CONTENT.value:
942 # replace
943 node.args[0].parts = [TableField.EMBEDDINGS.value]
944 node.args[1].value = [self._content_to_embeddings(node.args[1].value)]
946 def get_vector_db(self) -> VectorStoreHandler:
947 """
948 helper to get vector db handler
949 """
950 if self._vector_db is None: 950 ↛ 951line 950 didn't jump to line 951 because the condition on line 950 was never true
951 database = db.Integration.query.get(self._kb.vector_database_id)
952 if database is None:
953 raise ValueError("Vector database not found. Is it deleted?")
954 database_name = database.name
955 self._vector_db = self.session.integration_controller.get_data_handler(database_name)
956 return self._vector_db
958 def get_vector_db_table_name(self) -> str:
959 """
960 helper to get underlying table name used for embeddings
961 """
962 return self._kb.vector_database_table
    def _df_to_embeddings(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Returns embeddings for input dataframe.
        Uses model embedding model to convert content to embeddings.
        Automatically detects input and output of model using model description
        :param df: dataframe with a `content` column to embed
        :return: dataframe with a single `embeddings` column, row-aligned with `df`
        """

        if df.empty:
            return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])

        model_id = self._kb.embedding_model_id

        if model_id is None:
            # No mindsdb predictor attached: call the configured embedding provider directly.
            messages = list(df[TableField.CONTENT.value])
            embedding_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
            new_api_key = rotate_provider_api_key(embedding_params)
            if new_api_key:
                # update key (persist the rotated api key on the KB record)
                if "embedding_model" not in self._kb.params:
                    self._kb.params["embedding_model"] = {}
                self._kb.params["embedding_model"]["api_key"] = new_api_key
                flag_modified(self._kb, "params")
                db.session.commit()

            llm_client = LLMClient(embedding_params, session=self.session)
            results = llm_client.embeddings(messages)
            # wrap each embedding so the frame has one cell per row in the embeddings column
            results = [[val] for val in results]
            return pd.DataFrame(results, columns=[TableField.EMBEDDINGS.value])

        # Legacy path: the KB uses a mindsdb predictor as its embedding model.
        # get the input columns
        model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()

        assert model_rec is not None, f"Model not found: {model_id}"
        model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()

        project_datanode = self.session.datahub.get(model_project.name)

        # detect which column the model expects as input (question_column takes precedence)
        model_using = model_rec.learn_args.get("using", {})
        input_col = model_using.get("question_column")
        if input_col is None:
            input_col = model_using.get("input_column")

        if input_col is not None and input_col != TableField.CONTENT.value:
            df = df.rename(columns={TableField.CONTENT.value: input_col})

        df_out = project_datanode.predict(model_name=model_rec.name, df=df, params=self.model_params)

        target = model_rec.to_predict[0]
        if target != TableField.EMBEDDINGS.value:
            # adapt output for vectordb
            df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
        df_out = df_out[[TableField.EMBEDDINGS.value]]

        return df_out
1023 def _content_to_embeddings(self, content: str) -> List[float]:
1024 """
1025 Converts string to embeddings
1026 :param content: input string
1027 :return: embeddings
1028 """
1029 df = pd.DataFrame([[content]], columns=[TableField.CONTENT.value])
1030 res = self._df_to_embeddings(df)
1031 return res[TableField.EMBEDDINGS.value][0]
1033 @staticmethod
1034 def call_litellm_embedding(session, model_params, messages):
1035 args = copy.deepcopy(model_params)
1037 if "model_name" not in args:
1038 raise ValueError("'model_name' must be provided for embedding model")
1040 llm_model = args.pop("model_name")
1041 engine = args.pop("provider")
1043 module = session.integration_controller.get_handler_module("litellm")
1044 if module is None or module.Handler is None:
1045 raise ValueError(f'Unable to use "{engine}" provider. Litellm handler is not installed')
1046 return module.Handler.embeddings(engine, llm_model, messages, args)
    def build_rag_pipeline(self, retrieval_config: dict):
        """
        Builds a RAG pipeline with returned sources

        Args:
            retrieval_config: dict with retrieval config

        Returns:
            RAG: Configured RAG pipeline instance

        Raises:
            ValueError: If the configuration is invalid or required components are missing
        """
        # Get embedding model from knowledge base
        # (imports are function-local so module import doesn't require the RAG extras)
        from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import (
            construct_model_from_args,
        )
        from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
        from mindsdb.integrations.utilities.rag.config_loader import load_rag_config

        embedding_model_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
        if self._kb.embedding_model:
            # Extract embedding model args from knowledge base table
            embedding_args = self._kb.embedding_model.learn_args.get("using", {})
            # Construct the embedding model directly
            embeddings_model = construct_model_from_args(embedding_args)
            logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
        elif embedding_model_params:
            embeddings_model = construct_model_from_args(adapt_embedding_model_params(embedding_model_params))
            logger.debug(f"Using knowledge base embedding model from params: {self._kb.params['embedding_model']}")
        else:
            embeddings_model_class = get_default_embeddings_model_class()
            embeddings_model = embeddings_model_class()
            logger.debug("Using default embedding model as knowledge base has no embedding model")

        # Update retrieval config with knowledge base parameters
        kb_params = {"vector_store_config": {"kb_table": self}}

        # Load and validate config; any failure below is re-raised as ValueError
        try:
            rag_config = load_rag_config(retrieval_config, kb_params, embeddings_model)

            # Build LLM if specified
            if "llm_model_name" in rag_config:
                llm_args = {"model_name": rag_config.llm_model_name}
                if not rag_config.llm_provider:
                    llm_args["provider"] = get_llm_provider(llm_args)
                else:
                    llm_args["provider"] = rag_config.llm_provider
                # create_chat_model requires the optional langchain dependency
                _require_agent_extra("Building knowledge base retrieval pipelines")
                rag_config.llm = create_chat_model(llm_args)

            # Create RAG pipeline
            rag = RAG(rag_config)
            logger.debug(f"RAG pipeline created with config: {rag_config}")
            return rag

        except Exception as e:
            logger.exception("Error building RAG pipeline:")
            raise ValueError(f"Failed to build RAG pipeline: {str(e)}") from e
1109 def _parse_metadata(self, base_metadata):
1110 """Helper function to robustly parse metadata string to dict"""
1111 if isinstance(base_metadata, dict):
1112 return base_metadata
1113 if isinstance(base_metadata, str):
1114 try:
1115 import ast
1117 return ast.literal_eval(base_metadata)
1118 except (SyntaxError, ValueError):
1119 logger.warning(f"Could not parse metadata: {base_metadata}. Using empty dict.")
1120 return {}
1121 return {}
1123 def _generate_document_id(self, content: str, content_column: str, provided_id: str = None) -> str:
1124 """Generate a deterministic document ID using the utility function."""
1125 from mindsdb.interfaces.knowledge_base.utils import generate_document_id
1127 return generate_document_id(content=content, provided_id=provided_id)
1129 def _convert_metadata_value(self, value):
1130 """
1131 Convert metadata value to appropriate Python type.
1133 Args:
1134 value: The value to convert
1136 Returns:
1137 Converted value in appropriate Python type
1138 """
1139 if pd.isna(value):
1140 return None
1142 # Handle pandas/numpy types
1143 if pd.api.types.is_datetime64_any_dtype(value) or isinstance(value, pd.Timestamp):
1144 return str(value)
1145 elif pd.api.types.is_integer_dtype(type(value)):
1146 return int(value)
1147 elif pd.api.types.is_float_dtype(type(value)):
1148 return float(value)
1149 elif pd.api.types.is_bool_dtype(type(value)):
1150 return bool(value)
1152 # Handle basic Python types
1153 if isinstance(value, (int, float, bool)):
1154 return value
1156 # Convert everything else to string
1157 return str(value)
1159 def create_index(self):
1160 """
1161 Create an index on the knowledge base table
1162 :param index_name: name of the index
1163 :param params: parameters for the index
1164 """
1165 db_handler = self.get_vector_db()
1166 db_handler.create_index(self._kb.vector_database_table)
class KnowledgeBaseController:
    """
    Knowledge base controller: handles all
    management of knowledge bases (create, update, delete, list, evaluate).
    """

    # version stamp written into new KBs' params ("version" key) on creation
    KB_VERSION = 2

    def __init__(self, session) -> None:
        # mindsdb server session; provides the controllers (database, integration,
        # model) and the datahub used throughout this class
        self.session = session
1180 def _check_kb_input_params(self, params):
1181 # check names and types KB params
1182 try:
1183 KnowledgeBaseInputParams.model_validate(params)
1184 except ValidationError as e:
1185 problems = []
1186 for error in e.errors():
1187 parameter = ".".join([str(i) for i in error["loc"]])
1188 param_type = error["type"]
1189 if param_type == "extra_forbidden":
1190 msg = f"Parameter '{parameter}' is not allowed"
1191 else:
1192 msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}"
1193 problems.append(msg)
1195 msg = "\n".join(problems)
1196 if len(problems) > 1:
1197 msg = "\n" + msg
1198 raise ValueError(f"Problem with knowledge base parameters: {msg}") from e
    def add(
        self,
        name: str,
        project_name: str,
        storage: Identifier,
        params: dict,
        preprocessing_config: Optional[dict] = None,
        if_not_exists: bool = False,
        keyword_search_enabled: bool = False,
        # embedding_model: Identifier = None, # Legacy: Allow MindsDB models to be passed as embedding_model.
    ) -> db.KnowledgeBase:
        """
        Add a new knowledge base to the database
        :param name: name of the new knowledge base
        :param project_name: project to create it in
        :param storage: optional `<vector_db>.<table>` identifier; when None a default
            store is created (pgvector if KB_PGVECTOR_URL is set, otherwise chromadb)
        :param preprocessing_config: Optional preprocessing configuration to validate and store
        :param if_not_exists: when True, return the existing KB instead of raising
        :param keyword_search_enabled: when True, add a full-text index on the content column
        :param is_sparse: Whether to use sparse vectors for embeddings (read from params)
        :param vector_size: Optional size specification for vectors, required when is_sparse=True (read from params)
        """

        # Validate preprocessing config first if provided
        if preprocessing_config is not None:
            PreprocessingConfig(**preprocessing_config)  # Validate before storing
            params = params or {}
            params["preprocessing"] = preprocessing_config

        self._check_kb_input_params(params)

        # Check if vector_size is provided when using sparse vectors
        is_sparse = params.get("is_sparse")
        vector_size = params.get("vector_size")
        if is_sparse and vector_size is None:
            raise ValueError("vector_size is required when is_sparse=True")

        # get project id
        project = self.session.database_controller.get_project(project_name)
        project_id = project.id

        # check if knowledge base already exists
        kb = self.get(name, project_id)
        if kb is not None:
            if if_not_exists:
                return kb
            raise EntityExistsError("Knowledge base already exists", name)

        embedding_params = get_model_params(params.get("embedding_model", {}), "default_embedding_model")
        params["embedding_model"] = embedding_params
        rotate_provider_api_key(embedding_params)

        # if model_name is None:  # Legacy
        # validates the embedding config and returns its dimension
        embed_info = self._check_embedding_model(
            project.name,
            params=embedding_params,
            kb_name=name,
        )

        # if params.get("reranking_model", {}) is bool and False we evaluate it to empty dictionary
        reranking_model_params = params.get("reranking_model", {})

        if isinstance(reranking_model_params, bool) and not reranking_model_params:
            params["reranking_model"] = {}
        else:
            reranking_model_params = get_model_params(reranking_model_params, "default_reranking_model")

            params["reranking_model"] = reranking_model_params
            if reranking_model_params:
                # Get reranking model from params.
                # This is called here to check validity of the parameters.
                rotate_provider_api_key(reranking_model_params)
                self._test_reranking(reranking_model_params)

        # search for the vector database table
        if storage is None:
            cloud_pg_vector = os.environ.get("KB_PGVECTOR_URL")
            if cloud_pg_vector:
                vector_table_name = name
                # Add sparse vector support for pgvector
                vector_db_params = {}
                # Check both explicit parameter and model configuration
                if is_sparse:
                    vector_db_params["is_sparse"] = True
                    if vector_size is not None:
                        vector_db_params["vector_size"] = vector_size
                vector_db_name = self._create_persistent_pgvector(vector_db_params)

                params["default_vector_storage"] = vector_db_name
            else:
                # create chroma db with same name
                vector_table_name = "default_collection"
                vector_db_name = self._create_persistent_chroma(name)
                # memorize to remove it later
                params["default_vector_storage"] = vector_db_name
        elif len(storage.parts) != 2:
            raise ValueError("Storage param has to be vector db with table")
        else:
            vector_db_name, vector_table_name = storage.parts

        data_node = self.session.datahub.get(vector_db_name)
        if data_node:
            vector_store_handler = data_node.integration_handler
        else:
            raise ValueError(
                f"Unable to find database named {vector_db_name}, please make sure {vector_db_name} is defined"
            )
        # create table in vectordb before creating KB
        if "default_vector_storage" in params:
            # if vector db is a default - drop previous table, if exists
            try:
                vector_store_handler.drop_table(vector_table_name)
            except Exception:
                ...
        vector_store_handler.create_table(vector_table_name)
        # ensure an existing table's embedding dimension matches the model's
        self._check_vector_table(embed_info, vector_store_handler, vector_table_name)

        if keyword_search_enabled:
            vector_store_handler.add_full_text_index(vector_table_name, TableField.CONTENT.value)
        vector_database_id = self.session.integration_controller.get(vector_db_name)["id"]

        # Store sparse vector settings in params if specified
        if is_sparse:
            params = params or {}
            params["vector_config"] = {"is_sparse": is_sparse}
            if vector_size is not None:
                params["vector_config"]["vector_size"] = vector_size

        params["version"] = self.KB_VERSION
        kb = db.KnowledgeBase(
            name=name,
            project_id=project_id,
            vector_database_id=vector_database_id,
            vector_database_table=vector_table_name,
            embedding_model_id=None,
            params=params,
        )
        db.session.add(kb)
        db.session.commit()
        return kb
1335 def _check_vector_table(self, embed_info, vector_store_handler, vector_table_name):
1336 query = Select(
1337 targets=[Identifier(TableField.EMBEDDINGS.value)],
1338 from_table=Identifier(parts=[vector_table_name]),
1339 limit=Constant(1),
1340 )
1341 dimension = None
1342 if hasattr(vector_store_handler, "get_dimension"): 1342 ↛ 1343line 1342 didn't jump to line 1343 because the condition on line 1342 was never true
1343 dimension = vector_store_handler.get_dimension(vector_table_name)
1344 else:
1345 df = vector_store_handler.dispatch_select(query, [])
1346 if len(df) > 0: 1346 ↛ 1347line 1346 didn't jump to line 1347 because the condition on line 1346 was never true
1347 value = df[TableField.EMBEDDINGS.value][0]
1348 if isinstance(value, str):
1349 value = json.loads(value)
1350 dimension = len(value)
1351 if dimension is not None and dimension != embed_info["dimension"]: 1351 ↛ 1352line 1351 didn't jump to line 1352 because the condition on line 1351 was never true
1352 raise ValueError(
1353 f"Dimension of embedding model doesn't match to dimension of vector table: {embed_info['dimension']} != {dimension}"
1354 )
    def update(
        self,
        name: str,
        project_name: str,
        params: dict,
        preprocessing_config: Optional[dict] = None,
    ) -> db.KnowledgeBase:
        """
        Update the knowledge base
        :param name: The name of the knowledge base
        :param project_name: Current project name
        :param params: The parameters to update
        :param preprocessing_config: Optional preprocessing configuration to validate and store
        :return: the refreshed knowledge base record
        :raises EntityNotExistsError: if the KB doesn't exist
        :raises ValueError: on invalid parameter updates
        """

        # Validate preprocessing config first if provided
        if preprocessing_config is not None:
            PreprocessingConfig(**preprocessing_config)  # Validate before storing
            params = params or {}
            params["preprocessing"] = preprocessing_config

        self._check_kb_input_params(params)

        # get project id
        project = self.session.database_controller.get_project(project_name)
        project_id = project.id

        # get existed KB
        kb = self.get(name.lower(), project_id)
        if kb is None:
            raise EntityNotExistsError("Knowledge base doesn't exists", name)

        if "embedding_model" in params:
            new_config = params["embedding_model"]
            # update embedding
            embed_params = kb.params.get("embedding_model", {})
            if not embed_params:
                # maybe old version of KB
                raise ValueError("No embedding config to update")

            # some parameters are not allowed to update
            for key in ("provider", "model_name"):
                if key in new_config and new_config[key] != embed_params.get(key):
                    raise ValueError(f"You can't update '{key}' setting")

            embed_params.update(new_config)

            # re-check the merged config before persisting it
            self._check_embedding_model(
                project.name,
                params=embed_params,
                kb_name=name,
            )
            kb.params["embedding_model"] = embed_params

        if "reranking_model" in params:
            new_config = params["reranking_model"]
            # update reranking
            rerank_params = kb.params.get("reranking_model", {})

            if new_config is False:
                # disable reranking
                rerank_params = {}
            elif "provider" in new_config and new_config["provider"] != rerank_params.get("provider"):
                # use new config (and include default config)
                rerank_params = get_model_params(new_config, "default_reranking_model")
            else:
                # update current config
                rerank_params.update(new_config)

            if rerank_params:
                self._test_reranking(rerank_params)

            kb.params["reranking_model"] = rerank_params

        # update other keys
        for key in ["id_column", "metadata_columns", "content_columns", "preprocessing"]:
            if key in params:
                kb.params[key] = params[key]

        flag_modified(kb, "params")
        db.session.commit()

        return self.get(name.lower(), project_id)
1440 def _test_reranking(self, params):
1441 try:
1442 reranker = get_reranking_model_from_params(params)
1443 reranker.get_scores("test", ["test"])
1444 except (ValueError, RuntimeError) as e:
1445 if params["provider"] in ("azure_openai", "openai") and params.get("method") != "no-logprobs":
1446 # check with no-logprobs
1447 params["method"] = "no-logprobs"
1448 self._test_reranking(params)
1449 logger.warning(
1450 f"logprobs is not supported for this model: {params.get('model_name')}. using no-logprobs mode"
1451 )
1452 else:
1453 raise RuntimeError(f"Problem with reranker config: {e}") from e
1455 def _create_persistent_pgvector(self, params=None):
1456 """Create default vector database for knowledge base, if not specified"""
1457 vector_store_name = "kb_pgvector_store"
1459 # check if exists
1460 if self.session.integration_controller.get(vector_store_name):
1461 return vector_store_name
1463 self.session.integration_controller.add(vector_store_name, "pgvector", params or {})
1464 return vector_store_name
1466 def _create_persistent_chroma(self, kb_name, engine="chromadb"):
1467 """Create default vector database for knowledge base, if not specified"""
1469 vector_store_name = f"{kb_name}_{engine}"
1471 vector_store_folder_name = f"{vector_store_name}"
1472 connection_args = {"persist_directory": vector_store_folder_name}
1474 # check if exists
1475 if self.session.integration_controller.get(vector_store_name): 1475 ↛ 1476line 1475 didn't jump to line 1476 because the condition on line 1475 was never true
1476 return vector_store_name
1478 self.session.integration_controller.add(vector_store_name, engine, connection_args)
1479 return vector_store_name
1481 def _check_embedding_model(self, project_name, params: dict = None, kb_name="") -> dict:
1482 """check embedding model for knowledge base, return embedding model info"""
1484 # if mindsdb model from old KB exists - drop it
1485 model_name = f"kb_embedding_{kb_name}"
1486 try:
1487 model = self.session.model_controller.get_model(model_name, project_name=project_name)
1488 if model is not None:
1489 self.session.model_controller.delete_model(model_name, project_name)
1490 except PredictorRecordNotFound:
1491 pass
1493 if "provider" not in params: 1493 ↛ 1494line 1493 didn't jump to line 1494 because the condition on line 1493 was never true
1494 raise ValueError("'provider' parameter is required for embedding model")
1496 # check available providers
1497 avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google", "ollama", "snowflake")
1498 if params["provider"] not in avail_providers: 1498 ↛ 1499line 1498 didn't jump to line 1499 because the condition on line 1498 was never true
1499 raise ValueError(
1500 f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
1501 )
1503 llm_client = LLMClient(params, session=self.session)
1505 try:
1506 resp = llm_client.embeddings(["test"])
1507 return {"dimension": len(resp[0])}
1508 except Exception as e:
1509 raise RuntimeError(f"Problem with embedding model config: {e}") from e
    def delete(self, name: str, project_name: str, if_exists: bool = False) -> None:
        """
        Delete a knowledge base from the database

        :param name: KB name
        :param project_name: project the KB belongs to
        :param if_exists: when True, silently return if the KB does not exist
        :raises ValueError: when the project is not found
        :raises EntityNotExistsError: when the KB is missing and if_exists is False
        """
        try:
            project = self.session.database_controller.get_project(project_name)
        except ValueError as e:
            raise ValueError(f"Project not found: {project_name}") from e
        project_id = project.id

        # check if knowledge base exists
        kb = self.get(name, project_id)
        if kb is None:
            # knowledge base does not exist
            if if_exists:
                return
            else:
                raise EntityNotExistsError("Knowledge base does not exist", name)

        # kb exists
        db.session.delete(kb)
        db.session.commit()

        # drop objects if they were created automatically
        if "default_vector_storage" in kb.params:
            try:
                dn = self.session.datahub.get(kb.params["default_vector_storage"])
                dn.integration_handler.drop_table(kb.vector_database_table)
                # shared pgvector store is kept; other engines are removed with the KB
                if dn.ds_type != "pgvector":
                    self.session.integration_controller.delete(kb.params["default_vector_storage"])
            except EntityNotExistsError:
                pass
        if "created_embedding_model" in kb.params:
            try:
                self.session.model_controller.delete_model(kb.params["created_embedding_model"], project_name)
            except EntityNotExistsError:
                pass
1549 def get(self, name: str, project_id: int) -> db.KnowledgeBase:
1550 """
1551 Get a knowledge base from the database
1552 by name + project_id
1553 """
1554 kb = (
1555 db.session.query(db.KnowledgeBase)
1556 .filter_by(
1557 name=name,
1558 project_id=project_id,
1559 )
1560 .first()
1561 )
1562 return kb
1564 def get_table(self, name: str, project_id: int, params: dict = None) -> KnowledgeBaseTable:
1565 """
1566 Returns kb table object with properly configured preprocessing
1567 :param name: table name
1568 :param project_id: project id
1569 :param params: runtime parameters for KB. Keys: 'model' - parameters for embedding model
1570 :return: kb table object
1571 """
1572 kb = self.get(name, project_id)
1573 if kb is not None: 1573 ↛ exitline 1573 didn't return from function 'get_table' because the condition on line 1573 was always true
1574 table = KnowledgeBaseTable(kb, self.session)
1575 if params: 1575 ↛ 1579line 1575 didn't jump to line 1579 because the condition on line 1575 was always true
1576 table.model_params = params.get("model")
1578 # Always configure preprocessing - either from params or default
1579 if kb.params and "preprocessing" in kb.params: 1579 ↛ 1580line 1579 didn't jump to line 1580 because the condition on line 1579 was never true
1580 table.configure_preprocessing(kb.params["preprocessing"])
1581 else:
1582 table.configure_preprocessing(None) # This ensures default preprocessor is created
1584 return table
1586 def list(self, project_name: str = None) -> List[dict]:
1587 """
1588 List all knowledge bases from the database
1589 belonging to a project
1590 """
1591 project_controller = ProjectController()
1592 projects = project_controller.get_list()
1593 if project_name is not None: 1593 ↛ 1596line 1593 didn't jump to line 1596 because the condition on line 1593 was always true
1594 projects = [p for p in projects if p.name == project_name]
1596 query = db.session.query(db.KnowledgeBase).filter(
1597 db.KnowledgeBase.project_id.in_(list([p.id for p in projects]))
1598 )
1600 data = []
1601 project_names = {i.id: i.name for i in project_controller.get_list()}
1603 for record in query: 1603 ↛ 1604line 1603 didn't jump to line 1604 because the loop on line 1603 never started
1604 kb = record.as_dict(with_secrets=self.session.show_secrets)
1605 kb["project_name"] = project_names[record.project_id]
1607 data.append(kb)
1609 return data
1611 def create_index(self, table_name, project_name):
1612 project_id = self.session.database_controller.get_project(project_name).id
1613 kb_table = self.get_table(table_name, project_id)
1614 kb_table.create_index()
1616 def evaluate(self, table_name: str, project_name: str, params: dict = None) -> pd.DataFrame:
1617 """
1618 Run evaluate and/or create test data for evaluation
1619 :param table_name: name of KB
1620 :param project_name: project of KB
1621 :param params: evaluation parameters
1622 :return: evaluation results
1623 """
1624 project_id = self.session.database_controller.get_project(project_name).id
1625 kb_table = self.get_table(table_name, project_id)
1627 scores = EvaluateBase.run(self.session, kb_table, params)
1629 return scores