# mindsdb/integrations/handlers/llama_index_handler/llama_index_handler.py


from typing import Optional, Dict

import pandas as pd
from llama_index.llms.openai import OpenAI
from llama_index.core import (
    Document,
    PromptTemplate,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.readers.web import SimpleWebPageReader
from llama_index.embeddings.openai import OpenAIEmbedding

from mindsdb.integrations.libs.base import BaseMLEngine
from mindsdb.utilities.config import Config
from mindsdb.utilities.security import validate_urls
from mindsdb.integrations.handlers.llama_index_handler.settings import llama_index_config, LlamaIndexModel
from mindsdb.integrations.libs.api_handler_exceptions import MissingConnectionParams
from mindsdb.integrations.utilities.handler_utils import get_api_key


class LlamaIndexHandler(BaseMLEngine):
    """Integration with the LlamaIndex data framework for LLM applications."""

    name = "llama_index"
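
    # Illustrative SQL usage (a sketch: exact option names are defined in this
    # handler's settings module, and the model name, URL, and key are
    # placeholders):
    #
    #   CREATE MODEL my_llama_model
    #   PREDICT answer
    #   USING
    #       engine = 'llama_index',
    #       reader = 'SimpleWebPageReader',
    #       source_url_link = 'https://docs.mindsdb.com',
    #       input_column = 'question',
    #       openai_api_key = '...';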

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True
        self.default_index_class = llama_index_config.DEFAULT_INDEX_CLASS
        self.supported_index_class = llama_index_config.SUPPORTED_INDEXES
        self.default_reader = llama_index_config.DEFAULT_READER
        self.supported_reader = llama_index_config.SUPPORTED_READERS
        self.config = Config()

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        # Guard against a missing args dict as well as a missing USING clause.
        if args is None or "using" not in args:
            raise MissingConnectionParams("LlamaIndex engine requires USING clause!")
        LlamaIndexModel(**args["using"])
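
    # For example, create_validation(None, {}) raises MissingConnectionParams,
    # while a dict with a "using" key has its options checked by constructing
    # LlamaIndexModel.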

    def create(
        self,
        target: str,
        df: Optional[pd.DataFrame] = None,
        args: Optional[Dict] = None,
    ) -> None:
        # Workaround to create a LlamaIndex model without input data.
        if df is None or df.empty:
            df = pd.DataFrame([{"text": ""}])

        args_reader = args.get("using", {}).get("reader", self.default_reader)

        if args_reader == "DFReader":
            # Flatten each row into a single "col: value, col: value, ..."
            # string and wrap it in a LlamaIndex Document.
            dstrs = df.apply(
                lambda x: ", ".join([f"{col}: {str(entry)}" for col, entry in zip(df.columns, x)]),
                axis=1,
            )
            documents = [Document(text=text) for text in dstrs.tolist()]
        elif args_reader == "SimpleWebPageReader":
            url = args["using"]["source_url_link"]
            allowed_urls = self.config.get("web_crawling_allowed_sites", [])
            if allowed_urls and not validate_urls(url, allowed_urls):
                raise ValueError(
                    f"The provided URL is not allowed for web crawling. Please use any of {', '.join(allowed_urls)}."
                )
            documents = SimpleWebPageReader(html_to_text=True).load_data([url])
        else:
            raise Exception(f"Invalid reader. Please use one of {self.supported_reader}.")

        self.model_storage.json_set("args", args)
        index = self._setup_index(documents)
        path = self.model_storage.folder_get("context")
        index.storage_context.persist(persist_dir=path)
        self.model_storage.folder_sync("context")
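
    # A sketch of the DFReader flattening above: a row such as
    # {"title": "intro", "body": "hello"} becomes
    # Document(text="title: intro, body: hello"), so every column contributes
    # to the indexed text.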

    def update(self, args) -> None:
        args_cur = self.model_storage.json_get("args")
        args_cur["using"].update(args["using"])

        # Validate the merged set of arguments before persisting it.
        self.create_validation(None, args_cur)

        self.model_storage.json_set("args", args_cur)
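
    # For example, update({"using": {"temperature": 0.2}}) merges the new value
    # into the stored arguments, re-validates them, and persists the result.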

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        pred_args = args.get("predict_params", {}) if args else {}

        args = self.model_storage.json_get("args")
        engine_kwargs = {}

        if args["using"].get("mode") == "conversational":
            user_column = args["using"]["user_column"]
            assistant_column = args["using"]["assistant_column"]

            # Flatten every turn except the last into a plain-text history;
            # the last user turn becomes the question to answer.
            messages = []
            for row in df[:-1].to_dict("records"):
                messages.append(f"user: {row[user_column]}")
                messages.append(f"assistant: {row[assistant_column]}")

            conversation = "\n".join(messages)

            questions = [df.iloc[-1][user_column]]

            if "prompt" in pred_args and pred_args["prompt"] is not None:
                user_prompt = pred_args["prompt"]
            else:
                user_prompt = args["using"].get("prompt", "")

            prompt_template = (
                f"{user_prompt}\n"
                f"---------------------\n"
                f"We have provided context information below.\n"
                f"{{context_str}}\n"
                f"---------------------\n"
                f"This is the previous conversation history:\n"
                f"{conversation}\n"
                f"---------------------\n"
                f"Given this information, please answer the question: {{query_str}}"
            )

            engine_kwargs["text_qa_template"] = PromptTemplate(prompt_template)

        else:
            input_column = args["using"].get("input_column", None)

            prompt_template = args["using"].get("prompt_template", args.get("prompt_template", None))
            if prompt_template is not None:
                self.create_validation(None, args=args)
                engine_kwargs["text_qa_template"] = PromptTemplate(prompt_template)

            if input_column is None:
                raise Exception(
                    "`input_column` must be provided at model creation time or through the USING clause when predicting."
                )

            if input_column not in df.columns:
                raise Exception(f'Column "{input_column}" not found in input data!')

            questions = df[input_column]

        index_path = self.model_storage.folder_get("context")
        storage_context = StorageContext.from_defaults(persist_dir=index_path)
        self._get_service_context()

        index = load_index_from_storage(storage_context)
        query_engine = index.as_query_engine(**engine_kwargs)

        results = []
        for question in questions:
            # TODO: provide extra_info in an explain_target column.
            query_results = query_engine.query(question)
            results.append(query_results.response)

        result_df = pd.DataFrame({"question": questions, args["target"]: results})
        return result_df
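
    # A sketch of conversational mode: given user/assistant rows
    # ("hi", "hello!") followed by a final row whose user turn is
    # "what is mindsdb?", the history becomes "user: hi\nassistant: hello!"
    # and the final question is answered against the index; the returned frame
    # has a "question" column plus the target column with the responses.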

    def _get_service_context(self) -> None:
        """Configure the global llama_index Settings with an OpenAI LLM and embedding model."""
        args = self.model_storage.json_get("args")
        engine_storage = self.engine_storage
        openai_api_key = get_api_key("openai", args["using"], engine_storage, strict=True)
        llm_kwargs = {"api_key": openai_api_key}

        if "temperature" in args["using"]:
            llm_kwargs["temperature"] = args["using"]["temperature"]
        if "model_name" in args["using"]:
            # The llama_index OpenAI wrapper takes the model name via `model`.
            llm_kwargs["model"] = args["using"]["model_name"]
        if "max_tokens" in args["using"]:
            llm_kwargs["max_tokens"] = args["using"]["max_tokens"]

        if Settings.llm is None:
            # Apply the collected kwargs so temperature/model/max_tokens take effect.
            Settings.llm = OpenAI(**llm_kwargs)
        if Settings.embed_model is None:
            # Pass the key explicitly so embeddings do not depend on the
            # OPENAI_API_KEY environment variable.
            Settings.embed_model = OpenAIEmbedding(api_key=openai_api_key)
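
    # For example, USING model_name = 'gpt-4o-mini' and temperature = 0.0 end
    # up as OpenAI(api_key=..., model='gpt-4o-mini', temperature=0.0) above
    # (the model name here is only an illustration).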

    def _setup_index(self, documents):
        """Build a vector store index over the given documents using the configured Settings."""
        self._get_service_context()
        index = VectorStoreIndex.from_documents(documents)
        return index
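
# A minimal, self-contained sketch of the persist/load round-trip this handler
# relies on (assumes OPENAI_API_KEY is set in the environment; not part of the
# handler itself):
#
#   from llama_index.core import Document, StorageContext, VectorStoreIndex, load_index_from_storage
#
#   docs = [Document(text="col_a: 1, col_b: foo")]
#   index = VectorStoreIndex.from_documents(docs)
#   index.storage_context.persist(persist_dir="/tmp/llama_ctx")
#   storage_context = StorageContext.from_defaults(persist_dir="/tmp/llama_ctx")
#   restored = load_index_from_storage(storage_context)
#   print(restored.as_query_engine().query("What is col_b?"))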