Coverage for mindsdb / integrations / utilities / sql_utils.py: 35%
154 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from enum import Enum
2from typing import Any
3import pandas as pd
5from mindsdb.api.executor.utilities.sql import query_df
6from mindsdb_sql_parser import ast
7from mindsdb_sql_parser.ast.base import ASTNode
9from mindsdb.integrations.utilities.query_traversal import query_traversal
10from mindsdb.utilities.config import config
class FilterOperator(Enum):
    """
    Operators a filter condition can carry.

    Each member's value is the operator's textual form.
    """

    # comparisons
    EQUAL = "="
    NOT_EQUAL = "!="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL = "<="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL = ">="

    # membership and ranges
    IN = "IN"
    NOT_IN = "NOT IN"
    BETWEEN = "BETWEEN"
    NOT_BETWEEN = "NOT BETWEEN"

    # pattern matching
    LIKE = "LIKE"
    NOT_LIKE = "NOT LIKE"

    # null / identity checks
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"
    IS = "IS"
    IS_NOT = "IS NOT"
class FilterCondition:
    """
    Base class for filter conditions.

    A condition couples a column name, an operator and the value the column is
    compared with. ``applied`` starts out as ``False``; nothing in this class
    changes it afterwards (presumably callers flip it once the condition has
    been handled -- confirm against the callers).

    Args:
        column: Name of the column the condition refers to.
        op: Operator of the condition.
        value: Value (or values) the column is compared with.
    """

    # Defining __eq__ already makes instances unhashable; spelling it out keeps
    # that visible. Instances are mutable, so hashing them would be unsafe.
    __hash__ = None

    def __init__(self, column: str, op: FilterOperator, value: Any):
        self.column = column
        self.op = op
        self.value = value
        self.applied = False

    def __eq__(self, other: object) -> bool:
        """Compare by column, op and value; ``applied`` is not part of equality."""
        if not isinstance(other, FilterCondition):
            # Give the other operand a chance to handle the comparison; for
            # unrelated types the end result of ``==`` is still False.
            return NotImplemented
        return self.column == other.column and self.op == other.op and self.value == other.value

    def __repr__(self) -> str:
        # Single line, values shown with repr() so strings stay unambiguous.
        return f"FilterCondition(column={self.column!r}, op={self.op!r}, value={self.value!r})"
class KeywordSearchArgs:
    """Pair of a column name and the search query string to look for in it."""

    def __init__(self, column: str, query: str):
        """
        Args:
            column: The column to search in.
            query: The search query string.
        """
        self.column, self.query = column, query
class SortColumn:
    """A column to sort by, its direction, and an ``applied`` flag that starts as ``False``."""

    def __init__(self, column: str, ascending: bool = True):
        # ``applied`` is only initialised here; this class never changes it.
        self.applied = False
        self.ascending = ascending
        self.column = column
def make_sql_session():
    """Build a ``SessionController`` whose ``database`` is the configured ``default_project``."""
    # Imported inside the function rather than at module level -- presumably to
    # avoid an import cycle; confirm before hoisting it.
    from mindsdb.api.executor.controllers.session_controller import SessionController

    session = SessionController()
    default_project = config.get("default_project")
    session.database = default_project
    return session
def conditions_to_filter(binary_op: ASTNode):
    """
    Turn a condition AST into a ``{column: value}`` mapping.

    Only ``=`` comparisons are accepted; any other operator raises
    ``NotImplementedError``. When a column occurs more than once, the last
    value wins.
    """
    equality_by_column = {}
    for operator, column, value in extract_comparison_conditions(binary_op):
        if operator == "=":
            equality_by_column[column] = value
            continue
        raise NotImplementedError
    return equality_by_column
def extract_comparison_conditions(binary_op: ASTNode, ignore_functions=False, strict=True):
    """Extracts all simple comparison conditions that must be true from an AST node.
    Does NOT support 'or' conditions.

    The tree walk is delegated to ``query_traversal``; every visited node is
    inspected by the nested visitor below.

    Args:
        binary_op: Root node of the condition.
        ignore_functions: When true, ``lower(col)`` / ``upper(col)`` on the left
            side of a comparison is treated as plain ``col``.
        strict: When true, a comparison whose left side is not an identifier
            raises ``NotImplementedError``; otherwise that node is appended to
            the result unchanged.

    Returns:
        A list. Comparisons are stored as ``[op, column, value]`` (``value`` is
        a list when the right side is a tuple), BETWEEN as
        ``[op, column, (first, second)]``, and -- in non-strict mode --
        unsupported comparisons as the raw AST node.

    Raises:
        NotImplementedError: For unsupported left/right operands or a BETWEEN
            whose arguments are not identifier/constant/constant.
    """
    collected = []

    def _visit(node: ASTNode, **kwargs):
        if isinstance(node, ast.BinaryOperation):
            operator = node.op.lower()
            # 'and' is a connective, not a comparison: never record it as a condition.
            if operator == "and":
                return

            left, right = node.args

            # Optionally look through lower(col) / upper(col) to the column itself.
            wraps_identifier = (
                ignore_functions
                and isinstance(left, ast.Function)
                and left.op.lower() in ("lower", "upper")
                and isinstance(left.args[0], ast.Identifier)
            )
            if wraps_identifier:
                left = left.args[0]

            # Only support [identifier] =/</>/>=/<=/etc [constant] comparisons.
            if not isinstance(left, ast.Identifier):
                if not strict:
                    collected.append(node)
                    return
                raise NotImplementedError(f"Not implemented arg1: {left}")

            if isinstance(right, ast.Constant):
                literal = right.value
            elif isinstance(right, ast.Tuple):
                literal = [item.value for item in right.items]
            else:
                raise NotImplementedError(f"Not implemented arg2: {right}")

            collected.append([operator, left.parts[-1], literal])

        if isinstance(node, ast.BetweenOperation):
            subject, first_bound, second_bound = node.args
            supported = (
                isinstance(subject, ast.Identifier)
                and isinstance(first_bound, ast.Constant)
                and isinstance(second_bound, ast.Constant)
            )
            if not supported:
                raise NotImplementedError(f"Not implemented: {node}")

            collected.append([node.op.lower(), subject.parts[-1], (first_bound.value, second_bound.value)])

    query_traversal(binary_op, _visit)
    return collected
def project_dataframe(df, targets, table_columns):
    """
    case-insensitive projection
    'select A' and 'select a' return different column case but with the same content

    Args:
        df: Source DataFrame. It is not modified.
        targets: Projection targets; ``ast.Star`` and ``ast.Identifier`` are
            supported. Targets after a ``*`` are ignored.
        table_columns: Column names that ``*`` expands to.

    Returns:
        A new DataFrame whose columns follow the projection order. A column
        found in ``df`` (exact name first, then case-insensitively) keeps its
        content and is named after the identifier's alias if it has one,
        otherwise as spelled in the query / ``table_columns``. A requested
        name with no counterpart in ``df`` becomes a column of ``None`` under
        the requested name. An empty ``df`` yields an empty frame with the
        projected column names.

    Raises:
        NotImplementedError: For any other kind of target.
    """
    # Label -> position lookups. The exact spelling is preferred so a frame that
    # holds both "A" and "a" stays unambiguous; the lower-cased lookup is the
    # case-insensitive fallback (last label wins on collisions).
    exact_position = {}
    folded_position = {}
    for position, label in enumerate(df.columns):
        exact_position[label] = position
        folded_position[label.lower()] = position

    def locate(name):
        position = exact_position.get(name)
        if position is None:
            position = folded_position.get(name.lower())
        return position

    # (position in df, or None when absent; output label) in projection order.
    selection = []
    for target in targets:
        if isinstance(target, ast.Star):
            selection.extend((locate(name), name) for name in table_columns)
            break
        elif isinstance(target, ast.Identifier):
            name = target.parts[-1]
            position = locate(name)
            alias = getattr(target, "alias", None)
            if position is not None and isinstance(alias, ast.Identifier):
                selection.append((position, alias.parts[0]))
            else:
                selection.append((position, name))
        else:
            raise NotImplementedError

    output_labels = [label for _, label in selection]

    if len(df) == 0:
        return pd.DataFrame([], columns=output_labels)

    # add absent columns -- on a shallow copy, so the caller's frame is left alone
    frame = df
    filler_position = {}
    for position, label in selection:
        if position is None and label not in filler_position:
            if frame is df:
                frame = df.copy(deep=False)
            filler_position[label] = df.shape[1] + len(filler_position)
            frame[label] = None

    # Select by position using the frame's real columns, then apply the
    # projection's names. Selecting by the query's spelling would miss columns
    # that differ only in case.
    positions = [filler_position[label] if position is None else position for position, label in selection]
    projected = frame.iloc[:, positions]
    projected.columns = output_labels
    return projected
def filter_dataframe(df: pd.DataFrame, conditions: list, raw_conditions=None, order_by=None):
    """
    Filter (and optionally order) a DataFrame by running ``SELECT * FROM df`` over it.

    Args:
        df: Frame handed to ``query_df`` together with the built query.
        conditions: ``[op, column, value]`` triples; assumed to come from
            ``extract_comparison_conditions``.
        raw_conditions: Optional AST nodes AND-ed onto the WHERE clause as they are.
        order_by: Optional value stored on the query's ``order_by`` attribute.

    Returns:
        Whatever ``query_df`` returns for the built query.
    """

    def _to_node(operator, column, operand):
        # Convert one [op, column, value] triple back into an AST node.
        operator = operator.lower()
        if operator == "between":
            return ast.BetweenOperation(
                args=[ast.Identifier(column), ast.Constant(operand[0]), ast.Constant(operand[1])]
            )
        if isinstance(operand, (tuple, list)):
            operand = ast.Tuple(operand)
        return ast.BinaryOperation(op=operator, args=[ast.Identifier(column), ast.Constant(operand)])

    predicates = [_to_node(op, column, operand) for op, column, operand in conditions]
    if raw_conditions:
        predicates.extend(raw_conditions)

    # Chain every predicate with "and", left to right.
    where_clause = None
    for predicate in predicates:
        if where_clause is None:
            where_clause = predicate
        else:
            where_clause = ast.BinaryOperation(op="and", args=[where_clause, predicate])

    select = ast.Select(targets=[ast.Star()], from_table=ast.Identifier("df"), where=where_clause)
    if order_by:
        select.order_by = order_by

    return query_df(df, select)
def sort_dataframe(df, order_by: list):
    """
    Sort ``df`` by the ``ast.OrderBy`` entries that name one of its columns.

    Entries that are not ``ast.OrderBy`` or whose column is not in ``df`` are
    skipped. A direction of ``desc`` (any case) sorts descending, anything else
    ascending. When nothing usable remains, ``df`` is returned as it is.
    """
    sort_keys = []
    for entry in order_by:
        if isinstance(entry, ast.OrderBy):
            name = entry.field.parts[-1]
            if name in df.columns:
                sort_keys.append((name, entry.direction.lower() != "desc"))

    if not sort_keys:
        return df

    by = [name for name, _ in sort_keys]
    ascending = [flag for _, flag in sort_keys]
    return df.sort_values(by=by, ascending=ascending)