Coverage for mindsdb / integrations / handlers / mongodb_handler / utils / mongodb_ast.py: 0%
137 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import re
2import ast as py_ast
3import typing as t
5from mindsdb_sql_parser.ast import OrderBy, Identifier, Star, Select, Constant, BinaryOperation, Tuple, Latest
class MongoToAst:
    """
    Convert MongoDB queries (``find`` arguments or a parsed query pipeline)
    into ``mindsdb_sql_parser`` AST nodes.
    """

    # comparison operators: MongoDB operator -> SQL operator.
    # "$gte"/"$lte" are the real MongoDB names; "$ge"/"$le" are kept for
    # backward compatibility with callers that relied on them.
    COMPARISON_OPS = {
        "$gt": ">",
        "$gte": ">=",
        "$ge": ">=",
        "$lt": "<",
        "$lte": "<=",
        "$le": "<=",
        "$ne": "!=",
        "$eq": "=",
    }
    # membership operators: MongoDB operator -> SQL operator
    MEMBERSHIP_OPS = {"$in": "in", "$nin": "not in"}
    # logical operators: MongoDB operator -> SQL operator
    LOGICAL_OPS = {"$and": "and", "$or": "or"}

    def from_mongoqeury(self, query):
        """
        Convert a parsed MongoDB query into a ``Select`` node.

        NOTE: not used yet and not finished.

        :param query: object with ``collection`` and ``pipeline`` attributes; each
            pipeline step is a mapping with a ``"method"`` key and an optional
            ``"args"`` list, e.g. ``{"method": "find", "args": [filter, projection]}``
        :return: ``Select`` node built by :meth:`find`
        :raises NotImplementedError: for pipeline methods other than find/sort/limit/skip
        """
        collection = query.collection

        filter_, projection = None, None
        sort, limit, skip = None, None, None
        for step in query.pipeline:
            method = step["method"]
            args = step.get("args") or []
            first_arg = args[0] if args else None
            if method == "find":
                filter_ = first_arg
                # the projection is the optional *second* argument of find()
                projection = args[1] if len(args) > 1 else None
            elif method == "sort":
                sort = first_arg
            elif method == "limit":
                limit = first_arg
            elif method == "skip":
                skip = first_arg
            else:
                # ignoring a step would silently produce a query with different results
                raise NotImplementedError(f"Unsupported pipeline method: {method}")

        return self.find(collection, filter=filter_, sort=sort, projection=projection, limit=limit, skip=skip)

    # correctly spelled alias of from_mongoqeury
    from_mongoquery = from_mongoqeury

    def find(
        self,
        collection: t.Union[list, tuple, str],
        filter=None,
        sort=None,
        projection=None,
        limit=None,
        skip=None,
        **kwargs,
    ):
        """
        Build a ``Select`` node equivalent to ``db.<collection>.find(...)``.

        https://www.mongodb.com/docs/v4.2/reference/method/db.collection.find/

        :param collection: collection name (passed to ``Identifier(path_str=...)``)
            or a list/tuple of identifier parts
        :param filter: MongoDB filter document, converted by :meth:`convert_filter`
        :param sort: ``{field: direction}`` mapping or a list of ``(field, direction)``
            pairs; ``-1`` means descending, anything else ascending
        :param projection: ``{field: value}`` mapping; a string value is used as the
            column alias, a truthy value includes the field, a falsy value (0/False)
            excludes it
        :param limit: maximum number of rows; ``0`` means "no limit" in MongoDB and
            adds no LIMIT clause
        :param skip: number of rows to skip; ``0`` adds no OFFSET clause
        :param kwargs: other ``find`` options, ignored
        :return: ``Select`` node
        """
        order_by = None
        if sort is not None:
            # mongo shell passes {field: direction}; pymongo also accepts [(field, direction), ...]
            sort_items = sort.items() if hasattr(sort, "items") else sort
            order_by = [
                OrderBy(field=Identifier(parts=[col]), direction="DESC" if direction == -1 else "ASC")
                for col, direction in sort_items
            ]

        targets = []
        if projection is not None:
            for col, alias in projection.items():
                if isinstance(alias, str):
                    # {field: "name"}: select the field under an alias (identifiers only)
                    targets.append(Identifier(path_str=col, alias=Identifier(parts=[alias])))
                elif alias:
                    # {field: 1 / True}: include the field
                    targets.append(Identifier(path_str=col, alias=None))
                # {field: 0 / False}: MongoDB excludes the field, so it is not selected
        if not targets:
            # no projection, an empty one, or an exclusion-only one: SQL cannot express
            # "all but X", so select everything
            targets = [Star()]

        where = None
        if filter is not None:
            where = self.convert_filter(filter)

        if isinstance(collection, (list, tuple)):
            from_table = Identifier(parts=list(collection))
        else:
            from_table = Identifier(path_str=collection)

        node = Select(
            targets=targets,
            from_table=from_table,
            where=where,
            order_by=order_by,
        )
        # limit(0) is MongoDB's "no limit", so it must not become LIMIT 0
        if limit is not None and limit != 0:
            node.limit = Constant(value=limit)

        if skip is not None and skip != 0:
            node.offset = Constant(value=skip)

        return node

    def convert_filter(self, filter):
        """
        Convert a MongoDB filter document into a WHERE condition.

        Every top-level key contributes one condition and all of them are combined
        with AND (MongoDB's implicit AND), including ``$and`` / ``$or`` / ``$where``.

        :param filter: filter mapping, e.g. ``{"a": 1, "b": {"$gt": 2}, "$or": [...]}``
        :return: condition AST node, or ``None`` when the filter matches everything
        :raises ValueError: when ``$and`` / ``$or`` is not a non-empty list
        :raises NotImplementedError: for unsupported operators or values
        """
        conditions = []
        for key, value in filter.items():
            if key in self.LOGICAL_OPS:
                conditions.append(self._convert_logical(key, value))
            elif key in ("$where", "$expr"):
                # try to parse simple expression like 'this.saledate > this.latest'
                if not isinstance(value, str):
                    # NOTE(review): $expr normally holds an aggregation document; only
                    # string expressions are handled here
                    raise NotImplementedError(f"Unsupported {key} value: {value}")
                conditions.append(MongoWhereParser(value).to_ast())
            elif isinstance(value, dict) and len(value) > 1:
                # several operators on one field, e.g. {"$gte": 1, "$lt": 5}: AND them
                conditions.extend(self._field_condition(key, {op: arg}) for op, arg in value.items())
            else:
                conditions.append(self._field_condition(key, value))
        return self._join_conditions(conditions, "and")

    def _convert_logical(self, key, value):
        """Convert the sub-filters of a ``$and`` / ``$or`` key into one condition (or ``None``)."""
        if not isinstance(value, (list, tuple)) or not value:
            # MongoDB itself rejects anything but a non-empty array here
            raise ValueError(f"{key} must be a non-empty list of filters")
        nodes = [self.convert_filter(cond) for cond in value]
        if key == "$or" and any(node is None for node in nodes):
            # an empty sub-filter matches every document, so the whole $or does too
            return None
        return self._join_conditions(nodes, self.LOGICAL_OPS[key])

    def _field_condition(self, field, value):
        """Build the comparison node for a single ``{field: value}`` filter entry."""
        op, operand = self.handle_filter(value)
        if isinstance(operand, Tuple):
            # IN / NOT IN: the Tuple is already an AST node
            arg2 = operand
        else:
            if operand is None:
                # SQL "x = NULL" never matches; compare with IS / IS NOT instead
                op = {"=": "is", "!=": "is not"}.get(op, op)
            arg2 = Constant(value=operand)
        return BinaryOperation(op=op, args=[Identifier(parts=[field]), arg2])

    @staticmethod
    def _join_conditions(nodes, op):
        """Fold ``nodes`` left-to-right into nested ``BinaryOperation(op)`` nodes, skipping ``None``."""
        result = None
        for node in nodes:
            if node is None:
                continue
            result = node if result is None else BinaryOperation(op=op, args=[result, node])
        return result

    def handle_filter(self, value):
        """
        Translate the value of a ``{field: value}`` filter entry into ``(sql_op, operand)``.

        - plain value ``v``                    -> ``("=", v)``
        - ``{"$gt": v}`` (and other comparisons) -> ``(">", v)``
        - ``{"$in": [...]}`` / ``{"$nin": [...]}`` -> ``("in" / "not in", Tuple of Constants)``

        :raises NotImplementedError: for lists, empty or multi-operator dicts, unknown
            operators, and non-list ``$in`` / ``$nin`` arguments
        """
        if isinstance(value, dict):
            if len(value) != 1:
                raise NotImplementedError(f"Expected exactly one operator, got {value}")
            key, arg = next(iter(value.items()))
            if key in self.COMPARISON_OPS:
                return self.COMPARISON_OPS[key], arg

            if key in self.MEMBERSHIP_OPS:
                if not isinstance(arg, (list, tuple)):
                    raise NotImplementedError(f"Unknown type {key}, {arg}")
                return self.MEMBERSHIP_OPS[key], Tuple([Constant(value=item) for item in arg])

            raise NotImplementedError(f"Unknown type {key}")

        if isinstance(value, list):
            raise NotImplementedError(f"Unknown filter {value}")

        # plain value: equality
        return "=", value
class MongoWhereParser:
    """
    Convert a simple MongoDB ``$where`` expression such as
    ``'this.saledate > this.latest'`` into a ``mindsdb_sql_parser`` AST node.

    The expression is rewritten into Python syntax and parsed with :mod:`ast`, so
    only expressions that are valid Python after the rewrite are supported:
    ``this.<field>`` references, constants, ``latest``, single comparisons and
    and/or chains.
    """

    # a lone "=" that is not part of "==", "!=", "<=", ">="
    _SINGLE_EQ_RE = re.compile(r"([^=><!])=([^=])")

    # python ast comparison class name -> SQL operator
    _COMPARE_OPS = {
        "Eq": "=",
        "NotEq": "!=",
        "Gt": ">",
        "Lt": "<",
        "GtE": ">=",
        "LtE": "<=",
    }

    def __init__(self, query):
        """
        :param query: the ``$where`` expression string
        """
        self.query = query

    @classmethod
    def _to_python_syntax(cls, query):
        """Rewrite JavaScript-style operators in ``query`` into their Python equivalents."""
        # NOTE: the rewrite is purely textual, so it also touches string literals
        query = query.replace("!==", "!=").replace("===", "==")
        query = re.sub(r"\s*&&\s*", " and ", query)
        query = re.sub(r"\s*\|\|\s*", " or ", query)
        # a single '=' is treated as equality
        return cls._SINGLE_EQ_RE.sub(r"\1==\2", query)

    def to_ast(self):
        """
        Parse the expression and convert it into an AST node.

        :raises SyntaxError: if the rewritten expression is not valid Python
        :raises NotImplementedError: for unsupported constructs
        """
        expression = py_ast.parse(self._to_python_syntax(self.query), mode="eval")
        return self.process(expression.body)

    def process(self, node):
        """
        Recursively convert a python ``ast`` expression node into an AST node.

        :raises NotImplementedError: for unsupported node types, chained comparisons,
            unknown names and attributes not rooted at ``this``
        """
        if isinstance(node, py_ast.BoolOp):
            # "and" / "or" chains may hold more than two operands: fold left-to-right
            op = type(node.op).__name__.lower()
            result = self.process(node.values[0])
            for operand in node.values[1:]:
                result = BinaryOperation(op=op, args=[result, self.process(operand)])
            return result

        if isinstance(node, py_ast.Compare):
            # chained comparisons (a < b < c) are not supported
            if len(node.ops) != 1:
                raise NotImplementedError(f"Multiple ops {node.ops}")
            op = self.compare_op(node.ops[0])
            left = self.process(node.left)
            right = self.process(node.comparators[0])
            return BinaryOperation(op=op, args=[left, right])

        if isinstance(node, py_ast.Name):
            # bare names are special operators: latest, ...
            if node.id == "latest":
                return Latest()
            raise NotImplementedError(f"Unknown name {node.id}")

        # Python 3.8+ emits ast.Constant for every literal; the deprecated
        # ast.Str / ast.Num classes (removed in 3.14) are no longer needed
        if isinstance(node, py_ast.Constant):
            return Constant(value=node.value)

        if (
            isinstance(node, py_ast.UnaryOp)
            and isinstance(node.op, (py_ast.USub, py_ast.UAdd))
            and isinstance(node.operand, py_ast.Constant)
            and isinstance(node.operand.value, (int, float))
            and not isinstance(node.operand.value, bool)
        ):
            # signed numeric literal, e.g. "this.a > -1"
            value = node.operand.value
            return Constant(value=-value if isinstance(node.op, py_ast.USub) else value)

        if isinstance(node, py_ast.Attribute):
            # 'this.field' -> column reference
            if not isinstance(node.value, py_ast.Name):
                raise NotImplementedError(f"Unknown variable {py_ast.dump(node.value)}")
            if node.value.id != "this":
                raise NotImplementedError(f"Unknown variable {node.value.id}")
            return Identifier(parts=[node.attr])

        raise NotImplementedError(f"Unknown node {node}")

    def compare_op(self, op):
        """
        Map a python ``ast`` comparison operator node to its SQL operator string.

        :raises NotImplementedError: for unsupported operators (``in``, ``is``, ...)
        """
        # TODO: in, not
        opname = type(op).__name__
        try:
            return self._COMPARE_OPS[opname]
        except KeyError:
            raise NotImplementedError(f"Unknown $where op: {opname}") from None

    @classmethod
    def test(cls):
        """Self-check: parse a sample expression and compare its SQL rendering."""
        assert cls('this.a ==1 and "te" >= latest').to_ast().to_string() == "a = 1 AND 'te' >= LATEST"