Coverage for mindsdb / api / executor / sql_query / steps / apply_predictor_step.py: 6%
240 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import datetime as dt
2import re
4import pandas as pd
6from mindsdb_sql_parser.ast import (
7 BinaryOperation,
8 Identifier,
9 Constant,
10 BetweenOperation,
11 Parameter,
12)
13from mindsdb_sql_parser.ast.mindsdb import Latest
15from mindsdb.api.executor.planner.step_result import Result
16from mindsdb.api.executor.planner.steps import (
17 ApplyTimeseriesPredictorStep,
18 ApplyPredictorRowStep,
19 ApplyPredictorStep,
20)
22from mindsdb.api.executor.sql_query.result_set import ResultSet, Column
23from mindsdb.utilities.cache import get_cache, dataframe_checksum
25from .base import BaseStepCall
def get_preditor_alias(step, mindsdb_database):
    """Build the ``(database, table_name, table_alias)`` triple for a predictor step.

    The table name is the dotted join of the predictor's identifier parts;
    the alias falls back to that same name when the step's predictor carries
    no explicit alias.
    """
    name = ".".join(step.predictor.parts)
    if step.predictor.alias is None:
        alias = name
    else:
        alias = ".".join(step.predictor.alias.parts)
    return (mindsdb_database, name, alias)
class ApplyPredictorBaseCall(BaseStepCall):
    """Shared logic for predictor-application steps: route the input dataframe
    either to an agent or to a model inside the project."""

    def apply_predictor(self, project_name, predictor_name, df, version, params):
        """Run *predictor_name* from *project_name* over *df* and return predictions.

        If an agent with that name exists in the project, the dataframe rows are
        treated as chat messages and sent to the agent; otherwise the project
        datanode performs a regular model prediction (optionally pinned to
        *version*).
        """
        agent = self.session.agents_controller.get_agent(predictor_name, project_name)
        if agent is None:
            # plain model prediction through the project datanode
            datanode = self.session.datahub.get(project_name)
            return datanode.predict(model_name=predictor_name, df=df, version=version, params=params)
        # agent completion: every dataframe row becomes one message dict
        return self.session.agents_controller.get_completion(
            agent, messages=df.to_dict("records"), project_name=project_name, params=params
        )
class ApplyPredictorRowStepCall(ApplyPredictorBaseCall):
    """Apply a predictor to a single row of values taken from the query's WHERE clause."""

    bind = ApplyPredictorRowStep

    def call(self, step):
        """Resolve the step's row dict, run the predictor on it and wrap the result.

        Returns a :class:`ResultSet` marked as a prediction, named after the
        predictor's alias.
        """
        project_name = step.namespace
        predictor_name = step.predictor.parts[0]
        project_datanode = self.session.datahub.get(project_name)

        # Resolve Parameter placeholders against the outputs of earlier steps.
        where_data = {}
        for key, value in step.row_dict.items():
            if isinstance(value, Parameter):
                prev = self.steps_data[value.value.step_num]
                col_values = prev.get_column_values(col_idx=0)
                # a single value stays a scalar; several values become a list
                value = col_values[0] if prev.length() == 1 else col_values
            where_data[key] = value

        # A trailing all-digit identifier part selects a model version.
        version = None
        if len(step.predictor.parts) > 1 and step.predictor.parts[-1].isdigit():
            version = int(step.predictor.parts[-1])

        predictions = self.apply_predictor(
            project_name, predictor_name, pd.DataFrame([where_data]), version, step.params
        )

        # Echo the input values back into the prediction frame.
        for key, value in where_data.items():
            predictions[key] = value

        table_name = get_preditor_alias(step, self.context.get("database"))

        if len(predictions) == 0:
            # No rows predicted: still return the model's column layout.
            columns_names = project_datanode.get_table_columns_names(predictor_name)
            predictions = pd.DataFrame([], columns=columns_names)

        return ResultSet.from_df(
            df=predictions,
            database=table_name[0],
            table_name=table_name[1],
            table_alias=table_name[2],
            is_prediction=True,
        )
class ApplyPredictorStepCall(ApplyPredictorBaseCall):
    """Apply a predictor to the dataframe produced by a previous plan step.

    Handles row-id bookkeeping, constants injected from the WHERE clause,
    prediction caching, column-name remapping to the model's schema, and —
    for timeseries models — LATEST handling plus post-prediction time
    filtering via :meth:`apply_ts_filter`.
    """

    bind = ApplyPredictorStep

    def call(self, step):
        """Run the predictor over the step's input data and return a ResultSet.

        Reads the input from ``self.steps_data`` (keyed by the step number of
        ``step.dataframe``) and mutates it in place (row ids, injected
        constants, forecast-offset column).
        """
        # set row_id
        data = self.steps_data[step.dataframe.step_num]

        params = step.params or {}

        # adding __mindsdb_row_id, use first table if exists
        if len(data.find_columns("__mindsdb_row_id")) == 0:
            table = data.get_tables()[0] if len(data.get_tables()) > 0 else None

            row_id_col = Column(
                name="__mindsdb_row_id",
                database=table["database"] if table is not None else None,
                table_name=table["table_name"] if table is not None else None,
                table_alias=table["table_alias"] if table is not None else None,
            )

            # sequential ids continue from the shared per-query counter
            row_id = self.context.get("row_id")
            values = range(row_id, row_id + data.length())
            data.add_column(row_id_col, values)
            self.context["row_id"] += data.length()

        project_name = step.namespace
        predictor_name = step.predictor.parts[0]

        # add constants from where
        if step.row_dict is not None:
            for k, v in step.row_dict.items():
                if isinstance(v, Result):
                    prev_result = self.steps_data[v.step_num]
                    # TODO we await only one value: model.param = (subselect)
                    v = prev_result.get_column_values(col_idx=0)[0]
                data.set_column_values(k, v)

        # locate metadata for this predictor within this project
        predictor_metadata = {}
        for pm in self.context["predictor_metadata"]:
            if pm["name"] == predictor_name and pm["integration_name"].lower() == project_name:
                predictor_metadata = pm
                break
        # NOTE(review): if no metadata matches, this raises KeyError on
        # "timeseries" — presumably the planner guarantees a match; confirm.
        is_timeseries = predictor_metadata["timeseries"]
        _mdb_forecast_offset = None
        if is_timeseries:
            # LATEST detection is done on the raw query string
            if "> LATEST" in self.context["query_str"]:
                # stream mode -- if > LATEST, forecast starts on inferred next timestamp
                _mdb_forecast_offset = 1
            elif "= LATEST" in self.context["query_str"]:
                # override: when = LATEST, forecast starts on last provided timestamp instead of inferred next time
                _mdb_forecast_offset = 0
            else:
                # normal mode -- emit a forecast ($HORIZON data points on each) for each provided timestamp
                params["force_ts_infer"] = True
                _mdb_forecast_offset = None
            data.add_column(Column(name="__mdb_forecast_offset"), _mdb_forecast_offset)

        table_name = get_preditor_alias(step, self.context["database"])

        project_datanode = self.session.datahub.get(project_name)
        if len(data) == 0:
            # empty input: return an empty ResultSet with the model's columns
            columns_names = project_datanode.get_table_columns_names(predictor_name) + ["__mindsdb_row_id"]
            result = ResultSet(is_prediction=True)
            for column_name in columns_names:
                result.add_column(
                    Column(
                        name=column_name, database=table_name[0], table_name=table_name[1], table_alias=table_name[2]
                    )
                )
        else:
            predictor_id = predictor_metadata["id"]
            table_df = data.to_df()

            # prediction cache keyed by model identity + input checksum
            if self.session.predictor_cache is not False:
                key = f"{predictor_name}_{predictor_id}_{dataframe_checksum(table_df)}"

                predictor_cache = get_cache("predict")
                predictions = predictor_cache.get(key)
            else:
                predictions = None

            if predictions is None:
                # handle columns mapping to model
                if step.columns_map is not None:
                    # step.columns_map is {str: Identifier}

                    cols_to_rename = {}
                    for model_col, table_col in step.columns_map.items():
                        # only two-part identifiers (table.column) are mapped
                        if len(table_col.parts) != 2:
                            continue
                        tbl_name, col_name = table_col.parts
                        data_cols = data.find_columns(col_name, table_alias=tbl_name)
                        if len(data_cols) == 0:
                            continue
                        # add first found column to rename list
                        cols_to_rename[data.get_col_index(data_cols[0])] = model_col
                    # update input data
                    if cols_to_rename:
                        columns = list(table_df.columns)
                        for col_idx, name in cols_to_rename.items():
                            columns[col_idx] = name
                        table_df.columns = columns

                # trailing all-digit identifier part selects a model version
                version = None
                if len(step.predictor.parts) > 1 and step.predictor.parts[-1].isdigit():
                    version = int(step.predictor.parts[-1])
                predictions = self.apply_predictor(project_name, predictor_name, table_df, version, params)

                # only cache real dataframes (agent completions may differ)
                if self.session.predictor_cache is not False:
                    if predictions is not None and isinstance(predictions, pd.DataFrame):
                        predictor_cache.set(key, predictions)

            # apply filter
            if is_timeseries:
                pred_data = predictions.to_dict(orient="records")
                where_data = list(data.get_records())
                pred_data = self.apply_ts_filter(pred_data, where_data, step, predictor_metadata)
                predictions = pd.DataFrame(pred_data)

            result = ResultSet.from_df(
                predictions,
                database=table_name[0],
                table_name=table_name[1],
                table_alias=table_name[2],
                is_prediction=True,
            )

        return result

    def apply_ts_filter(self, predictor_data, table_data, step, predictor_metadata):
        """Filter timeseries prediction rows by the query's time condition.

        ``predictor_data`` and ``table_data`` are lists of record dicts; both
        may be mutated in place (order-column values and filter-arg constants
        are coerced to comparable types). Returns the filtered list of
        prediction rows, or the input unchanged when the filter does not apply.
        """
        if step.output_time_filter is None:
            # no filter, exit
            return predictor_data

        # apply filter
        group_cols = predictor_metadata["group_by_columns"]
        order_col = predictor_metadata["order_by_column"]

        filter_args = step.output_time_filter.args
        filter_op = step.output_time_filter.op

        # filter field must be order column
        if not (isinstance(filter_args[0], Identifier) and filter_args[0].parts[-1] == order_col):
            # exit otherwise
            return predictor_data

        def get_date_format(samples):
            # Infer a strptime format that parses every string in `samples`.
            # Try common formats first with explicit patterns
            for date_format, pattern in (
                ("%Y-%m-%d", r"[\d]{4}-[\d]{2}-[\d]{2}"),
                ("%Y-%m-%d %H:%M:%S", r"[\d]{4}-[\d]{2}-[\d]{2} [\d]{2}:[\d]{2}:[\d]{2}"),
                # ('%Y-%m-%d %H:%M:%S%z', r'[\d]{4}-[\d]{2}-[\d]{2} [\d]{2}:[\d]{2}:[\d]{2}\+[\d]{2}:[\d]{2}'),
                # ('%Y', '[\d]{4}')
            ):
                if re.match(pattern, samples[0]):
                    # suggested format
                    for sample in samples:
                        try:
                            dt.datetime.strptime(sample, date_format)
                        except ValueError:
                            date_format = None
                            break
                    if date_format is not None:
                        return date_format

            # Use dateparser as fallback and infer format
            try:
                # Parse the first sample to get its format
                # The import is heavy, so we do it here on-demand
                import dateparser

                parsed_date = dateparser.parse(samples[0])
                if parsed_date is None:
                    raise ValueError("Could not parse date")

                # Verify the format works for all samples
                for sample in samples[1:]:
                    if dateparser.parse(sample) is None:
                        raise ValueError("Inconsistent date formats in samples")
                # Convert to strftime format based on the input
                if re.search(r"\d{2}:\d{2}:\d{2}", samples[0]):
                    return "%Y-%m-%d %H:%M:%S"
                return "%Y-%m-%d"
            except (ValueError, AttributeError):
                # If dateparser fails, return a basic format as last resort
                return "%Y-%m-%d"

        model_types = predictor_metadata["model_types"]
        if model_types.get(order_col) in ("float", "integer"):
            # convert strings to digits
            fnc = {"integer": int, "float": float}[model_types[order_col]]

            # convert predictor_data
            if len(predictor_data) > 0:
                if isinstance(predictor_data[0][order_col], str):
                    for row in predictor_data:
                        row[order_col] = fnc(row[order_col])
                elif isinstance(predictor_data[0][order_col], dt.date):
                    # convert to datetime
                    # NOTE(review): int()/float() of a date would raise —
                    # presumably this branch never fires for numeric models; confirm.
                    for row in predictor_data:
                        row[order_col] = fnc(row[order_col])

            # convert predictor_data
            if isinstance(table_data[0][order_col], str):
                for row in table_data:
                    row[order_col] = fnc(row[order_col])
            elif isinstance(table_data[0][order_col], dt.date):
                # convert to datetime
                for row in table_data:
                    row[order_col] = fnc(row[order_col])

            # convert args to date
            samples = [arg.value for arg in filter_args if isinstance(arg, Constant) and isinstance(arg.value, str)]
            if len(samples) > 0:
                for arg in filter_args:
                    if isinstance(arg, Constant) and isinstance(arg.value, str):
                        arg.value = fnc(arg.value)

        # NOTE(review): with empty predictor_data and a non-date model type,
        # predictor_data[0] below raises IndexError — verify callers guard this.
        if model_types.get(order_col) in ("date", "datetime") or isinstance(predictor_data[0][order_col], pd.Timestamp):  # noqa
            # convert strings to date
            # it is making side effect on original data by changing it but let it be

            def _cast_samples(data, order_col):
                # Normalize the order column of `data` rows to dt.datetime in place.
                if isinstance(data[0][order_col], str):
                    samples = [row[order_col] for row in data]
                    date_format = get_date_format(samples)

                    for row in data:
                        row[order_col] = dt.datetime.strptime(row[order_col], date_format)
                elif isinstance(data[0][order_col], dt.datetime):
                    pass  # check because dt.datetime is instance of dt.date but here we don't need to add HH:MM:SS
                elif isinstance(data[0][order_col], dt.date):
                    # convert to datetime
                    for row in data:
                        row[order_col] = dt.datetime.combine(row[order_col], dt.datetime.min.time())

            # convert predictor_data
            if len(predictor_data) > 0:
                _cast_samples(predictor_data, order_col)

            # convert table data
            _cast_samples(table_data, order_col)

            # convert args to date
            samples = [arg.value for arg in filter_args if isinstance(arg, Constant) and isinstance(arg.value, str)]
            if len(samples) > 0:
                date_format = get_date_format(samples)

                for arg in filter_args:
                    if isinstance(arg, Constant) and isinstance(arg.value, str):
                        arg.value = dt.datetime.strptime(arg.value, date_format)
                # TODO can be dt.date in args?

        # first pass: get max values for Latest in table data
        latest_vals = {}
        if Latest() in filter_args:
            for row in table_data:
                if group_cols is None:
                    key = 0  # the same for any value
                else:
                    key = tuple([str(row[i]) for i in group_cols])
                val = row[order_col]
                if key not in latest_vals or latest_vals[key] < val:
                    latest_vals[key] = val

        # second pass: do filter rows
        data2 = []
        for row in predictor_data:
            val = row[order_col]

            if isinstance(step.output_time_filter, BetweenOperation):
                # inclusive range check on both bounds
                if val >= filter_args[1].value and val <= filter_args[2].value:
                    data2.append(row)
            elif isinstance(step.output_time_filter, BinaryOperation):
                # dispatch comparison via the rich-comparison dunder name
                op_map = {
                    "<": "__lt__",
                    "<=": "__le__",
                    ">": "__gt__",
                    ">=": "__ge__",
                    "=": "__eq__",
                }
                arg = filter_args[1]
                if isinstance(arg, Latest):
                    if group_cols is None:
                        key = 0  # the same for any value
                    else:
                        key = tuple([str(row[i]) for i in group_cols])
                    if key not in latest_vals:
                        # pass this row
                        continue
                    arg = latest_vals[key]
                elif isinstance(arg, Constant):
                    arg = arg.value

                if filter_op not in op_map:
                    # unknown operation, exit immediately
                    return predictor_data

                # check condition
                filter_op2 = op_map[filter_op]
                if getattr(val, filter_op2)(arg):
                    data2.append(row)
            else:
                # unknown operation, add anyway
                data2.append(row)

        return data2
class ApplyTimeseriesPredictorStepCall(ApplyPredictorStepCall):
    """Timeseries variant of the predictor step call.

    Inherits all behavior from ApplyPredictorStepCall; only the bound plan
    step type differs.
    """

    bind = ApplyTimeseriesPredictorStep