Coverage for mindsdb / interfaces / agents / mindsdb_chat_model.py: 25%
100 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from __future__ import annotations
3import logging
4from typing import (
5 Any,
6 Dict,
7 List,
8 Mapping,
9 Optional,
10)
12import pandas as pd
13from langchain_core.callbacks import (
14 CallbackManagerForLLMRun,
15)
16from langchain_core.language_models.chat_models import (
17 BaseChatModel,
18)
19from langchain_core.messages import (
20 AIMessage,
21 BaseMessage,
22 ChatMessage,
23 FunctionMessage,
24 HumanMessage,
25 SystemMessage,
26)
27from langchain_core.outputs import (
28 ChatGeneration,
29 ChatResult,
30)
31from pydantic import model_validator
33from mindsdb.interfaces.agents.constants import USER_COLUMN
34from mindsdb.utilities.config import config
# Module-level logger.
logger = logging.getLogger(__name__)
# Project used by ChatMindsdb when no project name is supplied.
default_project = config.get('default_project')
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into a ``role``/``content`` dictionary.

    ``ChatMessage`` keeps its own role; human, AI, system and function messages
    map to ``"user"``, ``"assistant"``, ``"system"`` and ``"function"``.
    AI messages carry over ``additional_kwargs["function_call"]`` when present,
    and function messages carry their ``name``. A ``name`` entry in
    ``additional_kwargs`` always wins over any name set above.

    Raises:
        ValueError: if ``message`` is not one of the supported message types.
    """
    extras = {}
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
        if "function_call" in message.additional_kwargs:
            extras["function_call"] = message.additional_kwargs["function_call"]
    elif isinstance(message, SystemMessage):
        role = "system"
    elif isinstance(message, FunctionMessage):
        role = "function"
        extras["name"] = message.name
    else:
        raise ValueError(f"Got unknown type {message}")

    # Keys are laid out as role, content, then any type-specific extras.
    converted = {"role": role, "content": message.content, **extras}

    kwargs = message.additional_kwargs
    if "name" in kwargs:
        converted["name"] = kwargs["name"]
    return converted
class ChatMindsdb(BaseChatModel):
    """LangChain chat model that answers by calling a model in a MindsDB project.

    Incoming messages are flattened into a single prompt, sent to
    ``project_datanode.predict`` for ``model_name``, and the first value of the
    prediction column is returned as a single ``AIMessage``.
    """

    model_name: str
    project_name: Optional[str] = default_project
    # Model metadata from the model controller; populated by ``validate_environment``.
    model_info: Optional[dict] = None
    # Object whose ``predict`` runs the model; populated by ``validate_environment``.
    project_datanode: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        allow_reuse = True

    @property
    def _default_params(self) -> Dict[str, Any]:
        """No extra default parameters are sent with predictions."""
        return {}

    def completion(self, messages: List[dict]) -> Any:
        """Run one prediction for ``messages`` and wrap the answer.

        Args:
            messages: ``role``/``content`` dicts (see ``_convert_message_to_dict``).
                A single message is sent as its bare content; several messages
                are joined as ``"role: content"`` lines.

        Returns:
            ``{'messages': [result]}``. ``result`` is the first row of the model's
            ``predict`` column, or of the first column if that one is absent.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        if not messages:
            raise ValueError('At least one message is required')

        problem_definition = self.model_info['problem_definition'].get('using', {})
        output_col = self.model_info['predict']

        # TODO create table for conversational model?
        if len(messages) > 1:
            content = '\n'.join(f"{m['role']}: {m['content']}" for m in messages)
        else:
            content = messages[0]['content']

        record = {}
        params = {}
        # Default to conversational if not set.
        mode = problem_definition.get('mode', 'conversational')
        if mode in ('conversational', 'retrieval'):
            # flag for langchain to prevent calling agent inside of agent
            if self.model_info['engine'] == 'langchain':
                params['mode'] = 'chat_model'
            user_column = problem_definition.get('user_column', USER_COLUMN)
            record[user_column] = content
        elif 'column' in problem_definition:
            # input defined as 'column' param
            record[problem_definition['column']] = content
        else:
            # fallback: maybe the handler supports template injection
            params['prompt_template'] = content

        predictions = self.project_datanode.predict(
            model_name=self.model_name,
            df=pd.DataFrame([record]),
            params=params,
        )

        col = output_col if output_col in predictions.columns else predictions.columns[0]

        # Take the first row by position. The returned frame's index is not
        # guaranteed to contain the label 0, so ``[0]`` could raise KeyError.
        result = predictions[col].iloc[0]

        # TODO token calculation
        return {'messages': [result]}

    @model_validator(mode='before')
    @classmethod
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve model metadata and the project data node before field validation.

        This runs on the raw constructor input, so field defaults have not been
        applied yet. An omitted or ``None`` ``project_name`` therefore falls back
        to ``default_project`` explicitly here.
        """
        model_name = values['model_name']
        project_name = values.get('project_name')
        if project_name is None:
            project_name = default_project
        values['project_name'] = project_name

        from mindsdb.api.executor.controllers import SessionController

        session = SessionController()
        session.database = default_project

        values['model_info'] = session.model_controller.get_model(model_name, project_name=project_name)
        values['project_datanode'] = session.datahub.get(project_name)

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Convert ``messages``, run ``completion`` and build a ``ChatResult``.

        ``stop``, ``run_manager``, ``stream`` and extra kwargs are accepted for
        interface compatibility but are not used.
        """
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        response = self.completion(messages=message_dicts)
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Wrap each entry of ``response['messages']`` in a ``ChatGeneration``."""
        generations = [
            ChatGeneration(
                message=AIMessage(content=content),
                generation_info=dict(finish_reason=None),
            )
            for content in response["messages"]
        ]
        llm_output = {
            "token_usage": response.get("usage", {}),
            "model": self.model_name,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name}

    @property
    def _llm_type(self) -> str:
        return "mindsdb"