Coverage for mindsdb / api / executor / planner / plan_join_ts.py: 91%
181 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import copy
3from mindsdb_sql_parser.ast.mindsdb import Latest
4from mindsdb_sql_parser.ast import (
5 Select,
6 Identifier,
7 BetweenOperation,
8 Join,
9 Star,
10 BinaryOperation,
11 Constant,
12 OrderBy,
13 NullConstant,
14)
16from mindsdb.integrations.utilities.query_traversal import query_traversal
18from mindsdb.api.executor.planner.exceptions import PlanningException
19from mindsdb.api.executor.planner import utils
20from mindsdb.api.executor.planner.steps import (
21 JoinStep,
22 LimitOffsetStep,
23 MultipleSteps,
24 MapReduceStep,
25 ApplyTimeseriesPredictorStep,
26)
27from mindsdb.api.executor.planner.ts_utils import (
28 validate_ts_where_condition,
29 find_time_filter,
30 replace_time_filter,
31 find_and_remove_time_filter,
32 recursively_check_join_identifiers_for_ambiguity,
33)
class PlanJoinTSPredictorQuery:
    """Plans JOIN queries between a data source and a timeseries predictor.

    Delegates integration selects and step bookkeeping to the owning planner.
    """

    def __init__(self, planner):
        # Owning query planner; provides predictor metadata, integration
        # selects and the step list this class appends to.
        self.planner = planner
    def adapt_dbt_query(self, query, integration):
        """Rewrite a dbt-shaped query (sub-select joined with a predictor) in place.

        dbt wraps the data source in a sub-select; this method hoists that
        sub-select to the top level so the rest of the planner can treat it as
        a plain table query.

        :param query: original Select whose ``from_table`` is a Join with a
            sub-select on the left.
        :param integration: default integration name to prepend to unqualified
            table identifiers (dbt workaround); may be None.
        :returns: tuple ``(query, join_left)`` — the hoisted inner select and
            its ``from_table``.
        """
        orig_query = query

        join = query.from_table
        join_left = join.left

        # dbt query.

        # move latest into subquery
        moved_conditions = []

        def move_latest(node, **kwargs):
            # Collect conditions that compare against LATEST so they can be
            # relocated from the outer WHERE into the inner select's WHERE.
            if isinstance(node, BinaryOperation):
                if Latest() in node.args:
                    for arg in node.args:
                        if isinstance(arg, Identifier):
                            # remove table alias
                            arg.parts = [arg.parts[-1]]
                    moved_conditions.append(node)

        query_traversal(query.where, move_latest)

        # TODO make project step from query.target

        # TODO support complex query. Only one table is supported at the moment.
        # if not isinstance(join_left.from_table, Identifier):
        #     raise PlanningException(f'Statement not supported: {query.to_string()}')

        # move properties to upper query
        query = join_left

        if query.from_table.alias is not None:
            table_alias = [query.from_table.alias.parts[0]]
        else:
            table_alias = query.from_table.parts

        # add latest to query.where
        for cond in moved_conditions:
            if query.where is not None:
                query.where = BinaryOperation("and", args=[query.where, cond])
            else:
                query.where = cond

        def add_aliases(node, is_table, **kwargs):
            # Qualify bare column identifiers in WHERE with the table alias so
            # they stay unambiguous after the sub-select is hoisted.
            if not is_table and isinstance(node, Identifier):
                if len(node.parts) == 1:
                    # add table alias to field
                    node.parts = table_alias + node.parts
                    node.is_quoted = [False] + node.is_quoted

        query_traversal(query.where, add_aliases)

        if isinstance(query.from_table, Identifier):
            # DBT workaround: allow use tables without integration.
            # if table.part[0] not in integration - take integration name from create table command
            if integration is not None and query.from_table.parts[0] not in self.planner.databases:
                # add integration name to table
                query.from_table.parts.insert(0, integration)
                query.from_table.is_quoted.insert(0, False)

        join_left = join_left.from_table

        # Keep the stricter (smaller) of outer vs inner LIMIT.
        if orig_query.limit is not None:
            if query.limit is None or query.limit.value > orig_query.limit.value:
                query.limit = orig_query.limit

        # The inner select becomes the top-level query: drop sub-select dressing.
        query.parentheses = False
        query.alias = None

        return query, join_left
110 def get_aliased_fields(self, targets):
111 # get aliases from select target
112 aliased_fields = {}
113 for target in targets:
114 if target.alias is not None:
115 aliased_fields[target.alias.to_string()] = target
116 return aliased_fields
118 def plan_fetch_timeseries_partitions(self, query, table, predictor_group_by_names):
119 targets = [Identifier(column) for column in predictor_group_by_names]
121 query = Select(
122 distinct=True,
123 targets=targets,
124 from_table=table,
125 where=query.where,
126 modifiers=query.modifiers,
127 )
128 select_step = self.planner.plan_integration_select(query)
129 return select_step
131 def plan(self, query, integration=None):
132 # integration is for dbt only
134 join = query.from_table
135 join_left = join.left
136 join_right = join.right
138 predictor_is_left = False
139 if self.planner.is_predictor(join_left):
140 # predictor is in the left, put it in the right
141 join_left, join_right = join_right, join_left
142 predictor_is_left = True
144 if self.planner.is_predictor(join_left): 144 ↛ 146line 144 didn't jump to line 146 because the condition on line 144 was never true
145 # in the left is also predictor
146 raise PlanningException(f"Can't join two predictors {join_left} and {join_left}")
148 orig_query = query
149 # dbt query?
150 if isinstance(join_left, Select) and isinstance(join_left.from_table, Identifier):
151 query, join_left = self.adapt_dbt_query(query, integration)
153 predictor_namespace, predictor = self.planner.get_predictor_namespace_and_name_from_identifier(join_right)
154 table = join_left
156 aliased_fields = self.get_aliased_fields(query.targets)
158 recursively_check_join_identifiers_for_ambiguity(query.where)
159 recursively_check_join_identifiers_for_ambiguity(query.group_by, aliased_fields=aliased_fields)
160 recursively_check_join_identifiers_for_ambiguity(query.having)
161 recursively_check_join_identifiers_for_ambiguity(query.order_by, aliased_fields=aliased_fields)
163 predictor_steps = self.plan_timeseries_predictor(query, table, predictor_namespace, predictor)
165 # add join
166 # Update reference
168 left = Identifier(predictor_steps["predictor"].result.ref_name)
169 right = Identifier(predictor_steps["data"].result.ref_name)
171 if not predictor_is_left:
172 # swap join
173 left, right = right, left
174 new_join = Join(left=left, right=right, join_type=join.join_type)
176 left = predictor_steps["predictor"].result
177 right = predictor_steps["data"].result
178 if not predictor_is_left:
179 # swap join
180 left, right = right, left
182 last_step = self.planner.plan.add_step(JoinStep(left=left, right=right, query=new_join))
184 # limit from timeseries
185 if predictor_steps.get("saved_limit"):
186 last_step = self.planner.plan.add_step(
187 LimitOffsetStep(dataframe=last_step.result, limit=predictor_steps["saved_limit"])
188 )
190 return self.planner.plan_project(orig_query, last_step.result)
    def plan_timeseries_predictor(self, query, table, predictor_namespace, predictor):
        """Plan the history-fetch and prediction steps for a timeseries predictor.

        Builds one or two integration selects that fetch the history window
        required by the predictor — the shape depends on the time filter found
        in the WHERE clause (BETWEEN, > LATEST, =, >/>=, or none). When the
        predictor has group-by columns, the selects are parameterized with
        ``$var[...]`` placeholders and wrapped in a map-reduce over the
        distinct partition values.

        :param query: Select on the data side of the join.
        :param table: data table (Identifier or sub-select).
        :param predictor_namespace: integration/project containing the predictor.
        :param predictor: predictor identifier.
        :returns: dict with keys 'predictor' (ApplyTimeseriesPredictorStep),
            'data' (fetch step) and 'saved_limit' (original LIMIT value or
            None, for the caller to re-apply after the join).
        :raises PlanningException: if the query has ORDER BY, GROUP BY, HAVING
            or OFFSET, or if WHERE references disallowed columns.
        """
        predictor_metadata = self.planner.get_predictor(predictor)

        predictor_time_column_name = predictor_metadata["order_by_column"]
        predictor_group_by_names = predictor_metadata["group_by_columns"]
        if predictor_group_by_names is None:
            predictor_group_by_names = []
        predictor_window = predictor_metadata["window"]

        if query.order_by:
            raise PlanningException(
                f"Can't provide ORDER BY to time series predictor, it will be taken from predictor settings. Found: {query.order_by}"
            )

        # LIMIT is not pushed into the history fetch; it is returned to the
        # caller and applied after the join.
        saved_limit = None
        if query.limit is not None:
            saved_limit = query.limit.value

        if query.group_by or query.having or query.offset:
            raise PlanningException(f"Unsupported query to timeseries predictor: {str(query)}")

        # WHERE may only reference the order-by column and group-by columns.
        allowed_columns = [predictor_time_column_name.lower()]
        if len(predictor_group_by_names) > 0:
            allowed_columns += [i.lower() for i in predictor_group_by_names]

        # Deep copy so WHERE rewrites below don't mutate the caller's query.
        no_time_filter_query = copy.deepcopy(query)

        preparation_where = no_time_filter_query.where

        validate_ts_where_condition(preparation_where, allowed_columns=allowed_columns)

        time_filter = find_time_filter(preparation_where, time_column_name=predictor_time_column_name)

        # History is always fetched newest-first on the time column.
        order_by = [OrderBy(Identifier(parts=[predictor_time_column_name]), direction="DESC")]

        query_modifiers = query.modifiers

        # add {order_by_field} is not null
        def add_order_not_null(condition):
            # AND an "<time column> is not null" guard onto the condition
            # (or use the guard alone when there is no condition).
            order_field_not_null = BinaryOperation(
                op="is not", args=[Identifier(parts=[predictor_time_column_name]), NullConstant()]
            )
            if condition is not None:
                condition = BinaryOperation(op="and", args=[condition, order_field_not_null])
            else:
                condition = order_field_not_null
            return condition

        # preparation_where2 keeps a pristine copy for time-filter rewriting.
        preparation_where2 = copy.deepcopy(preparation_where)
        preparation_where = add_order_not_null(preparation_where)

        # Obtain integration selects
        if isinstance(time_filter, BetweenOperation):
            # BETWEEN a AND b: fetch `window` rows before a (for warm-up
            # context) plus all rows matching the original filter.
            between_from = time_filter.args[1]
            preparation_time_filter = BinaryOperation("<", args=[Identifier(predictor_time_column_name), between_from])
            preparation_where2 = replace_time_filter(preparation_where2, time_filter, preparation_time_filter)
            integration_select_1 = Select(
                targets=[Star()],
                from_table=table,
                where=add_order_not_null(preparation_where2),
                modifiers=query_modifiers,
                order_by=order_by,
                limit=Constant(predictor_window),
            )

            integration_select_2 = Select(
                targets=[Star()],
                from_table=table,
                where=preparation_where,
                modifiers=query_modifiers,
                order_by=order_by,
            )

            integration_selects = [integration_select_1, integration_select_2]
        elif isinstance(time_filter, BinaryOperation) and time_filter.op == ">" and time_filter.args[1] == Latest():
            # > LATEST: fetch only the most recent `window` rows; the LATEST
            # filter itself is removed from the select.
            integration_select = Select(
                targets=[Star()],
                from_table=table,
                where=preparation_where,
                modifiers=query_modifiers,
                order_by=order_by,
                limit=Constant(predictor_window),
            )
            integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter)
            integration_selects = [integration_select]
        elif isinstance(time_filter, BinaryOperation) and time_filter.op == "=":
            integration_select = Select(
                targets=[Star()],
                from_table=table,
                where=preparation_where,
                modifiers=query_modifiers,
                order_by=order_by,
                limit=Constant(predictor_window),
            )

            if type(time_filter.args[1]) is Latest:
                # = LATEST behaves like > LATEST: strip the filter.
                integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter)
            else:
                # = <date>: fetch the window ending at that date, then flip
                # the filter to ">" so the predictor output (which receives
                # time_filter as output_time_filter below) is restricted to
                # rows after the given date.
                time_filter_date = time_filter.args[1]
                preparation_time_filter = BinaryOperation(
                    "<=", args=[Identifier(predictor_time_column_name), time_filter_date]
                )
                integration_select.where = add_order_not_null(
                    replace_time_filter(preparation_where2, time_filter, preparation_time_filter)
                )
                time_filter.op = ">"

            integration_selects = [integration_select]
        elif isinstance(time_filter, BinaryOperation) and time_filter.op in (">", ">="):
            # > / >= <date>: one select for `window` rows of prior context,
            # one for the rows matching the original filter.
            time_filter_date = time_filter.args[1]
            preparation_time_filter_op = {">": "<=", ">=": "<"}[time_filter.op]

            preparation_time_filter = BinaryOperation(
                preparation_time_filter_op, args=[Identifier(predictor_time_column_name), time_filter_date]
            )
            preparation_where2 = replace_time_filter(preparation_where2, time_filter, preparation_time_filter)
            integration_select_1 = Select(
                targets=[Star()],
                from_table=table,
                where=add_order_not_null(preparation_where2),
                modifiers=query_modifiers,
                order_by=order_by,
                limit=Constant(predictor_window),
            )

            integration_select_2 = Select(
                targets=[Star()],
                from_table=table,
                where=preparation_where,
                modifiers=query_modifiers,
                order_by=order_by,
            )

            integration_selects = [integration_select_1, integration_select_2]
        else:
            # No recognized time filter: fetch all rows (no window limit).
            integration_select = Select(
                targets=[Star()],
                from_table=table,
                where=preparation_where,
                modifiers=query_modifiers,
                order_by=order_by,
            )
            integration_selects = [integration_select]

        if len(predictor_group_by_names) == 0:
            # ts query without grouping
            # one or multistep
            if len(integration_selects) == 1:
                select_partition_step = self.planner.get_integration_select_step(integration_selects[0])
            else:
                select_partition_step = MultipleSteps(
                    steps=[self.planner.get_integration_select_step(s) for s in integration_selects], reduce="union"
                )

            # fetch data step
            data_step = self.planner.plan.add_step(select_partition_step)
        else:
            # inject $var to queries
            for integration_select in integration_selects:
                condition = integration_select.where
                for num, column in enumerate(predictor_group_by_names):
                    cond = BinaryOperation("=", args=[Identifier(column), Constant(f"$var[{column}]")])

                    # join to main condition
                    if condition is None:
                        condition = cond
                    else:
                        condition = BinaryOperation("and", args=[condition, cond])

                integration_select.where = condition
            # one or multistep
            if len(integration_selects) == 1:
                select_partition_step = self.planner.get_integration_select_step(integration_selects[0])
            else:
                select_partition_step = MultipleSteps(
                    steps=[self.planner.get_integration_select_step(s) for s in integration_selects], reduce="union"
                )

            # get grouping values
            no_time_filter_query.where = find_and_remove_time_filter(no_time_filter_query.where, time_filter)
            select_partitions_step = self.plan_fetch_timeseries_partitions(
                no_time_filter_query, table, predictor_group_by_names
            )

            # sub-query by every grouping value
            map_reduce_step = self.planner.plan.add_step(
                MapReduceStep(values=select_partitions_step.result, reduce="union", step=select_partition_step)
            )
            data_step = map_reduce_step

        predictor_identifier = utils.get_predictor_name_identifier(predictor)

        params = None
        if query.using is not None:
            params = query.using
        predictor_step = self.planner.plan.add_step(
            ApplyTimeseriesPredictorStep(
                # NOTE: time_filter may have been mutated above (op flipped to
                # ">" for the "=" case) before being used as output filter.
                output_time_filter=time_filter,
                namespace=predictor_namespace,
                dataframe=data_step.result,
                predictor=predictor_identifier,
                params=params,
            )
        )

        return {
            "predictor": predictor_step,
            "data": data_step,
            "saved_limit": saved_limit,
        }