Coverage for mindsdb / integrations / utilities / query_traversal.py: 85%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from mindsdb_sql_parser import ast 

2 

3 

def query_traversal(node, callback, is_table=False, is_target=False, parent_query=None, stack=None):
    """
    Walk over an AST, calling ``callback`` on every visited element and
    applying the replacements it returns.

    :param node: element to visit
    :param callback: function applied to every element. It is called as
        ``callback(node, is_table=..., is_target=..., parent_query=..., callstack=...)``
        where ``callstack`` is the list of ancestors of ``node``, nearest first.
    :param is_table: it is table in query
    :param is_target: it is the target in select
    :param parent_query: current query (select/update/create/...) where we are now
    :param stack: ancestors of ``node``, nearest first (empty list when omitted)
    :return:
        new element if it is needed to be replaced
        or None to keep element and traverse over it

        When ``node`` is a plain ``list`` and the callback does not replace it,
        a new list with the (possibly replaced) items is returned.

    Notes:
        * If the callback returns a replacement for a node, the children of
          that node are not visited.
        * For elements of lists (targets, args, group_by, ...) a falsy
          replacement is treated the same as ``None``: the original is kept.
        * A replacement for a select target may be a list; its items are
          spliced into the targets.

    Usage:
        Create callback function to check or replace nodes
        Example:
        ```python
        def remove_predictors(node, is_table, **kwargs):
            if is_table and isinstance(node, Identifier):
                if is_predictor(node):
                    return Constant(None)

        utils.query_traversal(ast_query, remove_predictors)
        ```

    """

    if stack is None:
        stack = []

    res = callback(node, is_table=is_table, is_target=is_target, parent_query=parent_query, callstack=stack)
    if res is not None:
        # node is going to be replaced
        return res

    # ancestors as seen by the children: this node first, then its own ancestors
    stack2 = [node] + stack

    if isinstance(node, ast.Select):
        if node.from_table is not None:
            node_out = query_traversal(node.from_table, callback, is_table=True, parent_query=node, stack=stack2)
            if node_out is not None:
                node.from_table = node_out

        targets = []
        for target in node.targets:
            node_out = query_traversal(target, callback, parent_query=node, is_target=True, stack=stack2) or target
            if isinstance(node_out, list):
                # one target may be expanded into several
                targets.extend(node_out)
            else:
                targets.append(node_out)
        node.targets = targets

        if node.cte is not None:
            for cte in node.cte:
                node_out = query_traversal(cte.query, callback, parent_query=node, stack=stack2)
                if node_out is not None:
                    # replace the query inside the CTE, keep the CTE wrapper itself in the list
                    cte.query = node_out

        if node.where is not None:
            node_out = query_traversal(node.where, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.where = node_out

        if node.group_by is not None:
            node.group_by = _traverse_each(node.group_by, callback, node, stack2)

        if node.having is not None:
            node_out = query_traversal(node.having, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.having = node_out

        if node.order_by is not None:
            node.order_by = _traverse_each(node.order_by, callback, node, stack2)

    elif isinstance(node, (ast.Union, ast.Intersect, ast.Except)):
        node_out = query_traversal(node.left, callback, parent_query=node, stack=stack2)
        if node_out is not None:
            node.left = node_out
        node_out = query_traversal(node.right, callback, parent_query=node, stack=stack2)
        if node_out is not None:
            node.right = node_out

    elif isinstance(node, ast.Join):
        # the right side is visited before the left one
        node_out = query_traversal(node.right, callback, is_table=True, parent_query=parent_query, stack=stack2)
        if node_out is not None:
            node.right = node_out
        node_out = query_traversal(node.left, callback, is_table=True, parent_query=parent_query, stack=stack2)
        if node_out is not None:
            node.left = node_out
        if node.condition is not None:
            node_out = query_traversal(node.condition, callback, parent_query=parent_query, stack=stack2)
            if node_out is not None:
                node.condition = node_out

    elif isinstance(node, (ast.Function, ast.BinaryOperation, ast.UnaryOperation, ast.BetweenOperation,
                           ast.Exists, ast.NotExists)):
        node.args = _traverse_each(node.args, callback, parent_query, stack2)

        if isinstance(node, ast.Function):
            if node.from_arg is not None:
                node_out = query_traversal(node.from_arg, callback, parent_query=parent_query, stack=stack2)
                if node_out is not None:
                    node.from_arg = node_out

    elif isinstance(node, ast.WindowFunction):
        node_out = query_traversal(node.function, callback, parent_query=parent_query, stack=stack2)
        if node_out is not None:
            # apply the replacement instead of silently dropping it
            node.function = node_out
        if node.partition is not None:
            node.partition = _traverse_each(node.partition, callback, parent_query, stack2)
        if node.order_by is not None:
            node.order_by = _traverse_each(node.order_by, callback, parent_query, stack2)

    elif isinstance(node, ast.TypeCast):
        node_out = query_traversal(node.arg, callback, parent_query=parent_query, stack=stack2)
        if node_out is not None:
            node.arg = node_out

    elif isinstance(node, ast.Tuple):
        node.items = _traverse_each(node.items, callback, parent_query, stack2)

    elif isinstance(node, ast.Insert):
        if node.table is not None:
            node_out = query_traversal(node.table, callback, is_table=True, parent_query=node, stack=stack2)
            if node_out is not None:
                node.table = node_out

        if node.values is not None:
            node.values = [_traverse_each(row, callback, node, stack2) for row in node.values]

        if node.from_select is not None:
            node_out = query_traversal(node.from_select, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.from_select = node_out

    elif isinstance(node, ast.Update):
        if node.table is not None:
            node_out = query_traversal(node.table, callback, is_table=True, parent_query=node, stack=stack2)
            if node_out is not None:
                node.table = node_out

        if node.where is not None:
            node_out = query_traversal(node.where, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.where = node_out

        if node.update_columns is not None:
            # collect replacements first, apply them after the iteration is finished
            changes = {}
            for column, value in node.update_columns.items():
                value_out = query_traversal(value, callback, parent_query=node, stack=stack2)
                if value_out is not None:
                    changes[column] = value_out
            if changes:
                node.update_columns.update(changes)

        if node.from_select is not None:
            node_out = query_traversal(node.from_select, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.from_select = node_out

    elif isinstance(node, ast.CreateTable):
        if node.columns is not None:
            node.columns = _traverse_each(node.columns, callback, node, stack2)

        if node.name is not None:
            node_out = query_traversal(node.name, callback, is_table=True, parent_query=node, stack=stack2)
            if node_out is not None:
                node.name = node_out

        if node.from_select is not None:
            node_out = query_traversal(node.from_select, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.from_select = node_out

    elif isinstance(node, ast.Delete):
        if node.where is not None:
            node_out = query_traversal(node.where, callback, parent_query=node, stack=stack2)
            if node_out is not None:
                node.where = node_out

    elif isinstance(node, ast.OrderBy):
        if node.field is not None:
            node_out = query_traversal(node.field, callback, parent_query=parent_query, stack=stack2)
            if node_out is not None:
                node.field = node_out

    elif isinstance(node, ast.Case):
        rules = []
        for condition, result in node.rules:
            condition2 = query_traversal(condition, callback, parent_query=parent_query, stack=stack2)
            result2 = query_traversal(result, callback, parent_query=parent_query, stack=stack2)

            condition = condition if condition2 is None else condition2
            result = result if result2 is None else result2
            rules.append([condition, result])
        node.rules = rules

        if node.default is not None:
            # nothing to visit when the default branch is absent
            default = query_traversal(node.default, callback, parent_query=parent_query, stack=stack2)
            if default is not None:
                node.default = default

    elif isinstance(node, list):
        return _traverse_each(node, callback, parent_query, stack2)

    # keep original node
    return None


def _traverse_each(nodes, callback, parent_query, stack):
    """Traverse every item of ``nodes`` in order and return a new list with replacements applied.

    An item is kept as is when its traversal returns ``None`` or any other falsy value.
    """
    return [
        query_traversal(item, callback, parent_query=parent_query, stack=stack) or item
        for item in nodes
    ]