Coverage for mindsdb / integrations / utilities / rag / settings.py: 82%
283 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from enum import Enum
2from functools import lru_cache
3from typing import List, Union, Any, Optional, Dict, OrderedDict, TYPE_CHECKING
5from pydantic import BaseModel, Field, field_validator, ConfigDict
7if TYPE_CHECKING: # pragma: no cover - import only for type hints
8 from langchain_core.documents import Document
9 from langchain_core.embeddings import Embeddings
10 from langchain_core.language_models import BaseChatModel
11 from langchain_core.vectorstores import VectorStore
12 from langchain_core.stores import BaseStore
13 from langchain_text_splitters import TextSplitter
14else: # Avoid importing heavy optional dependencies at runtime
15 Document = Embeddings = BaseChatModel = VectorStore = BaseStore = TextSplitter = Any
18def _require_kb_dependency(feature: str, exc: ModuleNotFoundError):
19 missing = exc.name or "langchain dependency"
20 raise ImportError(
21 f"{feature} requires the optional knowledge base dependencies (missing {missing}). "
22 "Install them via `pip install mindsdb[kb]`."
23 ) from exc
@lru_cache(maxsize=1)
def _load_vector_store_classes():
    """Import and cache the optional vector store classes.

    Returns a dict mapping store identifier ("chromadb" / "pgvector") to the
    corresponding langchain class. When the optional langchain packages are
    missing, raises a helpful ImportError via _require_kb_dependency.
    """
    try:
        from langchain_community.vectorstores.chroma import Chroma
        from langchain_community.vectorstores.pgvector import PGVector
    except ModuleNotFoundError as exc:  # pragma: no cover - runtime guard
        # Bug fix: exc.name can legitimately be None, in which case the old
        # getattr(...) expression returned None and None.startswith(...) raised
        # TypeError. Coalesce to "" before testing.
        missing_name = getattr(exc, "name", "") or ""
        if missing_name.startswith("langchain") or "langchain" in str(exc):
            _require_kb_dependency("Vector store configuration", exc)
        raise
    return {"chromadb": Chroma, "pgvector": PGVector}
# ---------------------------------------------------------------------------
# Default values shared by the RAG pipeline settings models defined below.
# ---------------------------------------------------------------------------
DEFAULT_COLLECTION_NAME = "default_collection"

# Multi retriever specific
DEFAULT_ID_KEY = "doc_id"
DEFAULT_MAX_CONCURRENCY = 5
DEFAULT_K = 20

# Auto retriever / chunking / connection defaults
DEFAULT_CARDINALITY_THRESHOLD = 40
DEFAULT_MAX_SUMMARIZATION_TOKENS = 4000
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_CHUNK_OVERLAP = 200
DEFAULT_POOL_RECYCLE = 3600  # presumably seconds before recycling pooled DB connections — TODO confirm
# LLM defaults
DEFAULT_LLM_MODEL = "gpt-4o"
DEFAULT_LLM_MODEL_PROVIDER = "openai"
DEFAULT_CONTENT_COLUMN_NAME = "body"
DEFAULT_DATASET_DESCRIPTION = "email inbox"
DEFAULT_TEST_TABLE_NAME = "test_email"
# Reranker defaults
DEFAULT_RERANKER_FLAG = False
DEFAULT_RERANKING_MODEL = "gpt-4o"
DEFAULT_LLM_ENDPOINT = "https://api.openai.com/v1"
DEFAULT_RERANKER_N = 1
DEFAULT_RERANKER_LOGPROBS = True
DEFAULT_RERANKER_TOP_LOGPROBS = 4
DEFAULT_RERANKER_MAX_TOKENS = 100
DEFAULT_VALID_CLASS_TOKENS = ["1", "2", "3", "4"]
# Auto-generates per-column metadata descriptions from a dataframe sample.
# Format variables: {description}, {dataframe}. Literal braces are escaped {{ }}.
DEFAULT_AUTO_META_PROMPT_TEMPLATE = """
Below is a json representation of a table with information about {description}.
Return a JSON list with an entry for each column. Each entry should have
{{"name": "column name", "description": "column description", "type": "column data type"}}
\n\n{dataframe}\n\nJSON:\n
"""

# Final question-answering prompt. Format variables: {question}, {context}.
DEFAULT_RAG_PROMPT_TEMPLATE = """You are an assistant for
question-answering tasks. Use the following pieces of retrieved context
to answer the question. If you don't know the answer, just say that you
don't know. Use two sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:"""

# Generates sample Q&A pairs from a document. Format variables: {document}, {metadata}.
DEFAULT_QA_GENERATION_PROMPT_TEMPLATE = """You are an assistant for
generating sample questions and answers from the given document and metadata. Given
a document and its metadata as context, generate a question and answer from that document and its metadata.

The document will be a string. The metadata will be a JSON string. You need
to parse the JSON to understand it.

Generate a question that requires BOTH the document and metadata to answer, if possible.
Otherwise, generate a question that requires ONLY the document to answer.

Return a JSON dictionary with the question and answer like this:
{{ "question": <the full generated question>, "answer": <the full generated answer> }}

Make sure the JSON string is valid before returning it. You must return the question and answer
in the specified JSON format no matter what.

Document: {document}
Metadata: {metadata}
Answer:"""

# Map step of map-reduce summarization. Format variables: {docs}, {input}.
DEFAULT_MAP_PROMPT_TEMPLATE = """The following is a set of documents
{docs}
Based on this list of docs, please summarize based on the user input.

User input: {input}

Helpful Answer:"""

# Reduce step of map-reduce summarization. Format variables: {docs}, {input}.
DEFAULT_REDUCE_PROMPT_TEMPLATE = """The following is set of summaries:
{docs}
Take these and distill it into a final, consolidated summary related to the user input.

User input: {input}

Helpful Answer:"""

# Rewrites the user input into a cleaner search query. Format variable: {input}.
DEFAULT_SEMANTIC_PROMPT_TEMPLATE = """Provide a better search query for web search engine to answer the given question.

<< EXAMPLES >>
1. Input: "Show me documents containing how to finetune a LLM please"
Output: "how to finetune a LLM"

Output only a single better search query and nothing else like in the example.

Here is the user input: {input}
"""
# Generates PostgreSQL metadata filters from user input. Format variables:
# {format_instructions}, {schema}, {examples}, {input}.
DEFAULT_METADATA_FILTERS_PROMPT_TEMPLATE = """Construct a list of PostgreSQL metadata filters to filter documents in the database based on the user input.

<< INSTRUCTIONS >>
{format_instructions}

RETURN ONLY THE FINAL JSON. DO NOT EXPLAIN, JUST RETURN THE FINAL JSON.

<< TABLES YOU HAVE ACCESS TO >>

{schema}

<< EXAMPLES >>

{examples}

Here is the user input:
{input}
"""

# System prompt asking for a single yes/no verdict on schema relevance.
DEFAULT_BOOLEAN_PROMPT_TEMPLATE = """**Task:** Determine Schema Relevance for Database Search Queries

As an expert in constructing database search queries, you are provided with database schemas detailing tables, columns, and values. Your task is to assess whether these elements can be used to effectively search the database in relation to a given user query.

**Instructions:**

- **Evaluate the Schema**:
  - Analyze the tables, columns, and values described.
  - Consider their potential usefulness in retrieving information pertinent to the user query.

- **Decision Criteria**:
  - Determine if any part of the schema could assist in forming a relevant search query for the information requested.

- **Response**:
  - Reply with a single word: 'yes' if the schema components are useful, otherwise 'no'.

**Note:** Provide your answer based solely on the relevance of the described schema to the user query."""

# System prompt for structured metadata-filter generation.
# Format variable: {format_instructions}.
DEFAULT_GENERATIVE_SYSTEM_PROMPT = """You are an expert database analyst that can assist in building SQL queries by providing structured output. Follow these format instructions precisely to generate a metadata filter given the provided schema description.

## Format instructions:
{format_instructions}
 """
# Renders a ValueSchema entry for the filter-generation prompts. Format
# variables: {column_schema}, {header}, {type}, {description}, {value},
# {comparator}, {usage}, {examples}, {query}.
DEFAULT_VALUE_PROMPT_TEMPLATE = """
{column_schema}

# **Value Schema**
{header}

- The type of the value: {type}

## **Description**
{description}

{value}{comparator}

## **Usage**
{usage}

{examples}

## **Query**
{query}

"""

# Renders a ColumnSchema entry for the filter-generation prompts. Format
# variables: {table_schema}, {header}, {column}, {type}, {description},
# {usage}, {examples}, {query}.
DEFAULT_COLUMN_PROMPT_TEMPLATE = """
{table_schema}

# **Column Schema**
{header}

- The column name in the database table: {column}
- The type of the values in this column: {type}

## **Description**
{description}

## **Usage**
{usage}

{examples}

## **Query**
{query}
"""

# Renders a TableSchema entry for the filter-generation prompts. Format
# variables: {header}, {table}, {description}, {usage}, {columns},
# {examples}, {query}.
DEFAULT_TABLE_PROMPT_TEMPLATE = """# **Table Schema**
{header}

- The name of this table in the database: {table}

## **Description**
{description}

## **Usage**
{usage}

## **Column Descriptions**
Below are descriptions of each column in this table:

{columns}

{examples}

## **Query**
{query}
"""
# Asks the LLM to write the final retrieval SQL. Format variables: {dialect},
# {source_table}, {embeddings_table}, {distance_function}, {k}, {schema},
# {examples}, {input}. Literal braces in the JSON example are escaped {{ }};
# '{{embeddings}}' survives formatting as the '{embeddings}' placeholder for a
# later substitution pass.
DEFAULT_SQL_PROMPT_TEMPLATE = """
Construct a valid {dialect} SQL query to select documents relevant to the user input.
Source documents are found in the {source_table} table. You may need to join with other tables to get additional document metadata.

The JSON col "metadata" in the {embeddings_table} has a string field called "original_row_id". This "original_row_id" string field in the
"metadata" col is the document ID associated with a row in the {embeddings_table} table.
You MUST always join with the {embeddings_table} table containing vector embeddings for the documents. For example, for a table named sd with an id column "Id":
JOIN {embeddings_table} v ON (v."metadata"->>'original_row_id')::int = sd."Id"

You MUST always order the embeddings by the {distance_function} comparator with '{{embeddings}}'.
You MUST always limit by {k} returned documents.
For example:
ORDER BY v.embeddings {distance_function} '{{embeddings}}' LIMIT {k};


<< TABLES YOU HAVE ACCESS TO >>
1. {embeddings_table} - Contains document chunks, vector embeddings, and metadata for documents.
You MUST always include the metadata column in your SELECT statement.
You MUST always join with the {embeddings_table} table containing vector embeddings for the documents.
You MUST always order by the provided embeddings vector using the {distance_function} comparator.
You MUST always limit by {k} returned documents.

Columns:
```json
{{
    "id": {{
        "type": "string",
        "description": "Unique ID for this document chunk"
    }},
    "content": {{
        "type": "string",
        "description": "A document chunk (subset of the original document)"
    }},
    "embeddings": {{
        "type": "vector",
        "description": "Vector embeddings for the document chunk. ALWAYS order by the provided embeddings vector using the {distance_function} comparator."
    }},
    "metadata": {{
        "type": "jsonb",
        "description": "Metadata for the document chunk. Always select metadata and always join with the {source_table} table on the string metadata field 'original_row_id'"
    }}
}}

{schema}

<< EXAMPLES >>

{examples}

Output the {dialect} SQL query that is ready to be executed only WITHOUT ANY DELIMITERS. Make sure to properly quote identifiers.

Here is the user input:
{input}
"""
# Multi-hop follow-up question generator. Format variables: {question}, {context}.
# NOTE(review): the "Invalid responses" example lines use single (unescaped)
# braces, which str.format()-style rendering would reject — presumably this
# template is rendered with a mechanism that tolerates them; verify before reuse.
DEFAULT_QUESTION_REFORMULATION_TEMPLATE = """Given the original question and the retrieved context,
analyze what additional information is needed for a complete, accurate answer.

Original Question: {question}

Retrieved Context:
{context}

Analysis Instructions:
1. Evaluate Context Coverage:
   - Identify key entities and concepts from the question
   - Check for temporal information (dates, periods, sequences)
   - Verify causal relationships are explained
   - Confirm presence of requested quantitative data
   - Assess if geographic or spatial context is sufficient

2. Quality Assessment:
   If the retrieved context is:
   - Irrelevant or tangential
   - Too general or vague
   - Potentially contradictory
   - Missing key perspectives
   - Lacking proper evidence
   Generate questions to address these specific gaps.

3. Follow-up Question Requirements:
   - Questions must directly contribute to answering the original query
   - Break complex relationships into simpler, sequential steps
   - Maintain specificity rather than broad inquiries
   - Avoid questions answerable from existing context
   - Ensure questions build on each other logically
   - Limit questions to 150 characters each
   - Each question must be self-contained
   - Questions must end with a question mark

4. Response Format:
   - Return a JSON array of strings
   - Use square brackets and double quotes
   - Questions must be unique (no duplicates)
   - If context is sufficient, return empty array []
   - Maximum 3 follow-up questions
   - Minimum length per question: 30 characters
   - No null values or empty strings

Example:
Original: "How did the development of antibiotics affect military casualties in WWII?"

Invalid responses:
{'questions': ['What are antibiotics?']} // Wrong format
['What is WWII?'] // Too basic
['How did it impact things?'] // Too vague
['', 'Question 2'] // Contains empty string
['Same question?', 'Same question?'] // Duplicate

Valid response:
["What were military casualty rates from infections before widespread antibiotic use in 1942?",
 "How did penicillin availability change throughout different stages of WWII?",
 "What were the primary battlefield infections treated with antibiotics during WWII?"]

or [] if context fully answers the original question.

Your task: Based on the analysis of the original question and context,
output ONLY a JSON array of follow-up questions needed to provide a complete answer.
If no additional information is needed, output an empty array [].

Follow-up Questions:"""
# Asks the LLM to repair a failed generated SQL query. Format variables:
# {query}, {dialect}, {error}, {embeddings_table}, {schema}. Literal braces
# in the JSON example are escaped {{ }}.
DEFAULT_QUERY_RETRY_PROMPT_TEMPLATE = """
{query}

The {dialect} query above failed with the error message: {error}.

<< TABLES YOU HAVE ACCESS TO >>
1. {embeddings_table} - Contains document chunks, vector embeddings, and metadata for documents.

Columns:
```json
{{
    "id": {{
        "type": "string",
        "description": "Unique ID for this document chunk"
    }},
    "content": {{
        "type": "string",
        "description": "A document chunk (subset of the original document)"
    }},
    "embeddings": {{
        "type": "vector",
        "description": "Vector embeddings for the document chunk."
    }},
    "metadata": {{
        "type": "jsonb",
        "description": "Metadata for the document chunk."
    }}
}}

{schema}

Rewrite the query so it works.

Output the final SQL query only.

SQL Query:
"""

# Number of automatic LLM rewrite attempts for a failed generated SQL query.
DEFAULT_NUM_QUERY_RETRIES = 2
class LLMConfig(BaseModel):
    """Which LLM (and provider) to use for generation, plus pass-through params."""

    model_name: str = Field(default=DEFAULT_LLM_MODEL, description="LLM model to use for generation")
    provider: str = Field(
        default=DEFAULT_LLM_MODEL_PROVIDER,
        description="LLM model provider to use for generation",
    )
    # Free-form provider-specific keyword arguments, passed through unchanged.
    params: Dict[str, Any] = Field(default_factory=dict)
    # Disable pydantic v2's protected "model_" namespace so model_name is allowed.
    model_config = ConfigDict(protected_namespaces=())
class MultiVectorRetrieverMode(Enum):
    """
    Enum for MultiVectorRetriever types.
    """

    SPLIT = "split"  # split documents into chunks
    SUMMARIZE = "summarize"  # summarize documents
    BOTH = "both"  # both split and summarize
class VectorStoreType(Enum):
    """Supported vector store backends; values match _load_vector_store_classes keys."""

    CHROMA = "chromadb"
    PGVECTOR = "pgvector"
def get_vector_store_map():
    """Map each VectorStoreType member to its lazily imported store class."""
    loaded = _load_vector_store_classes()
    # Enum values ("chromadb"/"pgvector") double as keys of the loaded dict.
    return {store_type: loaded[store_type.value] for store_type in VectorStoreType}
def get_vector_store_class(store_type: "VectorStoreType"):
    """Return the vector store class registered for ``store_type``."""
    store_map = get_vector_store_map()
    return store_map[store_type]
class VectorStoreConfig(BaseModel):
    """Connection/location settings for the vector store backing the pipeline."""

    vector_store_type: VectorStoreType = VectorStoreType.CHROMA
    # Annotation fix: these two fields default to None but were declared plain
    # `str`, which is a type error under strict validation tools.
    persist_directory: Optional[str] = None
    collection_name: str = DEFAULT_COLLECTION_NAME
    connection_string: Optional[str] = None
    kb_table: Any = None  # knowledge-base table handle (project type, hence Any)
    is_sparse: bool = False  # whether embeddings are sparse vectors
    vector_size: Optional[int] = None

    class Config:
        arbitrary_types_allowed = True
        extra = "forbid"
def _default_vector_store_factory():
    """Resolve the vector store class implied by a default VectorStoreConfig."""
    default_type = VectorStoreConfig().vector_store_type
    return get_vector_store_class(default_type)
class RetrieverType(str, Enum):
    """Retriever type for RAG pipeline"""

    VECTOR_STORE = "vector_store"  # plain vector store retriever
    AUTO = "auto"  # auto retriever
    MULTI = "multi"  # multi-vector retriever
    SQL = "sql"  # SQL-generating retriever
    MULTI_HOP = "multi_hop"  # iterative multi-hop retrieval (needs MultiHopRetrieverConfig)
class SearchType(Enum):
    """
    Enum for vector store search types.
    """

    SIMILARITY = "similarity"
    MMR = "mmr"  # maximal marginal relevance
    SIMILARITY_SCORE_THRESHOLD = "similarity_score_threshold"
class SearchKwargs(BaseModel):
    """Keyword arguments forwarded to the vector store retriever's search call."""

    k: int = Field(default=DEFAULT_K, description="Amount of documents to return", ge=1)
    filter: Optional[Dict[str, Any]] = Field(default=None, description="Filter by document metadata")
    # For similarity_score_threshold search type
    score_threshold: Optional[float] = Field(
        default=None,
        description="Minimum relevance threshold for similarity_score_threshold search",
        ge=0.0,
        le=1.0,
    )
    # For MMR search type
    fetch_k: Optional[int] = Field(default=None, description="Amount of documents to pass to MMR algorithm", ge=1)
    lambda_mult: Optional[float] = Field(
        default=None,
        description="Diversity of results returned by MMR (1=min diversity, 0=max)",
        ge=0.0,
        le=1.0,
    )

    def model_dump(self, *args, **kwargs):
        """Dump the model, excluding unset (None) options by default.

        Fix: the previous override unconditionally forced exclude_none=True,
        silently overriding a caller's explicit ``exclude_none=False``;
        ``setdefault`` keeps the default behavior while respecting callers.
        """
        kwargs.setdefault("exclude_none", True)
        return super().model_dump(*args, **kwargs)
class LLMExample(BaseModel):
    """A single few-shot input/output example used in prompt construction."""

    input: str = Field(description="User input for the example")
    output: str = Field(description="What the LLM should generate for this example's input")
class ValueSchema(BaseModel):
    """Describes a value (or set of values) in a table column, rendered into the
    value-level prompt when generating SQL/metadata filters."""

    value: Union[
        Union[str, int, float],
        Dict[Union[str, int, float], str],
        List[Union[str, int, float]],
    ] = Field(
        description="One of the following. The value as it exists in the table column. A dict of {table_value: descriptive value, ...}, where table_value is the value in the table. A list of sample values taken from the column."
    )
    comparator: Optional[Union[str, List[str]]] = Field(
        # Typo fix: "posgtres" -> "postgres" (this text is surfaced to the LLM
        # via the generated schema documentation).
        description="The postgres sql operators used to compare two values. For example: `>`, `<`, `=`, or `%`.",
        default="=",
    )
    type: str = Field(
        description="A valid postgres type for this value. One of: int, string, float, or bool. When numbers appear they should be of type int or float."
    )
    description: str = Field(description="Description of what the value represents.")
    usage: str = Field(description="How and when to use this value for search.")
    example_questions: Optional[List[LLMExample]] = Field(
        default=None, description="Example questions where this value is set."
    )
    # exclude=True: tuning knob, not part of the serialized schema payload.
    filter_threshold: Optional[float] = Field(
        default=0.0,
        description="Minimum relevance threshold to include metadata filters from this column.",
        exclude=True,
    )
    priority: Optional[int] = Field(
        default=0,
        description="Priority level for this column, lower numbers will be processed first.",
    )
    # exclude=True: populated internally during search, never serialized.
    relevance: Optional[float] = Field(
        default=None,
        description="Relevance computed during search. Should not be set by the end user.",
        exclude=True,
    )
class MetadataConfig(BaseModel):
    """Class to configure metadata for retrieval. Only supports very basic document name lookup at the moment."""

    table: str = Field(description="Source table for metadata.")
    max_document_context: int = Field(
        # To work well with models with context window of 32768.
        default=16384,
        description="Truncate a document before using as context with an LLM if it exceeds this amount of tokens",
    )
    embeddings_table: str = Field(default="embeddings", description="Source table for embeddings")
    id_column: str = Field(default="Id", description="Name of ID column in metadata table")
    name_column: str = Field(default="Title", description="Name of column containing name or title of document")
    name_column_index: Optional[str] = Field(default=None, description="Name of GIN index to use when looking up name.")
    content_column: str = Field(
        default="content", description="Name of column in embeddings table containing chunk content"
    )
    embeddings_metadata_column: str = Field(
        default="metadata", description="Name of column in embeddings table containing chunk metadata"
    )
    # Matches the 'original_row_id' field referenced by the SQL prompt templates.
    doc_id_key: str = Field(
        default="original_row_id", description="Metadata field that links an embedded chunk back to source document ID"
    )
class ColumnSchema(BaseModel):
    """Describes a table column (and optionally its known values) for the
    SQL/metadata-filter generation prompts."""

    column: str = Field(description="Name of the column in the database")
    type: str = Field(description="Type of the column (e.g. int, string, datetime)")
    description: str = Field(description="Description of what the column represents")
    # Copy-paste fix: the next two descriptions previously said "Table"/"table"
    # although this schema documents a column.
    usage: str = Field(description="How and when to use this column for search.")
    values: Optional[
        Union[
            OrderedDict[Union[str, int, float], ValueSchema],
            Dict[Union[str, int, float], ValueSchema],
        ]
    ] = Field(
        default=None,
        description="One of the following. A dict or ordered dict of {schema_value: ValueSchema, ...}, where schema value is the name given for this value description in the schema.",
    )
    example_questions: Optional[List[LLMExample]] = Field(
        default=None, description="Example questions where this column is useful."
    )
    max_filters: Optional[int] = Field(default=1, description="Maximum number of filters to generate for this column.")
    filter_threshold: Optional[float] = Field(
        default=0.0,
        description="Minimum relevance threshold to include metadata filters from this column.",
    )
    priority: Optional[int] = Field(
        default=1,
        description="Priority level for this column, lower numbers will be processed first.",
    )
    relevance: Optional[float] = Field(
        default=None,
        description="Relevance computed during search. Should not be set by the end user.",
    )
class TableSchema(BaseModel):
    """Describes a database table and its columns for the SQL/metadata-filter
    generation prompts."""

    table: str = Field(description="Name of table in the database")
    description: str = Field(description="Description of what the table represents")
    usage: str = Field(description="How and when to use this Table for search.")
    # No default: required at construction time, though it may be passed as None.
    columns: Optional[Union[OrderedDict[str, ColumnSchema], Dict[str, ColumnSchema]]] = Field(
        description="Dict or Ordered Dict of {column_name: ColumnSchemas} describing the metadata columns available for the table"
    )
    example_questions: Optional[List[LLMExample]] = Field(
        default=None, description="Example questions where this table is useful."
    )
    join: str = Field(
        description="SQL join string to join this table with source documents table",
        default="",
    )
    max_filters: Optional[int] = Field(default=1, description="Maximum number of filters to generate for this table.")
    filter_threshold: Optional[float] = Field(
        default=0.0,
        description="Minimum relevance required to use this table to generate filters.",
    )
    priority: Optional[int] = Field(
        default=1,
        description="Priority level for this table, lower numbers will be processed first.",
    )
    relevance: Optional[float] = Field(
        default=None,
        description="Relevance computed during search. Should not be set by the end user.",
    )
class DatabaseSchema(BaseModel):
    """Describes a whole database (a set of TableSchemas) for the
    SQL/metadata-filter generation prompts."""

    database: str = Field(description="Name of database in the Database")
    description: str = Field(description="Description of what the Database represents")
    usage: str = Field(description="How and when to use this Database for search.")
    # Copy-paste fix: this description previously repeated the ColumnSchema text
    # ("{column_name: ColumnSchemas} ... for the table") although the field maps
    # table names to TableSchema objects.
    tables: Union[OrderedDict[str, TableSchema], Dict[str, TableSchema]] = Field(
        description="Dict or Ordered Dict of {table_name: TableSchema, ...} describing the tables available in the database"
    )
    example_questions: Optional[List[LLMExample]] = Field(
        default=None, description="Example questions where this Database is useful."
    )
    max_filters: Optional[int] = Field(
        default=1,
        description="Maximum number of filters to generate for this Database.",
    )
    filter_threshold: Optional[float] = Field(
        default=0.0,
        description="Minimum relevance required to use this Database to generate filters.",
    )
    priority: Optional[int] = Field(
        default=0,
        description="Priority level for this Database, lower numbers will be processed first.",
    )
    relevance: Optional[float] = Field(
        default=None,
        description="Relevance computed during search. Should not be set by the end user.",
    )
class SQLRetrieverConfig(BaseModel):
    """Configuration for the SQL retriever: the LLM, the prompt templates used
    to build/repair queries, and limits applied to generated filters."""

    llm_config: LLMConfig = Field(
        default_factory=LLMConfig,
        description="LLM configuration to use for generating the final SQL query for retrieval",
    )
    metadata_filters_prompt_template: str = Field(
        default=DEFAULT_METADATA_FILTERS_PROMPT_TEMPLATE,
        description="Prompt template to generate PostgreSQL metadata filters. Has 'format_instructions', 'schema', 'examples', and 'input' input variables",
    )
    num_retries: int = Field(
        default=DEFAULT_NUM_QUERY_RETRIES,
        description="How many times for an LLM to try rewriting a failed SQL query before using the fallback retriever.",
    )
    rewrite_prompt_template: str = Field(
        default=DEFAULT_SEMANTIC_PROMPT_TEMPLATE,
        description="Prompt template to rewrite user input to be better suited for retrieval. Has 'input' input variable.",
    )
    # Copy-paste fix: the five descriptions below previously all repeated the
    # rewrite-template text verbatim; they now describe each template's role.
    table_prompt_template: str = Field(
        default=DEFAULT_TABLE_PROMPT_TEMPLATE,
        description="Prompt template used to render a table schema when generating metadata filters.",
    )
    column_prompt_template: str = Field(
        default=DEFAULT_COLUMN_PROMPT_TEMPLATE,
        description="Prompt template used to render a column schema when generating metadata filters.",
    )
    value_prompt_template: str = Field(
        default=DEFAULT_VALUE_PROMPT_TEMPLATE,
        description="Prompt template used to render a value schema when generating metadata filters.",
    )
    boolean_system_prompt: str = Field(
        default=DEFAULT_BOOLEAN_PROMPT_TEMPLATE,
        description="System prompt asking the LLM for a yes/no decision on schema relevance to the user query.",
    )
    generative_system_prompt: str = Field(
        default=DEFAULT_GENERATIVE_SYSTEM_PROMPT,
        description="System prompt instructing the LLM to emit structured metadata-filter output. Has 'format_instructions' input variable.",
    )
    source_table: str = Field(
        description="Name of the source table containing the original documents that were embedded"
    )
    source_id_column: str = Field(description="Name of the column containing the UUID.", default="Id")
    max_filters: Optional[int] = Field(description="Maximum number of filters to generate for sql queries.", default=10)
    filter_threshold: Optional[float] = Field(
        description="Minimum relevance required to use this Database to generate filters.",
        default=0.0,
    )
    min_k: Optional[int] = Field(
        description="Minimum number of documents accepted from a generated sql query.",
        default=10,
    )
    database_schema: Optional[DatabaseSchema] = Field(
        default=None,
        description="DatabaseSchema describing the database.",
    )
    examples: Optional[List[LLMExample]] = Field(
        default=None,
        description="Optional examples of final generated pgvector queries based on user input.",
    )
class SummarizationConfig(BaseModel):
    """Map-reduce summarization settings for condensing retrieved documents."""

    llm_config: LLMConfig = Field(
        default_factory=LLMConfig,
        description="LLM configuration to use for summarization",
    )
    map_prompt_template: str = Field(
        default=DEFAULT_MAP_PROMPT_TEMPLATE,
        description="Prompt for an LLM to summarize a single document",
    )
    reduce_prompt_template: str = Field(
        default=DEFAULT_REDUCE_PROMPT_TEMPLATE,
        description="Prompt for an LLM to summarize a set of summaries of documents into one",
    )
    max_summarization_tokens: int = Field(
        default=DEFAULT_MAX_SUMMARIZATION_TOKENS,
        description="Max number of tokens for summarized documents",
    )
class RerankerMode(str, Enum):
    """Reranking strategy: score documents one-by-one or all together."""

    POINTWISE = "pointwise"
    LISTWISE = "listwise"

    @classmethod
    def _missing_(cls, value):
        # Case-insensitive lookup so e.g. "POINTWISE"/"ListWise" resolve too.
        if not isinstance(value, str):
            return None
        normalized = value.lower()
        return next((member for member in cls if member.value == normalized), None)
class RerankerConfig(BaseModel):
    """Settings for LLM-based document reranking."""

    model: str = DEFAULT_RERANKING_MODEL  # LLM used for reranking
    base_url: str = DEFAULT_LLM_ENDPOINT  # OpenAI-compatible API endpoint
    filtering_threshold: float = 0.5  # presumably the relevance cutoff applied when filtering — confirm in reranker impl
    num_docs_to_keep: Optional[int] = None  # None = keep all documents that pass filtering
    mode: RerankerMode = Field(
        default=RerankerMode.POINTWISE,
        description="Reranking mode to use. 'pointwise' for individual scoring, '"
        "listwise' for joint scoring of all documents.",
    )
    max_concurrent_requests: int = 20
    max_retries: int = 3
    retry_delay: float = 1.0  # presumably seconds between retries — confirm in reranker impl
    early_stop: bool = True  # Whether to enable early stopping
    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping
    n: int = DEFAULT_RERANKER_N  # Number of completions to generate
    logprobs: bool = DEFAULT_RERANKER_LOGPROBS  # Whether to include log probabilities
    top_logprobs: int = DEFAULT_RERANKER_TOP_LOGPROBS  # Number of top log probabilities to include
    max_tokens: int = DEFAULT_RERANKER_MAX_TOKENS  # Maximum tokens to generate
    valid_class_tokens: List[str] = DEFAULT_VALID_CLASS_TOKENS  # Valid class tokens to look for in the response
class MultiHopRetrieverConfig(BaseModel):
    """Configuration for multi-hop retrieval"""

    # RAGPipelineModel requires this config when retriever_type is MULTI_HOP.
    base_retriever_type: RetrieverType = Field(
        default=RetrieverType.VECTOR_STORE,
        description="Type of base retriever to use for multi-hop retrieval",
    )
    max_hops: int = Field(default=3, description="Maximum number of follow-up questions to generate", ge=1)
    reformulation_template: str = Field(
        default=DEFAULT_QUESTION_REFORMULATION_TEMPLATE,
        description="Template for reformulating questions",
    )
    llm_config: LLMConfig = Field(
        default_factory=LLMConfig,
        description="LLM configuration to use for generating follow-up questions",
    )
class RAGPipelineModel(BaseModel):
    """Top-level configuration for a RAG pipeline: documents, stores, models,
    retriever selection, and per-retriever sub-configurations."""

    documents: Optional[List[Document]] = Field(default=None, description="List of documents")
    vector_store_config: VectorStoreConfig = Field(
        default_factory=VectorStoreConfig, description="Vector store configuration"
    )
    llm: Optional[BaseChatModel] = Field(default=None, description="Language model")
    llm_model_name: str = Field(default=DEFAULT_LLM_MODEL, description="Language model name")
    llm_provider: Optional[str] = Field(default=None, description="Language model provider")
    vector_store: VectorStore = Field(
        default_factory=_default_vector_store_factory,
        description="Vector store",
    )
    db_connection_string: Optional[str] = Field(default=None, description="Database connection string")
    metadata_config: Optional[MetadataConfig] = Field(
        default=None, description="Configuration for metadata to be used for retrieval"
    )
    table_name: str = Field(default=DEFAULT_TEST_TABLE_NAME, description="Table name")
    embedding_model: Optional[Embeddings] = Field(default=None, description="Embedding model")
    rag_prompt_template: str = Field(default=DEFAULT_RAG_PROMPT_TEMPLATE, description="RAG prompt template")
    retriever_prompt_template: Optional[Union[str, dict]] = Field(default=None, description="Retriever prompt template")
    retriever_type: RetrieverType = Field(default=RetrieverType.VECTOR_STORE, description="Retriever type")
    search_type: SearchType = Field(default=SearchType.SIMILARITY, description="Type of search to perform")
    search_kwargs: SearchKwargs = Field(
        default_factory=SearchKwargs,
        description="Search configuration for the retriever",
    )
    summarization_config: Optional[SummarizationConfig] = Field(
        default=None,
        description="Configuration for summarizing retrieved documents as context",
    )
    # SQL retriever specific.
    sql_retriever_config: Optional[SQLRetrieverConfig] = Field(
        default=None,
        description="Configuration for retrieving documents by generating SQL to filter by metadata & order by distance function",
    )

    # Multi retriever specific
    multi_retriever_mode: MultiVectorRetrieverMode = Field(
        default=MultiVectorRetrieverMode.BOTH, description="Multi retriever mode"
    )
    max_concurrency: int = Field(default=DEFAULT_MAX_CONCURRENCY, description="Maximum concurrency")
    # Annotation fix: DEFAULT_ID_KEY is the string "doc_id", but this field was
    # declared `int` (pydantic skips validating defaults, so it went unnoticed).
    id_key: str = Field(default=DEFAULT_ID_KEY, description="ID key")
    parent_store: Optional[BaseStore] = Field(default=None, description="Parent store")
    text_splitter: Optional[TextSplitter] = Field(default=None, description="Text splitter")
    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, description="Chunk size")
    chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, description="Chunk overlap")

    # Auto retriever specific
    auto_retriever_filter_columns: Optional[List[str]] = Field(default=None, description="Filter columns")
    cardinality_threshold: int = Field(default=DEFAULT_CARDINALITY_THRESHOLD, description="Cardinality threshold")
    content_column_name: str = Field(
        default=DEFAULT_CONTENT_COLUMN_NAME,
        description="Content column name (the column we will get embeddings)",
    )
    dataset_description: str = Field(default=DEFAULT_DATASET_DESCRIPTION, description="Description of the dataset")
    reranker: bool = Field(default=DEFAULT_RERANKER_FLAG, description="Whether to use reranker")
    reranker_config: RerankerConfig = Field(default_factory=RerankerConfig, description="Reranker configuration")

    multi_hop_config: Optional[MultiHopRetrieverConfig] = Field(
        default=None,
        description="Configuration for multi-hop retrieval. Required when retriever_type is MULTI_HOP.",
    )

    @field_validator("multi_hop_config")
    @classmethod
    def validate_multi_hop_config(cls, v: Optional[MultiHopRetrieverConfig], info):
        """Validate that multi_hop_config is set when using multi-hop retrieval."""
        values = info.data
        if values.get("retriever_type") == RetrieverType.MULTI_HOP and v is None:
            raise ValueError("multi_hop_config must be set when using multi-hop retrieval")
        return v

    class Config:
        arbitrary_types_allowed = True
        extra = "forbid"

        json_schema_extra = {
            "example": {
                "retriever_type": RetrieverType.VECTOR_STORE.value,
                "multi_retriever_mode": MultiVectorRetrieverMode.BOTH.value,
                # add more examples here
            }
        }

    @classmethod
    def get_field_names(cls):
        """Return the names of all declared fields on this model."""
        return list(cls.model_fields.keys())

    @field_validator("search_kwargs")
    @classmethod
    def validate_search_kwargs(cls, v: SearchKwargs, info) -> SearchKwargs:
        """Cross-validate search_kwargs against the selected search_type.

        MMR-only options (fetch_k, lambda_mult) must be absent for other search
        types and consistent with each other for MMR; score_threshold is
        required for (and exclusive to) similarity_score_threshold.
        """
        search_type = info.data.get("search_type", SearchType.SIMILARITY)

        # Validate MMR-specific parameters
        if search_type == SearchType.MMR:
            if v.fetch_k is not None and v.fetch_k <= v.k:
                raise ValueError("fetch_k must be greater than k")
            if v.lambda_mult is not None and (v.lambda_mult < 0 or v.lambda_mult > 1):
                raise ValueError("lambda_mult must be between 0 and 1")
            if v.fetch_k is None and v.lambda_mult is not None:
                raise ValueError("fetch_k is required when using lambda_mult with MMR search type")
            if v.lambda_mult is None and v.fetch_k is not None:
                raise ValueError("lambda_mult is required when using fetch_k with MMR search type")
        else:  # simplified from the redundant `elif search_type != SearchType.MMR`
            if v.fetch_k is not None:
                raise ValueError("fetch_k is only valid for MMR search type")
            if v.lambda_mult is not None:
                raise ValueError("lambda_mult is only valid for MMR search type")

        # Validate similarity_score_threshold parameters
        if search_type == SearchType.SIMILARITY_SCORE_THRESHOLD:
            if v.score_threshold is not None and (v.score_threshold < 0 or v.score_threshold > 1):
                raise ValueError("score_threshold must be between 0 and 1")
            if v.score_threshold is None:
                raise ValueError("score_threshold is required for similarity_score_threshold search type")
        elif v.score_threshold is not None:  # equivalent: elif branch already implies the other search types
            raise ValueError("score_threshold is only valid for similarity_score_threshold search type")

        return v