Coverage for mindsdb / api / executor / sql_query / steps / subselect_step.py: 45%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from collections import defaultdict 

2 

3import pandas as pd 

4 

5from mindsdb_sql_parser.ast import ( 

6 Identifier, 

7 Select, 

8 Star, 

9 Constant, 

10 Function, 

11 Variable, 

12 BinaryOperation, 

13) 

14 

15from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import SERVER_VARIABLES 

16from mindsdb.api.executor.planner.step_result import Result 

17from mindsdb.api.executor.planner.steps import SubSelectStep, QueryStep 

18from mindsdb.api.executor.sql_query.result_set import ResultSet, Column 

19from mindsdb.api.executor.utilities.sql import query_df 

20from mindsdb.api.executor.exceptions import KeyColumnDoesNotExist 

21from mindsdb.integrations.utilities.query_traversal import query_traversal 

22from mindsdb.interfaces.query_context.context_controller import query_context_controller 

23 

24from .base import BaseStepCall 

25from .fetch_dataframe import get_fill_param_fnc 

26 

27 

class SubSelectStepCall(BaseStepCall):
    """Step handler bound to ``SubSelectStep``.

    Runs the step's query over the dataframe of an already computed step
    result taken from ``self.steps_data``.
    """

    bind = SubSelectStep

    def call(self, step):
        """Execute ``step.query`` over the result referenced by ``step.dataframe``.

        Args:
            step: the step to execute. This method reads ``step.dataframe.step_num``,
                ``step.table_name``, ``step.query`` and ``step.add_absent_cols``.

        Returns:
            A ``ResultSet`` built from the queried dataframe. It is labelled with
            the database of the first column of the source result and with
            ``step.table_name`` (``"df_table"`` when the step has no table name).

        Note:
            ``step.query`` is modified in place: its ``from_table`` is replaced and
            the traversal callbacks may substitute nodes.
        """
        result = self.steps_data[step.dataframe.step_num]

        # fall back to the name of the in-memory table when the step has no name
        table_name = step.table_name
        if table_name is None:
            table_name = "df_table"

        query = step.query
        query.from_table = Identifier("df_table")

        if step.add_absent_cols and isinstance(query, Select):
            query_cols = set()

            def f_all_cols(node, **kwargs):
                if isinstance(node, Identifier):
                    # remember the column name (last part of the identifier)
                    query_cols.add(node.parts[-1])
                elif isinstance(node, Result):
                    # replace a reference to a previous step with the first value of its first column
                    prev_result = self.steps_data[node.step_num]
                    return Constant(prev_result.get_column_values(col_idx=0)[0])

            query_traversal(query.where, f_all_cols)

            # add the columns that are used in WHERE but are absent in the result
            result_cols = {col.name for col in result.columns}
            for col_name in query_cols:
                if col_name not in result_cols:
                    result.add_column(Column(name=col_name))

        # inject previous step values
        if isinstance(query, Select):
            fill_params = get_fill_param_fnc(self.steps_data)
            query_traversal(query, fill_params)

        df = result.to_df()
        res = query_df(df, query, session=self.session)

        # get database from first column; a result without columns has no database to report
        # NOTE(review): assumes ResultSet.from_df accepts None as database - confirm
        columns = result.columns
        database = columns[0].database if len(columns) > 0 else None

        return ResultSet.from_df(res, database, table_name)

73 

74 

75class QueryStepCall(BaseStepCall): 

76 bind = QueryStep 

77 

78 def call(self, step: QueryStep): 

79 query = step.query 

80 

81 if step.from_table is not None: 81 ↛ 88line 81 didn't jump to line 88 because the condition on line 81 was always true

82 if isinstance(step.from_table, pd.DataFrame): 82 ↛ 85line 82 didn't jump to line 85 because the condition on line 82 was always true

83 result_set = ResultSet.from_df(step.from_table) 

84 else: 

85 result_set = self.steps_data[step.from_table.step_num] 

86 else: 

87 # only from_table can content result 

88 prev_step_num = query.from_table.value.step_num 

89 result_set = self.steps_data[prev_step_num] 

90 

91 df, col_names = result_set.to_df_cols() 

92 col_idx = {} 

93 tbl_idx = defaultdict(list) 

94 for name, col in col_names.items(): 

95 col_idx[col.alias] = name 

96 col_idx[(col.table_alias, col.alias)] = name 

97 # add to tables 

98 tbl_idx[col.table_name].append(name) 

99 if col.table_name != col.table_alias: 99 ↛ 100line 99 didn't jump to line 100 because the condition on line 99 was never true

100 tbl_idx[col.table_alias].append(name) 

101 

102 lower_col_idx = {} 

103 for key, value in col_idx.items(): 

104 if isinstance(key, int): 

105 key = str(key) 

106 if isinstance(key, str): 

107 lower_col_idx[key.lower()] = value 

108 continue 

109 lower_col_idx[tuple(str(x).lower() for x in key)] = value 

110 

111 # get aliases of first level 

112 aliases = [] 

113 for col in query.targets: 

114 if col.alias is not None: 114 ↛ 113line 114 didn't jump to line 113 because the condition on line 114 was always true

115 aliases.append(col.alias.parts[0]) 

116 

117 # analyze condition and change name of columns 

118 def check_fields(node, is_target=None, **kwargs): 

119 if isinstance(node, Function): 119 ↛ 120line 119 didn't jump to line 120 because the condition on line 119 was never true

120 function_name = node.op.lower() 

121 

122 functions_results = { 

123 "database": self.session.database, 

124 "current_user": self.session.username, 

125 "user": self.session.username, 

126 "version": "8.0.17", 

127 "current_schema": "public", 

128 "schema": "public", 

129 "connection_id": self.context.get("connection_id"), 

130 } 

131 if function_name in functions_results: 

132 return Constant(functions_results[function_name], alias=Identifier(parts=[function_name])) 

133 

134 if isinstance(node, Variable): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true

135 var_name = node.value 

136 column_name = f"@@{var_name}" 

137 result = SERVER_VARIABLES.get(column_name) 

138 if result is None: 

139 raise ValueError(f"Unknown variable '{var_name}'") 

140 else: 

141 return Constant(result[0], alias=Identifier(parts=[column_name])) 

142 

143 if isinstance(node, Identifier): 

144 # only column name 

145 col_name = node.parts[-1] 

146 if is_target and isinstance(col_name, Star): 146 ↛ 147line 146 didn't jump to line 147 because the condition on line 146 was never true

147 if len(node.parts) == 1: 

148 # left as is 

149 return 

150 else: 

151 # replace with all columns from table 

152 table_name = node.parts[-2] 

153 return [Identifier(parts=[col]) for col in tbl_idx.get(table_name, [])] 

154 

155 if node.parts[-1].lower() == "session_user": 155 ↛ 156line 155 didn't jump to line 156 because the condition on line 155 was never true

156 return Constant(self.session.username, alias=node) 

157 if node.parts[-1].lower() == "$$": 157 ↛ 160line 157 didn't jump to line 160 because the condition on line 157 was never true

158 # NOTE: sinve version 9.0 mysql client sends query 'select $$'. 

159 # Connection can be continued only if answer is parse error. 

160 raise ValueError( 

161 "You have an error in your SQL syntax; check the manual that corresponds to your server " 

162 "version for the right syntax to use near '$$' at line 1" 

163 ) 

164 

165 match node.parts, node.is_quoted: 

166 case [column_name], [column_quoted]: 166 ↛ 181line 166 didn't jump to line 181 because the pattern on line 166 always matched

167 if column_name in aliases: 167 ↛ 169line 167 didn't jump to line 169 because the condition on line 167 was never true

168 # key is defined as alias 

169 return 

170 

171 key = column_name if column_quoted else column_name.lower() 

172 

173 if key not in col_idx and key not in lower_col_idx: 173 ↛ 184line 173 didn't jump to line 184 because the condition on line 173 was always true

174 # it can be local alias of a query, like: 

175 # SELECT t1.a + t2.a col1, min(t1.a) c 

176 # FROM dummy_data.tbl1 as t1 

177 # JOIN pg.tbl2 as t2 on t1.c=t2.c 

178 # group by col1 

179 # order by c -- <--- "с" is alias 

180 return 

181 case [*_, table_name, column_name], [*_, column_quoted]: 

182 key = (table_name, column_name) if column_quoted else (table_name.lower(), column_name.lower()) 

183 

184 search_idx = col_idx if column_quoted else lower_col_idx 

185 

186 if key not in search_idx: 

187 raise KeyColumnDoesNotExist(f"Table not found for column: {key}") 

188 

189 new_name = search_idx[key] 

190 return Identifier(parts=[new_name], alias=node.alias, with_rollup=node.with_rollup) 

191 

192 # fill params 

193 fill_params = get_fill_param_fnc(self.steps_data) 

194 query_traversal(query, fill_params) 

195 

196 if not step.strict_where: 196 ↛ 201line 196 didn't jump to line 201 because the condition on line 196 was never true

197 # remove conditions with not-existed columns. 

198 # these conditions can be already used as input to model or knowledge base 

199 # but can be absent in their output 

200 

201 def remove_not_used_conditions(node, **kwargs): 

202 if isinstance(node, BinaryOperation): 

203 for arg in node.args: 

204 if isinstance(arg, Identifier) and len(arg.parts) > 1: 

205 key = tuple(arg.parts[-2:]) 

206 if key not in col_idx: 

207 # exclude 

208 node.args = [Constant(0), Constant(0)] 

209 node.op = "=" 

210 

211 query_traversal(query.where, remove_not_used_conditions) 

212 

213 query_traversal(query, check_fields) 

214 query.where = query_context_controller.remove_lasts(query.where) 

215 

216 query.from_table = Identifier("df_table") 

217 res = query_df(df, query, session=self.session) 

218 

219 return ResultSet.from_df_cols(df=res, columns_dict=col_names, strict=False)