Coverage for mindsdb/integrations/utilities/rag/retrievers/auto_retriever.py: 25%
44 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import List
2import json
5from langchain.retrievers.self_query.base import SelfQueryRetriever
7import pandas as pd
9from mindsdb.integrations.utilities.rag.retrievers.base import BaseRetriever
11from mindsdb.integrations.utilities.rag.utils import documents_to_df
12from mindsdb.integrations.utilities.rag.vector_store import VectorStoreOperator
14from mindsdb.integrations.utilities.rag.settings import RAGPipelineModel
class AutoRetriever(BaseRetriever):
    """
    AutoRetriever uses langchain to extract metadata field descriptions from a List of Document
    (via an LLM prompt) and queries the vector store through a SelfQueryRetriever.
    """

    def __init__(
            self,
            config: RAGPipelineModel
    ):
        """
        Copy the settings this retriever needs off the pipeline configuration.

        :param config: RAGPipelineModel providing documents, content_column_name, vector_store,
            auto_retriever_filter_columns, dataset_description, llm, embedding_model,
            retriever_prompt_template and cardinality_threshold
        """
        self.documents = config.documents
        self.content_column_name = config.content_column_name
        self.vectorstore = config.vector_store
        self.filter_columns = config.auto_retriever_filter_columns
        self.document_description = config.dataset_description
        self.llm = config.llm
        self.embedding_model = config.embedding_model
        self.prompt_template = config.retriever_prompt_template
        self.cardinality_threshold = config.cardinality_threshold

    def _get_low_cardinality_columns(self, data: pd.DataFrame) -> list:
        """
        Given a dataframe, return the columns with low cardinality whose datatype is not bool.

        Only ``self.filter_columns`` are considered when it is set; otherwise every column of
        ``data`` is a candidate. A column qualifies when its number of distinct values is
        strictly below ``self.cardinality_threshold``.

        :param data: pd.DataFrame
        :return: list of qualifying column labels, in candidate order
        """
        columns = data.columns if self.filter_columns is None else self.filter_columns
        return [
            column
            for column in columns
            if data[column].dtype != "bool"
            and data[column].nunique() < self.cardinality_threshold
        ]

    @staticmethod
    def _sorted_valid_values(values: pd.Series) -> list:
        """
        Return the distinct, non-missing values of a column as a sorted list of native scalars.

        :param values: pd.Series holding one column of the documents dataframe
        :return: sorted list of distinct values
        """
        # Missing values are dropped: nunique() (used to pick the column) does not count them,
        # NaN is not a usable filter value, and NaN next to strings makes sorting raise.
        # Series.tolist() yields Python scalars, so the rendered list reads "[1, 2]" rather
        # than "[np.int64(1), np.int64(2)]" under NumPy 2.
        distinct = values.dropna().drop_duplicates().tolist()
        try:
            return sorted(distinct)
        except TypeError:
            # Mixed, mutually unorderable types (e.g. int and str): order by text form instead.
            return sorted(distinct, key=str)

    def get_metadata_field_info(self) -> List[dict]:
        """
        Given a list of Document, use llm to extract metadata from it.

        The documents are turned into a dataframe, a sample of it (``head()``) is sent to the
        LLM together with the dataset description, and the LLM response is parsed as JSON.
        For low cardinality columns, the matching entry's description is extended with the
        sorted valid values.

        :return: List[dict] parsed from the LLM response; entries are expected to carry
            "name" and "description" keys
        :raises json.JSONDecodeError: if the LLM response content is not valid JSON
        """
        data = documents_to_df(
            self.content_column_name,
            self.documents
        )

        prompt = self.prompt_template.format(
            dataframe=data.head().to_json(),
            description=self.document_description
        )
        result: List[dict] = json.loads(self.llm.invoke(input=prompt).content)

        for column_name in self._get_low_cardinality_columns(data):
            valid_values = self._sorted_valid_values(data[column_name])
            for entry in result:
                if entry["name"] == column_name:
                    entry["description"] += f". Valid values: {valid_values}"

        return result

    def get_vectorstore(self):
        """
        Build a VectorStoreOperator from the configured vector store, documents and embedding
        model.

        :return: the operator's ``vector_store`` attribute
        """
        return VectorStoreOperator(
            vector_store=self.vectorstore,
            documents=self.documents,
            embedding_model=self.embedding_model
        ).vector_store

    def as_runnable(self) -> BaseRetriever:
        """
        Return the self-query retriever.

        The vector store is resolved first, then the metadata field info is generated (which
        invokes the LLM), and both are handed to ``SelfQueryRetriever.from_llm``.

        :return: the retriever created by ``SelfQueryRetriever.from_llm``
        """
        vectorstore = self.get_vectorstore()

        return SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=vectorstore,
            document_contents=self.document_description,
            metadata_field_info=self.get_metadata_field_info(),
            verbose=True
        )