Coverage for mindsdb / api / executor / sql_query / steps / subselect_step.py: 45%
129 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from collections import defaultdict
3import pandas as pd
5from mindsdb_sql_parser.ast import (
6 Identifier,
7 Select,
8 Star,
9 Constant,
10 Function,
11 Variable,
12 BinaryOperation,
13)
15from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import SERVER_VARIABLES
16from mindsdb.api.executor.planner.step_result import Result
17from mindsdb.api.executor.planner.steps import SubSelectStep, QueryStep
18from mindsdb.api.executor.sql_query.result_set import ResultSet, Column
19from mindsdb.api.executor.utilities.sql import query_df
20from mindsdb.api.executor.exceptions import KeyColumnDoesNotExist
21from mindsdb.integrations.utilities.query_traversal import query_traversal
22from mindsdb.interfaces.query_context.context_controller import query_context_controller
24from .base import BaseStepCall
25from .fetch_dataframe import get_fill_param_fnc
class SubSelectStepCall(BaseStepCall):
    """Step handler bound to ``SubSelectStep``.

    Runs the step's query against the result set produced by an earlier step.
    """

    bind = SubSelectStep

    def call(self, step):
        """Execute ``step.query`` over the result of ``step.dataframe``.

        Args:
            step: the planner step; ``step.dataframe.step_num`` selects the
                earlier result in ``self.steps_data`` that is queried.

        Returns:
            A ``ResultSet`` built from the queried frame, tagged with the
            database of the source result's first column and with the step's
            table name (``"df_table"`` when the step does not define one).
        """
        source = self.steps_data[step.dataframe.step_num]

        # Name attached to the returned result set.
        output_table = step.table_name
        if output_table is None:
            output_table = "df_table"

        query = step.query
        # The query itself always reads from the fixed name "df_table".
        query.from_table = Identifier("df_table")

        is_select = isinstance(query, Select)

        if step.add_absent_cols and is_select:
            referenced = set()

            def collect_columns(node, **kwargs):
                # Remember every column name used in the WHERE clause and
                # replace references to earlier results with a constant.
                if isinstance(node, Identifier):
                    referenced.add(node.parts[-1])
                    return None
                if isinstance(node, Result):
                    earlier = self.steps_data[node.step_num]
                    return Constant(earlier.get_column_values(col_idx=0)[0])
                return None

            query_traversal(query.where, collect_columns)

            # Snapshot taken once, before any column is appended.
            present = [col.name for col in source.columns]
            for col_name in referenced:
                if col_name in present:
                    continue
                source.add_column(Column(name=col_name))

        # inject previous step values
        if is_select:
            query_traversal(query, get_fill_param_fnc(self.steps_data))

        frame = source.to_df()
        queried = query_df(frame, query, session=self.session)

        # get database from first column
        database = source.columns[0].database

        return ResultSet.from_df(queried, database, output_table)
75class QueryStepCall(BaseStepCall):
76 bind = QueryStep
78 def call(self, step: QueryStep):
79 query = step.query
81 if step.from_table is not None: 81 ↛ 88line 81 didn't jump to line 88 because the condition on line 81 was always true
82 if isinstance(step.from_table, pd.DataFrame): 82 ↛ 85line 82 didn't jump to line 85 because the condition on line 82 was always true
83 result_set = ResultSet.from_df(step.from_table)
84 else:
85 result_set = self.steps_data[step.from_table.step_num]
86 else:
87 # only from_table can content result
88 prev_step_num = query.from_table.value.step_num
89 result_set = self.steps_data[prev_step_num]
91 df, col_names = result_set.to_df_cols()
92 col_idx = {}
93 tbl_idx = defaultdict(list)
94 for name, col in col_names.items():
95 col_idx[col.alias] = name
96 col_idx[(col.table_alias, col.alias)] = name
97 # add to tables
98 tbl_idx[col.table_name].append(name)
99 if col.table_name != col.table_alias: 99 ↛ 100line 99 didn't jump to line 100 because the condition on line 99 was never true
100 tbl_idx[col.table_alias].append(name)
102 lower_col_idx = {}
103 for key, value in col_idx.items():
104 if isinstance(key, int):
105 key = str(key)
106 if isinstance(key, str):
107 lower_col_idx[key.lower()] = value
108 continue
109 lower_col_idx[tuple(str(x).lower() for x in key)] = value
111 # get aliases of first level
112 aliases = []
113 for col in query.targets:
114 if col.alias is not None: 114 ↛ 113line 114 didn't jump to line 113 because the condition on line 114 was always true
115 aliases.append(col.alias.parts[0])
117 # analyze condition and change name of columns
118 def check_fields(node, is_target=None, **kwargs):
119 if isinstance(node, Function): 119 ↛ 120line 119 didn't jump to line 120 because the condition on line 119 was never true
120 function_name = node.op.lower()
122 functions_results = {
123 "database": self.session.database,
124 "current_user": self.session.username,
125 "user": self.session.username,
126 "version": "8.0.17",
127 "current_schema": "public",
128 "schema": "public",
129 "connection_id": self.context.get("connection_id"),
130 }
131 if function_name in functions_results:
132 return Constant(functions_results[function_name], alias=Identifier(parts=[function_name]))
134 if isinstance(node, Variable): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true
135 var_name = node.value
136 column_name = f"@@{var_name}"
137 result = SERVER_VARIABLES.get(column_name)
138 if result is None:
139 raise ValueError(f"Unknown variable '{var_name}'")
140 else:
141 return Constant(result[0], alias=Identifier(parts=[column_name]))
143 if isinstance(node, Identifier):
144 # only column name
145 col_name = node.parts[-1]
146 if is_target and isinstance(col_name, Star): 146 ↛ 147line 146 didn't jump to line 147 because the condition on line 146 was never true
147 if len(node.parts) == 1:
148 # left as is
149 return
150 else:
151 # replace with all columns from table
152 table_name = node.parts[-2]
153 return [Identifier(parts=[col]) for col in tbl_idx.get(table_name, [])]
155 if node.parts[-1].lower() == "session_user": 155 ↛ 156line 155 didn't jump to line 156 because the condition on line 155 was never true
156 return Constant(self.session.username, alias=node)
157 if node.parts[-1].lower() == "$$": 157 ↛ 160line 157 didn't jump to line 160 because the condition on line 157 was never true
158 # NOTE: sinve version 9.0 mysql client sends query 'select $$'.
159 # Connection can be continued only if answer is parse error.
160 raise ValueError(
161 "You have an error in your SQL syntax; check the manual that corresponds to your server "
162 "version for the right syntax to use near '$$' at line 1"
163 )
165 match node.parts, node.is_quoted:
166 case [column_name], [column_quoted]: 166 ↛ 181line 166 didn't jump to line 181 because the pattern on line 166 always matched
167 if column_name in aliases: 167 ↛ 169line 167 didn't jump to line 169 because the condition on line 167 was never true
168 # key is defined as alias
169 return
171 key = column_name if column_quoted else column_name.lower()
173 if key not in col_idx and key not in lower_col_idx: 173 ↛ 184line 173 didn't jump to line 184 because the condition on line 173 was always true
174 # it can be local alias of a query, like:
175 # SELECT t1.a + t2.a col1, min(t1.a) c
176 # FROM dummy_data.tbl1 as t1
177 # JOIN pg.tbl2 as t2 on t1.c=t2.c
178 # group by col1
179 # order by c -- <--- "с" is alias
180 return
181 case [*_, table_name, column_name], [*_, column_quoted]:
182 key = (table_name, column_name) if column_quoted else (table_name.lower(), column_name.lower())
184 search_idx = col_idx if column_quoted else lower_col_idx
186 if key not in search_idx:
187 raise KeyColumnDoesNotExist(f"Table not found for column: {key}")
189 new_name = search_idx[key]
190 return Identifier(parts=[new_name], alias=node.alias, with_rollup=node.with_rollup)
192 # fill params
193 fill_params = get_fill_param_fnc(self.steps_data)
194 query_traversal(query, fill_params)
196 if not step.strict_where: 196 ↛ 201line 196 didn't jump to line 201 because the condition on line 196 was never true
197 # remove conditions with not-existed columns.
198 # these conditions can be already used as input to model or knowledge base
199 # but can be absent in their output
201 def remove_not_used_conditions(node, **kwargs):
202 if isinstance(node, BinaryOperation):
203 for arg in node.args:
204 if isinstance(arg, Identifier) and len(arg.parts) > 1:
205 key = tuple(arg.parts[-2:])
206 if key not in col_idx:
207 # exclude
208 node.args = [Constant(0), Constant(0)]
209 node.op = "="
211 query_traversal(query.where, remove_not_used_conditions)
213 query_traversal(query, check_fields)
214 query.where = query_context_controller.remove_lasts(query.where)
216 query.from_table = Identifier("df_table")
217 res = query_df(df, query, session=self.session)
219 return ResultSet.from_df_cols(df=res, columns_dict=col_names, strict=False)