Coverage for mindsdb/integrations/handlers/llama_index_handler/llama_index_handler.py: 0% (116 statements)
from typing import Optional, Dict

import pandas as pd
from llama_index.llms.openai import OpenAI
from llama_index.core import Document
from llama_index.readers.web import SimpleWebPageReader
from llama_index.core import PromptTemplate
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import VectorStoreIndex
from llama_index.core import Settings

from mindsdb.integrations.libs.base import BaseMLEngine
from mindsdb.utilities.config import Config
from mindsdb.utilities.security import validate_urls
from mindsdb.integrations.handlers.llama_index_handler.settings import llama_index_config, LlamaIndexModel
from mindsdb.integrations.libs.api_handler_exceptions import MissingConnectionParams
from mindsdb.integrations.utilities.handler_utils import get_api_key


class LlamaIndexHandler(BaseMLEngine):
    """Integration with the LlamaIndex data framework for LLM applications."""

    name = "llama_index"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True
        self.default_index_class = llama_index_config.DEFAULT_INDEX_CLASS
        self.supported_index_class = llama_index_config.SUPPORTED_INDEXES
        self.default_reader = llama_index_config.DEFAULT_READER
        self.supported_reader = llama_index_config.SUPPORTED_READERS
        self.config = Config()

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        if "using" not in args:
            raise MissingConnectionParams("LlamaIndex engine requires USING clause!")
        else:
            args = args["using"]
            LlamaIndexModel(**args)
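
    # Illustrative only: a rough sketch of the USING keys this handler reads elsewhere
    # in this file; the authoritative schema is LlamaIndexModel in settings.py, so treat
    # this listing as an assumption, not a specification.
    #   reader                              -> "DFReader" or "SimpleWebPageReader"
    #   source_url_link                     -> URL to crawl when reader is "SimpleWebPageReader"
    #   input_column                        -> column holding questions in the default predict mode
    #   mode, user_column, assistant_column -> enable/configure conversational mode
    #   prompt / prompt_template            -> optional prompt text for the query engine
    #   openai_api_key, temperature, model_name, max_tokens -> OpenAI settings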

    def create(
        self,
        target: str,
        df: Optional[pd.DataFrame] = None,
        args: Optional[Dict] = None,
    ) -> None:
        # workaround to create llama model without input data
        if df is None or df.empty:
            df = pd.DataFrame([{"text": ""}])
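
        # Reader selection: "DFReader" serializes each input row into a llama_index
        # Document, while "SimpleWebPageReader" crawls a single URL (checked against
        # the web_crawling_allowed_sites config before fetching).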
        args_reader = args.get("using", {}).get("reader", self.default_reader)

        if args_reader == "DFReader":
            dstrs = df.apply(
                lambda x: ", ".join([f"{col}: {str(entry)}" for col, entry in zip(df.columns, x)]),
                axis=1,
            )
            reader = list(map(lambda x: Document(text=x), dstrs.tolist()))
        elif args_reader == "SimpleWebPageReader":
            url = args["using"]["source_url_link"]
            allowed_urls = self.config.get("web_crawling_allowed_sites", [])
            if allowed_urls and not validate_urls(url, allowed_urls):
                raise ValueError(
                    f"The provided URL is not allowed for web crawling. Please use any of {', '.join(allowed_urls)}."
                )
            reader = SimpleWebPageReader(html_to_text=True).load_data([url])
        else:
            raise Exception(f"Invalid operation mode. Please use one of {self.supported_reader}.")
        self.model_storage.json_set("args", args)
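
        # Build the vector index over the documents and persist it into the model's
        # "context" folder so that predict() can reload it later.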
        index = self._setup_index(reader)
        path = self.model_storage.folder_get("context")
        index.storage_context.persist(persist_dir=path)
        self.model_storage.folder_sync("context")

    def update(self, args) -> None:
        args_cur = self.model_storage.json_get("args")
        args_cur["using"].update(args["using"])

        # check new set of arguments
        self.create_validation(None, args_cur)

        self.model_storage.json_set("args", args_cur)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        pred_args = args["predict_params"] if args else {}

        args = self.model_storage.json_get("args")
        engine_kwargs = {}
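
        # Two prediction modes: "conversational" builds a chat-style prompt from the
        # user/assistant history columns and answers only the last user message, while
        # the default mode answers every row of `input_column` independently.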
        if args["using"].get("mode") == "conversational":
            user_column = args["using"]["user_column"]
            assistant_column = args["using"]["assistant_column"]

            messages = []
            for row in df[:-1].to_dict("records"):
                messages.append(f"user: {row[user_column]}")
                messages.append(f"assistant: {row[assistant_column]}")

            conversation = "\n".join(messages)

            questions = [df.iloc[-1][user_column]]

            if "prompt" in pred_args and pred_args["prompt"] is not None:
                user_prompt = pred_args["prompt"]
            else:
                user_prompt = args["using"].get("prompt", "")

            prompt_template = (
                f"{user_prompt}\n"
                f"---------------------\n"
                f"We have provided context information below. \n"
                f"{{context_str}}\n"
                f"---------------------\n"
                f"This is previous conversation history:\n"
                f"{conversation}\n"
                f"---------------------\n"
                f"Given this information, please answer the question: {{query_str}}"
            )

            engine_kwargs["text_qa_template"] = PromptTemplate(prompt_template)

        else:
            input_column = args["using"].get("input_column", None)

            prompt_template = args["using"].get("prompt_template", args.get("prompt_template", None))
            if prompt_template is not None:
                self.create_validation(None, args=args)
                engine_kwargs["text_qa_template"] = PromptTemplate(prompt_template)

            if input_column is None:
                raise Exception(
                    "`input_column` must be provided at model creation time or through USING clause when predicting. Please try again."
                )  # noqa

            if input_column not in df.columns:
                raise Exception(f'Column "{input_column}" not found in input data! Please try again.')

            questions = df[input_column]
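
        # Reload the index persisted by create() and answer each question through a
        # llama_index query engine, applying any prompt template configured above.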
        index_path = self.model_storage.folder_get("context")
        storage_context = StorageContext.from_defaults(persist_dir=index_path)
        self._get_service_context()

        index = load_index_from_storage(storage_context)
        query_engine = index.as_query_engine(**engine_kwargs)

        results = []

        for question in questions:
            query_results = query_engine.query(question)  # TODO: provide extra_info in explain_target col
            results.append(query_results.response)

        result_df = pd.DataFrame({"question": questions, args["target"]: results})  # result_df['answer'].tolist()
        return result_df

    def _get_service_context(self) -> None:
        args = self.model_storage.json_get("args")
        engine_storage = self.engine_storage
        openai_api_key = get_api_key("openai", args["using"], engine_storage, strict=True)
        llm_kwargs = {"api_key": openai_api_key}

        if "temperature" in args["using"]:
            llm_kwargs["temperature"] = args["using"]["temperature"]
        if "model_name" in args["using"]:
            llm_kwargs["model_name"] = args["using"]["model_name"]
        if "max_tokens" in args["using"]:
            llm_kwargs["max_tokens"] = args["using"]["max_tokens"]
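
        # Note: llm_kwargs is assembled above but never passed to OpenAI() below;
        # as the comment that follows suggests, only the API key reaches the LLM here.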
        # only way this works is by sending the key through openai
        if Settings.llm is None:
            llm = OpenAI(api_key=openai_api_key)
            Settings.llm = llm
        if Settings.embed_model is None:
            embed_model = OpenAIEmbedding()
            Settings.embed_model = embed_model
        # TODO: all usual params should be added to Settings

    def _setup_index(self, documents):
        self._get_service_context()
        index = VectorStoreIndex.from_documents(documents)
        return index
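

# Illustrative usage only, not part of the handler: a hedged sketch of how this engine
# is typically driven from MindsDB SQL. The datasource, table, target name, and key value
# below are placeholders/assumptions, not values taken from this file.
#
#   CREATE MODEL qa_model
#   FROM my_datasource (SELECT * FROM my_table)
#   PREDICT answer
#   USING
#       engine = 'llama_index',
#       reader = 'DFReader',
#       input_column = 'question',
#       openai_api_key = 'sk-...';
#
#   SELECT question, answer
#   FROM qa_model
#   WHERE question = 'What does the indexed content say about pricing?';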