Coverage for mindsdb / api / executor / sql_query / steps / apply_predictor_step.py: 6%

240 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import datetime as dt 

2import re 

3 

4import pandas as pd 

5 

6from mindsdb_sql_parser.ast import ( 

7 BinaryOperation, 

8 Identifier, 

9 Constant, 

10 BetweenOperation, 

11 Parameter, 

12) 

13from mindsdb_sql_parser.ast.mindsdb import Latest 

14 

15from mindsdb.api.executor.planner.step_result import Result 

16from mindsdb.api.executor.planner.steps import ( 

17 ApplyTimeseriesPredictorStep, 

18 ApplyPredictorRowStep, 

19 ApplyPredictorStep, 

20) 

21 

22from mindsdb.api.executor.sql_query.result_set import ResultSet, Column 

23from mindsdb.utilities.cache import get_cache, dataframe_checksum 

24 

25from .base import BaseStepCall 

26 

27 

def get_preditor_alias(step, mindsdb_database):
    """Build the (database, table_name, table_alias) triple for a predictor step.

    The table name is the dotted predictor identifier; the alias falls back
    to the full name when the step carries no explicit alias.
    """
    name = ".".join(step.predictor.parts)
    if step.predictor.alias is None:
        alias = name
    else:
        alias = ".".join(step.predictor.alias.parts)
    return (mindsdb_database, name, alias)

32 

33 

class ApplyPredictorBaseCall(BaseStepCall):
    """Shared logic for predictor-apply steps: route a dataframe either to an
    agent or to a regular model in the project."""

    def apply_predictor(self, project_name, predictor_name, df, version, params):
        """Return predictions for *df*.

        If an agent with *predictor_name* exists in the project, the dataframe
        rows are sent to it as chat messages; otherwise the project datanode's
        model prediction is used.
        """
        agents = self.session.agents_controller
        agent = agents.get_agent(predictor_name, project_name)
        if agent is None:
            # plain model: delegate to the project datanode
            datanode = self.session.datahub.get(project_name)
            return datanode.predict(model_name=predictor_name, df=df, version=version, params=params)
        # agent path: each row becomes one message dict
        return agents.get_completion(
            agent, messages=df.to_dict("records"), project_name=project_name, params=params
        )

48 

49 

class ApplyPredictorRowStepCall(ApplyPredictorBaseCall):
    """Apply a predictor to a single row of constants (values from WHERE)."""

    bind = ApplyPredictorRowStep

    def _resolve_param(self, value):
        # A Parameter references the output of a previous step; unwrap it.
        # A one-row result collapses to a scalar, otherwise keep the column list.
        if not isinstance(value, Parameter):
            return value
        prev = self.steps_data[value.value.step_num]
        column = prev.get_column_values(col_idx=0)
        return column[0] if prev.length() == 1 else column

    def call(self, step):
        """Run the predictor over ``step.row_dict`` and return the ResultSet."""
        project_name = step.namespace
        predictor_name = step.predictor.parts[0]
        datanode = self.session.datahub.get(project_name)

        # resolve Parameter placeholders against earlier step results
        where_data = {key: self._resolve_param(value) for key, value in step.row_dict.items()}

        # a trailing numeric identifier part selects a model version
        parts = step.predictor.parts
        version = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else None

        predictions = self.apply_predictor(
            project_name, predictor_name, pd.DataFrame([where_data]), version, step.params
        )

        # echo the input values back into the prediction frame
        for key, value in where_data.items():
            predictions[key] = value

        alias = get_preditor_alias(step, self.context.get("database"))

        if len(predictions) == 0:
            # no output rows: return an empty frame with the model's column set
            predictions = pd.DataFrame([], columns=datanode.get_table_columns_names(predictor_name))

        return ResultSet.from_df(
            df=predictions,
            database=alias[0],
            table_name=alias[1],
            table_alias=alias[2],
            is_prediction=True,
        )

95 

96 

class ApplyPredictorStepCall(ApplyPredictorBaseCall):
    """Apply a predictor to the dataframe produced by a previous plan step.

    Handles row-id bookkeeping, constants injected from WHERE, prediction
    caching and, for timeseries models, LATEST-offset handling plus
    post-prediction time filtering.
    """

    bind = ApplyPredictorStep

    def call(self, step):
        """Run the predictor over the referenced step's data; return a ResultSet."""
        # input data comes from the referenced previous step
        data = self.steps_data[step.dataframe.step_num]

        params = step.params or {}

        # adding __mindsdb_row_id (attached to the first input table, if any)
        # so predictions can be joined back to their source rows later
        if len(data.find_columns("__mindsdb_row_id")) == 0:
            table = data.get_tables()[0] if len(data.get_tables()) > 0 else None

            row_id_col = Column(
                name="__mindsdb_row_id",
                database=table["database"] if table is not None else None,
                table_name=table["table_name"] if table is not None else None,
                table_alias=table["table_alias"] if table is not None else None,
            )

            row_id = self.context.get("row_id")
            values = range(row_id, row_id + data.length())
            data.add_column(row_id_col, values)
            self.context["row_id"] += data.length()

        project_name = step.namespace
        predictor_name = step.predictor.parts[0]

        # add constants from where
        if step.row_dict is not None:
            for k, v in step.row_dict.items():
                if isinstance(v, Result):
                    prev_result = self.steps_data[v.step_num]
                    # TODO we await only one value: model.param = (subselect)
                    v = prev_result.get_column_values(col_idx=0)[0]
                data.set_column_values(k, v)

        # find this predictor's metadata in the planner context
        predictor_metadata = {}
        for pm in self.context["predictor_metadata"]:
            if pm["name"] == predictor_name and pm["integration_name"].lower() == project_name:
                predictor_metadata = pm
                break
        is_timeseries = predictor_metadata["timeseries"]
        _mdb_forecast_offset = None
        if is_timeseries:
            if "> LATEST" in self.context["query_str"]:
                # stream mode -- if > LATEST, forecast starts on inferred next timestamp
                _mdb_forecast_offset = 1
            elif "= LATEST" in self.context["query_str"]:
                # override: when = LATEST, forecast starts on last provided timestamp instead of inferred next time
                _mdb_forecast_offset = 0
            else:
                # normal mode -- emit a forecast ($HORIZON data points on each) for each provided timestamp
                params["force_ts_infer"] = True
                _mdb_forecast_offset = None

            data.add_column(Column(name="__mdb_forecast_offset"), _mdb_forecast_offset)

        table_name = get_preditor_alias(step, self.context["database"])

        project_datanode = self.session.datahub.get(project_name)
        if len(data) == 0:
            # no input rows: build an empty result with the model's columns
            columns_names = project_datanode.get_table_columns_names(predictor_name) + ["__mindsdb_row_id"]
            result = ResultSet(is_prediction=True)
            for column_name in columns_names:
                result.add_column(
                    Column(
                        name=column_name, database=table_name[0], table_name=table_name[1], table_alias=table_name[2]
                    )
                )
        else:
            predictor_id = predictor_metadata["id"]
            table_df = data.to_df()

            # try the prediction cache first (keyed by model name+id and input checksum)
            if self.session.predictor_cache is not False:
                key = f"{predictor_name}_{predictor_id}_{dataframe_checksum(table_df)}"

                predictor_cache = get_cache("predict")
                predictions = predictor_cache.get(key)
            else:
                predictions = None

            if predictions is None:
                # handle columns mapping to model
                if step.columns_map is not None:
                    # step.columns_map is {str: Identifier}
                    cols_to_rename = {}
                    for model_col, table_col in step.columns_map.items():
                        if len(table_col.parts) != 2:
                            continue
                        tbl_name, col_name = table_col.parts
                        data_cols = data.find_columns(col_name, table_alias=tbl_name)
                        if len(data_cols) == 0:
                            continue
                        # add first found column to rename list
                        cols_to_rename[data.get_col_index(data_cols[0])] = model_col
                    # update input data
                    if cols_to_rename:
                        columns = list(table_df.columns)
                        for col_idx, name in cols_to_rename.items():
                            columns[col_idx] = name
                        table_df.columns = columns

                version = None
                if len(step.predictor.parts) > 1 and step.predictor.parts[-1].isdigit():
                    version = int(step.predictor.parts[-1])
                predictions = self.apply_predictor(project_name, predictor_name, table_df, version, params)

                if self.session.predictor_cache is not False:
                    # only dataframes are cacheable (agents may return other types)
                    if predictions is not None and isinstance(predictions, pd.DataFrame):
                        predictor_cache.set(key, predictions)

            # apply filter
            if is_timeseries:
                pred_data = predictions.to_dict(orient="records")
                where_data = list(data.get_records())
                pred_data = self.apply_ts_filter(pred_data, where_data, step, predictor_metadata)
                predictions = pd.DataFrame(pred_data)

            result = ResultSet.from_df(
                predictions,
                database=table_name[0],
                table_name=table_name[1],
                table_alias=table_name[2],
                is_prediction=True,
            )

        return result

    def apply_ts_filter(self, predictor_data, table_data, step, predictor_metadata):
        """Apply the query's time condition to timeseries predictions.

        :param predictor_data: list of prediction rows (dicts)
        :param table_data: list of input rows (dicts), used to resolve LATEST
        :param step: the plan step carrying ``output_time_filter``
        :param predictor_metadata: model metadata (group/order columns, types)
        :return: the filtered list of prediction rows

        Values of the order column in rows and filter constants are coerced
        in place to a comparable type (number or datetime) before filtering.
        """
        if step.output_time_filter is None:
            # no filter, exit
            return predictor_data

        # apply filter
        group_cols = predictor_metadata["group_by_columns"]
        order_col = predictor_metadata["order_by_column"]

        filter_args = step.output_time_filter.args
        filter_op = step.output_time_filter.op

        # filter field must be order column
        if not (isinstance(filter_args[0], Identifier) and filter_args[0].parts[-1] == order_col):
            # exit otherwise
            return predictor_data

        def get_date_format(samples):
            """Guess the strptime format shared by all string *samples*."""
            # Try common formats first with explicit patterns
            for date_format, pattern in (
                ("%Y-%m-%d", r"[\d]{4}-[\d]{2}-[\d]{2}"),
                ("%Y-%m-%d %H:%M:%S", r"[\d]{4}-[\d]{2}-[\d]{2} [\d]{2}:[\d]{2}:[\d]{2}"),
            ):
                if re.match(pattern, samples[0]):
                    # suggested format: validate it against every sample
                    for sample in samples:
                        try:
                            dt.datetime.strptime(sample, date_format)
                        except ValueError:
                            date_format = None
                            break
                    if date_format is not None:
                        return date_format

            # Use dateparser as fallback and infer format
            try:
                # Parse the first sample to get its format
                # The import is heavy, so we do it here on-demand
                import dateparser

                parsed_date = dateparser.parse(samples[0])
                if parsed_date is None:
                    raise ValueError("Could not parse date")

                # Verify the format works for all samples
                for sample in samples[1:]:
                    if dateparser.parse(sample) is None:
                        raise ValueError("Inconsistent date formats in samples")
                # Convert to strftime format based on the input
                if re.search(r"\d{2}:\d{2}:\d{2}", samples[0]):
                    return "%Y-%m-%d %H:%M:%S"
                return "%Y-%m-%d"
            except (ValueError, AttributeError):
                # If dateparser fails, return a basic format as last resort
                return "%Y-%m-%d"

        model_types = predictor_metadata["model_types"]
        if model_types.get(order_col) in ("float", "integer"):
            # convert strings to digits
            fnc = {"integer": int, "float": float}[model_types[order_col]]

            # convert predictor_data
            if len(predictor_data) > 0:
                if isinstance(predictor_data[0][order_col], str):
                    for row in predictor_data:
                        row[order_col] = fnc(row[order_col])
                elif isinstance(predictor_data[0][order_col], dt.date):
                    # NOTE(review): applies int/float to a date value — looks
                    # fragile/unreachable, preserved as-is; confirm intent
                    for row in predictor_data:
                        row[order_col] = fnc(row[order_col])

            # convert table data
            if isinstance(table_data[0][order_col], str):
                for row in table_data:
                    row[order_col] = fnc(row[order_col])
            elif isinstance(table_data[0][order_col], dt.date):
                # NOTE(review): same fragile date->number cast as above
                for row in table_data:
                    row[order_col] = fnc(row[order_col])

            # convert filter constants
            samples = [arg.value for arg in filter_args if isinstance(arg, Constant) and isinstance(arg.value, str)]
            if len(samples) > 0:
                for arg in filter_args:
                    if isinstance(arg, Constant) and isinstance(arg.value, str):
                        arg.value = fnc(arg.value)

        # BUGFIX: guard the predictor_data[0] access — previously this raised
        # IndexError when the model returned no rows and the order column was
        # not typed date/datetime (the numeric branch above already guards).
        if model_types.get(order_col) in ("date", "datetime") or (
            len(predictor_data) > 0 and isinstance(predictor_data[0][order_col], pd.Timestamp)
        ):  # noqa
            # convert strings to date
            # it is making side effect on original data by changing it but let it be

            def _cast_samples(data, order_col):
                # normalize the order column of *data* (list of dicts) to datetime, in place
                if isinstance(data[0][order_col], str):
                    samples = [row[order_col] for row in data]
                    date_format = get_date_format(samples)

                    for row in data:
                        row[order_col] = dt.datetime.strptime(row[order_col], date_format)
                elif isinstance(data[0][order_col], dt.datetime):
                    pass  # check because dt.datetime is instance of dt.date but here we don't need to add HH:MM:SS
                elif isinstance(data[0][order_col], dt.date):
                    # convert to datetime
                    for row in data:
                        row[order_col] = dt.datetime.combine(row[order_col], dt.datetime.min.time())

            # convert predictor_data
            if len(predictor_data) > 0:
                _cast_samples(predictor_data, order_col)

            # convert table data
            _cast_samples(table_data, order_col)

            # convert args to date
            samples = [arg.value for arg in filter_args if isinstance(arg, Constant) and isinstance(arg.value, str)]
            if len(samples) > 0:
                date_format = get_date_format(samples)

                for arg in filter_args:
                    if isinstance(arg, Constant) and isinstance(arg.value, str):
                        arg.value = dt.datetime.strptime(arg.value, date_format)
                # TODO can be dt.date in args?

        # first pass: get max values for Latest in table data (per group key)
        latest_vals = {}
        if Latest() in filter_args:
            for row in table_data:
                if group_cols is None:
                    key = 0  # the same for any value
                else:
                    key = tuple([str(row[i]) for i in group_cols])
                val = row[order_col]
                if key not in latest_vals or latest_vals[key] < val:
                    latest_vals[key] = val

        # second pass: do filter rows
        data2 = []
        for row in predictor_data:
            val = row[order_col]

            if isinstance(step.output_time_filter, BetweenOperation):
                if val >= filter_args[1].value and val <= filter_args[2].value:
                    data2.append(row)
            elif isinstance(step.output_time_filter, BinaryOperation):
                op_map = {
                    "<": "__lt__",
                    "<=": "__le__",
                    ">": "__gt__",
                    ">=": "__ge__",
                    "=": "__eq__",
                }
                arg = filter_args[1]
                if isinstance(arg, Latest):
                    if group_cols is None:
                        key = 0  # the same for any value
                    else:
                        key = tuple([str(row[i]) for i in group_cols])
                    if key not in latest_vals:
                        # no latest value known for this group: pass this row
                        continue
                    arg = latest_vals[key]
                elif isinstance(arg, Constant):
                    arg = arg.value

                if filter_op not in op_map:
                    # unknown operation, exit immediately
                    return predictor_data

                # check condition
                filter_op2 = op_map[filter_op]
                if getattr(val, filter_op2)(arg):
                    data2.append(row)
            else:
                # unknown operation, add anyway
                data2.append(row)

        return data2

405 

406 

class ApplyTimeseriesPredictorStepCall(ApplyPredictorStepCall):
    """Handler for ApplyTimeseriesPredictorStep; reuses ApplyPredictorStepCall logic unchanged."""

    bind = ApplyTimeseriesPredictorStep