Coverage for mindsdb / integrations / utilities / sql_utils.py: 35%

154 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from enum import Enum 

2from typing import Any 

3import pandas as pd 

4 

5from mindsdb.api.executor.utilities.sql import query_df 

6from mindsdb_sql_parser import ast 

7from mindsdb_sql_parser.ast.base import ASTNode 

8 

9from mindsdb.integrations.utilities.query_traversal import query_traversal 

10from mindsdb.utilities.config import config 

11 

12 

class FilterOperator(Enum):
    """
    Operators a filter condition can compare with.

    The value of each member is the operator's textual (SQL-style) spelling,
    e.g. ``FilterOperator.NOT_IN.value == "NOT IN"``.
    """

    # Equality and ordering comparisons.
    EQUAL = "="
    NOT_EQUAL = "!="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL = "<="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL = ">="

    # Membership and range checks.
    IN = "IN"
    NOT_IN = "NOT IN"
    BETWEEN = "BETWEEN"
    NOT_BETWEEN = "NOT BETWEEN"

    # Pattern matching.
    LIKE = "LIKE"
    NOT_LIKE = "NOT LIKE"

    # NULL / identity checks.
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"
    IS = "IS"
    IS_NOT = "IS NOT"

34 

35 

class FilterCondition:
    """
    A single ``column <op> value`` filter condition.

    Attributes:
        column: Name of the column the condition refers to.
        op: Comparison operator (annotated as ``FilterOperator``).
        value: Right-hand-side operand of the comparison.
        applied: Initialised to ``False``. It is not part of equality.
    """

    def __init__(self, column: str, op: "FilterOperator", value: Any):
        self.column = column
        self.op = op
        self.value = value
        # NOTE(review): presumably flipped by a consumer once the filter has been applied.
        self.applied = False

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, FilterCondition):
            # Defer to the other operand. If it also declines, Python falls back to
            # identity comparison, so ``condition == 5`` is still False.
            return NotImplemented
        # ``applied`` is deliberately ignored: two conditions are equal if they filter the same way.
        return self.column == __value.column and self.op == __value.op and self.value == __value.value

    # Defining __eq__ already disables hashing. Stated explicitly because ``value`` may be
    # an unhashable list (e.g. for IN).
    __hash__ = None

    def __repr__(self) -> str:
        # Single-line, conventional repr so conditions read cleanly in logs and assertion output.
        return f"{type(self).__name__}(column={self.column!r}, op={self.op!r}, value={self.value!r})"

61 

62 

class KeywordSearchArgs:
    """Parameters of a keyword search: the column to look in and the text to look for."""

    def __init__(self, column: str, query: str):
        """
        Store the keyword-search parameters.

        Args:
            column: Name of the column to search in.
            query: The search query string.
        """
        self.column, self.query = column, query

72 

73 

class SortColumn:
    """A column to sort by, plus its sort direction and an ``applied`` flag."""

    def __init__(self, column: str, ascending: bool = True):
        """
        Args:
            column: Name of the column to sort on.
            ascending: True (the default) for ascending order, False for descending.
        """
        self.column, self.ascending = column, ascending
        # Starts as False. NOTE(review): presumably set by whoever performs the sort.
        self.applied = False

79 

80 

def make_sql_session():
    """
    Create a fresh ``SessionController`` whose database is the configured default project.

    Returns:
        The new session, with ``database`` set to ``config.get("default_project")``.
    """
    # Imported inside the function rather than at module level
    # (NOTE(review): presumably to avoid an import cycle — confirm).
    from mindsdb.api.executor.controllers.session_controller import SessionController

    session = SessionController()
    default_project = config.get("default_project")
    session.database = default_project
    return session

87 

88 

def conditions_to_filter(binary_op: ASTNode):
    """
    Convert a condition tree of ``column = value`` comparisons into a dict.

    Args:
        binary_op: AST node to read conditions from. It is passed to
            ``extract_comparison_conditions`` in its default (strict) mode.

    Returns:
        dict mapping column name (last identifier part) to the compared value.
        If a column appears more than once, the last value wins.

    Raises:
        NotImplementedError: if any extracted condition uses an operator other
            than ``=``, or if the extractor rejects the tree.
    """
    filters = {}
    for op, column, value in extract_comparison_conditions(binary_op):
        if op != "=":
            # Name the offending operator so the caller can see why the query was rejected.
            raise NotImplementedError(f"Only '=' conditions are supported, got: {op}")
        filters[column] = value
    return filters

98 

99 

def extract_comparison_conditions(binary_op: ASTNode, ignore_functions=False, strict=True):
    """Extracts all simple comparison conditions that must be true from an AST node.

    Does NOT support 'or' conditions.

    The tree is walked with ``query_traversal``. ``AND`` nodes are not recorded
    themselves, so each of their operands is collected individually.

    Args:
        binary_op: Root AST node to inspect (e.g. a WHERE clause).
        ignore_functions: If True, a left-hand side of ``lower(col)`` or
            ``upper(col)`` is treated as a plain reference to ``col``.
        strict: If True, a comparison whose left-hand side is not an identifier
            raises ``NotImplementedError``. If False, that whole node is appended
            to the result as-is.

    Returns:
        A list whose items are either ``[op, column, value]`` or, in non-strict
        mode, raw AST nodes. In a triple:
        - ``op`` is lower-cased.
        - ``column`` is the identifier's last part.
        - ``value`` is a scalar, a list for tuple operands, or a ``(first, second)``
          pair of the BETWEEN bounds.

    Raises:
        NotImplementedError: for operand shapes other than those described above.
    """
    conditions = []

    def _extract_comparison_conditions(node: ASTNode, **kwargs):
        if isinstance(node, ast.BinaryOperation):
            op = node.op.lower()
            if op == "and":
                # Don't record 'and' itself. Returning None lets the traversal
                # reach the individual operands.
                return

            arg1, arg2 = node.args
            if (
                ignore_functions
                and isinstance(arg1, ast.Function)
                and arg1.op.lower() in ("lower", "upper")
                # Guard against a call with no arguments before indexing.
                and arg1.args
                and isinstance(arg1.args[0], ast.Identifier)
            ):
                # lower(col) / upper(col) -> col
                arg1 = arg1.args[0]

            if not isinstance(arg1, ast.Identifier):
                # Only support [identifier] =/</>/>=/<=/etc [constant] comparisons.
                if strict:
                    raise NotImplementedError(f"Not implemented arg1: {arg1}")
                # NOTE(review): the callback returns None here, so query_traversal presumably
                # still descends into this node's operands. For an OR node, that could also
                # add the OR's operands to the result as if they were required. Confirm.
                conditions.append(node)
                return

            if isinstance(arg2, ast.Constant):
                value = arg2.value
            elif isinstance(arg2, ast.Tuple):
                try:
                    value = [item.value for item in arg2.items]
                except AttributeError as err:
                    # A tuple item without a plain value (e.g. a nested expression).
                    raise NotImplementedError(f"Not implemented arg2: {arg2}") from err
            else:
                raise NotImplementedError(f"Not implemented arg2: {arg2}")

            conditions.append([op, arg1.parts[-1], value])

        if isinstance(node, ast.BetweenOperation):
            column, first_bound, second_bound = node.args
            if not (
                isinstance(column, ast.Identifier)
                and isinstance(first_bound, ast.Constant)
                and isinstance(second_bound, ast.Constant)
            ):
                raise NotImplementedError(f"Not implemented: {node}")

            op = node.op.lower()
            conditions.append([op, column.parts[-1], (first_bound.value, second_bound.value)])

    query_traversal(binary_op, _extract_comparison_conditions)
    return conditions

148 

149 

def project_dataframe(df, targets, table_columns):
    """
    case-insensitive projection
    'select A' and 'select a' return different column case but with the same content

    A target matches the ``df`` column whose name is equal ignoring case. The
    output column is named as written in the query, or by the target's alias.

    Args:
        df: Source pandas DataFrame. It is not modified.
        targets: Projection targets.
            - ``ast.Star`` expands to ``table_columns``; any targets after it are ignored.
            - ``ast.Identifier`` selects one column.
        table_columns: Column names that ``*`` expands to.

    Returns:
        A new DataFrame with one column per projected target, in target order.
        Targets with no matching column in ``df`` are filled with ``None``.
        For an empty ``df``, an empty DataFrame with those column names is returned.

    Raises:
        NotImplementedError: for any other target type.
    """
    # lower-cased name -> position in df. Positions let a duplicated label still yield a single
    # Series. As in a dict comprehension over labels, the last duplicate wins.
    df_cols_idx = {col.lower(): pos for pos, col in enumerate(df.columns)}

    # (position of the source column in df, or None if absent; output column name)
    projection = []
    for target in targets:
        if isinstance(target, ast.Star):
            for col in table_columns:
                projection.append((df_cols_idx.get(col.lower()), col))
            break
        elif isinstance(target, ast.Identifier):
            col = target.parts[-1]
            alias = getattr(target, "alias", None)
            name = alias.parts[0] if isinstance(alias, ast.Identifier) else col
            projection.append((df_cols_idx.get(col.lower()), name))
        else:
            raise NotImplementedError

    names = [name for _, name in projection]
    if len(df) == 0:
        return pd.DataFrame([], columns=names)

    # Select source columns by position, not by the query-cased name: a target "a" must pick up
    # the data of df column "A". Positional integer keys keep duplicate output names apart until
    # the final names are assigned.
    data = {
        i: (df.iloc[:, pos] if pos is not None else None)
        for i, (pos, _) in enumerate(projection)
    }
    result = pd.DataFrame(data, index=df.index)
    result.columns = names
    return result

195 

196 

def _filter_condition_to_ast(op, column, value):
    """Build the AST node for one ``[op, column, value]`` triple as produced by extract_comparison_conditions."""
    op = op.lower()

    if op == "between":
        return ast.BetweenOperation(args=[ast.Identifier(column), ast.Constant(value[0]), ast.Constant(value[1])])

    if isinstance(value, (tuple, list)):
        # IN / NOT IN operand: a Tuple of Constant nodes. Wrapping raw values in a Tuple and then
        # in a Constant does not produce a valid list operand.
        right = ast.Tuple([item if isinstance(item, ASTNode) else ast.Constant(item) for item in value])
    else:
        right = ast.Constant(value)
    return ast.BinaryOperation(op=op, args=[ast.Identifier(column), right])


def filter_dataframe(df: pd.DataFrame, conditions: list, raw_conditions=None, order_by=None):
    """
    Filter (and optionally sort) ``df`` by running a ``SELECT * FROM df WHERE ...`` query over it.

    Args:
        df: DataFrame to filter. The query refers to it as table ``df``.
        conditions: ``[op, column, value]`` triples, as returned by extract_comparison_conditions.
            ``None`` is treated as no conditions.
        raw_conditions: Optional AST condition nodes, used as-is.
        order_by: Optional ORDER BY list attached to the query when truthy.

    Returns:
        The result of ``query_df`` for the built query.

    The triples come first and the raw conditions follow. All are combined
    left-to-right with AND. With no conditions at all, the query has no WHERE clause.
    """
    parts = [_filter_condition_to_ast(op, column, value) for op, column, value in (conditions or [])]
    parts.extend(raw_conditions or [])

    where_query = None
    for part in parts:
        where_query = part if where_query is None else ast.BinaryOperation(op="and", args=[where_query, part])

    query = ast.Select(targets=[ast.Star()], from_table=ast.Identifier("df"), where=where_query)

    if order_by:
        query.order_by = order_by

    return query_df(df, query)

229 

230 

def sort_dataframe(df, order_by: list):
    """
    Sort ``df`` by the ``ast.OrderBy`` items in ``order_by``.

    The following items are skipped rather than raising:
    - items that are not ``ast.OrderBy``;
    - items whose field has no identifier parts;
    - items whose column (the last identifier part, matched exactly) is not in ``df``.

    Direction ``desc`` (any case) sorts descending. Anything else sorts ascending.

    Returns:
        The sorted DataFrame, or ``df`` unchanged when no usable sort column remains.
    """
    cols = []
    ascending = []
    for order in order_by or []:
        if not isinstance(order, ast.OrderBy):
            continue

        parts = getattr(order.field, "parts", None)
        if not parts:
            # No column name to sort by (NOTE(review): presumably an ordinal or an expression).
            continue

        col = parts[-1]
        if col not in df.columns:
            continue

        cols.append(col)
        ascending.append((order.direction or "").lower() != "desc")

    if cols:
        df = df.sort_values(by=cols, ascending=ascending)
    return df