Coverage for mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py: 14% (355 statements)

coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

import re
import math
import logging
import collections
from typing import List, Any, Optional, Dict, Tuple, Union, Callable

import numpy as np
from pydantic import BaseModel, Field
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.embeddings import Embeddings
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever

from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
from mindsdb.integrations.libs.response import HandlerResponse
from mindsdb.integrations.libs.vectordatabase_handler import (
    DistanceFunction,
    VectorStoreHandler,
)
from mindsdb.integrations.utilities.rag.settings import (
    DatabaseSchema,
    TableSchema,
    ColumnSchema,
    ValueSchema,
    SearchKwargs,
)
from mindsdb.utilities import log

logger = log.getLogger(__name__)


class MetadataFilter(BaseModel):
    """An LLM-generated metadata filter to apply to a PostgreSQL query."""

    attribute: str = Field(description="Database column to apply filter to")
    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
    value: Any = Field(description="Value to use to filter database column")


class AblativeMetadataFilter(MetadataFilter):
    """Adds fields tying a filter back to the schema that produced it, to support ablation."""

    schema_table: str = Field(description="Key of the table schema for this filter")
    schema_column: str = Field(description="Key of the column schema for this filter")
    schema_value: str = Field(description="Key of the value schema for this filter")


class MetadataFilters(BaseModel):
    """List of LLM-generated metadata filters to apply to a PostgreSQL query."""

    filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
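
# Illustrative example (hypothetical column name, not from the source): for the
# question "show me invoices from 2023", the LLM might produce
#   MetadataFilter(attribute="invoice_year", comparator="=", value=2023)
# which later becomes the SQL condition `"invoice_year" = 2023`.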


class SQLRetriever(BaseRetriever):
    """Retriever that uses an LLM to generate pgvector queries to do similarity search with metadata filters.

    How it works:

    1. Use an LLM to rewrite the user input into something more suitable for retrieval. For example:
       "Show me documents containing how to finetune a LLM please" --> "how to finetune a LLM"

    2. Use an LLM to generate structured metadata filters from the user input. Provided
       metadata schemas & examples are used as additional context.

    3. Generate a prepared PostgreSQL query from the structured metadata filters.

    4. Execute the query against the vector database and return the retrieved documents.
    """

    fallback_retriever: BaseRetriever
    vector_store_handler: VectorStoreHandler
    # Search parameters.
    max_filters: int
    filter_threshold: float
    min_k: int

    # Schema description.
    database_schema: Optional[DatabaseSchema] = None

    # Embeddings.
    embeddings_model: Embeddings
    search_kwargs: SearchKwargs

    # Prompt templates.
    rewrite_prompt_template: str

    # Schema templates.
    table_prompt_template: str
    column_prompt_template: str
    value_prompt_template: str

    # Formatting templates.
    boolean_system_prompt: str
    generative_system_prompt: str

    # SQL search config.
    num_retries: int
    embeddings_table: str
    source_table: str
    source_id_column: str
    distance_function: DistanceFunction

    # Re-rank and metadata generation model.
    llm: BaseChatModel
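
    # Illustrative usage sketch (hypothetical handler, model, and template
    # instances; not from the source):
    #   retriever = SQLRetriever(
    #       fallback_retriever=vector_retriever,
    #       vector_store_handler=pgvector_handler,
    #       llm=chat_model,
    #       embeddings_model=embedder,
    #       database_schema=schema,
    #       ...,
    #   )
    #   docs = retriever.get_relevant_documents("show me invoices from 2023")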

    def _sort_schema_by_priority_key(
        self,
        schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
    ):
        return schema_dict_item[1].priority

    def _sort_schema_by_relevance_key(
        self,
        schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
    ):
        if schema_dict_item[1].relevance is not None:
            return schema_dict_item[1].relevance
        return 0

    def _sort_schema_by_key(
        self,
        schema: Union[DatabaseSchema, TableSchema, ColumnSchema],
        key: Callable,
        update: Optional[Dict[str, Any]] = None,
    ) -> Union[DatabaseSchema, TableSchema, ColumnSchema]:
        """Return a copy of the schema whose child collection is an OrderedDict sorted by ``key`` in descending order."""
        if isinstance(schema, DatabaseSchema):
            collection_key = "tables"
        elif isinstance(schema, TableSchema):
            collection_key = "columns"
        elif isinstance(schema, ColumnSchema):
            collection_key = "values"
        else:
            raise TypeError("schema must be either a DatabaseSchema, TableSchema, or ColumnSchema.")

        if update is not None:
            ordered = collections.OrderedDict(sorted(update.items(), key=key, reverse=True))
        else:
            ordered = collections.OrderedDict(sorted(getattr(schema, collection_key).items(), key=key, reverse=True))
        schema = schema.model_copy(update={collection_key: ordered})

        return schema

    def _sort_database_schema_by_key(self, database_schema: DatabaseSchema, key: Callable) -> DatabaseSchema:
        """Re-build the schema with OrderedDicts, sorting tables, columns, and values by ``key``."""
        tables = {}
        # Build a new tables dict.
        for table_key, table_schema in database_schema.tables.items():
            columns = {}
            # Build a new columns dict.
            for column_key, column_schema in table_schema.columns.items():
                # Sort values directly and update the column schema.
                columns[column_key] = self._sort_schema_by_key(schema=column_schema, key=key)
            # Update the table schema with sorted columns.
            tables[table_key] = self._sort_schema_by_key(schema=table_schema, key=key, update=columns)
        # Update the database schema with sorted tables.
        database_schema = self._sort_schema_by_key(schema=database_schema, key=key, update=tables)

        return database_schema
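
    # Worked example (illustrative values): sorting with reverse=True puts the
    # highest priority or relevance first, so
    #   sorted({"a": 1, "b": 3}.items(), key=lambda kv: kv[1], reverse=True)
    #   -> [("b", 3), ("a", 1)]
    # meaning downstream loops always visit the most promising schemas first.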

    def _prepare_value_prompt(
        self,
        value_schema: ValueSchema,
        column_schema: ColumnSchema,
        table_schema: TableSchema,
        boolean_system_prompt: bool = True,
        format_instructions: Optional[str] = None,
    ) -> ChatPromptTemplate:
        if boolean_system_prompt:
            system_prompt = self.boolean_system_prompt
        else:
            system_prompt = self.generative_system_prompt

        prepared_column_prompt = self._prepare_column_prompt(column_schema=column_schema, table_schema=table_schema)
        column_schema_str = (
            prepared_column_prompt.messages[1]
            .format(
                **prepared_column_prompt.partial_variables,
                query="See query at the lowest level schema.",
            )
            .content
        )

        base_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", self.value_prompt_template)]
        )

        value_str = ""
        header_str = ""
        if isinstance(value_schema.value, (str, int, float, bool)):
            header_str = f"This schema describes a single value in the {column_schema.column} column."

            value_str = f"""
- **Value**: {value_schema.value}
"""

        elif isinstance(value_schema.value, dict):
            header_str = f"This schema describes enumerated values in the {column_schema.column} column."

            value_str = """
## **Enumerated Values**

The values in the column are an enumeration of named values. These are listed below with format **[Column Value]**: [named value].
"""
            for value, value_name in value_schema.value.items():
                value_str += f"""
- **{value}:** {value_name}"""

        elif isinstance(value_schema.value, list):
            header_str = f"This schema describes some of the values in the {column_schema.column} column."

            value_str = """
## **Sample Values**

There are too many values in this column to list exhaustively. Below is a sampling of values found in the column:
"""
            for value in value_schema.value:
                value_str += f"""
- {value}"""

        if getattr(value_schema, "comparator", None) is not None:
            comparator_str = """

## **Comparators**

Below is a list of comparison operators for constructing filters for this value schema:
"""
            if isinstance(value_schema.comparator, str):
                comparator_str += f"""- {value_schema.comparator}
"""
            else:
                for comp in value_schema.comparator:
                    comparator_str += f"""- {comp}
"""
        else:
            comparator_str = ""

        if getattr(value_schema, "example_questions", None) is not None:
            example_str = """## **Example Questions**
"""
            for i, example in enumerate(value_schema.example_questions, start=1):
                example_str += f"""{i}. **Query:** {example.input} **Answer:** {example.output}
"""
        else:
            example_str = ""

        return base_prompt_template.partial(
            format_instructions=format_instructions,
            header=header_str,
            column_schema=column_schema_str,
            value=value_str,
            comparator=comparator_str,
            type=value_schema.type,
            description=value_schema.description,
            usage=value_schema.usage,
            examples=example_str,
        )
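
    # Illustrative rendering (hypothetical enum, not from the source): for a
    # ValueSchema with value={"O": "Open", "C": "Closed"}, the `value` partial
    # renders roughly as:
    #   ## **Enumerated Values**
    #   ...listed below with format **[Column Value]**: [named value].
    #   - **O:** Open
    #   - **C:** Closed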

    def _prepare_column_prompt(
        self,
        column_schema: ColumnSchema,
        table_schema: TableSchema,
        boolean_system_prompt: bool = True,
    ) -> ChatPromptTemplate:
        if boolean_system_prompt:
            system_prompt = self.boolean_system_prompt
        else:
            system_prompt = self.generative_system_prompt

        prepared_table_prompt = self._prepare_table_prompt(
            table_schema=table_schema, boolean_system_prompt=boolean_system_prompt
        )
        table_schema_str = (
            prepared_table_prompt.messages[1]
            .format(
                **prepared_table_prompt.partial_variables,
                query="See query at the lowest level schema",
            )
            .content
        )

        base_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", self.column_prompt_template)]
        )

        header_str = f"This schema describes a column in the {table_schema.table} table."

        value_str = """
## **Content**

Below is a description of the contents in this column in list format:
"""
        for value_schema in column_schema.values.values():
            value_str += f"""
- {value_schema.description}
"""
        value_str += """
**Important:** The above descriptions are not the actual values stored in this column. See the Value schema for actual values.
"""

        if getattr(column_schema, "examples", None) is not None:
            example_str = """## **Example Questions**
"""
            for example in column_schema.examples:
                example_str += f"""- {example}
"""
        else:
            example_str = ""

        return base_prompt_template.partial(
            table_schema=table_schema_str,
            header=header_str,
            column=column_schema.column,
            type=column_schema.type,
            description=column_schema.description,
            usage=column_schema.usage,
            values=value_str,
            examples=example_str,
        )

    def _prepare_table_prompt(
        self, table_schema: TableSchema, boolean_system_prompt: bool = True
    ) -> ChatPromptTemplate:
        if boolean_system_prompt:
            system_prompt = self.boolean_system_prompt
        else:
            system_prompt = self.generative_system_prompt

        base_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", self.table_prompt_template)]
        )

        header_str = "This schema describes a table in the database."

        columns_str = ""
        for column_schema in table_schema.columns.values():
            columns_str += f"""
- **{column_schema.column}:** {column_schema.description}
"""

        if getattr(table_schema, "examples", None) is not None:
            example_str = """## **Example Questions**
"""
            for example in table_schema.examples:
                example_str += f"""- {example}
"""
        else:
            example_str = ""

        return base_prompt_template.partial(
            header=header_str,
            table=table_schema.table,
            description=table_schema.description,
            usage=table_schema.usage,
            columns=columns_str,
            examples=example_str,
        )
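
    # Note on structure: these three prompt builders nest. The value prompt
    # embeds the rendered column prompt, which in turn embeds the rendered
    # table prompt, so the LLM always sees a value in its full table/column
    # context when ranking or generating filters.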

    def _rank_schema(self, prompt: ChatPromptTemplate, query: str) -> float:
        rank_chain = LLMChain(llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False)
        output = rank_chain({"query": query})  # Returns generation metadata as well as text.

        # Parse through metadata tokens until encountering either "yes" or "no".
        score = None  # A None score indicates the model output could not be parsed.
        for content in output["full_generation"][0].message.response_metadata["logprobs"]["content"]:
            # Convert the answer to a score using the model's confidence.
            if content["token"].lower().strip() == "yes":
                score = (1 + math.exp(content["logprob"])) / 2  # If yes, use the model's confidence.
                break
            elif content["token"].lower().strip() == "no":
                score = (1 - math.exp(content["logprob"])) / 2  # If no, invert the confidence.
                break

        if score is None:
            score = 0.0

        return score
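
    # Worked example of the scoring formula (illustrative numbers): if the
    # first decisive token is "yes" with logprob -0.105 (probability ~0.9),
    # the score is (1 + 0.9) / 2 = 0.95; if it is "no" with the same
    # probability, the score is (1 - 0.9) / 2 = 0.05. Scores therefore fall
    # in [0, 1], with 0.5 meaning maximal uncertainty.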

    def _breadth_first_search(self, query: str, greedy: bool = False) -> Tuple:
        """Search breadth-wise through Tables, then Columns, then Values.

        Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy.
        """

        # Sort based on priority.
        ordered_database_schema = self._sort_database_schema_by_key(
            database_schema=self.database_schema, key=self._sort_schema_by_priority_key
        )

        # Rank Tables ########################################################
        greedy_count = 0
        tables = {}
        # Rank tables by relevance.
        for table_key, table_schema in ordered_database_schema.tables.items():
            prompt: ChatPromptTemplate = self._prepare_table_prompt(
                table_schema=table_schema, boolean_system_prompt=True
            )
            table_schema.relevance = self._rank_schema(prompt=prompt, query=query)

            # Keep every table ranked so far; in greedy mode, stop ranking once
            # enough tables have cleared the filter threshold.
            tables[table_key] = table_schema

            if greedy:
                if table_schema.relevance >= ordered_database_schema.filter_threshold:
                    greedy_count += 1
                    if greedy_count >= ordered_database_schema.max_filters:
                        break

        # Sort tables.
        ordered_database_schema = self._sort_schema_by_key(
            schema=ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        # Rank Columns #######################################################
        # Iterate through tables to rank columns.
        tables = {}
        table_count = 0  # Take only the top n tables, where n is the database's max filters.
        for table_key, table_schema in ordered_database_schema.tables.items():
            # Only drop into tables above the filter threshold.
            if table_schema.relevance >= ordered_database_schema.filter_threshold:
                greedy_count = 0
                # Rank columns by relevance.
                columns = {}
                for column_key, column_schema in table_schema.columns.items():
                    prompt: ChatPromptTemplate = self._prepare_column_prompt(
                        column_schema=column_schema,
                        table_schema=table_schema,
                        boolean_system_prompt=True,
                    )
                    column_schema.relevance = self._rank_schema(prompt=prompt, query=query)

                    columns[column_key] = column_schema

                    if greedy:
                        if column_schema.relevance >= table_schema.filter_threshold:
                            greedy_count += 1
                            if greedy_count >= table_schema.max_filters:
                                break

                # Sort columns and keep only the columns that made the cut.
                tables[table_key] = self._sort_schema_by_key(
                    table_schema, key=self._sort_schema_by_relevance_key, update=columns
                )

                table_count += 1
                if table_count >= ordered_database_schema.max_filters:
                    break

        # Sort tables and keep only the tables that made the cut.
        ordered_database_schema = self._sort_schema_by_key(
            ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        # Rank Values ########################################################
        # Iterate through tables to rank values.
        tables = {}
        for table_key, table_schema in ordered_database_schema.tables.items():
            columns = {}
            column_count = 0
            # Iterate through columns to rank values.
            for column_key, column_schema in table_schema.columns.items():
                if column_schema.relevance >= table_schema.filter_threshold:
                    greedy_count = 0
                    values = {}
                    # Rank values by relevance.
                    for value_key, value_schema in column_schema.values.items():
                        prompt: ChatPromptTemplate = self._prepare_value_prompt(
                            value_schema=value_schema,
                            column_schema=column_schema,
                            table_schema=table_schema,
                            boolean_system_prompt=True,
                        )
                        value_schema.relevance = self._rank_schema(prompt=prompt, query=query)

                        values[value_key] = value_schema

                        if greedy:
                            if value_schema.relevance >= column_schema.filter_threshold:
                                greedy_count += 1
                                if greedy_count >= column_schema.max_filters:
                                    break

                    # Sort values and keep only the values that made the cut.
                    columns[column_key] = self._sort_schema_by_key(
                        column_schema,
                        key=self._sort_schema_by_relevance_key,
                        update=values,
                    )

                    column_count += 1
                    if column_count >= table_schema.max_filters:
                        break

            # Sort columns and keep only the columns that made the cut.
            tables[table_key] = self._sort_schema_by_key(
                table_schema, key=self._sort_schema_by_relevance_key, update=columns
            )

        # Sort tables and keep only the tables that made the cut.
        ordered_database_schema = self._sort_schema_by_key(
            ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        # Discard low-ranked values ###################################################################################
        tables = {}
        for table_key, table_schema in ordered_database_schema.tables.items():
            columns = {}
            # Iterate through columns to filter values.
            for column_key, column_schema in table_schema.columns.items():
                value_count = 0
                values = {}
                # Keep only values above the column's filter threshold.
                for value_key, value_schema in column_schema.values.items():
                    if value_schema.relevance >= column_schema.filter_threshold:
                        values[value_key] = value_schema

                        value_count += 1
                        if value_count >= column_schema.max_filters:
                            break

                # Sort values and keep only the values that made the cut.
                columns[column_key] = self._sort_schema_by_key(
                    column_schema,
                    key=self._sort_schema_by_relevance_key,
                    update=values,
                )

            # Sort columns and keep only the columns that made the cut.
            tables[table_key] = self._sort_schema_by_key(
                table_schema, key=self._sort_schema_by_relevance_key, update=columns
            )

        # Sort tables and keep only the tables that made the cut.
        ordered_database_schema = self._sort_schema_by_key(
            ordered_database_schema,
            key=self._sort_schema_by_relevance_key,
            update=tables,
        )

        ranked_database_schema = ordered_database_schema

        # Build Ablation #####################################################

        ablation_value_dict = {}
        # Assemble a relevance dictionary keyed by (table, column, value).
        for table_key, table_schema in ordered_database_schema.tables.items():
            for column_key, column_schema in table_schema.columns.items():
                for value_key, value_schema in column_schema.values.items():
                    ablation_value_dict[(table_key, column_key, value_key)] = value_schema.relevance

        ablation_value_dict = collections.OrderedDict(sorted(ablation_value_dict.items(), key=lambda x: x[1]))

        relevance_scores = list(ablation_value_dict.values())
        if len(relevance_scores) > 0:
            ablation_quantiles = np.quantile(relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1])
        else:
            ablation_quantiles = None

        return ranked_database_schema, ablation_value_dict, ablation_quantiles
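
    # Worked example of the quantile thresholds (illustrative): with
    # num_retries = 3, np.linspace(0, 1, 5)[1:-1] gives [0.25, 0.5, 0.75], so
    # np.quantile(relevance_scores, ...) yields one ablation threshold per
    # retry, rising from lenient to strict as retries progress.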

    def _dynamic_ablation(
        self,
        metadata_filters: List[AblativeMetadataFilter],
        ablation_value_dict,
        ablation_quantiles,
        retry: int,
    ):
        """Ablate metadata filters in aggregate by quantiles until the required minimum number of documents are returned."""

        ablated_dict = {}
        for key, value in ablation_value_dict.items():
            if value >= ablation_quantiles[retry]:
                ablated_dict[key] = value

        # Discard low-ranked filters ##################################################################################
        ablated_filters = []
        for metadata_filter in metadata_filters:
            key = (metadata_filter.schema_table, metadata_filter.schema_column, metadata_filter.schema_value)
            if key in ablated_dict:
                ablated_filters.append(metadata_filter)

        return ablated_filters
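
    # Illustrative ablation pass (hypothetical keys and thresholds): with
    # thresholds [0.25, 0.5, 0.75] and relevances
    #   {("t", "c", "v1"): 0.2, ("t", "c", "v2"): 0.9},
    # retry 0 keeps only the filter tied to ("t", "c", "v2") since 0.2 < 0.25.
    # Each retry raises the bar, dropping more filters until enough documents
    # are returned.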

    def depth_first_search(self, greedy=True):
        """Search depth-wise through Tables, then Columns, then Values. Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy."""
        pass

    def depth_first_ablation(self):
        """Ablate metadata filters in reverse depth-first-search order until the required minimum number of documents are returned."""
        pass

    def _prepare_retrieval_query(self, query: str) -> str:
        rewrite_prompt = PromptTemplate(input_variables=["input"], template=self.rewrite_prompt_template)
        rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
        return rewrite_chain.predict(input=query)

    def _prepare_pgvector_query(
        self,
        ranked_database_schema: DatabaseSchema,
        metadata_filters: List[AblativeMetadataFilter],
        retry: int = 0,
    ) -> str:
        # Base SELECT JOINed with the document source table.
        base_query = f"""SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" """

        # Return an empty string if the schema has not been ranked.
        if not ranked_database_schema:
            return ""

        # Add table JOIN statements.
        join_clauses = set()
        for metadata_filter in metadata_filters:
            join_clause = ranked_database_schema.tables[metadata_filter.schema_table].join
            if join_clause in join_clauses:
                continue
            join_clauses.add(join_clause)
            base_query += join_clause + " "

        # Add WHERE conditions from the metadata filters.
        if metadata_filters:
            base_query += "WHERE "
            for i, metadata_filter in enumerate(metadata_filters):
                value = metadata_filter.value
                if isinstance(value, str):
                    value = f"'{value}'"
                base_query += f'"{metadata_filter.attribute}" {metadata_filter.comparator} {value}'
                if i < len(metadata_filters) - 1:
                    base_query += " AND "

        base_query += (
            f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
        )
        return base_query
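
    # Illustrative output (hypothetical table names, filters, and distance
    # operator; not from the source). With embeddings_table="embeddings",
    # source_table="documents", and source_id_column="id", the prepared query
    # looks roughly like:
    #   SELECT * FROM embeddings AS e INNER JOIN documents AS s
    #       ON (e.metadata->>'original_row_id')::int = s."id"
    #   WHERE "invoice_year" = 2023 AND "status" = 'open'
    #   ORDER BY e.embeddings <=> '{embeddings}' LIMIT 5;
    # The '{embeddings}' placeholder is filled with the embedded query vector
    # just before execution.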

    def _generate_filter(self, prompt: ChatPromptTemplate, query: str) -> Dict[str, Any]:
        gen_filter_chain = LLMChain(llm=self.llm, prompt=prompt)
        output = gen_filter_chain({"query": query})
        return output

    def _generate_metadata_filters(
        self, query: str, ranked_database_schema
    ) -> Union[List[AblativeMetadataFilter], HandlerResponse]:
        parser = PydanticOutputParser(pydantic_object=MetadataFilter)

        metadata_filter_list = []
        # Iterate through tables to generate filters.
        for table_key, table_schema in ranked_database_schema.tables.items():
            # Iterate through columns to generate filters.
            for column_key, column_schema in table_schema.columns.items():
                if column_schema.relevance >= table_schema.filter_threshold:
                    # Generate a filter for each remaining value schema.
                    for value_key, value_schema in column_schema.values.items():
                        # Must use LLM generation if the value field is a dict or a list.
                        if isinstance(value_schema.value, (list, dict)):
                            try:
                                metadata_prompt: ChatPromptTemplate = self._prepare_value_prompt(
                                    format_instructions=parser.get_format_instructions(),
                                    value_schema=value_schema,
                                    column_schema=column_schema,
                                    table_schema=table_schema,
                                    boolean_system_prompt=False,
                                )

                                metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
                                metadata_filter_output = metadata_filters_chain.predict(
                                    query=query,
                                )

                                # If the LLM outputs raw JSON, use it as-is.
                                # If the output includes a ```json markdown section, use the last one.
                                json_markdown_output = re.findall(r"```json.*```", metadata_filter_output, re.DOTALL)
                                if json_markdown_output:
                                    metadata_filter_output = json_markdown_output[-1]
                                    # Strip the leading ```json and trailing ``` fences.
                                    metadata_filter_output = metadata_filter_output[7:]
                                    metadata_filter_output = metadata_filter_output[:-3]

                                metadata_filter = parser.invoke(metadata_filter_output)
                                model_dump = metadata_filter.model_dump()
                                model_dump.update(
                                    {
                                        "schema_table": table_key,
                                        "schema_column": column_key,
                                        "schema_value": value_key,
                                    }
                                )
                                metadata_filter = AblativeMetadataFilter(**model_dump)
                            except OutputParserException as e:
                                logger.warning(
                                    f"LLM failed to generate structured metadata filters: {e}",
                                    exc_info=logger.isEnabledFor(logging.DEBUG),
                                )
                                return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
                        else:
                            metadata_filter = AblativeMetadataFilter(
                                attribute=column_schema.column,
                                comparator=value_schema.comparator,
                                value=value_schema.value,
                                schema_table=table_key,
                                schema_column=column_key,
                                schema_value=value_key,
                            )
                        metadata_filter_list.append(metadata_filter)

        return metadata_filter_list
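
    # Illustrative parse (hypothetical LLM output, not from the source): given
    #   Here you go:
    #   ```json
    #   {"attribute": "status", "comparator": "=", "value": "open"}
    #   ```
    # the regex keeps the fenced block, the slicing strips the ```json/``` fences,
    # and the PydanticOutputParser yields
    #   MetadataFilter(attribute="status", comparator="=", value="open").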

    def _prepare_and_execute_query(
        self,
        ranked_database_schema: DatabaseSchema,
        metadata_filters: List[AblativeMetadataFilter],
        embeddings_str: str,
    ) -> HandlerResponse:
        try:
            checked_sql_query = self._prepare_pgvector_query(ranked_database_schema, metadata_filters)
            checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
            return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
        except Exception as e:
            logger.warning(
                f"Failed to prepare and execute SQL query from structured metadata: {e}",
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )
            return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        # Rewrite the query to be suitable for retrieval.
        retrieval_query = self._prepare_retrieval_query(query)

        # Embed the rewritten retrieval query & include it in the similarity search pgvector query.
        embedded_query = self.embeddings_model.embed_query(retrieval_query)

        # Search for relevant filters.
        ranked_database_schema, ablation_value_dict, ablation_quantiles = self._breadth_first_search(query=query)

        # Generate metadata filters.
        metadata_filters = self._generate_metadata_filters(query=query, ranked_database_schema=ranked_database_schema)

        if isinstance(metadata_filters, list):
            # Initial execution of the similarity search with metadata filters.
            document_response = self._prepare_and_execute_query(
                ranked_database_schema=ranked_database_schema,
                metadata_filters=metadata_filters,
                embeddings_str=str(embedded_query),
            )
            num_retries = 0
            while num_retries < self.num_retries:
                if (
                    document_response.resp_type != RESPONSE_TYPE.ERROR
                    and len(document_response.data_frame) >= self.min_k
                ):
                    # Successfully retrieved at least min_k documents to send to the re-ranker.
                    break
                elif document_response.resp_type == RESPONSE_TYPE.ERROR:
                    # LLMs won't always generate structured metadata, so we should have a fallback after retrying.
                    logger.info(f"SQL Retriever query failed with error {document_response.error_message}")
                else:
                    logger.info(
                        f"SQL Retriever did not retrieve {self.min_k} documents: {len(document_response.data_frame)} documents retrieved."
                    )

                ablated_metadata_filters = self._dynamic_ablation(
                    metadata_filters=metadata_filters,
                    ablation_value_dict=ablation_value_dict,
                    ablation_quantiles=ablation_quantiles,
                    retry=num_retries,
                )

                document_response = self._prepare_and_execute_query(
                    ranked_database_schema=ranked_database_schema,
                    metadata_filters=ablated_metadata_filters,
                    embeddings_str=str(embedded_query),
                )

                num_retries += 1

            retrieved_documents = []
            if document_response.resp_type != RESPONSE_TYPE.ERROR:
                document_df = document_response.data_frame
                for _, document_row in document_df.iterrows():
                    retrieved_documents.append(
                        Document(
                            page_content=document_row.get("content", ""),
                            metadata=document_row.get("metadata", {}),
                        )
                    )
            if retrieved_documents:
                return retrieved_documents

            # If the constructed SQL query did not return any documents, fall back.
            logger.info("No documents returned from SQL retriever, using fallback retriever.")
            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
        else:
            # If no metadata fields could be generated, fall back.
            logger.info("No metadata fields were successfully generated, using fallback retriever.")
            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)