Coverage for mindsdb / integrations / handlers / rag_handler / rag.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2from collections import defaultdict 

3from typing import List 

4 

5from openai.types import Completion 

6 

7from mindsdb.integrations.handlers.rag_handler.settings import ( 

8 LLMLoader, 

9 PersistedVectorStoreLoader, 

10 PersistedVectorStoreLoaderConfig, 

11 RAGBaseParameters, 

12 RAGHandlerParameters, 

13 load_embeddings_model, 

14) 

15from mindsdb.utilities import log 

16 

# Module-level logger named after this module, via mindsdb's logging utility.
logger = log.getLogger(__name__)

18 

19 

class RAGQuestionAnswerer:
    """Answer questions with retrieval-augmented generation (RAG).

    On construction this loads (or reuses) an embeddings model, opens a
    persisted vector store and — when full handler parameters are supplied —
    an LLM. Calling the instance retrieves the top-k relevant documents,
    builds a prompt (optionally summarizing the context first) and returns
    the LLM answer together with the source-document metadata.
    """

    def __init__(self, args: RAGBaseParameters):
        """Set up the embeddings model, vector store and (optionally) the LLM.

        Args:
            args: RAG configuration. When it is a ``RAGHandlerParameters``
                instance an LLM is also loaded; otherwise only retrieval is
                configured and methods that use ``self.llm`` will fail.
        """
        self.output_data = defaultdict(list)
        self.args = args

        # Reuse a caller-supplied embeddings model; otherwise load one by name.
        self.embeddings_model = args.embeddings_model
        if self.embeddings_model is None:
            self.embeddings_model = load_embeddings_model(args.embeddings_model_name)

        self.persist_directory = args.vector_store_storage_path
        self.collection_name = args.collection_name

        vector_store_config = PersistedVectorStoreLoaderConfig(
            vector_store_name=args.vector_store_name,
            embeddings_model=self.embeddings_model,
            persist_directory=self.persist_directory,
            collection_name=self.collection_name,
        )

        self.vector_store_loader = PersistedVectorStoreLoader(vector_store_config)
        self.persisted_vector_store = self.vector_store_loader.load_vector_store()

        self.prompt_template = args.prompt_template

        # Only full handler parameters carry LLM settings; base parameters
        # configure retrieval only.
        if isinstance(args, RAGHandlerParameters):
            llm_config = {"llm_config": args.llm_params.model_dump()}
            llm_loader = LLMLoader(**llm_config)
            self.llm = llm_loader.load_llm()

    def __call__(self, question: str) -> defaultdict:
        """Make the instance callable as the question-answering entry point."""
        return self.query(question)

    def _prepare_prompt(self, vector_store_response, question) -> str:
        """Build the final LLM prompt from the retrieved documents.

        When ``args.summarize_context`` is set, the combined context is first
        summarized by the LLM before being placed into the prompt template.
        """
        context = [doc.page_content for doc in vector_store_response]
        combined_context = "\n\n".join(context)

        if self.args.summarize_context:
            return self.summarize_context(combined_context, question)

        return self.prompt_template.format(question=question, context=combined_context)

    def summarize_context(self, combined_context: str, question: str) -> str:
        """Summarize retrieved context with the LLM, then build the QA prompt."""
        summarization_prompt_template = self.args.summarization_prompt_template

        summarization_prompt = summarization_prompt_template.format(
            context=combined_context, question=question
        )

        summarized_context = self.llm(prompt=summarization_prompt)

        return self.prompt_template.format(
            question=question, context=self.extract_generated_text(summarized_context)
        )

    def query_vector_store(self, question: str) -> List:
        """Return the ``args.top_k`` documents most similar to the question."""
        return self.persisted_vector_store.similarity_search(
            query=question,
            k=self.args.top_k,
        )

    @staticmethod
    def extract_generated_text(response) -> str:
        """Extract generated text from an LLM response.

        Accepts a JSON string, a mapping with a ``choices`` list, or an
        OpenAI ``Completion`` object. Returns the raw response unchanged
        (after logging) when no known shape matches.

        Raises:
            Exception: if a recognized shape cannot be parsed; the original
                error is chained as the cause.
        """
        # JSON-encoded responses are decoded first.
        data = json.loads(response) if isinstance(response, str) else response

        try:
            # Check the typed Completion object before the membership test:
            # `"choices" in data` is only well-defined for mapping-like data.
            if isinstance(data, Completion):
                return data.choices[0].text
            elif "choices" in data:
                return data["choices"][0]["text"]
            else:
                logger.info(
                    f"Error extracting generated text: failed to parse response {response}"
                )
                return response

        except Exception as e:
            # Chain the original error so the root cause is preserved.
            raise Exception(
                f"{e} Error extracting generated text: failed to parse response {response}"
            ) from e

    def query(self, question: str) -> defaultdict:
        """Answer a question and collect source-document metadata.

        Returns:
            defaultdict(list) with ``"answer"`` (one extracted answer string)
            and ``"source_documents"`` (one dict of parallel metadata lists:
            content, source, column and row per retrieved document).
        """
        llm_response, vector_store_response = self._query(question)

        result = defaultdict(list)
        result["answer"].append(self.extract_generated_text(llm_response))

        sources = defaultdict(list)

        # Collect parallel metadata lists; missing metadata keys become None.
        for document in vector_store_response:
            sources["sources_content"].append(document.page_content)
            sources["sources_document"].append(document.metadata.get("source", None))
            sources["column"].append(document.metadata.get("column", None))
            sources["sources_row"].append(document.metadata.get("row", None))

        result["source_documents"].append(dict(sources))

        return result

    def _query(self, question: str):
        """Retrieve context, build the prompt and run the LLM for a question."""
        logger.debug(f"Querying: {question}")

        vector_store_response = self.query_vector_store(question)

        formatted_prompt = self._prepare_prompt(vector_store_response, question)

        llm_response = self.llm(prompt=formatted_prompt)

        return llm_response, vector_store_response