Coverage for mindsdb / integrations / handlers / mongodb_handler / utils / mongodb_render.py: 64%
246 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import datetime as dt
2from typing import Dict, Union, Any, Optional, Tuple as TypingTuple
4from bson.objectid import ObjectId
5from mindsdb_sql_parser.ast import (
6 Select,
7 Update,
8 Identifier,
9 Star,
10 Constant,
11 Tuple,
12 BinaryOperation,
13 Latest,
14 TypeCast,
15 Function,
16)
17from mindsdb_sql_parser.ast.base import ASTNode
19from mindsdb.integrations.handlers.mongodb_handler.utils.mongodb_query import MongoQuery
# TODO: turn NonRelationalRender into a real shared base class, analogous to SqlAlchemyRender.
class NonRelationalRender:
    """
    Placeholder base class for renderers that target non-relational back ends.

    It currently defines no attributes or methods of its own.
    """
class MongodbRender(NonRelationalRender):
    """
    Renderer to convert SQL queries represented as ASTNodes to MongoQuery instances.

    SELECT statements are rendered as a single ``aggregate`` call and UPDATE statements as
    ``update_many``. Constructs that cannot be rendered faithfully raise ``NotImplementedError``
    instead of being silently dropped.
    """

    # Upper-cased SQL type name -> MongoDB ``$convert`` target type used for CAST.
    _CAST_TARGETS: Dict[str, str] = {
        **dict.fromkeys(("VARCHAR", "TEXT", "STRING"), "string"),
        **dict.fromkeys(("INT", "INTEGER", "BIGINT", "LONG"), "long"),
        **dict.fromkeys(("DOUBLE", "FLOAT", "DECIMAL", "NUMERIC"), "double"),
        **dict.fromkeys(("DATE", "DATETIME", "TIMESTAMP"), "date"),
    }

    # Lower-cased SQL comparison operator -> MongoDB comparison operator.
    _COMPARISON_OPS: Dict[str, str] = {
        ">=": "$gte",
        ">": "$gt",
        "<": "$lt",
        "<=": "$lte",
        "<>": "$ne",
        "!=": "$ne",
        "=": "$eq",
        "==": "$eq",
        "is": "$eq",
        "is not": "$ne",
    }

    # Lower-cased SQL aggregate -> $group accumulator applied directly to a field path.
    _FIELD_ACCUMULATORS: Dict[str, str] = {"avg": "$avg", "sum": "$sum", "min": "$min", "max": "$max"}

    # Literal formats accepted by CAST(... AS DATE/DATETIME/TIMESTAMP) in WHERE clauses, tried in order.
    _DATETIME_FORMATS: TypingTuple[str, ...] = (
        "%Y-%m-%d",
        "%Y-%m-%dT%H:%M:%S.%f",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%d %H:%M:%S.%f",
        "%Y-%m-%d %H:%M:%S",
    )

    @staticmethod
    def _strip_table_qualifier(parts: list, table_aliases: set) -> str:
        """
        Build a MongoDB field path from identifier parts, dropping a leading table qualifier.

        Args:
            parts (list): The parts of an Identifier, e.g. ``["t", "field"]``.
            table_aliases (set): Names referring to the queried collection (its name and alias).

        Returns:
            str: The dot-joined field path.
        """
        parts = list(parts)
        if len(parts) > 1 and (parts[0] in table_aliases or len(table_aliases) == 1):
            # NOTE(review): with a single known table name the first part of any multi-part
            # identifier is treated as a qualifier, so embedded paths such as ``address.city``
            # are shortened too. Kept as-is for backward compatibility.
            parts = parts[1:]
        return ".".join(str(p) for p in parts)

    @staticmethod
    def _type_cast_alias(node: TypeCast) -> str:
        """
        Return the output column name for a CAST target.

        Uses the explicit alias, then the cast column's name, then the cast constant's value.
        """
        if node.alias is not None:
            return node.alias.parts[-1]
        if isinstance(node.arg, Identifier):
            return node.arg.parts[-1]
        if isinstance(node.arg, Constant):
            return str(node.arg.value)
        return str(node.arg)

    @staticmethod
    def _coerce_id(field: str, value: Any) -> Any:
        """
        Convert ``value`` to an ObjectId when filtering on ``_id`` and the value is a valid ObjectId.

        Other values (None, integers, arbitrary strings) are returned unchanged. Collections may use
        non-ObjectId ``_id`` values, and ``ObjectId(...)`` would raise on them.
        """
        if field == "_id" and ObjectId.is_valid(value):
            return ObjectId(value)
        return value

    def _parse_inner_query(self, node: ASTNode) -> Any:
        """
        Return a field reference like ``"$field"`` or a constant value for projection expressions.

        Args:
            node (ASTNode): An Identifier (only its last part is used) or a Constant.

        Returns:
            Any: ``"$<name>"`` for identifiers, the raw value for constants.

        Raises:
            NotImplementedError: For any other node type.
        """
        if isinstance(node, Identifier):
            return f"${node.parts[-1]}"
        if isinstance(node, Constant):
            return node.value
        raise NotImplementedError(f"Not supported inner query {node}")

    def _convert_type_cast(self, node: TypeCast) -> Any:
        """
        Converts a TypeCast ASTNode to a MongoDB-compatible expression.

        Args:
            node (TypeCast): The TypeCast node to be converted.

        Returns:
            Any: A ``$convert`` expression (``onError: None``) for known target types. Unknown
            target types return the uncast field reference or constant.
        """
        inner_query = self._parse_inner_query(node.arg)
        to_type = self._CAST_TARGETS.get(node.type_name.upper())
        if to_type is None:
            return inner_query
        return {"$convert": {"input": inner_query, "to": to_type, "onError": None}}

    def _parse_select(self, from_table: Any) -> TypingTuple[str, Dict[str, Any], Optional[Dict[str, Any]]]:
        """
        Parses the from_table to extract the collection name.
        If from_table is a subquery, transform it into a filter and a projection for MongoDB.

        Args:
            from_table (Any): The from_table to be parsed.

        Returns:
            str: The collection name.
            Dict[str, Any]: The subquery's filters (empty when there are none).
            Optional[Dict[str, Any]]: The subquery's projection, or None to keep documents whole.

        Raises:
            NotImplementedError: If from_table is neither a collection nor a simple subquery.
        """
        # Simple collection
        if isinstance(from_table, Identifier):
            return from_table.parts[-1], {}, None

        if not isinstance(from_table, Select):
            raise NotImplementedError(f"Not supported from {from_table}")

        # Reject complex forms early: only a filter + projection over one collection is supported.
        if from_table.group_by is not None or from_table.having is not None:
            raise NotImplementedError(f"Not supported, subquery has `having` or `group by`: {from_table}")
        if from_table.distinct or from_table.limit is not None or from_table.offset is not None:
            # These change which rows the outer query sees; ignoring them would give wrong results.
            raise NotImplementedError(f"Not supported, subquery has `distinct`, `limit` or `offset`: {from_table}")
        if not isinstance(from_table.from_table, Identifier):
            raise NotImplementedError(f"Only simple subqueries are allowed in {from_table}")

        collection = from_table.from_table.parts[-1]
        pre_match: Dict[str, Any] = {}
        if from_table.where is not None:
            pre_match = self.handle_where(from_table.where)

        if from_table.targets is None:
            return collection, pre_match, {"_id": 0}
        if any(isinstance(t, Star) for t in from_table.targets):
            # SELECT * inside the subquery keeps documents unchanged.
            return collection, pre_match, None

        pre_project: Dict[str, Any] = {"_id": 0}
        for t in from_table.targets:
            if isinstance(t, Identifier):
                name = ".".join(t.parts)
                alias = name if t.alias is None else t.alias.parts[-1]
                pre_project[alias] = f"${name}"
            elif isinstance(t, Constant):
                alias = str(t.value) if t.alias is None else t.alias.parts[-1]
                pre_project[alias] = t.value
            elif isinstance(t, TypeCast):
                pre_project[self._type_cast_alias(t)] = self._convert_type_cast(t)
            else:
                raise NotImplementedError(f"Unsupported inner target: {t}")
        return collection, pre_match, pre_project

    def _aggregation_expression(self, func: Function, table_aliases: set) -> Dict[str, Any]:
        """
        Translate an aggregate function call into a ``$group`` accumulator expression.

        Args:
            func (Function): The aggregate call (COUNT, SUM, AVG, MIN or MAX).
            table_aliases (set): Names referring to the queried collection.

        Returns:
            Dict[str, Any]: The accumulator expression.

        Raises:
            NotImplementedError: For missing arguments, DISTINCT aggregates, unsupported
                arguments or unsupported functions.
        """
        func_name = func.op.lower()
        if len(func.args) == 0:
            raise NotImplementedError(f"Function {func_name.upper()} requires arguments")
        if getattr(func, "distinct", False):
            raise NotImplementedError(f"DISTINCT inside {func_name.upper()} is not supported")

        arg0 = func.args[0]
        if func_name == "count" and isinstance(arg0, Star):
            return {"$sum": 1}
        if func_name == "count" and isinstance(arg0, Constant):
            # COUNT(<constant>) counts every row; COUNT(NULL) counts none.
            return {"$sum": 0 if arg0.value is None else 1}
        if not isinstance(arg0, Identifier):
            raise NotImplementedError(f"Unsupported argument {arg0} for function {func_name.upper()}")

        field = f"${self._strip_table_qualifier(arg0.parts, table_aliases)}"
        if func_name == "count":
            # COUNT(col) skips NULLs. $ifNull maps missing fields to null as well; a bare
            # $ne against null would count documents that lack the field.
            return {"$sum": {"$cond": [{"$ne": [{"$ifNull": [field, None]}, None]}, 1, 0]}}
        if func_name in self._FIELD_ACCUMULATORS:
            return {self._FIELD_ACCUMULATORS[func_name]: field}
        raise NotImplementedError(f"Aggregation function '{func_name.upper()}' is not supported")

    def to_mongo_query(self, node: ASTNode) -> MongoQuery:
        """
        Converts SQL query to MongoQuery instance.

        Args:
            node (ASTNode): An ASTNode representing the SQL query to be converted.

        Returns:
            MongoQuery: The converted MongoQuery instance.

        Raises:
            NotImplementedError: If the statement is neither a Select nor an Update.
        """
        if isinstance(node, Select):
            return self.select(node)
        if isinstance(node, Update):
            return self.update(node)
        raise NotImplementedError(f"Unknown statement: {node.__class__.__name__}")

    def update(self, node: Update) -> MongoQuery:
        """
        Converts an Update statement to MongoQuery instance.

        Args:
            node (Update): An ASTNode representing the SQL Update statement.

        Returns:
            MongoQuery: A query with one ``update_many`` step that ``$set``s the new values.
        """
        collection = node.table.parts[-1]
        mquery = MongoQuery(collection)

        # As in SQL, an UPDATE without WHERE applies to every document.
        filters = self.handle_where(node.where) if node.where is not None else {}
        row = {k: v.value for k, v in node.update_columns.items()}
        mquery.add_step({"method": "update_many", "args": [filters, {"$set": row}]})
        return mquery

    def select(self, node: Select) -> MongoQuery:
        """
        Converts a Select statement to MongoQuery instance.

        The statement becomes one ``aggregate`` call. Its pipeline is, in order:
        - ``node.modifiers`` (as-is);
        - the subquery's ``$match``/``$project``;
        - ``$match`` (WHERE);
        - ``$group`` (DISTINCT / GROUP BY / aggregates), plus ``$replaceRoot`` for DISTINCT *;
        - ``$project``;
        - ``$sort`` (ORDER BY);
        - ``$skip`` (OFFSET);
        - ``$limit`` (LIMIT).

        Args:
            node (Select): An ASTNode representing the SQL Select statement.

        Returns:
            MongoQuery: The converted MongoQuery instance.

        Raises:
            NotImplementedError: If a target, aggregate, GROUP BY column or subquery cannot be rendered.
        """
        collection, pre_match, pre_project = self._parse_select(node.from_table)

        # Names that may qualify columns of the queried collection.
        table_aliases = {collection}
        if isinstance(node.from_table, Identifier) and node.from_table.alias is not None:
            table_aliases.add(node.from_table.alias.parts[-1])

        filters: Dict[str, Any] = self.handle_where(node.where) if node.where is not None else {}

        agg_group: Dict[str, Any] = {}
        group: Dict[str, Any] = {}
        project: Dict[str, Any] = {"_id": 0}
        if node.distinct:
            # DISTINCT is emulated by grouping on every selected field.
            group = {"_id": {}}

        for col in node.targets or []:
            if isinstance(col, Star):
                # Show all fields.
                project = {}
                if node.distinct:
                    # DISTINCT *: group on the whole document (restored by $replaceRoot below).
                    # Grouping on an empty key would collapse the collection into one row.
                    group = {"_id": "$$ROOT"}
                break
            if isinstance(col, Identifier):
                name = self._strip_table_qualifier(col.parts, table_aliases)
                alias = name if col.alias is None else col.alias.parts[-1]
                project[alias] = f"${name}"
                if node.distinct:
                    group["_id"][name] = f"${name}"  # Group field.
                    group[name] = {"$first": f"${name}"}  # Show field.
            elif isinstance(col, Function):
                alias = col.alias.parts[-1] if col.alias is not None else col.op.lower()
                agg_group[alias] = self._aggregation_expression(col, table_aliases)
            elif isinstance(col, Constant):
                alias = str(col.value) if col.alias is None else col.alias.parts[-1]
                project[alias] = col.value
            elif isinstance(col, TypeCast):
                if node.distinct or node.group_by is not None:
                    # The cast reads the raw field, which is not carried through $group.
                    raise NotImplementedError(f"CAST is not supported together with DISTINCT or GROUP BY: {col}")
                project[self._type_cast_alias(col)] = self._convert_type_cast(col)
            else:
                raise NotImplementedError(f"Unsupported target: {col}")

        if node.group_by is not None:
            if not isinstance(group.get("_id"), dict):
                group["_id"] = {}
            # Projection as built from the SELECT list, before it is rewritten to read $group output.
            selected = dict(project)
            for group_col in node.group_by:
                if not isinstance(group_col, Identifier):
                    raise NotImplementedError(f"Unsupported GROUP BY column {group_col}")
                field_name = self._strip_table_qualifier(group_col.parts, table_aliases)
                key = group_col.alias.parts[-1] if group_col.alias is not None else field_name
                group["_id"][key] = f"${field_name}"
                # Carry every selected column that reads this field through $group,
                # including aliased ones such as ``SELECT city AS c ... GROUP BY city``.
                for out_name, expr in selected.items():
                    if expr == f"${field_name}":
                        group[out_name] = {"$first": f"${field_name}"}
                        project[out_name] = f"${out_name}"
            for alias, expression in agg_group.items():
                group[alias] = expression
                project[alias] = f"${alias}"
        elif agg_group:
            # Aggregates without GROUP BY reduce the whole result to a single row.
            group = {"_id": None}
            for alias, expression in agg_group.items():
                group[alias] = expression
                project[alias] = f"${alias}"

        sort: Dict[str, int] = {}
        for order in node.order_by or []:
            # Only an explicit DESC sorts descending; an unspecified direction (which the parser
            # may report as a value other than "ASC") must use SQL's ascending default.
            sort[order.field.parts[-1]] = -1 if str(order.direction).upper() == "DESC" else 1

        # MongoDB pipeline steps for the aggregate method.
        margs: list = list(node.modifiers or [])
        if pre_match:
            margs.append({"$match": pre_match})
        if pre_project:
            margs.append({"$project": pre_project})
        if filters:
            margs.append({"$match": filters})
        if group:
            margs.append({"$group": group})
            if group.get("_id") == "$$ROOT":
                # DISTINCT *: promote each distinct document back to the top level.
                margs.append({"$replaceRoot": {"newRoot": "$_id"}})
        if project:
            margs.append({"$project": project})
        if sort:
            margs.append({"$sort": sort})
        if node.offset is not None:
            margs.append({"$skip": int(node.offset.value)})
        if node.limit is not None:
            margs.append({"$limit": int(node.limit.value)})

        mquery = MongoQuery(collection)
        mquery.add_step({"method": "aggregate", "args": [margs]})
        return mquery

    def handle_where(self, node: BinaryOperation) -> Dict:
        """
        Converts a BinaryOperation node to a dictionary of MongoDB query filters.

        Args:
            node (BinaryOperation): A BinaryOperation node representing the SQL WHERE clause.

        Returns:
            dict: The converted MongoDB query filters.

        Raises:
            NotImplementedError: For non-binary nodes or unsupported operators.
        """
        # TODO: UnaryOperation, function.
        if not isinstance(node, BinaryOperation):
            raise NotImplementedError(f"Not supported type {type(node)}")

        op = node.op.lower()
        a, b = node.args

        if op in ("and", "or"):
            return {f"${op}": [self.handle_where(a), self.handle_where(b)]}

        if isinstance(a, Identifier):
            # NOTE(review): unlike SELECT targets, table qualifiers are not stripped here.
            var_name = ".".join(a.parts)

            if isinstance(b, Constant):
                val = self._coerce_id(var_name, b.value)
                if op in ("=", "=="):
                    return {var_name: val}
                if op in self._COMPARISON_OPS:
                    return {var_name: {self._COMPARISON_OPS[op]: val}}
                raise NotImplementedError(f"Not supported operator {op}")

            if isinstance(b, Tuple):
                # Should be IN, NOT IN over a list of Constants.
                set_ops = {"in": "$in", "not in": "$nin"}
                if op not in set_ops:
                    raise NotImplementedError(f"Not supported operator {op}")
                values = [self._coerce_id(var_name, item.value) for item in b.items]
                return {var_name: {set_ops[op]: values}}

        # General case (e.g. field vs field, constant vs field): build an aggregation expression.
        val1 = self.where_element_convert(a)
        val2 = self.where_element_convert(b)
        if op not in self._COMPARISON_OPS:
            raise NotImplementedError(f"Not supported operator {op}")
        return {"$expr": {self._COMPARISON_OPS[op]: [val1, val2]}}

    def where_element_convert(self, node: Union[Identifier, Latest, Constant, TypeCast]) -> Any:
        """
        Converts a WHERE element to the corresponding MongoDB query element.

        Args:
            node (Union[Identifier, Latest, Constant, TypeCast]): The WHERE element to be converted.

        Returns:
            Any: ``"$<path>"`` for identifiers, ``"LATEST"`` for Latest, the raw value for
            constants, and a ``datetime`` for CAST(<literal> AS DATE/DATETIME/TIMESTAMP).

        Raises:
            NotImplementedError: If the WHERE element is not supported.
            RuntimeError: If the date format is not supported.
        """
        if isinstance(node, Identifier):
            return f"${'.'.join(node.parts)}"
        if isinstance(node, Latest):
            return "LATEST"
        if isinstance(node, Constant):
            return node.value
        if (
            isinstance(node, TypeCast)
            and node.type_name.upper() in ("DATE", "DATETIME", "TIMESTAMP")
            and isinstance(node.arg, Constant)
        ):
            for fmt in self._DATETIME_FORMATS:
                try:
                    return dt.datetime.strptime(node.arg.value, fmt)
                except ValueError:
                    continue
            raise RuntimeError(f"Not supported date format. Supported: {list(self._DATETIME_FORMATS)}")
        raise NotImplementedError(f"Unknown where element {node}")