Coverage for mindsdb / api / executor / sql_query / steps / project_step.py: 16%
49 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from collections import defaultdict
3from mindsdb_sql_parser.ast import (
4 Identifier,
5 Select,
6 Star,
7)
8from mindsdb.api.executor.planner.steps import ProjectStep
9from mindsdb.integrations.utilities.query_traversal import query_traversal
11from mindsdb.api.executor.sql_query.result_set import ResultSet
12from mindsdb.api.executor.utilities.sql import query_df
13from mindsdb.api.executor.exceptions import (
14 KeyColumnDoesNotExist,
15 NotSupportedYet
16)
18from .base import BaseStepCall
class ProjectStepCall(BaseStepCall):
    """Execute a ``ProjectStep``: select/rename columns of an earlier step's result.

    The result set of the step referenced by ``step.dataframe`` is turned into a
    DataFrame. The requested ``step.columns`` are rewritten so they refer to the
    internal column names returned by ``ResultSet.to_df_cols()``, and the
    projection is then evaluated with ``query_df`` against a table named
    ``df_table``.
    """

    bind = ProjectStep

    def call(self, step):
        """Project ``step.columns`` over the stored result of ``step.dataframe``.

        :param step: ProjectStep; ``step.dataframe.step_num`` indexes
            ``self.steps_data`` and ``step.columns`` holds the target expressions.
        :return: ResultSet built from the projected DataFrame, reusing the
            source column metadata with ``strict=False``.
        :raises NotSupportedYet: if traversal reports a table reference inside
            the targets (e.g. a subquery).
        :raises KeyColumnDoesNotExist: if an identifier cannot be resolved to a
            source column.
        """
        source = self.steps_data[step.dataframe.step_num]
        df, col_names = source.to_df_cols()

        # Lookup tables built from the source column metadata:
        #   by_column: "alias" and "(table_alias, alias)" -> internal column name
        #   by_table:  table name and (if different) table alias -> internal names
        # NOTE(review): an unqualified alias shared by several columns resolves
        # to the last one seen; confirm this is intended.
        by_column = {}
        by_table = defaultdict(list)
        for internal_name, column in col_names.items():
            by_column[column.alias] = internal_name
            by_column[(column.table_alias, column.alias)] = internal_name
            by_table[column.table_name].append(internal_name)
            # Register the alias too, without double-adding when it equals the name.
            if column.table_name != column.table_alias:
                by_table[column.table_alias].append(internal_name)

        def rewrite_identifier(node, is_table=None, **kwargs):
            """query_traversal callback mapping target identifiers to internal names.

            Returns None to leave a node untouched, a replacement Identifier, or
            a list of Identifiers when a ``table.*`` is expanded.
            """
            if is_table:
                raise NotSupportedYet('Subqueries is not supported in target')
            if not isinstance(node, Identifier):
                return None

            last_part = node.parts[-1]
            is_qualified = len(node.parts) != 1

            if isinstance(last_part, Star):
                if not is_qualified:
                    # A bare "*" is passed through to query_df unchanged.
                    return None
                # "table.*" becomes one Identifier per column of that table;
                # an unknown table yields an empty list.
                return [
                    Identifier(parts=[internal_name])
                    for internal_name in by_table.get(node.parts[-2], [])
                ]

            lookup = (node.parts[-2], last_part) if is_qualified else last_part
            if lookup not in by_column:
                raise KeyColumnDoesNotExist(f'Table not found for column: {lookup}')
            return Identifier(parts=[by_column[lookup]], alias=node.alias)

        query = Select(targets=step.columns, from_table=Identifier('df_table'))

        traversed = query_traversal(query.targets, rewrite_identifier)
        # Star expansion produces nested lists; splice them into one flat list.
        flat_targets = []
        for item in traversed:
            flat_targets.extend(item if isinstance(item, list) else [item])
        query.targets = flat_targets

        projected = query_df(df, query, session=self.session)
        return ResultSet.from_df_cols(df=projected, columns_dict=col_names, strict=False)