Coverage for mindsdb / api / executor / planner / plan_join.py: 89%

493 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import copy 

2from dataclasses import dataclass, field 

3 

4from mindsdb_sql_parser import ast 

5from mindsdb_sql_parser.ast import ( 

6 Select, 

7 Identifier, 

8 BetweenOperation, 

9 Join, 

10 Star, 

11 BinaryOperation, 

12 Constant, 

13 NativeQuery, 

14 Parameter, 

15 Function, 

16 Last, 

17 Tuple, 

18) 

19 

20from mindsdb.integrations.utilities.query_traversal import query_traversal 

21 

22from mindsdb.api.executor.planner.exceptions import PlanningException 

23from mindsdb.api.executor.planner.steps import ( 

24 FetchDataframeStep, 

25 FetchDataframeStepPartition, 

26 JoinStep, 

27 ApplyPredictorStep, 

28 SubSelectStep, 

29 QueryStep, 

30 MapReduceStep, 

31) 

32from mindsdb.api.executor.planner.utils import filters_to_bin_op 

33from mindsdb.api.executor.planner.plan_join_ts import PlanJoinTSPredictorQuery 

34 

35 

@dataclass
class TableInfo:
    """Resolved metadata for one table participating in a join.

    Built by ``PlanJoinTablesQuery.resolve_table`` and indexed by alias in
    ``tables_idx`` so column identifiers can be mapped back to their table.
    """

    # integration (database/project) name the table belongs to
    integration: str
    # table identifier with the integration prefix already stripped
    table: Identifier
    # all lowercase alias tuples under which this table can be referenced
    aliases: list[tuple[str, ...]] = field(default_factory=list)
    # WHERE-clause conditions that filter only this table (filled in later)
    conditions: list = None
    # nested SELECT/NativeQuery/Data node, if this "table" is a sub-select
    sub_select: ast.ASTNode = None
    # planner's predictor metadata dict, or None for a plain table
    predictor_info: dict = None
    # NOTE: deliberately unannotated — this stays a plain class attribute, NOT a
    # dataclass field (annotating it would add a parameter to __init__).
    # It is assigned per-instance in get_join_sequence.
    join_condition = None
    # position of this table in the join sequence (set in get_join_sequence)
    index: int = None

46 

47 

class PlanJoin:
    """Entry point for planning a JOIN query.

    Dispatches to one of three strategies:
      * push the whole query down to a single integration,
      * timeseries-predictor join planning (``PlanJoinTSPredictorQuery``),
      * generic multi-table join planning (``PlanJoinTablesQuery``).
    """

    def __init__(self, planner):
        # parent query planner; provides catalog lookups and the plan object
        self.planner = planner

    def is_timeseries(self, query) -> bool:
        """Return True if either side of the top-level join is a timeseries predictor.

        Only direct ``Identifier`` operands are checked; nested joins or
        sub-selects on either side are treated as non-predictors here.
        """
        join = query.from_table
        l_predictor = self.planner.get_predictor(join.left) if isinstance(join.left, Identifier) else None
        r_predictor = self.planner.get_predictor(join.right) if isinstance(join.right, Identifier) else None
        if l_predictor and l_predictor.get("timeseries"):
            return True
        if r_predictor and r_predictor.get("timeseries"):
            return True
        # explicit False (previously fell through returning None implicitly)
        return False

    def check_single_integration(self, query):
        """Return the integration name if the whole query can be sent to it as-is.

        The pushdown is allowed only when the query touches exactly one
        integration, no MindsDB objects, and the integration is a SQL
        database (not 'files'/'views' and not an API-type handler).
        Returns None when pushdown is not possible.
        """
        query_info = self.planner.get_query_info(query)
        # can we send all query to integration?
        # one integration and not mindsdb objects in query
        if (
            len(query_info["mdb_entities"]) == 0
            and len(query_info["integrations"]) == 1
            and "files" not in query_info["integrations"]
            and "views" not in query_info["integrations"]
        ):
            int_name = list(query_info["integrations"])[0]
            # if is sql database
            class_type = self.planner.integrations.get(int_name, {}).get("class_type")
            if class_type != "api":
                # send to this integration
                return int_name
        return None

    def plan(self, query, integration=None):
        """Plan the join query and return the final plan step."""
        # FIXME: Tableau workaround, INFORMATION_SCHEMA with Where
        # if isinstance(join.right, Identifier) \
        #    and self.resolve_database_table(join.right)[0] == 'INFORMATION_SCHEMA':
        #    pass

        # send join to integration as is?
        integration_to_send = self.check_single_integration(query)
        if integration_to_send:
            self.planner.prepare_integration_select(integration_to_send, query)

            fetch_params = self.planner.get_fetch_params(query.using)
            last_step = self.planner.plan.add_step(
                FetchDataframeStep(integration=integration_to_send, query=query, params=fetch_params)
            )
            return last_step
        elif self.is_timeseries(query):
            return PlanJoinTSPredictorQuery(self.planner).plan(query, integration)
        else:
            return PlanJoinTablesQuery(self.planner).plan(query)

99 

100 

class PlanJoinTablesQuery:
    """Plans a generic (non-timeseries) multi-table join query.

    Walks the join tree into a linear sequence of tables and Join nodes,
    emits a fetch step per table (with condition/limit/column pushdown where
    safe), Join steps to combine them, optional predictor application, and a
    final QueryStep for any remaining projection/filtering/ordering.
    """

    def __init__(self, planner):
        self.planner = planner

        # index to lookup tables: maps lowercase alias tuple -> TableInfo
        self.tables_idx = None
        self.tables = []
        # TableInfo.index -> fetch step, used for join-condition filter pushdown
        self.tables_fetch_step = {}

        # stack of produced steps; Join nodes pop two and push one
        self.step_stack = None
        # shared flags for the current query: binary_ops, use_limit, had_limit, ...
        self.query_context = {}

        # active MapReduceStep partition, if any (see add_plan_step)
        self.partition = None

    def plan(self, query):
        """Plan the query; return the last step (QueryStep or the join step).

        A wrapping QueryStep is added only when the outer query does more
        than a plain ``SELECT *`` over the join result.
        """
        self.tables_idx = {}
        join_step = self.plan_join_tables(query)

        if (
            query.group_by is not None
            or query.order_by is not None
            or query.having is not None
            or query.distinct is True
            or query.where is not None
            or query.limit is not None
            or query.offset is not None
            or len(query.targets) != 1
            or not isinstance(query.targets[0], Star)
        ):
            query2 = copy.deepcopy(query)
            # the join itself is already executed; strip what was consumed
            query2.from_table = None
            query2.using = None
            query2.cte = None
            sup_select = QueryStep(query2, from_table=join_step.result, strict_where=False)
            self.planner.plan.add_step(sup_select)
            return sup_select
        return join_step

    def resolve_table(self, table):
        """Resolve an Identifier into a TableInfo (integration + access name).

        Builds every alias tuple the table may be referenced by: the explicit
        alias if present, otherwise all suffixes of the dotted name.
        """
        # gets integration for table and name to access to it
        table = copy.deepcopy(table)
        # get possible table aliases
        aliases = []
        if table.alias is not None:
            # to lowercase
            parts = tuple(map(str.lower, table.alias.parts))
            aliases.append(parts)
        else:
            for i in range(0, len(table.parts)):
                parts = table.parts[i:]
                parts = tuple(map(str.lower, parts))
                aliases.append(parts)

        # try to use default namespace
        integration = self.planner.default_namespace
        if len(table.parts) > 0:
            # if not quoted check in lower case
            part = table.parts[0]
            if part not in self.planner.databases and not table.is_quoted[0]:
                part = part.lower()

            if part in self.planner.databases:
                # first part names a known database: strip it off the table name
                integration = part
                table.parts.pop(0)
                table.is_quoted.pop(0)
            else:
                integration = self.planner.default_namespace

        # sub_select placeholders (see plan_join_tables) don't need a database
        if integration is None and not hasattr(table, "sub_select"):
            raise PlanningException(f"Database not found for: {table}")

        sub_select = getattr(table, "sub_select", None)

        return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select)

    def get_table_for_column(self, column: Identifier):
        """Return the TableInfo owning this column, or None if not found.

        Lookup key is the lowercase tuple of all parts except the last
        (the column name itself).
        """
        if not isinstance(column, Identifier):
            return
        # to lowercase
        parts = tuple(map(str.lower, column.parts[:-1]))
        if parts in self.tables_idx:
            return self.tables_idx[parts]

    def get_join_sequence(self, node, condition=None):
        """Flatten the join tree into post-order: table1, table2, join12, table3, join...

        Each Identifier becomes a registered TableInfo; each Join node is
        appended after its right-hand table. The right side of a Join must
        resolve to exactly one table.
        """
        sequence = []
        if isinstance(node, Identifier):
            # resolve identifier

            table_info = self.resolve_table(node)
            for alias in table_info.aliases:
                self.tables_idx[alias] = table_info

            table_info.index = len(self.tables)
            self.tables.append(table_info)

            table_info.predictor_info = self.planner.get_predictor(node)

            if condition is not None:
                table_info.join_condition = condition
            sequence.append(table_info)

        elif isinstance(node, Join):
            # create sequence: 1)table1, 2)table2, 3)join 1 2, 4)table 3, 5)join 3 4

            # put all tables before
            sequence2 = self.get_join_sequence(node.left)
            for item in sequence2:
                sequence.append(item)

            sequence2 = self.get_join_sequence(node.right, condition=node.condition)
            if len(sequence2) != 1:
                raise PlanningException("Unexpected join nesting behavior")

            # put next table
            sequence.append(sequence2[0])

            # put join
            sequence.append(node)

        else:
            raise NotImplementedError()
        return sequence

    def check_node_condition(self, node):
        """If ``node`` is a simple per-table filter (col <op> literal(s)),
        record a copy of it on the owning table's ``conditions`` list.

        The copy keeps a ``_orig_node`` back-reference so the original can be
        neutralized later (see process_predictor).
        """
        col_idx = 0
        if len(node.args) == 2:
            if not isinstance(node.args[col_idx], Identifier):
                # try to use second arg, could be: 'x'=col
                col_idx = 1

        # check the case col <condition> value, col between value and value
        for i, arg in enumerate(node.args):
            if i == col_idx:
                if not isinstance(arg, Identifier):
                    return
            else:
                if not self.can_be_table_filter(arg):
                    return

        # checked, find table and store condition
        node2 = copy.deepcopy(node)

        arg1 = node2.args[col_idx]

        # unqualified column: can't tell which table it belongs to, skip
        if len(arg1.parts) < 2:
            return

        table_info = self.get_table_for_column(arg1)
        if table_info is None:
            raise PlanningException(f"Table not found for identifier: {arg1.to_string()}")

        # keep only column name
        arg1.parts = [arg1.parts[-1]]

        node2._orig_node = node
        table_info.conditions.append(node2)

    def can_be_table_filter(self, node):
        """
        Check if node can be used as a filter.
        It can contain only: Constant, Parameter, Tuple (for IN clauses), Function (with Last)
        """
        if isinstance(node, (Constant, Parameter)):
            return True
        if isinstance(node, Tuple):
            return all(isinstance(item, Constant) for item in node.items)
        if isinstance(node, Function):
            # `Last` must be in args
            if not any(isinstance(arg, Last) for arg in node.args):
                return False
            return all([self.can_be_table_filter(arg) for arg in node.args])

    def check_query_conditions(self, query):
        """Scan WHERE for per-table filters and collect the set of binary ops.

        The op list is stored in ``query_context['binary_ops']``; an 'or'
        anywhere disables condition pushdown (see process_table).
        """
        # get conditions for tables
        binary_ops = []

        def _check_node_condition(node, **kwargs):
            if isinstance(node, BetweenOperation):
                self.check_node_condition(node)

            if isinstance(node, BinaryOperation):
                binary_ops.append(node.op)

                self.check_node_condition(node)

        query_traversal(query.where, _check_node_condition)

        self.query_context["binary_ops"] = binary_ops

    def check_use_limit(self, query_in, join_sequence):
        """Decide whether LIMIT can be pushed into the first table fetch.

        Sets ``query_context['use_limit']`` and
        ``query_context['optimize_inner_join']`` (the latter enables the
        partitioned inner-join rewrite in optimize_inner_join).
        """
        # if only models (predictors), not for regular table joins
        use_limit = False
        optimize_inner_join = False
        if query_in.having is None and query_in.group_by is None and query_in.limit is not None:
            use_limit = True

            # Check what we're joining
            has_predictor = False

            for item in join_sequence:
                if isinstance(item, TableInfo):
                    if item.predictor_info is not None:
                        has_predictor = True
                elif isinstance(item, Join) and not has_predictor:
                    # LEFT JOIN preserves left table row count - LIMIT pushdown is safe
                    join_type = str(item.join_type).upper() if item.join_type else ""
                    if join_type in ("LEFT JOIN", "LEFT OUTER JOIN"):
                        continue

                    if query_in.offset is None:
                        optimize_inner_join = True
                        continue
                    use_limit = False

        self.query_context["use_limit"] = use_limit
        self.query_context["optimize_inner_join"] = optimize_inner_join

    def plan_join_tables(self, query_in):
        """Core join planning: build the join sequence and emit all steps.

        Mutates ``query_in`` (nested selects in targets/where are planned and
        replaced; offset may be moved into a fetch; where is rewritten with
        canonical aliases). Returns the last plan step.
        """
        # plan all nested selects in 'where'
        find_selects = self.planner.get_nested_selects_plan_fnc(self.planner.default_namespace, force=True)
        query_in.targets = query_traversal(query_in.targets, find_selects)
        query_traversal(query_in.where, find_selects)

        query = copy.deepcopy(query_in)

        # replace sub selects, with identifiers with links to original selects
        def replace_subselects(node, **args):
            if isinstance(node, Select) or isinstance(node, NativeQuery) or isinstance(node, ast.Data):
                # synthetic unique name for the placeholder identifier
                name = f"t_{id(node)}"
                node2 = Identifier(name, alias=node.alias)

                # save in attribute
                if isinstance(node, NativeQuery) or isinstance(node, ast.Data):
                    # wrap to select
                    node = Select(targets=[Star()], from_table=node)
                node2.sub_select = node
                return node2

        query_traversal(query.from_table, replace_subselects)

        # get all join tables, form join sequence
        join_sequence = self.get_join_sequence(query.from_table)
        self.join_sequence = join_sequence

        # find tables for identifiers used in query
        def _check_identifiers(node, is_table, **kwargs):
            if not is_table and isinstance(node, Identifier):
                if len(node.parts) > 1:
                    table_info = self.get_table_for_column(node)
                    if table_info is None:
                        raise PlanningException(f"Table not found for identifier: {node.to_string()}")

                    # rewrite the identifier to use the table's canonical alias prefix
                    col_parts = list(table_info.aliases[-1])
                    col_parts.append(node.parts[-1])
                    node.parts = col_parts

        query_traversal(query, _check_identifiers)

        self.check_query_conditions(query)

        # workaround for 'model join table': swap tables:
        if len(join_sequence) == 3 and join_sequence[0].predictor_info is not None:
            join_sequence = [join_sequence[1], join_sequence[0], join_sequence[2]]

        self.check_use_limit(query_in, join_sequence)

        # create plan
        # TODO add optimization: one integration without predictor

        self.step_stack = []
        for item in join_sequence:
            if isinstance(item, TableInfo):
                if item.sub_select is not None:
                    self.process_subselect(item, query_in)
                elif item.predictor_info is not None:
                    self.process_predictor(item, query_in)
                else:
                    # is table
                    self.process_table(item, query_in)

            elif isinstance(item, Join):
                # combine the two most recent results
                step_right = self.step_stack.pop()
                step_left = self.step_stack.pop()

                new_join = copy.deepcopy(item)

                # TODO
                new_join.left = Identifier("tab1")
                new_join.right = Identifier("tab2")
                new_join.implicit = False

                step = self.add_plan_step(JoinStep(left=step_left.result, right=step_right.result, query=new_join))

                self.step_stack.append(step)

        # propagate the rewritten (alias-canonical, possibly neutralized) WHERE
        query_in.where = query.where

        if self.query_context["optimize_inner_join"]:
            self.planner.plan.steps = self.optimize_inner_join(self.planner.plan.steps)

        self.close_partition()
        return self.planner.plan.steps[-1]

    def optimize_inner_join(self, steps_in):
        """Rewrite the plan: wrap the first limited fetch and the following
        join/fetch/subselect steps into a FetchDataframeStepPartition, so the
        limit can be applied after the inner join instead of before it.
        """
        steps_out = []

        partition_step = None
        partition_used = False

        for step in steps_in:
            if partition_step is None:
                if isinstance(step, FetchDataframeStep) and not partition_used and step.query.limit is not None:
                    # move the LIMIT from the fetch query into the partition condition
                    limit = step.query.limit.value
                    step.query.limit = None
                    partition_used = True

                    partition_step = FetchDataframeStepPartition(
                        step_num=step.step_num,
                        integration=step.integration,
                        query=step.query,
                        raw_query=step.raw_query,
                        params=step.params,
                        condition={"limit": limit},
                    )
                    steps_out.append(partition_step)
                    continue

            elif isinstance(step, (JoinStep, FetchDataframeStep, SubSelectStep)):
                # absorb follow-up steps into the open partition
                partition_step.steps.append(step)
                continue
            else:
                # non-absorbable step ends the partition
                partition_step = None

            steps_out.append(step)

        return steps_out

    def process_subselect(self, item, query_in):
        """Plan a sub-select table: plan the inner select, then wrap it in a
        SubSelectStep that applies this table's filters and alias.
        """
        # is sub select
        item.sub_select.alias = None
        item.sub_select.parentheses = False
        step = self.planner.plan_select(item.sub_select)

        where = filters_to_bin_op(item.conditions)

        # Column pruning for subselects:
        # - If subselect has pure SELECT *, we can prune to only needed columns
        # - If subselect has explicit columns (SELECT a, b, c), pass through all (don't prune)
        # This preserves column aliases and prevents breaking explicit projections
        targets = [Star()]
        needed_columns = self.get_fetch_columns_for_table(item, query_in)
        if needed_columns:
            targets = needed_columns

        # apply table alias
        query2 = Select(targets=targets, where=where)
        if item.table.alias is None:
            raise PlanningException(f"Subselect in join have to be aliased: {item.sub_select.to_string()}")
        table_name = item.table.alias.parts[-1]

        # literal ast.Data sources may miss columns; ask the step to add them
        add_absent_cols = False
        if hasattr(item.sub_select, "from_table") and isinstance(item.sub_select.from_table, ast.Data):
            add_absent_cols = True

        step2 = SubSelectStep(query2, step.result, table_name=table_name, add_absent_cols=add_absent_cols)
        step2 = self.add_plan_step(step2)
        self.step_stack.append(step2)

    def _collect_from_order_by(self, query_in, alias_map, add_column_callback):
        """Helper to collect columns from ORDER BY clause, resolving aliases and ordinals."""
        for order_col in query_in.order_by:
            field = order_col.field

            # Handle ORDER BY ordinal (e.g., ORDER BY 1)
            if isinstance(field, Constant) and isinstance(field.value, int):
                ordinal = field.value
                if 1 <= ordinal <= len(query_in.targets):
                    target_expr = query_in.targets[ordinal - 1]
                    query_traversal(target_expr, add_column_callback)
                    continue

            # Handle ORDER BY alias (e.g., ORDER BY alias_name)
            if isinstance(field, Identifier) and len(field.parts) == 1:
                alias_name = field.parts[0].lower()
                if alias_name in alias_map:
                    query_traversal(alias_map[alias_name], add_column_callback)
                    continue

            # Regular column reference
            query_traversal(field, add_column_callback)

    def _join_has_predictor(self, join_sequence) -> bool:
        """Check if the join sequence contains any predictor."""
        for item in join_sequence:
            if isinstance(item, TableInfo) and item.predictor_info is not None:
                return True
        return False

    def _can_prune_columns(self, table_info) -> bool:
        """
        Determine if column pruning can be applied to this table.

        Returns:
            True if column pruning can be applied
            False if we should skip pruning (use SELECT *)
        """

        # If this table is part of a join with a predictor: cannot prune
        # Predictors may need all columns from joined tables as input features
        if hasattr(self, "join_sequence") and self._join_has_predictor(self.join_sequence):
            return False

        # For subselects: can only prune if they have pure SELECT * (no other columns)
        sub = table_info.sub_select
        if sub is not None and isinstance(sub, Select):
            targets = getattr(sub, "targets", None) or []
            # Can prune only if subselect has PURE SELECT * (one target that is Star)
            # Cannot prune if:
            # - Mixed: SELECT *, col1 (has Star but also other columns)
            if len(targets) == 1 and isinstance(targets[0], Star):
                return True  # Pure SELECT * - can prune
            return False

        # For project tables (KB tables, views, etc.): cannot prune
        # Project tables need SELECT * for proper column mapping
        if table_info.integration and table_info.integration in self.planner.projects:
            return False

        # Regular integration tables: can prune
        return True

    def get_fetch_columns_for_table(self, table_info, query_in):
        """
        Collect all columns needed from a specific table for column pruning optimization.

        Note: Caller should check _can_prune_columns() before calling this method.

        Returns a list of column Identifiers or None if we should fetch all columns.
        """
        if not self._can_prune_columns(table_info):
            return None

        columns = {}
        has_qualified_star_for_table = False

        # NOTE(review): alias_map is built but not used in this method —
        # presumably intended for _collect_from_order_by; verify.
        alias_map = {}
        if query_in.targets:
            for target in query_in.targets:
                if isinstance(target, Identifier) and target.alias:
                    alias_map[target.alias.parts[-1].lower()] = target

        def add_column(node, **kwargs):
            if isinstance(node, Identifier):
                col_table = self.get_table_for_column(node)
                if not col_table or col_table.index != table_info.index:
                    return

                # Check for qualified star: t1.* or alias.*
                col_name = node.parts[-1]
                is_quoted = node.is_quoted[-1]

                if isinstance(col_name, Star):
                    nonlocal has_qualified_star_for_table
                    has_qualified_star_for_table = True
                    return

                # Store - if already exists, keep it quoted if either reference was quoted
                columns[col_name] = columns.get(col_name) or is_quoted

        # Check for bare Star() in targets
        if query_in.targets:
            for target in query_in.targets:
                if isinstance(target, Star):
                    return None

        query_traversal(query_in, add_column)

        # If qualified star found for this table, fetch all columns
        if has_qualified_star_for_table:
            return None

        # If we found no columns, fetch all
        if not columns:
            return None

        # Convert column names to Identifier objects, we need to preserve quoting
        result = []
        for col, is_quoted in sorted(columns.items()):
            result.append(Identifier(parts=[col], is_quoted=[is_quoted]))
        return result

    def process_table(self, item, query_in):
        """Emit a fetch step for a plain integration table, pushing down
        per-table conditions, pruned columns, and (once) LIMIT/OFFSET/ORDER BY.
        """
        table = copy.deepcopy(item.table)
        table.parts.insert(0, item.integration)
        table.is_quoted.insert(0, False)

        needed_columns = self.get_fetch_columns_for_table(item, query_in)
        targets = needed_columns if needed_columns else [Star()]

        query2 = Select(from_table=table, targets=targets)
        conditions = item.conditions
        # an OR anywhere in WHERE makes per-table pushdown unsafe
        if "or" in self.query_context["binary_ops"]:
            conditions = []

        # a previous table already consumed the LIMIT: try IN-filter pushdown
        if self.query_context.get("had_limit"):
            conditions += self.get_filters_from_join_conditions(item)

        if self.query_context["use_limit"]:
            order_by = None
            if query_in.order_by is not None:
                order_by = []
                # all order column are from this table
                for col in query_in.order_by:
                    table_info = self.get_table_for_column(col.field)
                    if table_info is None or table_info.table != item.table:
                        order_by = False
                        break
                    col = copy.deepcopy(col)
                    col.field.parts = [col.field.parts[-1]]
                    col.field.is_quoted = [col.field.is_quoted[-1]]
                    order_by.append(col)

            if order_by is not False:
                # copy limit from upper query
                query2.limit = query_in.limit
                # move offset from upper query
                query2.offset = query_in.offset
                query_in.offset = None
                # copy order
                query2.order_by = order_by

            # only the first fetched table gets the limit
            self.query_context["use_limit"] = False
            self.query_context["had_limit"] = True
        for cond in conditions:
            if query2.where is not None:
                query2.where = BinaryOperation("and", args=[query2.where, cond])
            else:
                query2.where = cond

        step = self.planner.get_integration_select_step(query2, params=query_in.using)
        self.tables_fetch_step[item.index] = step

        self.add_plan_step(step)
        self.step_stack.append(step)

    def join_condition_to_columns_map(self, model_table):
        """Build {model_column: other_table_identifier} from equality join
        conditions involving the model table; each matched condition is
        neutralized in place (replaced with 0=0).
        """
        columns_map = {}

        def _check_conditions(node, **kwargs):
            if not isinstance(node, BinaryOperation):
                return

            arg1, arg2 = node.args
            if not (isinstance(arg1, Identifier) and isinstance(arg2, Identifier)):
                return

            table1 = self.get_table_for_column(arg1)
            table2 = self.get_table_for_column(arg2)

            if table1 is model_table:
                # model is on the left
                columns_map[arg1.parts[-1]] = arg2
            elif table2 is model_table:
                # model is on the right
                columns_map[arg2.parts[-1]] = arg1
            else:
                # not found, skip
                return

            # exclude condition
            node.args = [Constant(0), Constant(0)]

        query_traversal(model_table.join_condition, _check_conditions)
        return columns_map

    def get_filters_from_join_conditions(self, fetch_table):
        """
        Extract filters from join conditions for filter pushdown optimization.

        Called from process_table once a previous table has consumed the LIMIT
        (query_context['had_limit']). For `our_col = other_col` equality
        conditions where the other table is already fetched, adds an
        `our_col IN (<distinct other_col values>)` condition backed by a
        SubSelectStep over the earlier fetch. Constant equalities are pushed
        as-is. Any join operator other than '=' / 'and' disables the pushdown.
        """
        binary_ops = set()
        conditions = []
        data_conditions = []

        def _check_conditions(node, **kwargs):
            if not isinstance(node, BinaryOperation):
                return

            if node.op != "=":
                binary_ops.add(node.op.lower())
                return

            arg1, arg2 = node.args
            table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None
            table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None

            if table1 is not fetch_table:
                if table2 is not fetch_table:
                    return
                # set our table first
                table1, table2 = table2, table1
                arg1, arg2 = arg2, arg1

            if isinstance(arg2, Constant):
                conditions.append(node)
            elif table2 is not None:
                data_conditions.append([arg1, arg2])

        query_traversal(fetch_table.join_condition, _check_conditions)

        binary_ops.discard("and")
        if len(binary_ops) > 0:
            # other operations exists, skip
            return []

        for arg1, arg2 in data_conditions:
            # is fetched?
            table2 = self.get_table_for_column(arg2)
            fetch_step = self.tables_fetch_step.get(table2.index)

            if fetch_step is None:
                continue

            # extract distinct values
            # remove aliases
            arg1 = Identifier(parts=[arg1.parts[-1]])
            arg2 = Identifier(parts=[arg2.parts[-1]])

            query2 = Select(targets=[arg2], distinct=True)
            subselect_step = SubSelectStep(query2, fetch_step.result)
            subselect_step = self.add_plan_step(subselect_step)

            conditions.append(BinaryOperation(op="in", args=[arg1, Parameter(subselect_step.result)]))

        return conditions

    def process_predictor(self, item, query_in):
        """Emit an ApplyPredictorStep fed by the previous step's result.

        WHERE equalities on the model become row_dict parameters (and are
        neutralized in the original WHERE via _orig_node); join conditions
        become columns_map; USING params prefixed with the model alias are
        routed to the model, and 'partition_size' enables partitioning.
        """
        if len(self.step_stack) == 0:
            raise NotImplementedError("Predictor can't be first element of join syntax")
        if item.predictor_info.get("timeseries"):
            raise NotImplementedError("TS predictor is not supported here yet")
        data_step = self.step_stack[-1]
        row_dict = None

        predict_target = item.predictor_info.get("to_predict")
        if isinstance(predict_target, list) and len(predict_target) > 0:
            predict_target = predict_target[0]
        if predict_target is not None:
            predict_target = predict_target.lower()

        columns_map = None
        if item.join_condition:
            columns_map = self.join_condition_to_columns_map(item)

        if item.conditions:
            row_dict = {}
            for i, el in enumerate(item.conditions):
                if isinstance(el.args[0], Identifier) and el.op == "=":
                    col_name = el.args[0].parts[-1]
                    if col_name.lower() == predict_target:
                        # don't add predict target to parameters
                        continue

                    if isinstance(el.args[1], (Constant, Parameter)):
                        row_dict[el.args[0].parts[-1]] = el.args[1].value

                    # exclude condition
                    el._orig_node.args = [Constant(0), Constant(0)]

        # params for model
        model_params = None
        partition_size = None
        if query_in.using is not None:
            model_params = {}
            for param, value in query_in.using.items():
                if "." in param:
                    # 'alias.param' applies only when alias matches this model
                    alias = param.split(".")[0]
                    if (alias,) in item.aliases:
                        new_param = ".".join(param.split(".")[1:])
                        model_params[new_param.lower()] = value
                else:
                    model_params[param.lower()] = value

            partition_size = model_params.pop("partition_size", None)

        predictor_step = ApplyPredictorStep(
            namespace=item.integration,
            dataframe=data_step.result,
            predictor=item.table,
            params=model_params,
            row_dict=row_dict,
            columns_map=columns_map,
        )

        self.step_stack.append(self.add_plan_step(predictor_step, partition_size=partition_size))

    def add_plan_step(self, step, partition_size=None):
        """
        Adds step to plan

        If partition_size is defined: create partition
        If partition is active
            If step can be partitioned:
                Add step to partition not in plan
            Otherwise:
                Add partition to plan
                Add step to plan
        """
        if self.partition:
            if isinstance(step, (JoinStep, ApplyPredictorStep)):
                # add to partition

                self.add_step_to_partition(step)
                return step

        elif partition_size is not None:
            # create partition

            self.partition = MapReduceStep(values=step.dataframe, reduce="union", step=[], partition=partition_size)
            self.planner.plan.add_step(self.partition)

            self.add_step_to_partition(step)
            return step

        else:
            # next step can't be partitioned.
            # NOTE(review): this branch runs only when self.partition is falsy,
            # so close_partition() is a no-op here — verify intent.
            self.close_partition()

        return self.planner.plan.add_step(step)

    def add_step_to_partition(self, step):
        """Append a step inside the active partition, numbering it as a sub-step."""
        step.step_num = f"{self.partition.step_num}_{len(self.partition.step)}"
        self.partition.step.append(step)

    def close_partition(self):
        """If a partition is open: make it the top of the step stack and clear it."""

        if self.partition:
            if len(self.step_stack) > 0:
                self.step_stack[-1] = self.partition

            self.partition = None