Coverage for mindsdb / interfaces / functions / controller.py: 18%

142 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import copy 

3 

4from duckdb.typing import BIGINT, DOUBLE, VARCHAR, BLOB, BOOLEAN 

5from mindsdb.interfaces.storage.model_fs import HandlerStorage 

6from mindsdb.utilities.config import config 

7 

8 

def python_to_duckdb_type(py_type):
    """Map a Python type name (given as a string) to a DuckDB SQL type.

    Any name not in the mapping falls back to VARCHAR.
    """
    type_map = {
        "int": BIGINT,
        "float": DOUBLE,
        "str": VARCHAR,
        "bool": BOOLEAN,
        "bytes": BLOB,
    }
    # Unknown type names are treated as text.
    return type_map.get(py_type, VARCHAR)

23 

24 

# duckdb doesn't like *args
def function_maker(n_args, other_function):
    """Return a fixed-arity wrapper that forwards its arguments to *other_function*.

    DuckDB requires UDFs with an explicit positional signature, so we pick a
    lambda with exactly *n_args* parameters (0..4 supported).

    Raises IndexError when n_args > 4.
    """
    return [
        lambda: other_function(),
        lambda arg_0: other_function(arg_0),
        lambda arg_0, arg_1: other_function(arg_0, arg_1),
        lambda arg_0, arg_1, arg_2: other_function(arg_0, arg_1, arg_2),
        # BUG FIX: the 4-arg wrapper previously forwarded arg_2 twice and
        # silently dropped arg_3.
        lambda arg_0, arg_1, arg_2, arg_3: other_function(arg_0, arg_1, arg_2, arg_3),
    ][n_args]

34 

35 

class BYOMFunctionsController:
    """
    User functions based on BYOM handler
    """

    def __init__(self, session):
        self.session = session

        # Lazily discovered list of BYOM engines in custom-function mode.
        self.byom_engines = None
        # Per-engine caches: exposed method descriptors and handler instances.
        self.byom_methods = {}
        self.byom_handlers = {}

        # Function metadata keyed by the mangled "<engine>_<method>" name.
        self.callbacks = {}

    def get_engines(self):
        """Return (and cache on first call) names of all BYOM ML engines
        configured with mode == "custom_function"."""
        if self.byom_engines is None:
            integrations = self.session.integration_controller.get_all()
            self.byom_engines = [
                name
                for name, info in integrations.items()
                if info["type"] == "ml"
                and info["engine"] == "byom"
                and info["connection_data"].get("mode") == "custom_function"
            ]
        return self.byom_engines

    def get_methods(self, engine):
        """Return the method descriptors exposed by *engine*, caching both
        the descriptors and the handler instance on first access."""
        if engine not in self.byom_methods:
            handler = self.session.integration_controller.get_ml_handler(engine)

            storage = HandlerStorage(handler.integration_id)
            self.byom_methods[engine] = storage.json_get("methods")
            self.byom_handlers[engine] = handler

        return self.byom_methods[engine]

    def check_function(self, node):
        """If *node* refers to a known BYOM function, register a callback for
        it and return its metadata dict; otherwise return None.

        Side effect: node.op is renamed to "<namespace>_<function>".
        """
        engine = node.namespace
        if engine not in self.get_engines():
            return

        methods = self.get_methods(engine)

        fnc_name = node.op.lower()
        if fnc_name not in methods:
            # do nothing
            return

        new_name = f"{node.namespace}_{fnc_name}"
        node.op = new_name

        if new_name in self.callbacks:
            # already exists
            return self.callbacks[new_name]

        def callback(*args):
            return self.method_call(engine, fnc_name, args)

        descriptor = methods[fnc_name]
        meta = {
            "name": new_name,
            "callback": callback,
            "input_types": [param["type"] for param in descriptor["input_params"]],
            "output_type": descriptor["output_type"],
        }

        self.callbacks[new_name] = meta
        return meta

    def method_call(self, engine, method_name, args):
        """Invoke *method_name* with *args* on the cached handler for *engine*."""
        return self.byom_handlers[engine].function_call(method_name, args)

    def create_function_set(self):
        """Return a DuckDB-facing adapter over this controller."""
        return DuckDBFunctions(self)

111 

112 

class FunctionController(BYOMFunctionsController):
    """Extends the BYOM function controller with builtin functions:
    LLM() and TO_MARKDOWN().
    """

    def check_function(self, node):
        """Return metadata for the function *node* refers to, or None.

        BYOM engine functions take priority; otherwise the builtin
        'llm' and 'to_markdown' functions are checked.
        """
        meta = super().check_function(node)
        if meta is not None:
            return meta

        # builtin functions
        name = node.op.lower()
        if name == "llm":
            return self.llm_call_function(node)

        elif name == "to_markdown":
            return self.to_markdown_call_function(node)

    def llm_call_function(self, node):
        """Build (and cache) callback metadata for the LLM() function.

        Raises RuntimeError when the chat model cannot be created
        (e.g. missing/invalid LLM_FUNCTION_* environment variables).
        """
        name = node.op.lower()

        if name in self.callbacks:
            return self.callbacks[name]

        chat_model_params = self._parse_chat_model_params()

        try:
            # Imported lazily: langchain is heavy and optional.
            from langchain_core.messages import HumanMessage
            from mindsdb.interfaces.agents.langchain_agent import create_chat_model

            llm = create_chat_model(chat_model_params)
        except Exception as e:
            raise RuntimeError(f"Unable to use LLM function, check ENV variables: {e}") from e

        def callback(question):
            resp = llm([HumanMessage(question)])
            return resp.content

        meta = {"name": name, "callback": callback, "input_types": ["str"], "output_type": "str"}
        self.callbacks[name] = meta
        return meta

    def to_markdown_call_function(self, node):
        """Build (and cache) callback metadata for the TO_MARKDOWN() function."""
        # load on-demand because lib is heavy
        from mindsdb.interfaces.functions.to_markdown import ToMarkdown

        name = node.op.lower()

        if name in self.callbacks:
            return self.callbacks[name]

        def prepare_chat_model_params(chat_model_params: dict) -> dict:
            """
            Prepares the chat model parameters for the ToMarkdown function.
            """
            params_copy = copy.deepcopy(chat_model_params)
            params_copy["model"] = params_copy.pop("model_name")

            # Set the base_url for the Google provider.
            if params_copy["provider"] == "google" and "base_url" not in params_copy:
                params_copy["base_url"] = "https://generativelanguage.googleapis.com/v1beta/"

            # BUG FIX: "api_keys" is only present when an api_key was
            # configured, so pop with a default to avoid a KeyError.
            params_copy.pop("api_keys", None)
            params_copy.pop("provider", None)

            return params_copy

        def callback(file_path_or_url):
            # Parsed per call so environment changes take effect immediately.
            chat_model_params = self._parse_chat_model_params("TO_MARKDOWN_FUNCTION_")
            chat_model_params = prepare_chat_model_params(chat_model_params)

            to_markdown = ToMarkdown()
            return to_markdown.call(file_path_or_url, **chat_model_params)

        meta = {"name": name, "callback": callback, "input_types": ["str"], "output_type": "str"}
        self.callbacks[name] = meta
        return meta

    def _parse_chat_model_params(self, param_prefix: str = "LLM_FUNCTION_"):
        """
        Parses the environment variables for chat model parameters.

        Starts from the `default_llm` config section, then overlays any
        `<param_prefix>*` environment variables: `<prefix>MODEL` maps to
        "model_name", anything else to its lowercased suffix. Defaults
        provider to "openai" and mirrors a configured "api_key" into an
        "api_keys" dict keyed by provider.
        """
        # BUG FIX: deep-copy so we never mutate the shared config dict
        # returned by config.get().
        chat_model_params = copy.deepcopy(config.get("default_llm") or {})
        for k, v in os.environ.items():
            if k.startswith(param_prefix):
                param_name = k[len(param_prefix) :]
                if param_name == "MODEL":
                    chat_model_params["model_name"] = v
                else:
                    chat_model_params[param_name.lower()] = v

        if "provider" not in chat_model_params:
            chat_model_params["provider"] = "openai"

        if "api_key" in chat_model_params:
            # move to api_keys dict
            chat_model_params["api_keys"] = {chat_model_params["provider"]: chat_model_params["api_key"]}

        return chat_model_params

210 

211 

class DuckDBFunctions:
    """Collects functions resolved by a controller and registers them
    as UDFs on a DuckDB connection."""

    def __init__(self, controller):
        self.controller = controller
        # name -> {"callback", "input", "output"}, ready for create_function.
        self.functions = {}

    def check_function(self, node):
        """Resolve *node* through the controller and stage the resulting
        function for registration (no-op if unknown or already staged)."""
        meta = self.controller.check_function(node)
        if meta is None:
            return

        name = meta["name"]
        if name in self.functions:
            return

        duck_inputs = [python_to_duckdb_type(t) for t in meta["input_types"]]

        self.functions[name] = {
            # duckdb needs a fixed-arity callable, not *args
            "callback": function_maker(len(duck_inputs), meta["callback"]),
            "input": duck_inputs,
            "output": python_to_duckdb_type(meta["output_type"]),
        }

    def register(self, connection):
        """Register every staged function as a UDF on *connection*."""
        for name, info in self.functions.items():
            connection.create_function(
                name,
                info["callback"],
                info["input"],
                info["output"],
                null_handling="special",
            )