Coverage for mindsdb / api / executor / sql_query / steps / project_step.py: 16%

49 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from collections import defaultdict 

2 

3from mindsdb_sql_parser.ast import ( 

4 Identifier, 

5 Select, 

6 Star, 

7) 

8from mindsdb.api.executor.planner.steps import ProjectStep 

9from mindsdb.integrations.utilities.query_traversal import query_traversal 

10 

11from mindsdb.api.executor.sql_query.result_set import ResultSet 

12from mindsdb.api.executor.utilities.sql import query_df 

13from mindsdb.api.executor.exceptions import ( 

14 KeyColumnDoesNotExist, 

15 NotSupportedYet 

16) 

17 

18from .base import BaseStepCall 

19 

20 

class ProjectStepCall(BaseStepCall):
    """Executor for ``ProjectStep``: pick/rename columns of an earlier step's result.

    The projection is evaluated as ``SELECT <step.columns> FROM df_table`` over the
    dataframe of the step referenced by ``step.dataframe``. Before the query runs,
    every column reference in the targets is rewritten to the dataframe's internal
    column name.
    """

    bind = ProjectStep

    def call(self, step):
        """Run the projection described by ``step`` and return its result.

        :param step: the ProjectStep; ``step.dataframe.step_num`` selects the input
            result in ``self.steps_data`` and ``step.columns`` are the target expressions.
        :return: a ResultSet built from the projected dataframe, with column metadata
            taken from the input result set (``strict=False``).
        :raises NotSupportedYet: if the traversal reports a table position
            (``is_table``) inside the targets.
        :raises KeyColumnDoesNotExist: if a column reference matches no input column.
        """
        source = self.steps_data[step.dataframe.step_num]
        df, col_names = source.to_df_cols()

        # Reverse lookups from user-facing references to internal df column names:
        #   by_column: "alias" or ("table_alias", "alias") -> internal name
        #   by_table:  table name / table alias -> internal names of that table's columns
        # NOTE(review): duplicate aliases across tables are last-wins for the bare
        # "alias" key — confirm that ambiguity is acceptable here.
        by_column = {}
        by_table = defaultdict(list)
        for internal_name, column in col_names.items():
            by_column[column.alias] = internal_name
            by_column[(column.table_alias, column.alias)] = internal_name
            by_table[column.table_name].append(internal_name)
            if column.table_name != column.table_alias:
                # register under the alias too, without double-listing when equal
                by_table[column.table_alias].append(internal_name)

        def rewrite_identifier(node, is_table=None, **kwargs):
            # Callback for query_traversal; a non-None return value is presumably
            # used as the replacement for ``node`` (lists are flattened below).
            if is_table:
                raise NotSupportedYet('Subqueries is not supported in target')
            if not isinstance(node, Identifier):
                return None

            parts = node.parts
            last = parts[-1]
            if isinstance(last, Star):
                if len(parts) == 1:
                    # a bare "*" is kept unchanged
                    return None
                # "<table>.*" becomes every column registered for that table.
                # NOTE(review): an unknown table expands to nothing rather than
                # raising — confirm this is intended.
                return [Identifier(parts=[name]) for name in by_table.get(parts[-2], [])]

            key = last if len(parts) == 1 else (parts[-2], last)
            if key not in by_column:
                raise KeyColumnDoesNotExist(f'Table not found for column: {key}')
            return Identifier(parts=[by_column[key]], alias=node.alias)

        query = Select(targets=step.columns, from_table=Identifier('df_table'))

        # A "<table>.*" target comes back as a list of identifiers; splice such
        # lists into the flat target list.
        flat_targets = []
        for item in query_traversal(query.targets, rewrite_identifier):
            flat_targets.extend(item if isinstance(item, list) else [item])
        query.targets = flat_targets

        projected = query_df(df, query, session=self.session)
        return ResultSet.from_df_cols(df=projected, columns_dict=col_names, strict=False)