Coverage for mindsdb / interfaces / agents / mindsdb_chat_model.py: 25%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from __future__ import annotations 

2 

3import logging 

4from typing import ( 

5 Any, 

6 Dict, 

7 List, 

8 Mapping, 

9 Optional, 

10) 

11 

12import pandas as pd 

13from langchain_core.callbacks import ( 

14 CallbackManagerForLLMRun, 

15) 

16from langchain_core.language_models.chat_models import ( 

17 BaseChatModel, 

18) 

19from langchain_core.messages import ( 

20 AIMessage, 

21 BaseMessage, 

22 ChatMessage, 

23 FunctionMessage, 

24 HumanMessage, 

25 SystemMessage, 

26) 

27from langchain_core.outputs import ( 

28 ChatGeneration, 

29 ChatResult, 

30) 

31from pydantic import model_validator 

32 

33from mindsdb.interfaces.agents.constants import USER_COLUMN 

34from mindsdb.utilities.config import config 

35 

# Module-level logger for this chat-model adapter.
logger = logging.getLogger(__name__)
# Default MindsDB project name, resolved once at import time from config.
default_project = config.get('default_project')

38 

39 

def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a langchain message into an OpenAI-style role/content dict.

    Raises:
        ValueError: if the message is not one of the known message classes.
    """
    extras = message.additional_kwargs

    if isinstance(message, ChatMessage):
        converted = {"role": message.role, "content": message.content}
    elif isinstance(message, AIMessage):
        converted = {"role": "assistant", "content": message.content}
        # Preserve any function-call payload attached by the model.
        if "function_call" in extras:
            converted["function_call"] = extras["function_call"]
    elif isinstance(message, FunctionMessage):
        converted = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        # Remaining message classes map to a fixed role string.
        for klass, role in ((HumanMessage, "user"), (SystemMessage, "system")):
            if isinstance(message, klass):
                converted = {"role": role, "content": message.content}
                break
        else:
            raise ValueError(f"Got unknown type {message}")

    # A caller-supplied name overrides/augments the dict for any message type.
    if "name" in extras:
        converted["name"] = extras["name"]
    return converted

62 

63 

class ChatMindsdb(BaseChatModel):
    """Langchain chat model backed by a MindsDB model.

    Adapts a MindsDB model (identified by ``model_name`` within
    ``project_name``) to the langchain ``BaseChatModel`` interface so it can
    be used anywhere langchain expects a chat model.
    """

    # Name of the MindsDB model to query.
    model_name: str
    # Project the model lives in; falls back to the configured default project.
    project_name: Optional[str] = default_project
    # Model metadata resolved at validation time (problem definition, engine, ...).
    model_info: Optional[dict] = None
    # Datanode used to run predictions against the project.
    project_datanode: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True
        # NOTE(review): 'allow_reuse' is a pydantic v1 validator kwarg, not a
        # recognized Config key — presumably inert here; confirm it can be removed.
        allow_reuse = True

    @property
    def _default_params(self) -> Dict[str, Any]:
        """No extra default parameters are sent to the model."""
        return {}

    def completion(
        self, messages: List[dict]
    ) -> Any:
        """Run the underlying MindsDB model on a list of chat messages.

        Args:
            messages: OpenAI-style dicts with 'role' and 'content' keys.

        Returns:
            A dict with a 'messages' key holding a one-element list containing
            the model's answer.
        """
        problem_definition = self.model_info['problem_definition'].get('using', {})
        output_col = self.model_info['predict']

        # TODO create table for conversational model?
        if len(messages) > 1:
            # Collapse the whole conversation into one "role: content" text blob.
            content = '\n'.join(
                f"{m['role']}: {m['content']}"
                for m in messages
            )
        else:
            content = messages[0]['content']

        record = {}
        params = {}
        # Default to conversational if not set.
        mode = problem_definition.get('mode', 'conversational')
        if mode in ('conversational', 'retrieval'):
            # flag for langchain to prevent calling agent inside of agent
            if self.model_info['engine'] == 'langchain':
                params['mode'] = 'chat_model'

            user_column = problem_definition.get('user_column', USER_COLUMN)
            record[user_column] = content

        elif 'column' in problem_definition:
            # input defined as 'column' param
            record[problem_definition['column']] = content

        else:
            # failback, maybe handler supports template injection
            params['prompt_template'] = content

        predictions = self.project_datanode.predict(
            model_name=self.model_name,
            df=pd.DataFrame([record]),
            params=params,
        )

        col = output_col
        if col not in predictions.columns:
            # Declared target column is absent — fall back to the first column.
            col = predictions.columns[0]

        # Use positional access for the first row: the prediction frame's index
        # is not guaranteed to be a 0-based RangeIndex, so label-based `[0]`
        # could raise KeyError.
        result = predictions[col].iloc[0]

        # TODO token calculation
        return {
            'messages': [result]
        }

    @model_validator(mode='before')
    @classmethod
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve model metadata and the project datanode before field validation.

        Populates 'model_info' and 'project_datanode' in ``values``.

        Raises:
            KeyError: if 'model_name' or 'project_name' is missing from values.
        """
        model_name = values['model_name']
        project_name = values['project_name']

        # Imported lazily to avoid a circular import at module load time.
        from mindsdb.api.executor.controllers import SessionController

        session = SessionController()
        session.database = default_project

        values['model_info'] = session.model_controller.get_model(model_name, project_name=project_name)

        values['project_datanode'] = session.datahub.get(values['project_name'])

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Langchain entry point: convert messages, run the model, wrap the result."""
        message_dicts = [_convert_message_to_dict(m) for m in messages]

        response = self.completion(
            messages=message_dicts
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Wrap raw completion output into a langchain ``ChatResult``."""
        generations = [
            ChatGeneration(
                message=AIMessage(content=content),
                generation_info=dict(finish_reason=None),
            )
            for content in response["messages"]
        ]
        # 'usage' is not currently produced by completion(); kept for forward
        # compatibility with token accounting.
        llm_output = {
            "token_usage": response.get("usage", {}),
            "model": self.model_name,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
        }

    @property
    def _llm_type(self) -> str:
        """Identifier used by langchain to label this chat-model type."""
        return "mindsdb"