Coverage for mindsdb / integrations / handlers / mongodb_handler / mongodb_handler.py: 66%

185 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import re 

2import time 

3import threading 

4 

5from bson import ObjectId 

6from mindsdb_sql_parser.ast.base import ASTNode 

7import pandas as pd 

8import pymongo 

9from pymongo import MongoClient 

10from pymongo.errors import ServerSelectionTimeoutError, OperationFailure, ConfigurationError, InvalidURI 

11from typing import Text, List, Dict, Any, Union 

12 

13from mindsdb.integrations.handlers.mongodb_handler.utils.mongodb_query import MongoQuery 

14from mindsdb.integrations.handlers.mongodb_handler.utils.mongodb_parser import MongodbParser 

15from mindsdb.integrations.libs.base import DatabaseHandler 

16from mindsdb.integrations.libs.response import ( 

17 HandlerStatusResponse as StatusResponse, 

18 HandlerResponse as Response, 

19 RESPONSE_TYPE, 

20) 

21from mindsdb.utilities import log 

22from .utils.mongodb_render import MongodbRender 

23 

24 

# Module-level logger for this handler, obtained from mindsdb's logging utility.
logger = log.getLogger(__name__)

26 

27 

class MongoDBHandler(DatabaseHandler):
    """
    This handler handles the connection and execution of SQL statements on MongoDB.
    """

    # Seconds to wait between polls of the change stream when no change is pending.
    _SUBSCRIBE_SLEEP_INTERVAL = 0.5
    name = "mongodb"

    def __init__(self, name: Text, **kwargs: Any) -> None:
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            kwargs: Arbitrary keyword arguments; must include ``connection_data``, a dict from which
                ``host``, ``port``, ``username``, ``password``, ``database`` and ``flatten_level``
                are read (each optional).
        """
        super().__init__(name)
        connection_data = kwargs["connection_data"]
        self.host = connection_data.get("host")
        self.port = int(connection_data.get("port") or 27017)
        self.user = connection_data.get("username")
        self.password = connection_data.get("password")
        self.database = connection_data.get("database")
        # Coerced like ``port``: the value may arrive as a string or None, and ``flatten``
        # compares it numerically.
        self.flatten_level = int(connection_data.get("flatten_level") or 0)

        self.connection = None
        self.is_connected = False

    def __del__(self) -> None:
        """
        Closes the connection when the handler instance is deleted.
        """
        # getattr: if __init__ raised early, ``is_connected`` may not exist yet.
        if getattr(self, "is_connected", False):
            self.disconnect()

    def connect(self) -> MongoClient:
        """
        Establishes a connection to the MongoDB host, reusing the current one if it is open.

        Raises:
            InvalidURI, ConfigurationError, Exception: Any error raised while constructing the
                MongoClient is logged and re-raised. Errors while resolving the default database
                are re-raised after the new client is closed.

        Returns:
            pymongo.MongoClient: A connection object to the MongoDB host.
        """
        # Every public method calls connect(); building a new MongoClient each time would
        # replace self.connection without ever closing the previous client.
        if self.is_connected and self.connection is not None:
            return self.connection

        kwargs = {}
        if isinstance(self.user, str) and len(self.user) > 0:
            kwargs["username"] = self.user

        if isinstance(self.password, str) and len(self.password) > 0:
            kwargs["password"] = self.password

        # An explicit tls=true / tls=false inside the host string is forwarded as a client option.
        # ``host`` may be None (not provided), so guard before lower-casing.
        host_lower = (self.host or "").lower()
        if re.search(r"tls=true", host_lower):
            kwargs["tls"] = True

        if re.search(r"tls=false", host_lower):
            kwargs["tls"] = False

        try:
            connection = MongoClient(self.host, port=self.port, **kwargs)
        except InvalidURI as invalid_uri_error:
            logger.error(f"Invalid URI provided for MongoDB connection: {invalid_uri_error}!")
            raise
        except ConfigurationError as config_error:
            logger.error(f"Configuration error connecting to MongoDB: {config_error}!")
            raise
        except Exception as unknown_error:
            logger.error(f"Unknown error connecting to MongoDB: {unknown_error}!")
            raise

        # Get the database name from the connection if it's not provided.
        if self.database is None:
            try:
                self.database = connection.get_database().name
            except Exception:
                # Do not leave the freshly created client open on failure.
                connection.close()
                raise

        self.is_connected = True
        self.connection = connection
        return self.connection

    def subscribe(
        self, stop_event: threading.Event, callback: callable, table_name: Text, columns: List = None, **kwargs: Any
    ) -> None:
        """
        Subscribes to changes in a MongoDB collection and calls the provided callback function when changes occur.

        Only ``insert`` and ``update`` change events are forwarded; the callback receives
        ``row=<document>`` and ``key={"_id": <id>}``. The change stream is always closed on exit,
        including when the callback raises.

        Args:
            stop_event (threading.Event): An event object to stop the subscription.
            callback (callable): The callback function to call when changes occur.
            table_name (Text): The name of the collection to subscribe to.
            columns (List): A list of columns to monitor for changes. When given, events that
                touch none of these fields are skipped.
            kwargs: Arbitrary keyword arguments.
        """
        con = self.connect()
        collection = con[self.database][table_name]
        watched = set(columns) if columns is not None else None

        stream = collection.watch()
        try:
            while not stop_event.is_set():
                change = stream.try_next()
                if change is None:
                    time.sleep(self._SUBSCRIBE_SLEEP_INTERVAL)
                    continue

                _id = change["documentKey"]["_id"]
                operation = change["operationType"]

                if operation == "insert":
                    document = change["fullDocument"]
                    if watched is not None and watched.isdisjoint(document.keys()):
                        continue
                    callback(row=document, key={"_id": _id})

                elif operation == "update":
                    updated_fields = change["updateDescription"]["updatedFields"]
                    if watched is not None and watched.isdisjoint(updated_fields.keys()):
                        continue
                    # Update events carry only the delta, so fetch the full document.
                    full_doc = collection.find_one(change["documentKey"])
                    callback(row=full_doc, key={"_id": _id})
        finally:
            stream.close()

    def disconnect(self) -> None:
        """
        Closes the connection to the MongoDB host if it's currently open.
        """
        if not self.is_connected:
            return

        self.connection.close()
        # Drop the reference so a closed client is never handed out again.
        self.connection = None
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the MongoDB host.

        The check pings the server and verifies that the configured database is listed.
        A connection opened only for this check is closed again; on failure the client is closed.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)
        need_to_close = self.is_connected is False

        try:
            con = self.connect()
            con.server_info()

            # Check if the database exists.
            if self.database not in con.list_database_names():
                raise ValueError(f"Database {self.database} not found!")

            response.success = True
        except (
            InvalidURI,
            ServerSelectionTimeoutError,
            OperationFailure,
            ConfigurationError,
            ValueError,
        ) as known_error:
            logger.error(f"Error connecting to MongoDB {self.database}, {known_error}!")
            response.error_message = str(known_error)
        except Exception as unknown_error:
            logger.error(f"Unknown error connecting to MongoDB {self.database}, {unknown_error}!")
            response.error_message = str(unknown_error)

        if response.success and need_to_close:
            self.disconnect()
        elif not response.success:
            # Actually close the client (not just flag it) so a broken one is not leaked.
            # disconnect() is a no-op when nothing is open.
            self.disconnect()

        return response

    def native_query(self, query: Union[Text, Dict, MongoQuery]) -> Response:
        """
        Executes a native MongoDB query on the host and returns the result.

        Args:
            query (Union[Text, Dict, MongoQuery]): A query string parsed by MongodbParser, a legacy
                dict of the form ``{"collection": ..., "call": [{"method": ..., "args": [...]}]}``,
                or a ready MongoQuery.

        Returns:
            Response: ``OK`` for write operations, ``TABLE`` with the (flattened) documents for
                reads, or ``ERROR`` with a message.
        """
        if isinstance(query, str):
            query = MongodbParser().from_string(query)

        if isinstance(query, dict):
            # Fallback for the previous API.
            mquery = MongoQuery(query["collection"])
            for call in query["call"]:
                mquery.add_step({"method": call["method"], "args": call["args"]})
            query = mquery

        collection = query.collection

        # connect() first: it fills in self.database when it was not configured explicitly.
        con = self.connect()
        database = self.database

        # Check if the collection exists.
        if collection not in con[database].list_collection_names():
            return Response(
                RESPONSE_TYPE.ERROR, error_message=f"Collection {collection} not found in database {database}!"
            )

        # Write operations return these non-iterable result objects instead of a cursor.
        write_result_types = (
            pymongo.results.InsertOneResult,
            pymongo.results.InsertManyResult,
            pymongo.results.UpdateResult,
            pymongo.results.DeleteResult,
            pymongo.results.BulkWriteResult,
        )

        try:
            cursor = con[database][collection]

            for step in query.pipeline:
                fnc = getattr(cursor, step["method"])
                cursor = fnc(*step["args"])

            if isinstance(cursor, write_result_types):
                return Response(RESPONSE_TYPE.OK)

            result = [self.flatten(row, level=self.flatten_level) for row in cursor]

            if result:
                df = pd.DataFrame(result)
            else:
                # No rows: still expose the collection's columns.
                columns = list(self.get_columns(collection).data_frame.Field)
                df = pd.DataFrame([], columns=columns)

            response = Response(RESPONSE_TYPE.TABLE, df)
        except Exception as e:
            logger.error(f"Error running query: {query} on {self.database}.{collection}!")
            response = Response(RESPONSE_TYPE.ERROR, error_message=str(e))

        return response

    def flatten(self, row: Dict, level: int = 0) -> Dict:
        """
        Prepares a document for tabular output, optionally flattening nested documents.

        ObjectId values at the current nesting level are converted to strings. When ``level`` > 0,
        every dict value is replaced by its recursively processed items under dotted keys
        ``"<parent>.<child>"``. Each recursion step decrements ``level`` by one.

        Args:
            row (Dict): The document to process. It is modified in place.
            level (int): How many levels of nesting to flatten. 0 (the default) flattens nothing
                and only converts top-level ObjectIds.

        Returns:
            Dict: The same ``row`` object after modification.
        """
        add = {}
        del_keys = []
        edit_keys = {}

        for k, v in row.items():
            # Convert ObjectId to string.
            if isinstance(v, ObjectId):
                edit_keys[k] = str(v)
            if level > 0:
                if isinstance(v, dict):
                    for k2, v2 in self.flatten(v, level=level - 1).items():
                        add[f"{k}.{k2}"] = v2
                    del_keys.append(k)

        # Mutations are deferred until after iteration; a dict cannot change size while iterated.
        if add:
            row.update(add)
        for key in del_keys:
            del row[key]
        if edit_keys:
            row.update(edit_keys)

        return row

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode on the MongoDB host and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Returns:
            Response: The response from the `native_query` method, containing the result of the SQL query execution.
        """
        renderer = MongodbRender()
        mquery = renderer.to_mongo_query(query)
        return self.native_query(mquery)

    def get_tables(self) -> Response:
        """
        Retrieves a list of all non-system tables (collections) in the MongoDB host.

        Returns:
            Response: A response object containing a list of tables (collections) in the MongoDB host.
        """
        con = self.connect()
        collections = con[self.database].list_collection_names()
        # Hide MongoDB's internal "system.*" collections, as promised above.
        rows = [[name] for name in collections if not name.startswith("system.")]
        df = pd.DataFrame(rows, columns=["table_name"])

        return Response(RESPONSE_TYPE.TABLE, df)

    def get_columns(self, table_name: Text) -> Response:
        """
        Retrieves column (field) details for a specified table (collection) in the MongoDB host.
        The first record in the collection is used to determine the column details.

        Args:
            table_name (Text): The name of the table (collection) for which to retrieve column (field) information.

        Raises:
            ValueError: If the 'table_name' is not a valid string.

        Returns:
            Response: A response object with ``Field`` and ``Type`` columns (the Python type name
                of each value). It is empty when the collection has no documents.
        """
        if not table_name or not isinstance(table_name, str):
            raise ValueError("Invalid table name provided.")

        con = self.connect()
        record = con[self.database][table_name].find_one()

        data = []
        if record is not None:
            record = self.flatten(record)
            data = [[k, type(v).__name__] for k, v in record.items()]

        df = pd.DataFrame(data, columns=["Field", "Type"])

        return Response(RESPONSE_TYPE.TABLE, df)