Coverage for mindsdb / api / executor / planner / plan_join.py: 89%
493 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import copy
2from dataclasses import dataclass, field
4from mindsdb_sql_parser import ast
5from mindsdb_sql_parser.ast import (
6 Select,
7 Identifier,
8 BetweenOperation,
9 Join,
10 Star,
11 BinaryOperation,
12 Constant,
13 NativeQuery,
14 Parameter,
15 Function,
16 Last,
17 Tuple,
18)
20from mindsdb.integrations.utilities.query_traversal import query_traversal
22from mindsdb.api.executor.planner.exceptions import PlanningException
23from mindsdb.api.executor.planner.steps import (
24 FetchDataframeStep,
25 FetchDataframeStepPartition,
26 JoinStep,
27 ApplyPredictorStep,
28 SubSelectStep,
29 QueryStep,
30 MapReduceStep,
31)
32from mindsdb.api.executor.planner.utils import filters_to_bin_op
33from mindsdb.api.executor.planner.plan_join_ts import PlanJoinTSPredictorQuery
@dataclass
class TableInfo:
    """Information the planner collects about one table participating in a join."""

    # integration (database) name the table belongs to
    integration: str
    # table identifier with the integration part already stripped (see resolve_table)
    table: Identifier
    # possible lower-cased alias tuples usable to reference this table's columns
    aliases: list[tuple[str, ...]] = field(default_factory=list)
    # filter conditions extracted from WHERE that apply only to this table
    conditions: list = None
    # original nested SELECT / native query / data, if this "table" is a sub-select
    sub_select: ast.ASTNode = None
    # planner's predictor metadata if the table is a model, else None
    predictor_info: dict = None
    # NOTE: unannotated, so this is a plain class attribute rather than a dataclass
    # field (excluded from __init__/__eq__/__repr__); assigned per-instance later
    # in get_join_sequence. Annotating it would change the generated __init__.
    join_condition = None
    # position of this table in the flattened join sequence
    index: int = None
class PlanJoin:
    """Dispatcher for planning JOIN queries.

    Depending on the query, it either:
    - sends the whole query as-is to a single SQL integration (when no mindsdb
      objects are involved),
    - delegates to the timeseries-predictor join planner, or
    - falls back to the generic table-join planner.
    """

    def __init__(self, planner):
        self.planner = planner

    def is_timeseries(self, query):
        """Return True if either side of the top-level join is a timeseries predictor.

        Only Identifier operands are checked: a nested join or sub-select on a side
        yields no predictor lookup for that side.
        """
        join = query.from_table
        l_predictor = self.planner.get_predictor(join.left) if isinstance(join.left, Identifier) else None
        r_predictor = self.planner.get_predictor(join.right) if isinstance(join.right, Identifier) else None
        if l_predictor and l_predictor.get("timeseries"):
            return True
        if r_predictor and r_predictor.get("timeseries"):
            return True
        # Explicitly return False instead of falling through (previously returned
        # None implicitly); callers rely on truthiness, so this is compatible.
        return False

    def check_single_integration(self, query):
        """Return the integration name if the whole query can be pushed to one SQL
        integration, otherwise None.

        Conditions: no mindsdb entities referenced, exactly one integration, and that
        integration is neither 'files'/'views' nor an 'api'-class handler.
        """
        query_info = self.planner.get_query_info(query)
        # can we send all query to integration?
        # one integration and no mindsdb objects in query
        if (
            len(query_info["mdb_entities"]) == 0
            and len(query_info["integrations"]) == 1
            and "files" not in query_info["integrations"]
            and "views" not in query_info["integrations"]
        ):
            int_name = list(query_info["integrations"])[0]
            # only push down to SQL-like databases, not API handlers
            class_type = self.planner.integrations.get(int_name, {}).get("class_type")
            if class_type != "api":
                # send to this integration
                return int_name
        return None

    def plan(self, query, integration=None):
        """Plan the join query; returns the final plan step."""
        # FIXME: Tableau workaround, INFORMATION_SCHEMA with Where
        # if isinstance(join.right, Identifier) \
        #     and self.resolve_database_table(join.right)[0] == 'INFORMATION_SCHEMA':
        #     pass

        # send join to integration as is?
        integration_to_send = self.check_single_integration(query)
        if integration_to_send:
            self.planner.prepare_integration_select(integration_to_send, query)

            fetch_params = self.planner.get_fetch_params(query.using)
            last_step = self.planner.plan.add_step(
                FetchDataframeStep(integration=integration_to_send, query=query, params=fetch_params)
            )
            return last_step
        elif self.is_timeseries(query):
            return PlanJoinTSPredictorQuery(self.planner).plan(query, integration)
        else:
            return PlanJoinTablesQuery(self.planner).plan(query)
class PlanJoinTablesQuery:
    """Plans a JOIN between regular tables, sub-selects and (non-timeseries) models.

    The join tree is flattened into a post-order sequence
    (table1, table2, join(1,2), table3, join(.,3), ...); a fetch/apply step is
    created per table and a JoinStep per join node. Along the way the planner
    applies push-down optimizations visible in this class: per-table WHERE
    conditions, column pruning, LIMIT/ORDER BY push-down, join-condition filter
    push-down and partitioned predictor execution.
    """

    def __init__(self, planner):
        self.planner = planner

        # index to lookup tables: lower-cased alias tuple -> TableInfo
        self.tables_idx = None
        self.tables = []
        # fetch step created per table, keyed by TableInfo.index
        # (used by get_filters_from_join_conditions)
        self.tables_fetch_step = {}

        # stack of produced steps while walking the join sequence
        self.step_stack = None
        # flags shared between planning phases ('binary_ops', 'use_limit', ...)
        self.query_context = {}

        # currently open MapReduceStep partition, see add_plan_step()
        self.partition = None

    def plan(self, query):
        """Plan the whole join query; returns the step producing the final result."""
        self.tables_idx = {}
        join_step = self.plan_join_tables(query)

        # If the query is more than a bare `SELECT *`, append a QueryStep over the
        # join result to apply projection / grouping / ordering / limit.
        if (
            query.group_by is not None
            or query.order_by is not None
            or query.having is not None
            or query.distinct is True
            or query.where is not None
            or query.limit is not None
            or query.offset is not None
            or len(query.targets) != 1
            or not isinstance(query.targets[0], Star)
        ):
            query2 = copy.deepcopy(query)
            query2.from_table = None
            query2.using = None
            query2.cte = None
            sup_select = QueryStep(query2, from_table=join_step.result, strict_where=False)
            self.planner.plan.add_step(sup_select)
            return sup_select
        return join_step

    def resolve_table(self, table):
        """Resolve a table Identifier to a TableInfo.

        Determines the owning integration (stripping it from the identifier parts)
        and collects the alias tuples by which the table's columns may be referenced.
        """
        # gets integration for table and name to access to it
        table = copy.deepcopy(table)
        # get possible table aliases
        aliases = []
        if table.alias is not None:
            # to lowercase
            parts = tuple(map(str.lower, table.alias.parts))
            aliases.append(parts)
        else:
            # without an explicit alias, every suffix of the dotted name may be used
            for i in range(0, len(table.parts)):
                parts = table.parts[i:]
                parts = tuple(map(str.lower, parts))
                aliases.append(parts)

        # try to use default namespace
        integration = self.planner.default_namespace
        if len(table.parts) > 0:
            # if not quoted check in lower case
            part = table.parts[0]
            if part not in self.planner.databases and not table.is_quoted[0]:
                part = part.lower()

            if part in self.planner.databases:
                # first part is a known database: use it and strip it from the name
                integration = part
                table.parts.pop(0)
                table.is_quoted.pop(0)
            else:
                integration = self.planner.default_namespace

        if integration is None and not hasattr(table, "sub_select"):
            raise PlanningException(f"Database not found for: {table}")

        sub_select = getattr(table, "sub_select", None)

        return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select)

    def get_table_for_column(self, column: Identifier):
        """Return the TableInfo a column belongs to (matched by alias prefix), or None."""
        if not isinstance(column, Identifier):
            return
        # to lowercase
        parts = tuple(map(str.lower, column.parts[:-1]))
        if parts in self.tables_idx:
            return self.tables_idx[parts]

    def get_join_sequence(self, node, condition=None):
        """Flatten the join tree into a post-order sequence of TableInfo/Join items.

        Also registers each table's aliases in tables_idx and looks up predictor
        metadata. `condition` is the join condition the table is joined with.
        """
        sequence = []
        if isinstance(node, Identifier):
            # resolve identifier

            table_info = self.resolve_table(node)
            for alias in table_info.aliases:
                self.tables_idx[alias] = table_info

            table_info.index = len(self.tables)
            self.tables.append(table_info)

            table_info.predictor_info = self.planner.get_predictor(node)

            if condition is not None:
                table_info.join_condition = condition
            sequence.append(table_info)

        elif isinstance(node, Join):
            # create sequence: 1)table1, 2)table2, 3)join 1 2, 4)table 3, 5)join 3 4

            # put all tables before
            sequence2 = self.get_join_sequence(node.left)
            for item in sequence2:
                sequence.append(item)

            sequence2 = self.get_join_sequence(node.right, condition=node.condition)
            if len(sequence2) != 1:
                raise PlanningException("Unexpected join nesting behavior")

            # put next table
            sequence.append(sequence2[0])

            # put join
            sequence.append(node)

        else:
            raise NotImplementedError()
        return sequence

    def check_node_condition(self, node):
        """If `node` compares a qualified column against pushable values, store a copy
        of the condition on that column's TableInfo (for later push-down)."""
        col_idx = 0
        if len(node.args) == 2:
            if not isinstance(node.args[col_idx], Identifier):
                # try to use second arg, could be: 'x'=col
                col_idx = 1

        # check the case col <condition> value, col between value and value
        for i, arg in enumerate(node.args):
            if i == col_idx:
                if not isinstance(arg, Identifier):
                    return
            else:
                if not self.can_be_table_filter(arg):
                    return

        # checked, find table and store condition
        node2 = copy.deepcopy(node)

        arg1 = node2.args[col_idx]

        # unqualified column: cannot attribute it to a table
        if len(arg1.parts) < 2:
            return

        table_info = self.get_table_for_column(arg1)
        if table_info is None:
            raise PlanningException(f"Table not found for identifier: {arg1.to_string()}")

        # keep only column name
        arg1.parts = [arg1.parts[-1]]

        # keep a link to the original node so it can be neutralized later
        # (see process_predictor, which replaces it with 0=0)
        node2._orig_node = node
        table_info.conditions.append(node2)

    def can_be_table_filter(self, node):
        """
        Check if node can be used as a filter.
        It can contain only: Constant, Parameter, Tuple (for IN clauses), Function (with Last)

        Returns a falsy value (None) for any other node type.
        """
        if isinstance(node, (Constant, Parameter)):
            return True
        if isinstance(node, Tuple):
            return all(isinstance(item, Constant) for item in node.items)
        if isinstance(node, Function):
            # `Last` must be in args
            if not any(isinstance(arg, Last) for arg in node.args):
                return False
            return all([self.can_be_table_filter(arg) for arg in node.args])

    def check_query_conditions(self, query):
        """Scan WHERE: collect per-table filter conditions and record every binary
        operator seen (used later to disable push-down when OR is present)."""
        # get conditions for tables
        binary_ops = []

        def _check_node_condition(node, **kwargs):
            if isinstance(node, BetweenOperation):
                self.check_node_condition(node)

            if isinstance(node, BinaryOperation):
                binary_ops.append(node.op)

                self.check_node_condition(node)

        query_traversal(query.where, _check_node_condition)

        self.query_context["binary_ops"] = binary_ops

    def check_use_limit(self, query_in, join_sequence):
        """Decide whether LIMIT can be pushed down to the first fetched table.

        Sets query_context['use_limit'] and query_context['optimize_inner_join'].
        """
        # if only models (predictors), not for regular table joins
        use_limit = False
        optimize_inner_join = False
        if query_in.having is None and query_in.group_by is None and query_in.limit is not None:
            use_limit = True

            # Check what we're joining
            has_predictor = False

            for item in join_sequence:
                if isinstance(item, TableInfo):
                    if item.predictor_info is not None:
                        has_predictor = True
                elif isinstance(item, Join) and not has_predictor:
                    # LEFT JOIN preserves left table row count - LIMIT pushdown is safe
                    join_type = str(item.join_type).upper() if item.join_type else ""
                    if join_type in ("LEFT JOIN", "LEFT OUTER JOIN"):
                        continue

                    if query_in.offset is None:
                        # inner join without OFFSET: LIMIT can't be pushed directly,
                        # but the join can run in partitions (see optimize_inner_join)
                        optimize_inner_join = True
                        continue
                    use_limit = False

        self.query_context["use_limit"] = use_limit
        self.query_context["optimize_inner_join"] = optimize_inner_join

    def plan_join_tables(self, query_in):
        """Create plan steps for every table and join in the query.

        Mutates query_in (targets, where, offset) along the way and returns the
        last step added to the plan.
        """
        # plan all nested selects in 'where'
        find_selects = self.planner.get_nested_selects_plan_fnc(self.planner.default_namespace, force=True)
        query_in.targets = query_traversal(query_in.targets, find_selects)
        query_traversal(query_in.where, find_selects)

        query = copy.deepcopy(query_in)

        # replace sub selects with identifiers holding links to the original selects
        def replace_subselects(node, **args):
            if isinstance(node, Select) or isinstance(node, NativeQuery) or isinstance(node, ast.Data):
                name = f"t_{id(node)}"
                node2 = Identifier(name, alias=node.alias)

                # save in attribute
                if isinstance(node, NativeQuery) or isinstance(node, ast.Data):
                    # wrap to select
                    node = Select(targets=[Star()], from_table=node)
                node2.sub_select = node
                return node2

        query_traversal(query.from_table, replace_subselects)

        # get all join tables, form join sequence
        join_sequence = self.get_join_sequence(query.from_table)
        self.join_sequence = join_sequence

        # find tables for identifiers used in query
        def _check_identifiers(node, is_table, **kwargs):
            if not is_table and isinstance(node, Identifier):
                if len(node.parts) > 1:
                    table_info = self.get_table_for_column(node)
                    if table_info is None:
                        raise PlanningException(f"Table not found for identifier: {node.to_string()}")

                    # replace the identifier's qualifier with the table's canonical alias
                    col_parts = list(table_info.aliases[-1])
                    col_parts.append(node.parts[-1])
                    node.parts = col_parts

        query_traversal(query, _check_identifiers)

        self.check_query_conditions(query)

        # workaround for 'model join table': swap tables so the data table comes first
        if len(join_sequence) == 3 and join_sequence[0].predictor_info is not None:
            join_sequence = [join_sequence[1], join_sequence[0], join_sequence[2]]

        self.check_use_limit(query_in, join_sequence)

        # create plan
        # TODO add optimization: one integration without predictor

        self.step_stack = []
        for item in join_sequence:
            if isinstance(item, TableInfo):
                if item.sub_select is not None:
                    self.process_subselect(item, query_in)
                elif item.predictor_info is not None:
                    self.process_predictor(item, query_in)
                else:
                    # is table
                    self.process_table(item, query_in)

            elif isinstance(item, Join):
                step_right = self.step_stack.pop()
                step_left = self.step_stack.pop()

                new_join = copy.deepcopy(item)

                # TODO: joined results are referenced by fixed placeholder names
                new_join.left = Identifier("tab1")
                new_join.right = Identifier("tab2")
                new_join.implicit = False

                step = self.add_plan_step(JoinStep(left=step_left.result, right=step_right.result, query=new_join))

                self.step_stack.append(step)

        query_in.where = query.where

        if self.query_context["optimize_inner_join"]:
            self.planner.plan.steps = self.optimize_inner_join(self.planner.plan.steps)

        self.close_partition()
        return self.planner.plan.steps[-1]

    def optimize_inner_join(self, steps_in):
        """Rewrite the step list: wrap the first limited fetch plus the following
        join/fetch/subselect steps into a FetchDataframeStepPartition.

        The LIMIT is moved from the fetch query into the partition condition, so the
        join can be executed chunk by chunk until the limit is satisfied.
        """
        steps_out = []

        partition_step = None
        partition_used = False

        for step in steps_in:
            if partition_step is None:
                if isinstance(step, FetchDataframeStep) and not partition_used and step.query.limit is not None:
                    limit = step.query.limit.value
                    step.query.limit = None
                    partition_used = True

                    partition_step = FetchDataframeStepPartition(
                        step_num=step.step_num,
                        integration=step.integration,
                        query=step.query,
                        raw_query=step.raw_query,
                        params=step.params,
                        condition={"limit": limit},
                    )
                    steps_out.append(partition_step)
                    continue

            elif isinstance(step, (JoinStep, FetchDataframeStep, SubSelectStep)):
                # absorb follow-up steps into the open partition
                partition_step.steps.append(step)
                continue
            else:
                # a step type that can't be partitioned closes the partition
                partition_step = None

            steps_out.append(step)

        return steps_out

    def process_subselect(self, item, query_in):
        """Plan a sub-select table: plan the inner select and wrap it in a
        SubSelectStep that applies this table's pushed-down conditions and alias."""
        # is sub select
        item.sub_select.alias = None
        item.sub_select.parentheses = False
        step = self.planner.plan_select(item.sub_select)

        where = filters_to_bin_op(item.conditions)

        # Column pruning for subselects:
        # - If subselect has pure SELECT *, we can prune to only needed columns
        # - If subselect has explicit columns (SELECT a, b, c), pass through all (don't prune)
        # This preserves column aliases and prevents breaking explicit projections
        targets = [Star()]
        needed_columns = self.get_fetch_columns_for_table(item, query_in)
        if needed_columns:
            targets = needed_columns

        # apply table alias
        query2 = Select(targets=targets, where=where)
        if item.table.alias is None:
            raise PlanningException(f"Subselect in join have to be aliased: {item.sub_select.to_string()}")
        table_name = item.table.alias.parts[-1]

        add_absent_cols = False
        if hasattr(item.sub_select, "from_table") and isinstance(item.sub_select.from_table, ast.Data):
            add_absent_cols = True

        step2 = SubSelectStep(query2, step.result, table_name=table_name, add_absent_cols=add_absent_cols)
        step2 = self.add_plan_step(step2)
        self.step_stack.append(step2)

    def _collect_from_order_by(self, query_in, alias_map, add_column_callback):
        """Helper to collect columns from ORDER BY clause, resolving aliases and ordinals."""
        for order_col in query_in.order_by:
            field = order_col.field  # NOTE: shadows the dataclasses.field import (harmless locally)

            # Handle ORDER BY ordinal (e.g., ORDER BY 1)
            if isinstance(field, Constant) and isinstance(field.value, int):
                ordinal = field.value
                if 1 <= ordinal <= len(query_in.targets):
                    target_expr = query_in.targets[ordinal - 1]
                    query_traversal(target_expr, add_column_callback)
                    continue

            # Handle ORDER BY alias (e.g., ORDER BY alias_name)
            if isinstance(field, Identifier) and len(field.parts) == 1:
                alias_name = field.parts[0].lower()
                if alias_name in alias_map:
                    query_traversal(alias_map[alias_name], add_column_callback)
                    continue

            # Regular column reference
            query_traversal(field, add_column_callback)

    def _join_has_predictor(self, join_sequence) -> bool:
        """Check if the join sequence contains any predictor."""
        for item in join_sequence:
            if isinstance(item, TableInfo) and item.predictor_info is not None:
                return True
        return False

    def _can_prune_columns(self, table_info) -> bool:
        """
        Determine if column pruning can be applied to this table.

        Returns:
            True if column pruning can be applied
            False if we should skip pruning (use SELECT *)
        """
        # If this table is part of a join with a predictor: cannot prune
        # Predictors may need all columns from joined tables as input features
        if hasattr(self, "join_sequence") and self._join_has_predictor(self.join_sequence):
            return False

        # For subselects: can only prune if they have pure SELECT * (no other columns)
        sub = table_info.sub_select
        if sub is not None and isinstance(sub, Select):
            targets = getattr(sub, "targets", None) or []
            # Can prune only if subselect has PURE SELECT * (one target that is Star)
            # Cannot prune if:
            # - Mixed: SELECT *, col1 (has Star but also other columns)
            if len(targets) == 1 and isinstance(targets[0], Star):
                return True  # Pure SELECT * - can prune
            return False

        # For project tables (KB tables, views, etc.): cannot prune
        # Project tables need SELECT * for proper column mapping
        if table_info.integration and table_info.integration in self.planner.projects:
            return False

        # Regular integration tables: can prune
        return True

    def get_fetch_columns_for_table(self, table_info, query_in):
        """
        Collect all columns needed from a specific table for column pruning optimization.

        Note: Caller should check _can_prune_columns() before calling this method.

        Returns a list of column Identifiers or None if we should fetch all columns.
        """
        if not self._can_prune_columns(table_info):
            return None

        columns = {}
        has_qualified_star_for_table = False

        # map of target alias -> target expression, for resolving ORDER BY aliases
        alias_map = {}
        if query_in.targets:
            for target in query_in.targets:
                if isinstance(target, Identifier) and target.alias:
                    alias_map[target.alias.parts[-1].lower()] = target

        def add_column(node, **kwargs):
            if isinstance(node, Identifier):
                col_table = self.get_table_for_column(node)
                if not col_table or col_table.index != table_info.index:
                    return

                # Check for qualified star: t1.* or alias.*
                col_name = node.parts[-1]
                is_quoted = node.is_quoted[-1]

                if isinstance(col_name, Star):
                    nonlocal has_qualified_star_for_table
                    has_qualified_star_for_table = True
                    return

                # Store - if already exists, keep it quoted if either reference was quoted
                columns[col_name] = columns.get(col_name) or is_quoted

        # Check for bare Star() in targets
        if query_in.targets:
            for target in query_in.targets:
                if isinstance(target, Star):
                    return None

        query_traversal(query_in, add_column)

        # If qualified star found for this table, fetch all columns
        if has_qualified_star_for_table:
            return None

        # If we found no columns, fetch all
        if not columns:
            return None

        # Convert column names to Identifier objects, we need to preserve quoting
        result = []
        for col, is_quoted in sorted(columns.items()):
            result.append(Identifier(parts=[col], is_quoted=[is_quoted]))
        return result

    def process_table(self, item, query_in):
        """Create the fetch step for a regular table, pushing down columns,
        conditions and (when safe) LIMIT/OFFSET/ORDER BY."""
        table = copy.deepcopy(item.table)
        table.parts.insert(0, item.integration)
        table.is_quoted.insert(0, False)

        needed_columns = self.get_fetch_columns_for_table(item, query_in)
        targets = needed_columns if needed_columns else [Star()]

        query2 = Select(from_table=table, targets=targets)
        conditions = item.conditions
        if "or" in self.query_context["binary_ops"]:
            # WHERE contains OR: per-table conditions can't be pushed down safely
            conditions = []

        if self.query_context.get("had_limit"):
            # a LIMIT was already pushed to a previous fetch: try to narrow this
            # fetch using the join conditions against already-fetched data
            conditions += self.get_filters_from_join_conditions(item)

        if self.query_context["use_limit"]:
            order_by = None
            if query_in.order_by is not None:
                order_by = []
                # push ORDER BY only if all order columns are from this table
                for col in query_in.order_by:
                    table_info = self.get_table_for_column(col.field)
                    if table_info is None or table_info.table != item.table:
                        order_by = False
                        break
                    col = copy.deepcopy(col)
                    col.field.parts = [col.field.parts[-1]]
                    col.field.is_quoted = [col.field.is_quoted[-1]]
                    order_by.append(col)

            if order_by is not False:
                # copy limit from upper query
                query2.limit = query_in.limit
                # move offset from upper query
                query2.offset = query_in.offset
                query_in.offset = None
                # copy order
                query2.order_by = order_by

            # only the first fetched table gets the limit
            self.query_context["use_limit"] = False
            self.query_context["had_limit"] = True

        for cond in conditions:
            if query2.where is not None:
                query2.where = BinaryOperation("and", args=[query2.where, cond])
            else:
                query2.where = cond

        step = self.planner.get_integration_select_step(query2, params=query_in.using)
        self.tables_fetch_step[item.index] = step

        self.add_plan_step(step)
        self.step_stack.append(step)

    def join_condition_to_columns_map(self, model_table):
        """Build a {model_column: data_column} map from `model.col = data.col`
        equalities in the model's join condition.

        Matched equalities are neutralized in place (replaced with 0=0) so the
        JoinStep does not re-apply them.
        """
        columns_map = {}

        def _check_conditions(node, **kwargs):
            if not isinstance(node, BinaryOperation):
                return

            arg1, arg2 = node.args
            if not (isinstance(arg1, Identifier) and isinstance(arg2, Identifier)):
                return

            table1 = self.get_table_for_column(arg1)
            table2 = self.get_table_for_column(arg2)

            if table1 is model_table:
                # model is on the left
                columns_map[arg1.parts[-1]] = arg2
            elif table2 is model_table:
                # model is on the right
                columns_map[arg2.parts[-1]] = arg1
            else:
                # not found, skip
                return

            # exclude condition
            node.args = [Constant(0), Constant(0)]

        query_traversal(model_table.join_condition, _check_conditions)
        return columns_map

    def get_filters_from_join_conditions(self, fetch_table):
        """
        Extract filters from join conditions for filter pushdown optimization.

        For `a.col = b.col` equalities where `b` was already fetched, produces
        `a.col IN (SELECT DISTINCT b.col ...)` over the fetched data. Only invoked
        from process_table after a LIMIT was pushed to an earlier fetch
        ('had_limit'), which bounds the amount of data feeding the IN list.

        Applied narrowly to avoid:
        - Creating massive IN clauses that exceed database query size limits
        - Making arbitrary assumptions about data distribution

        For cross-database joins with large tables, users should:
        - Add explicit WHERE clauses to filter data at the source
        - Use indexed/partitioned tables in their databases
        - Consider materializing intermediate results
        """
        binary_ops = set()
        conditions = []
        data_conditions = []

        def _check_conditions(node, **kwargs):
            if not isinstance(node, BinaryOperation):
                return

            if node.op != "=":
                binary_ops.add(node.op.lower())
                return

            arg1, arg2 = node.args
            table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None
            table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None

            if table1 is not fetch_table:
                if table2 is not fetch_table:
                    return
                # set our table first
                table1, table2 = table2, table1
                arg1, arg2 = arg2, arg1

            if isinstance(arg2, Constant):
                conditions.append(node)
            elif table2 is not None:
                data_conditions.append([arg1, arg2])

        query_traversal(fetch_table.join_condition, _check_conditions)

        # only pure conjunctions of equalities are safe to push down
        binary_ops.discard("and")
        if len(binary_ops) > 0:
            # other operations exists, skip
            return []

        for arg1, arg2 in data_conditions:
            # is fetched?
            table2 = self.get_table_for_column(arg2)
            fetch_step = self.tables_fetch_step.get(table2.index)

            if fetch_step is None:
                continue

            # extract distinct values
            # remove aliases
            arg1 = Identifier(parts=[arg1.parts[-1]])
            arg2 = Identifier(parts=[arg2.parts[-1]])

            query2 = Select(targets=[arg2], distinct=True)
            subselect_step = SubSelectStep(query2, fetch_step.result)
            subselect_step = self.add_plan_step(subselect_step)

            conditions.append(BinaryOperation(op="in", args=[arg1, Parameter(subselect_step.result)]))

        return conditions

    def process_predictor(self, item, query_in):
        """Create an ApplyPredictorStep feeding the model with the previous step's data.

        Equality conditions on the model become row parameters; USING params prefixed
        with the model alias are routed to the model; 'partition_size' enables
        partitioned execution.
        """
        if len(self.step_stack) == 0:
            raise NotImplementedError("Predictor can't be first element of join syntax")
        if item.predictor_info.get("timeseries"):
            raise NotImplementedError("TS predictor is not supported here yet")
        data_step = self.step_stack[-1]
        row_dict = None

        predict_target = item.predictor_info.get("to_predict")
        if isinstance(predict_target, list) and len(predict_target) > 0:
            predict_target = predict_target[0]
        if predict_target is not None:
            predict_target = predict_target.lower()

        columns_map = None
        if item.join_condition:
            columns_map = self.join_condition_to_columns_map(item)

        if item.conditions:
            row_dict = {}
            for i, el in enumerate(item.conditions):
                if isinstance(el.args[0], Identifier) and el.op == "=":
                    col_name = el.args[0].parts[-1]
                    if col_name.lower() == predict_target:
                        # don't add predict target to parameters
                        continue

                    if isinstance(el.args[1], (Constant, Parameter)):
                        row_dict[el.args[0].parts[-1]] = el.args[1].value

                        # exclude condition from the original WHERE (replace with 0=0)
                        el._orig_node.args = [Constant(0), Constant(0)]

        # params for model
        model_params = None
        partition_size = None
        if query_in.using is not None:
            model_params = {}
            for param, value in query_in.using.items():
                if "." in param:
                    # 'alias.param' form: only applies if alias matches this model
                    alias = param.split(".")[0]
                    if (alias,) in item.aliases:
                        new_param = ".".join(param.split(".")[1:])
                        model_params[new_param.lower()] = value
                else:
                    model_params[param.lower()] = value

            partition_size = model_params.pop("partition_size", None)

        predictor_step = ApplyPredictorStep(
            namespace=item.integration,
            dataframe=data_step.result,
            predictor=item.table,
            params=model_params,
            row_dict=row_dict,
            columns_map=columns_map,
        )

        self.step_stack.append(self.add_plan_step(predictor_step, partition_size=partition_size))

    def add_plan_step(self, step, partition_size=None):
        """
        Adds step to plan

        If partition_size is defined: create partition
        If partition is active
          If step can be partitioned:
             Add step to partition, not to the plan
          Otherwise:
             Close partition
             Add step to plan
        """
        if self.partition:
            if isinstance(step, (JoinStep, ApplyPredictorStep)):
                # add to partition

                self.add_step_to_partition(step)
                return step
        elif partition_size is not None:
            # create partition

            self.partition = MapReduceStep(values=step.dataframe, reduce="union", step=[], partition=partition_size)
            self.planner.plan.add_step(self.partition)

            self.add_step_to_partition(step)
            return step
        else:
            # next step can't be partitioned.
            self.close_partition()

        return self.planner.plan.add_step(step)

    def add_step_to_partition(self, step):
        """Append a step inside the open partition, numbering it as a sub-step."""
        step.step_num = f"{self.partition.step_num}_{len(self.partition.step)}"
        self.partition.step.append(step)

    def close_partition(self):
        """If a partition is open: close it and make it the top of the step stack."""
        if self.partition:
            if len(self.step_stack) > 0:
                self.step_stack[-1] = self.partition

            self.partition = None