# mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py
import re
import math
import logging
import collections
from typing import List, Any, Optional, Dict, Tuple, Union, Callable

import numpy as np
from pydantic import BaseModel, Field
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.embeddings import Embeddings
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever

from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
from mindsdb.integrations.libs.response import HandlerResponse
from mindsdb.integrations.libs.vectordatabase_handler import (
    DistanceFunction,
    VectorStoreHandler,
)
from mindsdb.integrations.utilities.rag.settings import (
    DatabaseSchema,
    TableSchema,
    ColumnSchema,
    ValueSchema,
    SearchKwargs,
)
from mindsdb.utilities import log

logger = log.getLogger(__name__)


class MetadataFilter(BaseModel):
    """Represents an LLM-generated metadata filter to apply to a PostgreSQL query."""

    attribute: str = Field(description="Database column to apply filter to")
    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
    value: Any = Field(description="Value to use to filter database column")


class AblativeMetadataFilter(MetadataFilter):
    """Adds fields recording where in the database schema a filter came from, to support ablation."""

    schema_table: str = Field(description="Schema name of the table for this filter")
    schema_column: str = Field(description="Schema name of the column for this filter")
    schema_value: str = Field(description="Schema name of the value for this filter")


class MetadataFilters(BaseModel):
    """List of LLM-generated metadata filters to apply to a PostgreSQL query."""

    filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
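
# Illustrative only: a filter such as
#   MetadataFilter(attribute="category", comparator="=", value="finance")
# is rendered by SQLRetriever._prepare_pgvector_query below as the WHERE
# condition "category" = 'finance' (string values are single-quoted; other
# values are interpolated as-is).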


class SQLRetriever(BaseRetriever):
    """Retriever that uses an LLM to generate pgvector queries to do similarity search with metadata filters.

    How it works:

    1. Use an LLM to rewrite the user input to something more suitable for retrieval. For example:
    "Show me documents containing how to finetune a LLM please" --> "how to finetune a LLM"

    2. Use an LLM to generate structured metadata filters based on the user input. Provided
    metadata schemas & examples are used as additional context.

    3. Generate a prepared PostgreSQL query from the structured metadata filters.

    4. Execute the query against the vector database to retrieve & return documents.
    """

    fallback_retriever: BaseRetriever
    vector_store_handler: VectorStoreHandler
    # Search parameters.
    max_filters: int
    filter_threshold: float
    min_k: int

    # Schema description.
    database_schema: Optional[DatabaseSchema] = None

    # Embeddings.
    embeddings_model: Embeddings
    search_kwargs: SearchKwargs

    # Prompt templates.
    rewrite_prompt_template: str

    # Schema templates.
    table_prompt_template: str
    column_prompt_template: str
    value_prompt_template: str

    # Formatting templates.
    boolean_system_prompt: str
    generative_system_prompt: str

    # SQL search config.
    num_retries: int
    embeddings_table: str
    source_table: str
    source_id_column: str
    distance_function: DistanceFunction

    # Re-rank and metadata generation model.
    llm: BaseChatModel

    def _sort_schema_by_priority_key(
        self,
        schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
    ):
        return schema_dict_item[1].priority

    def _sort_schema_by_relevance_key(
        self,
        schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
    ):
        if schema_dict_item[1].relevance is not None:
            return schema_dict_item[1].relevance
        return 0

    def _sort_schema_by_key(
        self,
        schema: Union[DatabaseSchema, TableSchema, ColumnSchema],
        key: Callable,
        update: Dict[str, Any] = None,
    ) -> Union[DatabaseSchema, TableSchema, ColumnSchema]:
        """Take a schema and convert its child collection into an OrderedDict sorted by `key` (descending)."""
        if isinstance(schema, DatabaseSchema):
            collection_key = "tables"
        elif isinstance(schema, TableSchema):
            collection_key = "columns"
        elif isinstance(schema, ColumnSchema):
            collection_key = "values"
        else:
            raise TypeError("schema must be either a DatabaseSchema, TableSchema, or ColumnSchema.")

        if update is not None:
            ordered = collections.OrderedDict(sorted(update.items(), key=key, reverse=True))
        else:
            ordered = collections.OrderedDict(sorted(getattr(schema, collection_key).items(), key=key, reverse=True))
        return schema.model_copy(update={collection_key: ordered})
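
    # Illustrative only (hypothetical schema): given a DatabaseSchema whose
    # tables have priorities {"orders": 2, "users": 5}, sorting with
    # _sort_schema_by_priority_key yields an OrderedDict in descending
    # priority order: ["users", "orders"].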

    def _sort_database_schema_by_key(self, database_schema: DatabaseSchema, key: Callable) -> DatabaseSchema:
        """Re-build the database schema with OrderedDicts at every level (tables, columns, values)."""
        tables = {}
        # Build the new tables dict.
        for table_key, table_schema in database_schema.tables.items():
            columns = {}
            # Build the new columns dict.
            for column_key, column_schema in table_schema.columns.items():
                # Sort values directly and update the column schema.
                columns[column_key] = self._sort_schema_by_key(schema=column_schema, key=key)
            # Update the table schema and sort its columns.
            tables[table_key] = self._sort_schema_by_key(schema=table_schema, key=key, update=columns)
        # Update the database schema and sort its tables.
        database_schema = self._sort_schema_by_key(schema=database_schema, key=key, update=tables)

        return database_schema

    def _prepare_value_prompt(
        self,
        value_schema: ValueSchema,
        column_schema: ColumnSchema,
        table_schema: TableSchema,
        boolean_system_prompt: bool = True,
        format_instructions: Optional[str] = None,
    ) -> ChatPromptTemplate:
        if boolean_system_prompt:
            system_prompt = self.boolean_system_prompt
        else:
            system_prompt = self.generative_system_prompt

        prepared_column_prompt = self._prepare_column_prompt(column_schema=column_schema, table_schema=table_schema)
        column_schema_str = (
            prepared_column_prompt.messages[1]
            .format(
                **prepared_column_prompt.partial_variables,
                query="See query at the lowest level schema.",
            )
            .content
        )

        base_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", self.value_prompt_template)]
        )

        value_str = ""
        header_str = ""
        if type(value_schema.value) in [str, int, float, bool]:
            header_str = f"This schema describes a single value in the {column_schema.column} column."

            value_str = f"""
- **Value**: {value_schema.value}
"""
        elif type(value_schema.value) is dict:
            header_str = f"This schema describes enumerated values in the {column_schema.column} column."

            value_str = """
## **Enumerated Values**

The values in the column are an enumeration of named values. These are listed below with format **[Column Value]**: [named value].
"""
            for value, value_name in value_schema.value.items():
                value_str += f"""
- **{value}:** {value_name}"""

        elif type(value_schema.value) is list:
            header_str = f"This schema describes some of the values in the {column_schema.column} column."

            value_str = """
## **Sample Values**

There are too many values in this column to list exhaustively. Below is a sampling of values found in the column:
"""
            for value in value_schema.value:
                value_str += f"""
- {value}"""

        if getattr(value_schema, "comparator", None) is not None:
            comparator_str = """
## **Comparators**

Below is a list of comparison operators for constructing filters for this value schema:
"""
            if type(value_schema.comparator) is str:
                comparator_str += f"""- {value_schema.comparator}
"""
            else:
                for comp in value_schema.comparator:
                    comparator_str += f"""- {comp}
"""
        else:
            comparator_str = ""

        if getattr(value_schema, "example_questions", None) is not None:
            example_str = """## **Example Questions**
"""
            for i, example in enumerate(value_schema.example_questions):
                example_str += f"""{i}. **Query:** {example.input} **Answer:** {example.output}
"""
        else:
            example_str = ""

        return base_prompt_template.partial(
            format_instructions=format_instructions,
            header=header_str,
            column_schema=column_schema_str,
            value=value_str,
            comparator=comparator_str,
            type=value_schema.type,
            description=value_schema.description,
            usage=value_schema.usage,
            examples=example_str,
        )

    def _prepare_column_prompt(
        self,
        column_schema: ColumnSchema,
        table_schema: TableSchema,
        boolean_system_prompt: bool = True,
    ) -> ChatPromptTemplate:
        if boolean_system_prompt:
            system_prompt = self.boolean_system_prompt
        else:
            system_prompt = self.generative_system_prompt

        prepared_table_prompt = self._prepare_table_prompt(
            table_schema=table_schema, boolean_system_prompt=boolean_system_prompt
        )
        table_schema_str = (
            prepared_table_prompt.messages[1]
            .format(
                **prepared_table_prompt.partial_variables,
                query="See query at the lowest level schema.",
            )
            .content
        )

        base_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", self.column_prompt_template)]
        )

        header_str = f"This schema describes a column in the {table_schema.table} table."

        value_str = """
## **Content**

Below is a description of the contents of this column in list format:
"""
        for value_schema in column_schema.values.values():
            value_str += f"""
- {value_schema.description}
"""
        value_str += """
**Important:** The above descriptions are not the actual values stored in this column. See the Value schema for actual values.
"""

        if getattr(column_schema, "examples", None) is not None:
            example_str = """## **Example Questions**
"""
            for example in column_schema.examples:
                example_str += f"""- {example}
"""
        else:
            example_str = ""

        return base_prompt_template.partial(
            table_schema=table_schema_str,
            header=header_str,
            column=column_schema.column,
            type=column_schema.type,
            description=column_schema.description,
            usage=column_schema.usage,
            values=value_str,
            examples=example_str,
        )

    def _prepare_table_prompt(
        self, table_schema: TableSchema, boolean_system_prompt: bool = True
    ) -> ChatPromptTemplate:
        if boolean_system_prompt:
            system_prompt = self.boolean_system_prompt
        else:
            system_prompt = self.generative_system_prompt

        base_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", self.table_prompt_template)]
        )

        header_str = "This schema describes a table in the database."

        columns_str = ""
        for column_schema in table_schema.columns.values():
            columns_str += f"""
- **{column_schema.column}:** {column_schema.description}
"""

        if getattr(table_schema, "examples", None) is not None:
            example_str = """## **Example Questions**
"""
            for example in table_schema.examples:
                example_str += f"""- {example}
"""
        else:
            example_str = ""

        return base_prompt_template.partial(
            header=header_str,
            table=table_schema.table,
            description=table_schema.description,
            usage=table_schema.usage,
            columns=columns_str,
            examples=example_str,
        )

    def _rank_schema(self, prompt: ChatPromptTemplate, query: str) -> float:
        rank_chain = LLMChain(llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False)
        output = rank_chain({"query": query})  # Returns generation metadata alongside the text.

        # Parse through the response tokens until encountering either "yes" or "no".
        score = None  # A None score indicates the model output could not be parsed.
        for content in output["full_generation"][0].message.response_metadata["logprobs"]["content"]:
            # Convert the answer to a score using the model's confidence.
            if content["token"].lower().strip() == "yes":
                score = (1 + math.exp(content["logprob"])) / 2  # If yes, use the model's confidence.
                break
            elif content["token"].lower().strip() == "no":
                score = (1 - math.exp(content["logprob"])) / 2  # If no, invert the confidence.
                break

        if score is None:
            score = 0.0

        return score
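
    # Worked example (illustrative): exp(logprob) is the token probability p.
    # If the model answers "yes" with logprob = -0.1 (p ~= 0.905), the score is
    # (1 + 0.905) / 2 ~= 0.95; the same confidence on "no" scores
    # (1 - 0.905) / 2 ~= 0.05. Scores therefore fall in [0, 1], approaching
    # 0.5 from either side as the model becomes less certain.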

    def _breadth_first_search(self, query: str, greedy: bool = False) -> Tuple:
        """Search breadth-wise through Tables, then Columns, then Values.

        Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy.
        """
        # Sort based on priority.
        ordered_database_schema = self._sort_database_schema_by_key(
            database_schema=self.database_schema, key=self._sort_schema_by_priority_key
        )

        # Rank Tables ########################################################
        greedy_count = 0
        tables = {}
        # Rank tables by relevance.
        for table_key, table_schema in ordered_database_schema.tables.items():
            prompt: ChatPromptTemplate = self._prepare_table_prompt(
                table_schema=table_schema, boolean_system_prompt=True
            )
            table_schema.relevance = self._rank_schema(prompt=prompt, query=query)

            tables[table_key] = table_schema

            # In greedy mode, stop ranking once enough tables pass the threshold.
            if greedy:
                if table_schema.relevance >= ordered_database_schema.filter_threshold:
                    greedy_count += 1
                    if greedy_count >= ordered_database_schema.max_filters:
                        break

        # Sort tables.
        ordered_database_schema = self._sort_schema_by_key(
            schema=ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        # Rank Columns #######################################################
        # Iterate through tables to rank columns.
        tables = {}
        table_count = 0  # Take only the top n tables specified by the database's max_filters.
        for table_key, table_schema in ordered_database_schema.tables.items():
            # Only drop into tables above the filter threshold.
            if table_schema.relevance >= ordered_database_schema.filter_threshold:
                greedy_count = 0
                # Rank columns by relevance.
                columns = {}
                for column_key, column_schema in table_schema.columns.items():
                    prompt: ChatPromptTemplate = self._prepare_column_prompt(
                        column_schema=column_schema,
                        table_schema=table_schema,
                        boolean_system_prompt=True,
                    )
                    column_schema.relevance = self._rank_schema(prompt=prompt, query=query)

                    columns[column_key] = column_schema

                    if greedy:
                        if column_schema.relevance >= table_schema.filter_threshold:
                            greedy_count += 1
                            if greedy_count >= table_schema.max_filters:
                                break

                # Sort columns and keep only the columns that made the cut.
                tables[table_key] = self._sort_schema_by_key(
                    table_schema, key=self._sort_schema_by_relevance_key, update=columns
                )

                table_count += 1
                if table_count >= ordered_database_schema.max_filters:
                    break

        # Sort tables and keep only the tables that made the cut.
        ordered_database_schema = self._sort_schema_by_key(
            ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        # Rank Values ########################################################
        # Iterate through tables to rank values.
        tables = {}
        for table_key, table_schema in ordered_database_schema.tables.items():
            columns = {}
            column_count = 0
            # Iterate through columns to rank values.
            for column_key, column_schema in table_schema.columns.items():
                if column_schema.relevance >= table_schema.filter_threshold:
                    greedy_count = 0
                    values = {}
                    # Rank values by relevance.
                    for value_key, value_schema in column_schema.values.items():
                        prompt: ChatPromptTemplate = self._prepare_value_prompt(
                            value_schema=value_schema,
                            column_schema=column_schema,
                            table_schema=table_schema,
                            boolean_system_prompt=True,
                        )
                        value_schema.relevance = self._rank_schema(prompt=prompt, query=query)

                        values[value_key] = value_schema

                        if greedy:
                            if value_schema.relevance >= column_schema.filter_threshold:
                                greedy_count += 1
                                if greedy_count >= column_schema.max_filters:
                                    break

                    # Sort values and keep only the values that made the cut.
                    columns[column_key] = self._sort_schema_by_key(
                        column_schema,
                        key=self._sort_schema_by_relevance_key,
                        update=values,
                    )

                    column_count += 1
                    if column_count >= table_schema.max_filters:
                        break

            # Sort columns and keep only the columns that made the cut.
            tables[table_key] = self._sort_schema_by_key(
                table_schema, key=self._sort_schema_by_relevance_key, update=columns
            )

        # Sort tables and keep only the tables that made the cut.
        ordered_database_schema = self._sort_schema_by_key(
            ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        # Discard low-ranked values ##########################################
        tables = {}
        for table_key, table_schema in ordered_database_schema.tables.items():
            columns = {}
            # Iterate through columns to filter values.
            for column_key, column_schema in table_schema.columns.items():
                value_count = 0
                values = {}
                # Keep values above the column's filter threshold.
                for value_key, value_schema in column_schema.values.items():
                    if value_schema.relevance >= column_schema.filter_threshold:
                        values[value_key] = value_schema

                        value_count += 1
                        if value_count >= column_schema.max_filters:
                            break

                # Sort values and keep only the values that made the cut.
                columns[column_key] = self._sort_schema_by_key(
                    column_schema,
                    key=self._sort_schema_by_relevance_key,
                    update=values,
                )

            # Sort columns and keep only the columns that made the cut.
            tables[table_key] = self._sort_schema_by_key(
                table_schema, key=self._sort_schema_by_relevance_key, update=columns
            )

        # Sort tables and keep only the tables that made the cut.
        ordered_database_schema = self._sort_schema_by_key(
            ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        ranked_database_schema = ordered_database_schema

        # Build Ablation #####################################################
        ablation_value_dict = {}
        # Assemble a relevance dictionary keyed by (table, column, value).
        for table_key, table_schema in ordered_database_schema.tables.items():
            for column_key, column_schema in table_schema.columns.items():
                for value_key, value_schema in column_schema.values.items():
                    ablation_value_dict[(table_key, column_key, value_key)] = value_schema.relevance

        ablation_value_dict = collections.OrderedDict(sorted(ablation_value_dict.items(), key=lambda x: x[1]))

        relevance_scores = list(ablation_value_dict.values())
        if len(relevance_scores) > 0:
            ablation_quantiles = np.quantile(relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1])
        else:
            ablation_quantiles = None

        return ranked_database_schema, ablation_value_dict, ablation_quantiles
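
    # Illustrative example of the ablation quantiles: with num_retries = 3,
    # np.linspace(0, 1, 5)[1:-1] gives the interior points [0.25, 0.5, 0.75],
    # so np.quantile(relevance_scores, ...) produces three rising relevance
    # cut-offs, one per retry, consumed by _dynamic_ablation below.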

    def _dynamic_ablation(
        self,
        metadata_filters: List[AblativeMetadataFilter],
        ablation_value_dict: Dict[Tuple[str, str, str], float],
        ablation_quantiles: np.ndarray,
        retry: int,
    ) -> List[AblativeMetadataFilter]:
        """Ablate metadata filters in aggregate by quantiles until the required minimum number of documents are returned."""
        # Keep only schema values at or above this retry's relevance cut-off.
        ablated_dict = {}
        for key, value in ablation_value_dict.items():
            if value >= ablation_quantiles[retry]:
                ablated_dict[key] = value

        # Discard filters whose (table, column, value) schema path was ablated away.
        ablated_filters = []
        for metadata_filter in metadata_filters:
            key = (metadata_filter.schema_table, metadata_filter.schema_column, metadata_filter.schema_value)
            if key in ablated_dict:
                ablated_filters.append(metadata_filter)

        return ablated_filters
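
    # Illustrative behaviour across retries: each retry indexes the next
    # (higher) quantile cut-off, so progressively fewer, more relevant filters
    # survive. E.g. with value relevances {A: 0.9, B: 0.6, C: 0.3} and cut-offs
    # [0.45, 0.6, 0.75], retry 0 keeps filters for A and B, retry 1 keeps A and
    # B (B sits exactly at the cut-off), and retry 2 keeps only A.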

    def depth_first_search(self, greedy: bool = True):
        """Search depth-wise through Tables, then Columns, then Values. Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy."""
        pass

    def depth_first_ablation(self):
        """Ablate metadata filters in reverse depth-first-search order until the required minimum number of documents are returned."""
        pass

    def _prepare_retrieval_query(self, query: str) -> str:
        rewrite_prompt = PromptTemplate(input_variables=["input"], template=self.rewrite_prompt_template)
        rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
        return rewrite_chain.predict(input=query)
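
    # Illustrative rewrite (the example from the class docstring): the raw user
    # input "Show me documents containing how to finetune a LLM please" becomes
    # the retrieval-friendly query "how to finetune a LLM".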

    def _prepare_pgvector_query(
        self,
        ranked_database_schema: DatabaseSchema,
        metadata_filters: List[AblativeMetadataFilter],
        retry: int = 0,
    ) -> str:
        # Base SELECT JOINed with the document source table.
        base_query = f"""SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" """

        # Return an empty string if the schema has not been ranked.
        if not ranked_database_schema:
            return ""

        # Add table JOIN statements, deduplicated across filters.
        join_clauses = set()
        for metadata_filter in metadata_filters:
            join_clause = ranked_database_schema.tables[metadata_filter.schema_table].join
            if join_clause in join_clauses:
                continue
            join_clauses.add(join_clause)
            base_query += join_clause + " "

        # Add WHERE conditions from the metadata filters.
        if metadata_filters:
            base_query += "WHERE "
            for i, metadata_filter in enumerate(metadata_filters):
                value = metadata_filter.value
                if isinstance(value, str):
                    # Escape embedded single quotes before interpolating the string literal.
                    value = "'{}'".format(value.replace("'", "''"))
                base_query += f'"{metadata_filter.attribute}" {metadata_filter.comparator} {value}'
                if i < len(metadata_filters) - 1:
                    base_query += " AND "

        base_query += (
            f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
        )
        return base_query
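
    # Illustrative output (hypothetical table and column names) for a single
    # filter (attribute="category", comparator="=", value="finance") with a
    # cosine-distance operator and k=5:
    #   SELECT * FROM embeddings AS e
    #   INNER JOIN documents AS s ON (e.metadata->>'original_row_id')::int = s."id"
    #   <join clause of the filter's table> WHERE "category" = 'finance'
    #   ORDER BY e.embeddings <=> '{embeddings}' LIMIT 5;
    # The '{embeddings}' placeholder is filled in later by
    # _prepare_and_execute_query via str.format.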

    def _generate_filter(self, prompt: ChatPromptTemplate, query: str) -> MetadataFilter:
        gen_filter_chain = LLMChain(llm=self.llm, prompt=prompt)
        output = gen_filter_chain({"query": query})
        return output

    def _generate_metadata_filters(
        self, query: str, ranked_database_schema
    ) -> Union[List[AblativeMetadataFilter], HandlerResponse]:
        """Generate one metadata filter per ranked value schema, using the LLM where needed."""
        parser = PydanticOutputParser(pydantic_object=MetadataFilter)

        metadata_filter_list = []
        # Iterate through tables to generate filters.
        for table_key, table_schema in ranked_database_schema.tables.items():
            # Iterate through columns to generate filters.
            for column_key, column_schema in table_schema.columns.items():
                if column_schema.relevance >= table_schema.filter_threshold:
                    # Generate filters.
                    for value_key, value_schema in column_schema.values.items():
                        # Must use LLM generation if the value field is a dict or a list.
                        if type(value_schema.value) in [list, dict]:
                            try:
                                metadata_prompt: ChatPromptTemplate = self._prepare_value_prompt(
                                    format_instructions=parser.get_format_instructions(),
                                    value_schema=value_schema,
                                    column_schema=column_schema,
                                    table_schema=table_schema,
                                    boolean_system_prompt=False,
                                )

                                metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
                                metadata_filter_output = metadata_filters_chain.predict(
                                    query=query,
                                )

                                # If the LLM outputs raw JSON, use it as-is.
                                # If the LLM outputs anything including a JSON markdown section, use the last one.
                                json_markdown_output = re.findall(r"```json.*```", metadata_filter_output, re.DOTALL)
                                if json_markdown_output:
                                    metadata_filter_output = json_markdown_output[-1]
                                    # Strip the leading ```json and trailing ``` fences.
                                    metadata_filter_output = metadata_filter_output[7:]
                                    metadata_filter_output = metadata_filter_output[:-3]

                                metadata_filter = parser.invoke(metadata_filter_output)
                                model_dump = metadata_filter.model_dump()
                                model_dump.update(
                                    {
                                        "schema_table": table_key,
                                        "schema_column": column_key,
                                        "schema_value": value_key,
                                    }
                                )
                                metadata_filter = AblativeMetadataFilter(**model_dump)
                            except OutputParserException as e:
                                logger.warning(
                                    f"LLM failed to generate structured metadata filters: {e}",
                                    exc_info=logger.isEnabledFor(logging.DEBUG),
                                )
                                return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
                        else:
                            metadata_filter = AblativeMetadataFilter(
                                attribute=column_schema.column,
                                comparator=value_schema.comparator,
                                value=value_schema.value,
                                schema_table=table_key,
                                schema_column=column_key,
                                schema_value=value_key,
                            )
                        metadata_filter_list.append(metadata_filter)

        return metadata_filter_list
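
    # Illustrative parse (hypothetical LLM output): given
    #   Here is the filter: ```json{"attribute": "year", "comparator": ">=", "value": 2020}```
    # the regex keeps the last fenced block, the seven-character ```json prefix
    # and three-character ``` suffix are stripped, and PydanticOutputParser
    # yields MetadataFilter(attribute="year", comparator=">=", value=2020).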

    def _prepare_and_execute_query(
        self,
        ranked_database_schema: DatabaseSchema,
        metadata_filters: List[AblativeMetadataFilter],
        embeddings_str: str,
    ) -> HandlerResponse:
        try:
            checked_sql_query = self._prepare_pgvector_query(ranked_database_schema, metadata_filters)
            checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
            return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
        except Exception as e:
            logger.warning(
                f"Failed to prepare and execute SQL query from structured metadata: {e}",
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )
            return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        # Rewrite the query to be suitable for retrieval.
        retrieval_query = self._prepare_retrieval_query(query)

        # Embed the rewritten retrieval query & include it in the similarity search pgvector query.
        embedded_query = self.embeddings_model.embed_query(retrieval_query)

        # Search for relevant filters.
        ranked_database_schema, ablation_value_dict, ablation_quantiles = self._breadth_first_search(query=query)

        # Generate metadata filters.
        metadata_filters = self._generate_metadata_filters(query=query, ranked_database_schema=ranked_database_schema)

        if isinstance(metadata_filters, list):
            # Initial execution of the similarity search with metadata filters.
            document_response = self._prepare_and_execute_query(
                ranked_database_schema=ranked_database_schema,
                metadata_filters=metadata_filters,
                embeddings_str=str(embedded_query),
            )
            num_retries = 0
            while num_retries < self.num_retries:
                if (
                    document_response.resp_type != RESPONSE_TYPE.ERROR
                    and len(document_response.data_frame) >= self.min_k
                ):
                    # Successfully retrieved at least min_k documents to send to the re-ranker.
                    break
                elif document_response.resp_type == RESPONSE_TYPE.ERROR:
                    # LLMs won't always generate structured metadata, so we fall back after retrying.
                    logger.info(f"SQL Retriever query failed with error {document_response.error_message}")
                else:
                    logger.info(
                        f"SQL Retriever did not retrieve {self.min_k} documents: {len(document_response.data_frame)} documents retrieved."
                    )

                # Relax the filters by ablating the lowest-ranked ones, then retry.
                ablated_metadata_filters = self._dynamic_ablation(
                    metadata_filters=metadata_filters,
                    ablation_value_dict=ablation_value_dict,
                    ablation_quantiles=ablation_quantiles,
                    retry=num_retries,
                )

                document_response = self._prepare_and_execute_query(
                    ranked_database_schema=ranked_database_schema,
                    metadata_filters=ablated_metadata_filters,
                    embeddings_str=str(embedded_query),
                )

                num_retries += 1

            retrieved_documents = []
            if document_response.resp_type != RESPONSE_TYPE.ERROR:
                document_df = document_response.data_frame
                for _, document_row in document_df.iterrows():
                    retrieved_documents.append(
                        Document(
                            document_row.get("content", ""),
                            metadata=document_row.get("metadata", {}),
                        )
                    )
            if retrieved_documents:
                return retrieved_documents

            # If the constructed SQL query did not return any documents, fall back.
            logger.info("No documents returned from SQL retriever, using fallback retriever.")
            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)

        # If no metadata fields could be generated, fall back.
        logger.info("No metadata fields were successfully generated, using fallback retriever.")
        return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
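

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; every object below is a hypothetical
# stand-in wired up elsewhere in the RAG pipeline, not provided by this
# module):
#
#   retriever = SQLRetriever(
#       fallback_retriever=vector_store.as_retriever(),  # any BaseRetriever
#       vector_store_handler=pgvector_handler,           # VectorStoreHandler
#       llm=chat_model,                                  # BaseChatModel supporting logprobs
#       embeddings_model=embeddings,                     # Embeddings
#       database_schema=schema,                          # DatabaseSchema describing metadata
#       search_kwargs=SearchKwargs(k=10),
#       max_filters=3,
#       filter_threshold=0.6,
#       min_k=4,
#       num_retries=3,
#       embeddings_table="embeddings",
#       source_table="documents",
#       source_id_column="id",
#       distance_function=DistanceFunction.COSINE_DISTANCE,  # enum member assumed
#       rewrite_prompt_template=...,  # prompt strings supplied by configuration
#       table_prompt_template=...,
#       column_prompt_template=...,
#       value_prompt_template=...,
#       boolean_system_prompt=...,
#       generative_system_prompt=...,
#   )
#   docs = retriever.invoke("show me finance documents about LLM finetuning")
# ---------------------------------------------------------------------------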