Coverage for mindsdb / integrations / handlers / mongodb_handler / utils / mongodb_ast.py: 0%

137 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import re 

2import ast as py_ast 

3import typing as t 

4 

5from mindsdb_sql_parser.ast import OrderBy, Identifier, Star, Select, Constant, BinaryOperation, Tuple, Latest 

6 

7 

class MongoToAst:
    """
    Converts query mongo to AST format.

    Translates MongoDB ``find`` arguments (filter / projection / sort / limit / skip)
    into a mindsdb_sql_parser ``Select`` node.
    """

    def from_mongoqeury(self, query):
        """
        Convert a parsed Mongo query object into a ``Select`` node.

        :param query: object exposing ``collection`` and ``pipeline``; each pipeline step is a
            dict with a ``"method"`` name and an ``"args"`` list of positional arguments.
        :return: result of :meth:`find` built from the find/sort/limit/skip steps.
        """
        # IS NOT USED YET AND NOT FINISHED

        collection = query.collection

        filter_, projection = None, None
        sort, limit, skip = None, None, None
        for step in query.pipeline:
            method = step["method"]
            # a step may carry no positional arguments, e.g. a bare find()
            args = step.get("args") or []
            if method == "find":
                # find(filter, projection): both positional arguments are optional
                if len(args) > 0:
                    filter_ = args[0]
                if len(args) > 1:
                    projection = args[1]
            elif method == "sort" and args:
                sort = args[0]
            elif method == "limit" and args:
                limit = args[0]
            elif method == "skip" and args:
                skip = args[0]

        return self.find(collection, filter=filter_, sort=sort, projection=projection, limit=limit, skip=skip)

    def find(
        self, collection: t.Union[list, str], filter=None, sort=None, projection=None, limit=None, skip=None, **kwargs
    ):
        """
        Build a ``Select`` node equivalent to ``db.<collection>.find(filter, projection)``
        with optional ``.sort()``, ``.limit()`` and ``.skip()``.

        :param collection: table name as a dotted path string, or a list of name parts.
        :param filter: MongoDB filter document, converted with :meth:`convert_filter`.
        :param sort: dict ``{field: direction}``; ``-1`` means DESC, any other value ASC.
        :param projection: dict ``{field: spec}``; a string spec becomes the column alias,
            ``0``/``False`` excludes the field, any other value includes it.
        :param limit: LIMIT value, applied when not None.
        :param skip: OFFSET value, applied when not None and non-zero.
        :param kwargs: accepted and ignored.
        :return: ``Select`` node.
        """
        # https://www.mongodb.com/docs/v4.2/reference/method/db.collection.find/

        order_by = None
        if sort is not None:
            # sort is dict
            order_by = [
                OrderBy(field=Identifier(parts=[col]), direction="DESC" if direction == -1 else "ASC")
                for col, direction in sort.items()
            ]

        targets = []
        if projection is not None:
            for col, spec in projection.items():
                if isinstance(spec, str):
                    # string value is used as an alias for the column
                    targets.append(Identifier(path_str=col, alias=Identifier(parts=[spec])))
                elif isinstance(spec, (int, float)) and spec == 0:
                    # 0 / False means "exclude this field" in MongoDB; a SELECT list cannot
                    # express exclusion, so the field is simply not selected
                    continue
                else:
                    targets.append(Identifier(path_str=col, alias=None))
        if not targets:
            # no projection, an empty one, or exclusions only: the closest SELECT is "*"
            targets = [Star()]

        where = None
        if filter is not None:
            where = self.convert_filter(filter)

        # convert to AST node
        # collection can be string or list
        if isinstance(collection, list):
            collection = Identifier(parts=collection)
        else:
            collection = Identifier(path_str=collection)

        node = Select(
            targets=targets,
            from_table=collection,
            where=where,
            order_by=order_by,
        )
        if limit is not None:
            node.limit = Constant(value=limit)

        if skip is not None and skip != 0:
            node.offset = Constant(value=skip)

        return node

    @staticmethod
    def _join_conditions(nodes, op):
        """Left-fold ``nodes`` with ``op``; return None for an empty list, the node itself for one."""
        if not nodes:
            return None
        result = nodes[0]
        for node in nodes[1:]:
            result = BinaryOperation(op=op, args=[result, node])
        return result

    def convert_filter(self, filter):
        """
        Convert a MongoDB filter document into a condition AST.

        All top-level keys are combined with AND, as MongoDB does:

        - ``$and`` / ``$or``: list of sub-filters, converted recursively and joined with the operator;
        - ``$where`` / ``$expr``: expression string parsed by ``MongoWhereParser``;
        - any other key is a field name whose value goes through :meth:`handle_filter`; an
          operator document holding several operators (``{"$gt": 1, "$lt": 5}``) yields one
          condition per operator.

        :return: condition node, or None when the filter has no conditions (e.g. ``{}``).
        :raises ValueError: if ``$and`` / ``$or`` is given an empty list.
        """
        cond_ops = {
            "$and": "and",
            "$or": "or",
        }

        conditions = []
        for k, v in filter.items():
            if k in cond_ops:
                if not v:
                    raise ValueError(f"{k} must be a nonempty list")
                nodes = [self.convert_filter(cond) for cond in v]
                if k == "$or" and any(node is None for node in nodes):
                    # an empty sub-filter matches every document, so the whole $or does too
                    continue
                # an empty sub-filter inside $and adds no restriction
                combined = self._join_conditions([node for node in nodes if node is not None], cond_ops[k])
                if combined is not None:
                    conditions.append(combined)
            elif k in ("$where", "$expr"):
                # try to parse simple expression like 'this.saledate > this.latest'
                conditions.append(MongoWhereParser(v).to_ast())
            else:
                # field filter; an operator document may hold several operators (e.g. a range)
                if isinstance(v, dict) and len(v) > 1:
                    pairs = [self.handle_filter({op_key: op_value}) for op_key, op_value in v.items()]
                else:
                    pairs = [self.handle_filter(v)]
                for op, value in pairs:
                    conditions.append(BinaryOperation(op=op, args=[Identifier(parts=[k]), Constant(value=value)]))

        return self._join_conditions(conditions, "and")

    def handle_filter(self, value):
        """
        Translate the value of a field filter into ``(sql_operator, operand)``.

        - plain value ``5``                        -> ``("=", 5)``
        - one-operator document ``{"$gt": 5}``     -> ``(">", 5)``; supported: ``$eq``, ``$ne``,
          ``$gt``, ``$gte``, ``$lt``, ``$lte`` and the legacy spellings ``$ge`` / ``$le``
        - ``{"$in": [...]}`` / ``{"$nin": [...]}`` -> ``("in" / "not in", Tuple([...]))``

        Only the first operator of a document is used; :meth:`convert_filter` splits
        multi-operator documents before calling this.

        :raises NotImplementedError: for lists, empty documents, unknown operators, or a
            non-list ``$in`` / ``$nin`` operand.
        """
        ops = {
            "$eq": "=",
            "$ne": "!=",
            "$gt": ">",
            "$gte": ">=",
            "$lt": "<",
            "$lte": "<=",
            # legacy spellings kept for backward compatibility; MongoDB itself uses $gte / $lte
            "$ge": ">=",
            "$le": "<=",
        }
        in_ops = {"$in": "in", "$nin": "not in"}

        if isinstance(value, list):
            raise NotImplementedError(f"Unknown filter {value}")
        if not isinstance(value, dict):
            # is simple type
            return "=", value
        if not value:
            raise NotImplementedError(f"Unknown filter {value}")

        key, operand = next(iter(value.items()))
        if key in ops:
            return ops[key], operand

        if key in in_ops:
            if not isinstance(operand, list):
                raise NotImplementedError(f"Unknown type {key}, {operand}")
            return in_ops[key], Tuple(operand)

        raise NotImplementedError(f"Unknown type {key}")

152 

153 

class MongoWhereParser:
    """
    Parse a simple MongoDB ``$where`` expression string into a mindsdb_sql_parser AST.

    The string is parsed with Python's ``ast`` module after light normalisation of
    equality operators, so only this subset is supported: comparisons
    (``=``, ``==``, ``===``, ``!=``, ``!==``, ``>``, ``<``, ``>=``, ``<=``), ``and`` / ``or``,
    number / string literals (optionally signed), ``this.<field>`` references and the
    special name ``latest``. Anything else raises ``NotImplementedError``.

    NOTE: the operator rewrites are textual, so they also apply inside string literals.
    """

    # JavaScript strict (in)equality has no Python spelling; map it to the plain operator
    _STRICT_OPERATORS = (("!==", "!="), ("===", "=="))
    # a lone "=" (not part of "==", ">=", "<=" or "!=") is treated as equality and doubled
    _SINGLE_EQUALS = re.compile(r"([^=><!])=([^=])")
    # python BoolOp class name -> SQL operator, lowercase like MongoToAst uses
    _BOOL_OPS = {"And": "and", "Or": "or"}

    def __init__(self, query):
        # query: expression string, e.g. 'this.saledate > latest'
        self.query = query

    @classmethod
    def _to_python_expr(cls, query):
        """Rewrite JS-style equality operators in ``query`` into valid Python syntax."""
        for js_op, py_op in cls._STRICT_OPERATORS:
            query = query.replace(js_op, py_op)
        return cls._SINGLE_EQUALS.sub(r"\1==\2", query)

    def to_ast(self):
        """Parse ``self.query`` as a Python expression and convert it to an AST node."""
        tree = py_ast.parse(self._to_python_expr(self.query), mode="eval")
        return self.process(tree.body)

    def process(self, node):
        """
        Convert one Python ``ast`` node into a mindsdb_sql_parser node.

        :raises NotImplementedError: for node types or operators outside the supported subset.
        """
        if isinstance(node, py_ast.BoolOp):
            # AND / OR; python groups "a and b and c" into one node with several values
            op = self._BOOL_OPS[type(node.op).__name__]
            result = self.process(node.values[0])
            for value in node.values[1:]:
                result = BinaryOperation(op=op, args=[result, self.process(value)])
            return result

        if isinstance(node, py_ast.Compare):
            # chained comparisons such as "1 < this.a < 5" are not supported
            if len(node.ops) != 1:
                raise NotImplementedError(f"Multiple ops {node.ops}")
            op = self.compare_op(node.ops[0])
            return BinaryOperation(op=op, args=[self.process(node.left), self.process(node.comparators[0])])

        if isinstance(node, py_ast.Name) and node.id == "latest":
            # is special operator
            return Latest()

        if isinstance(node, py_ast.Constant):
            return Constant(value=node.value)

        # Python 3.7 emits Str / Num nodes instead of Constant. Compare by type name so the
        # deprecated ast.Str / ast.Num attributes (removed in Python 3.14) are never accessed.
        node_type = type(node).__name__
        if node_type == "Str":
            return Constant(value=node.s)
        if node_type == "Num":
            return Constant(value=node.n)

        if isinstance(node, py_ast.UnaryOp) and isinstance(node.op, (py_ast.USub, py_ast.UAdd)):
            # signed numeric literal, e.g. "this.a > -1"
            operand = self.process(node.operand)
            value = getattr(operand, "value", None)
            if isinstance(operand, Constant) and type(value) in (int, float):
                return Constant(value=-value if isinstance(node.op, py_ast.USub) else value)
            raise NotImplementedError(f"Unknown node {node}")

        if isinstance(node, py_ast.Attribute):
            # is 'this.field' - reference to a document field
            owner = node.value
            if not isinstance(owner, py_ast.Name):
                # e.g. nested access "this.a.b"
                raise NotImplementedError(f"Unknown node {node}")
            if owner.id != "this":
                raise NotImplementedError(f"Unknown variable {owner.id}")
            return Identifier(parts=[node.attr])

        raise NotImplementedError(f"Unknown node {node}")

    def compare_op(self, op):
        """
        Map a Python ``ast`` comparison operator instance to its SQL spelling.

        :raises NotImplementedError: for operators not in the table (``in``, ``is``, ...).
        """
        opname = op.__class__.__name__

        # TODO: in, not

        ops = {
            "Eq": "=",
            "NotEq": "!=",
            "Gt": ">",
            "Lt": "<",
            "GtE": ">=",
            "LtE": "<=",
        }
        if opname not in ops:
            raise NotImplementedError(f"Unknown $where op: {opname}")
        return ops[opname]

    @staticmethod
    def test(cls=None):
        """
        Self-check of the parser.

        ``cls`` defaults to ``MongoWhereParser``; it is still accepted so the old
        ``MongoWhereParser.test(MongoWhereParser)`` call keeps working.
        """
        cls = MongoWhereParser if cls is None else cls
        assert cls('this.a ==1 and "te" >= latest').to_ast().to_string() == "a = 1 AND 'te' >= LATEST"