Coverage for mindsdb / integrations / handlers / mongodb_handler / utils / mongodb_render.py: 64%

246 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import datetime as dt 

2from typing import Dict, Union, Any, Optional, Tuple as TypingTuple 

3 

4from bson.objectid import ObjectId 

5from mindsdb_sql_parser.ast import ( 

6 Select, 

7 Update, 

8 Identifier, 

9 Star, 

10 Constant, 

11 Tuple, 

12 BinaryOperation, 

13 Latest, 

14 TypeCast, 

15 Function, 

16) 

17from mindsdb_sql_parser.ast.base import ASTNode 

18 

19from mindsdb.integrations.handlers.mongodb_handler.utils.mongodb_query import MongoQuery 

20 

21 

# TODO: Create base NonRelationalRender as SqlAlchemyRender
class NonRelationalRender:
    """Empty base class; it defines no attributes or methods of its own yet (see TODO above)."""

25 

26 

class MongodbRender(NonRelationalRender):
    """
    Renderer to convert SQL queries represented as ASTNodes to MongoQuery instances.

    A Select is rendered as a single ``aggregate`` step and an Update as a single
    ``update_many`` step on the returned MongoQuery.
    """

    # Upper-cased SQL type name -> value used as the "to" field of MongoDB's $convert.
    _CONVERT_TARGETS = {
        "VARCHAR": "string",
        "TEXT": "string",
        "STRING": "string",
        "INT": "long",
        "INTEGER": "long",
        "BIGINT": "long",
        "LONG": "long",
        "DOUBLE": "double",
        "FLOAT": "double",
        "DECIMAL": "double",
        "NUMERIC": "double",
        "DATE": "date",
        "DATETIME": "date",
        "TIMESTAMP": "date",
    }

    # Lower-cased SQL aggregate name -> $group accumulator applied directly to the field.
    _ACCUMULATORS = {"avg": "$avg", "sum": "$sum", "min": "$min", "max": "$max"}

    # Lower-cased SQL comparison operator -> MongoDB comparison operator.
    _COMPARISON_OPS = {
        ">=": "$gte",
        ">": "$gt",
        "<": "$lt",
        "<=": "$lte",
        "<>": "$ne",
        "!=": "$ne",
        "=": "$eq",
        "==": "$eq",
        "is": "$eq",
        "is not": "$ne",
    }

    # Type names for which where_element_convert parses the cast argument into a datetime.
    _DATE_CAST_TYPES = ("DATE", "DATETIME", "TIMESTAMP")

    @staticmethod
    def _strip_table_qualifier(parts: Any, table_aliases: set) -> str:
        """
        Join identifier parts into a dotted field name, dropping a leading table qualifier.

        The first part is dropped when the identifier has more than one part and either
        that part is one of ``table_aliases`` or ``table_aliases`` holds a single name
        (any leading part is then treated as an implicit qualifier).

        Args:
            parts (Any): The ``parts`` sequence of an Identifier.
            table_aliases (set): Known table names/aliases of the query.

        Returns:
            str: The dotted field name.
        """
        # NOTE(review): with a single known table name, a nested path such as
        # `address.city` also loses its first segment - confirm this is intended.
        remaining = list(parts)
        if len(remaining) > 1 and (remaining[0] in table_aliases or len(table_aliases) == 1):
            remaining = remaining[1:]
        return ".".join(str(part) for part in remaining)

    @staticmethod
    def _id_value(value: Any) -> Any:
        """
        Convert a constant compared against ``_id`` into an ObjectId.

        None is returned unchanged because ``ObjectId(None)`` would generate a brand-new
        id rather than represent NULL.
        """
        return value if value is None else ObjectId(value)

    def _parse_inner_query(self, node: ASTNode) -> Any:
        """
        Return field ref like "$field" or constant for projection expressions.

        Args:
            node (ASTNode): An Identifier or Constant node.

        Returns:
            Any: ``"$<last identifier part>"`` for an Identifier, the raw value for a Constant.

        Raises:
            NotImplementedError: If the node is neither an Identifier nor a Constant.
        """
        if isinstance(node, Identifier):
            return f"${node.parts[-1]}"
        if isinstance(node, Constant):
            return node.value
        raise NotImplementedError(f"Not supported inner query {node}")

    def _convert_type_cast(self, node: TypeCast) -> Any:
        """
        Converts a TypeCast ASTNode to a MongoDB-compatible format.

        Args:
            node (TypeCast): The TypeCast node to be converted.

        Returns:
            Any: A ``{"$convert": ...}`` expression (with ``onError`` set to None) when the
            target type is listed in ``_CONVERT_TARGETS``; otherwise the un-cast inner
            expression is returned as is.

        Raises:
            NotImplementedError: If the cast argument is neither an Identifier nor a Constant.
        """
        inner_query = self._parse_inner_query(node.arg)
        to_type = self._CONVERT_TARGETS.get(node.type_name.upper())
        if to_type is None:
            # Unknown target type: fall back to the plain field reference / constant.
            return inner_query
        return {"$convert": {"input": inner_query, "to": to_type, "onError": None}}

    def _parse_select(self, from_table: Any) -> TypingTuple[str, Dict[str, Any], Optional[Dict[str, Any]]]:
        """
        Parses the from_table to extract the collection name.
        If from_table is a subquery, transform it for MongoDB.

        Args:
            from_table (Any): The from_table to be parsed (Identifier or simple Select).

        Returns:
            str: The collection name.
            Dict[str, Any]: The query filters of the subquery (empty when there are none).
            Optional[Dict[str, Any]]: The projection of the subquery; None for a plain
                collection or when the subquery selects ``*``.

        Raises:
            NotImplementedError: For subqueries with GROUP BY / HAVING, subqueries whose own
                source is not a plain collection, unsupported subquery targets, or any other
                kind of from_table.
        """
        # Simple collection
        if isinstance(from_table, Identifier):
            return from_table.parts[-1], {}, None

        if not isinstance(from_table, Select):
            raise NotImplementedError(f"Not supported from {from_table}")

        # Trivial subselect: reject complex forms early.
        if from_table.group_by is not None or from_table.having is not None:
            raise NotImplementedError(f"Not supported, subquery has `having` or `group by`: {from_table}")
        if not isinstance(from_table.from_table, Identifier):
            raise NotImplementedError(f"Only simple subqueries are allowed in {from_table}")

        collection = from_table.from_table.parts[-1]

        pre_match: Dict[str, Any] = {}
        if from_table.where is not None:
            pre_match = self.handle_where(from_table.where)

        targets = from_table.targets
        if targets is None:
            return collection, pre_match, {"_id": 0}
        if any(isinstance(target, Star) for target in targets):
            # `SELECT *` in the subquery: nothing to project.
            return collection, pre_match, None

        pre_project: Dict[str, Any] = {"_id": 0}
        for target in targets:
            if isinstance(target, Identifier):
                name = ".".join(target.parts)
                alias = name if target.alias is None else target.alias.parts[-1]
                pre_project[alias] = f"${name}"
            elif isinstance(target, Constant):
                alias = str(target.value) if target.alias is None else target.alias.parts[-1]
                pre_project[alias] = target.value
            elif isinstance(target, TypeCast):
                # Convert first so an unsupported cast argument is reported as
                # NotImplementedError instead of failing while deriving the alias.
                converted = self._convert_type_cast(target)
                if target.alias is not None:
                    alias = target.alias.parts[-1]
                elif isinstance(target.arg, Identifier):
                    alias = target.arg.parts[-1]
                else:
                    # Cast of a constant without alias: name it like a bare constant.
                    alias = str(target.arg.value)
                pre_project[alias] = converted
            else:
                raise NotImplementedError(f"Unsupported inner target: {target}")

        return collection, pre_match, pre_project

    def to_mongo_query(self, node: ASTNode) -> MongoQuery:
        """
        Converts SQL query to MongoQuery instance.

        Args:
            node (ASTNode): An ASTNode representing the SQL query to be converted.

        Returns:
            MongoQuery: The converted MongoQuery instance.

        Raises:
            NotImplementedError: If the statement is neither a Select nor an Update.
        """
        if isinstance(node, Select):
            return self.select(node)
        if isinstance(node, Update):
            return self.update(node)
        raise NotImplementedError(f"Unknown statement: {node.__class__.__name__}")

    def update(self, node: Update) -> MongoQuery:
        """
        Converts an Update statement to MongoQuery instance.

        Args:
            node (Update): An ASTNode representing the SQL Update statement.

        Returns:
            MongoQuery: A query with one ``update_many`` step whose arguments are the filter
            document and a ``$set`` document built from the updated columns.
        """
        mquery = MongoQuery(node.table.parts[-1])

        # An UPDATE without WHERE applies to every document (empty filter), mirroring
        # how `select` treats a missing WHERE clause.
        filters = {} if node.where is None else self.handle_where(node.where)
        row = {column: value.value for column, value in node.update_columns.items()}
        mquery.add_step({"method": "update_many", "args": [filters, {"$set": row}]})
        return mquery

    def select(self, node: Select) -> MongoQuery:
        """
        Converts a Select statement to MongoQuery instance.

        The pipeline is assembled in this order: modifiers, subquery ``$match`` and
        ``$project``, ``$match`` for the WHERE clause, ``$group``, ``$project``, ``$sort``,
        ``$skip``, ``$limit``. Empty stages are omitted.

        Args:
            node (Select): An ASTNode representing the SQL Select statement.

        Returns:
            MongoQuery: A query with a single ``aggregate`` step.

        Raises:
            NotImplementedError: For aggregate functions without arguments, unsupported
                aggregate functions, or GROUP BY entries that are not identifiers.
        """
        collection, pre_match, pre_project = self._parse_select(node.from_table)

        # Names that may prefix a column as its table qualifier.
        table_aliases = {collection}
        if isinstance(node.from_table, Identifier) and node.from_table.alias is not None:
            table_aliases.add(node.from_table.alias.parts[-1])

        filters: Dict[str, Any] = {}
        if node.where is not None:
            filters = self.handle_where(node.where)

        agg_group: Dict[str, Any] = {}
        # DISTINCT is implemented by grouping on the selected fields.
        group: Dict[str, Any] = {"_id": {}} if node.distinct else {}
        project: Dict[str, Any] = {"_id": 0}

        if node.targets is not None:
            # Targets of any other type, and aggregate calls whose first argument is neither
            # `*` (COUNT only) nor an identifier, are skipped without raising.
            for col in node.targets:
                if isinstance(col, Star):
                    # Show all fields.
                    project = {}
                    break

                if isinstance(col, Identifier):
                    name = self._strip_table_qualifier(col.parts, table_aliases)
                    alias = name if col.alias is None else col.alias.parts[-1]
                    project[alias] = f"${name}"  # Project field.

                    if node.distinct:
                        group["_id"][name] = f"${name}"  # Group field.
                        group[name] = {"$first": f"${name}"}  # Show field.

                elif isinstance(col, Function):
                    func_name = col.op.lower()
                    alias = col.alias.parts[-1] if col.alias is not None else func_name

                    if len(col.args) == 0:
                        raise NotImplementedError(f"Function {func_name.upper()} requires arguments")

                    arg0 = col.args[0]
                    if func_name == "count" and isinstance(arg0, Star):
                        agg_group[alias] = {"$sum": 1}
                    elif isinstance(arg0, Identifier):
                        field_name = self._strip_table_qualifier(arg0.parts, table_aliases)
                        if func_name == "count":
                            # COUNT(field): add 1 only where the field compares unequal to null.
                            agg_group[alias] = {"$sum": {"$cond": [{"$ne": [f"${field_name}", None]}, 1, 0]}}
                        elif func_name in self._ACCUMULATORS:
                            agg_group[alias] = {self._ACCUMULATORS[func_name]: f"${field_name}"}
                        else:
                            raise NotImplementedError(f"Aggregation function '{func_name.upper()}' is not supported")

                elif isinstance(col, Constant):
                    alias = str(col.value) if col.alias is None else col.alias.parts[-1]
                    # NOTE(review): $project may interpret numeric/boolean values as
                    # include/exclude flags - confirm whether `$literal` is needed here.
                    project[alias] = col.value

        if node.group_by is not None:
            if "_id" not in group or not isinstance(group["_id"], dict):
                group["_id"] = {}

            for group_col in node.group_by:
                if not isinstance(group_col, Identifier):
                    raise NotImplementedError(f"Unsupported GROUP BY column {group_col}")

                field_name = self._strip_table_qualifier(group_col.parts, table_aliases)
                alias = group_col.alias.parts[-1] if group_col.alias is not None else field_name

                group["_id"][alias] = f"${field_name}"

                # Grouped columns that are also selected are carried through the $group
                # stage and then re-projected under their output name.
                if alias in project:
                    group[alias] = {"$first": f"${field_name}"}
                    project[alias] = f"${alias}"

            for alias, expression in agg_group.items():
                group[alias] = expression
                project[alias] = f"${alias}"
        elif agg_group:
            # Aggregates without GROUP BY: a single group over all documents.
            group = {"_id": None}
            for alias, expression in agg_group.items():
                group[alias] = expression
                project[alias] = f"${alias}"

        sort: Dict[str, int] = {}
        if node.order_by is not None:
            for order_col in node.order_by:
                # Only an explicit DESC sorts descending; every other direction value,
                # including an unspecified one, sorts ascending (the SQL default).
                descending = order_col.direction.upper() == "DESC"
                sort[order_col.field.parts[-1]] = -1 if descending else 1

        # Compose the MongoDB query.
        mquery = MongoQuery(collection)

        method: str = "aggregate"
        margs: list = []

        # MongoDB related pipeline steps for the aggregate method.
        if node.modifiers is not None:
            margs.extend(node.modifiers)

        if pre_match:
            margs.append({"$match": pre_match})
        if pre_project is not None and pre_project != {}:
            margs.append({"$project": pre_project})

        if filters:
            margs.append({"$match": filters})

        if group:
            margs.append({"$group": group})

        if project:
            margs.append({"$project": project})

        if sort:
            margs.append({"$sort": sort})

        if node.offset is not None:
            margs.append({"$skip": int(node.offset.value)})

        if node.limit is not None:
            margs.append({"$limit": int(node.limit.value)})

        mquery.add_step({"method": method, "args": [margs]})

        return mquery

    def handle_where(self, node: BinaryOperation) -> Dict[str, Any]:
        """
        Converts a BinaryOperation node to a dictionary of MongoDB query filters.

        ``AND`` / ``OR`` are converted recursively. ``<identifier> <op> <constant>`` and
        ``<identifier> [NOT] IN (<constants>)`` become plain field filters; constants
        compared against ``_id`` are converted to ObjectId in both forms. Any other
        comparison is rendered as an ``$expr`` filter.

        Args:
            node (BinaryOperation): A BinaryOperation node representing the SQL WHERE clause.

        Returns:
            dict: The converted MongoDB query filters.

        Raises:
            NotImplementedError: If the node is not a BinaryOperation or the operator is
                not supported for the given operands.
        """
        # TODO: UnaryOperation, function.
        if not isinstance(node, BinaryOperation):
            raise NotImplementedError(f"Not supported type {type(node)}")

        op = node.op.lower()
        a, b = node.args

        if op in ("and", "or"):
            return {f"${op}": [self.handle_where(a), self.handle_where(b)]}

        ops_map = self._COMPARISON_OPS

        if isinstance(a, Identifier):
            var_name = ".".join(a.parts)
            is_id = var_name == "_id"

            # Simple operation: Identifier and Constant.
            if isinstance(b, Constant):
                val = self._id_value(b.value) if is_id else b.value
                if op in ("=", "=="):
                    # Equality uses the short form {field: value}.
                    return {var_name: val}
                if op in ops_map:
                    return {var_name: {ops_map[op]: val}}
                raise NotImplementedError(f"Not supported operator {op}")

            # IN / NOT IN condition over a list of Constants.
            if isinstance(b, Tuple):
                membership_ops = {"in": "$in", "not in": "$nin"}
                if op not in membership_ops:
                    raise NotImplementedError(f"Not supported operator {op}")
                values = [item.value for item in b.items]
                if is_id:
                    # Keep `_id IN (...)` consistent with `_id = ...`.
                    values = [self._id_value(value) for value in values]
                return {var_name: {membership_ops[op]: values}}

        # Create expression.
        val1 = self.where_element_convert(a)
        val2 = self.where_element_convert(b)

        if op not in ops_map:
            raise NotImplementedError(f"Not supported operator {op}")

        return {"$expr": {ops_map[op]: [val1, val2]}}

    def where_element_convert(self, node: Union[Identifier, Latest, Constant, TypeCast]) -> Any:
        """
        Converts a WHERE element to the corresponding MongoDB query element.

        Args:
            node (Union[Identifier, Latest, Constant, TypeCast]): The WHERE element to be converted.

        Returns:
            Any: ``"$<dotted name>"`` for an Identifier, the string ``"LATEST"`` for Latest,
            the raw value for a Constant, or a ``datetime`` parsed from the argument of a
            DATE / DATETIME / TIMESTAMP cast.

        Raises:
            NotImplementedError: If the WHERE element is not supported.
            RuntimeError: If the date format is not supported.
        """
        if isinstance(node, Identifier):
            return f"${'.'.join(node.parts)}"
        if isinstance(node, Latest):
            return "LATEST"
        if isinstance(node, Constant):
            return node.value
        if isinstance(node, TypeCast) and node.type_name.upper() in self._DATE_CAST_TYPES:
            formats = ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S.%f"]
            # Try each known layout in turn; the first one that parses wins.
            for fmt in formats:
                try:
                    return dt.datetime.strptime(node.arg.value, fmt)
                except ValueError:
                    continue
            raise RuntimeError(f"Not supported date format. Supported: {formats}")
        raise NotImplementedError(f"Unknown where element {node}")