Coverage for mindsdb / integrations / utilities / rag / retrievers / auto_retriever.py: 25%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List 

2import json 

3 

4 

5from langchain.retrievers.self_query.base import SelfQueryRetriever 

6 

7import pandas as pd 

8 

9from mindsdb.integrations.utilities.rag.retrievers.base import BaseRetriever 

10 

11from mindsdb.integrations.utilities.rag.utils import documents_to_df 

12from mindsdb.integrations.utilities.rag.vector_store import VectorStoreOperator 

13 

14from mindsdb.integrations.utilities.rag.settings import RAGPipelineModel 

15 

16 

class AutoRetriever(BaseRetriever):
    """
    Build a langchain self-query retriever from a RAG pipeline config.

    The configured LLM is asked to describe the metadata fields of the
    configured documents; those descriptions are then handed to
    ``SelfQueryRetriever.from_llm`` together with the vector store.
    """

    def __init__(
        self,
        config: RAGPipelineModel
    ):
        """
        Copy the settings used by this retriever from the pipeline config.

        :param config: RAGPipelineModel
        """
        self.documents = config.documents
        self.content_column_name = config.content_column_name
        self.vectorstore = config.vector_store
        self.filter_columns = config.auto_retriever_filter_columns
        self.document_description = config.dataset_description
        self.llm = config.llm
        self.embedding_model = config.embedding_model
        self.prompt_template = config.retriever_prompt_template
        self.cardinality_threshold = config.cardinality_threshold

    def _get_low_cardinality_columns(self, data: pd.DataFrame):
        """
        Return the non-bool columns whose distinct-value count is below
        ``self.cardinality_threshold``.

        Only ``self.filter_columns`` are inspected when it is set; otherwise
        every column of ``data`` is considered.

        :param data: pd.DataFrame
        :return: list of column names, in the order they were inspected
        """
        columns = data.columns if self.filter_columns is None else self.filter_columns
        return [
            column
            for column in columns
            if data[column].dtype != "bool"
            and data[column].nunique() < self.cardinality_threshold
        ]

    def get_metadata_field_info(self):
        """
        Ask the LLM for metadata field descriptions of the documents.

        The documents are converted with ``documents_to_df``, the first rows
        (``DataFrame.head()``) are sent to the LLM as JSON inside the prompt
        template, and the response content is parsed as JSON. For every low
        cardinality column the matching entry's description is extended with
        the sorted list of values found in the data.

        :return: List[dict] parsed from the LLM response; entries are read
            through their "name" and "description" keys
        :raises json.JSONDecodeError: if the LLM response content is not valid JSON
        """

        def _alter_description(data: pd.DataFrame,
                               low_cardinality_columns: list,
                               result: List[dict]):
            """
            For low cardinality columns, alter the description to include the sorted valid values.

            ``result`` is modified in place.

            :param data: pd.DataFrame
            :param low_cardinality_columns: list
            :param result: List[dict]
            """
            for column_name in low_cardinality_columns:
                # Series.tolist() yields plain Python scalars, so the text sent to
                # the LLM reads "[1, 2]" rather than "[np.int64(1), np.int64(2)]"
                # under NumPy >= 2.
                distinct_values = data[column_name].drop_duplicates().tolist()
                try:
                    valid_values = sorted(distinct_values)
                except TypeError:
                    # Mixed, non-comparable values (e.g. strings next to None):
                    # fall back to a deterministic order instead of crashing.
                    valid_values = sorted(distinct_values, key=str)
                for entry in result:
                    if entry["name"] == column_name:
                        entry["description"] += f". Valid values: {valid_values}"

        data = documents_to_df(
            self.content_column_name,
            self.documents
        )

        prompt = self.prompt_template.format(dataframe=data.head().to_json(),
                                             description=self.document_description)
        result: List[dict] = json.loads(self.llm.invoke(input=prompt).content)

        _alter_description(
            data,
            self._get_low_cardinality_columns(data),
            result
        )

        return result

    def get_vectorstore(self):
        """
        Return the ``vector_store`` attribute of a ``VectorStoreOperator`` built
        from the configured vector store, documents and embedding model.

        :return: the operator's vector store
        """
        return VectorStoreOperator(vector_store=self.vectorstore,
                                   documents=self.documents,
                                   embedding_model=self.embedding_model).vector_store

    def as_runnable(self) -> SelfQueryRetriever:
        """
        Return the self-query retriever.

        The vector store is resolved first, then the metadata field info is
        requested from the LLM and both are passed to
        ``SelfQueryRetriever.from_llm``.

        :return: the retriever built by ``SelfQueryRetriever.from_llm``
        """
        vectorstore = self.get_vectorstore()

        return SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=vectorstore,
            document_contents=self.document_description,
            metadata_field_info=self.get_metadata_field_info(),
            verbose=True
        )