Coverage for mindsdb / interfaces / query_context / context_controller.py: 26%

282 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List, Optional, Iterable 

2import pickle 

3import datetime as dt 

4 

5from sqlalchemy.orm.attributes import flag_modified 

6import pandas as pd 

7 

8from mindsdb_sql_parser import Select, Star, OrderBy 

9 

10from mindsdb_sql_parser.ast import Identifier, BinaryOperation, Last, Constant, ASTNode, Function 

11from mindsdb.integrations.utilities.query_traversal import query_traversal 

12from mindsdb.utilities.cache import get_cache 

13 

14from mindsdb.interfaces.storage import db 

15from mindsdb.utilities.context import context as ctx 

16from mindsdb.utilities.config import config 

17 

18from .last_query import LastQuery 

19 

20 

class RunningQuery:
    """
    A query that is in progress.

    Wraps a `db.Queries` record and provides:
      - partitioned fetching of source data (streamed or batched),
      - progress tracking (processed row count and max value of the track column),
      - resume support (saved step number and pickled step data).
    """

    # object type used when linking this query to a background task record
    OBJECT_TYPE = "query"

    def __init__(self, record: db.Queries):
        self.record = record
        self.sql = record.sql
        # fall back to the server's default project when no database is stored
        self.database = record.database or config.get("default_project")

    def get_partitions(self, dn, step_call, query: Select) -> Iterable:
        """
        Gets chunks of data from data handler for executing them in next steps of the planner.
        Checks whether the datanode supports fetching with a stream.

        NOTE: requires `set_params` to have been called first — it defines `self.batch_size`.

        :param dn: datanode to execute query
        :param step_call: instance of StepCall to get some parameters from it
        :param query: AST query to execute
        :return: generator with query results
        """
        if hasattr(dn, "has_support_stream") and dn.has_support_stream():
            # streaming mode: the handler feeds dataframes of `batch_size` rows
            query2 = self.get_partition_query(step_call.current_step_num, query, stream=True)

            for df in dn.query_stream(query2, fetch_size=self.batch_size):
                # capture the max track value before substeps consume the batch
                max_track_value = self.get_max_track_value(df)
                yield df
                # record progress only after the batch was fully processed
                self.set_progress(max_track_value=max_track_value)

        else:
            # polling mode: repeatedly query the next batch until it comes back empty
            while True:
                query2 = self.get_partition_query(step_call.current_step_num, query, stream=False)

                response = dn.query(query=query2, session=step_call.session)
                df = response.data_frame

                if df is None or len(df) == 0:
                    break

                max_track_value = self.get_max_track_value(df)
                yield df
                self.set_progress(max_track_value=max_track_value)

    def get_partition_query(self, step_num: int, query: Select, stream: bool = False) -> Select:
        """
        Generate query for fetching the next partition.
        It wraps the query to:
            select * from ({query})
            where {track_column} > {previous_value}
            order by track_column
            limit {batch_size}
        and fills track_column, previous_value, batch_size.

        If stream is true:
          - if track_column is defined: don't add limit
          - else: return user query without modifications

        :param step_num: current planner step number (progress is reset when it changes)
        :param query: user query to wrap
        :param stream: whether the datanode fetches with a stream
        :raises RuntimeError: when resuming a streamed query that has no track_column
        :raises ValueError: batched (non-stream) fetching requires a track_column
        """

        track_column = self.record.parameters.get("track_column")
        if track_column is None and stream:
            # if no track column for stream fetching: it is not resumable query, execute original query

            # check if it is first run of the query
            if self.record.processed_rows > 0:
                raise RuntimeError("Can't resume query without track_column")
            return query

        if not stream and track_column is None:
            raise ValueError("Track column is not defined")

        query = Select(
            targets=[Star()],
            from_table=query,
            order_by=[OrderBy(Identifier(track_column))],
        )
        if not stream:
            query.limit = Constant(self.batch_size)

        track_value = self.record.context.get("track_value")
        # is it a different step? then the saved track_value belongs to the
        # previous step and must be reset
        cur_step_num = self.record.context.get("step_num")
        if cur_step_num is not None and cur_step_num != step_num:
            track_value = None
            self.record.context["track_value"] = None
            self.record.context["step_num"] = step_num
            # context is a mutable JSON column: tell SQLAlchemy it changed
            flag_modified(self.record, "context")
            db.session.commit()

        if track_value is not None:
            query.where = BinaryOperation(
                op=">",
                args=[Identifier(track_column), Constant(track_value)],
            )

        return query

    def get_info(self) -> dict:
        """Return metadata of the query as a dict (used by `list_queries`)."""
        record = self.record
        return {
            "id": record.id,
            "sql": record.sql,
            "database": record.database,
            "started_at": record.started_at,
            "finished_at": record.finished_at,
            "parameters": record.parameters,
            "context": record.context,
            "processed_rows": record.processed_rows,
            "error": record.error,
            "updated_at": record.updated_at,
        }

    def add_to_task(self):
        """Create a task record so the query is tracked by the task system."""
        task_record = db.Tasks(
            company_id=ctx.company_id,
            user_class=ctx.user_class,
            object_type=self.OBJECT_TYPE,
            object_id=self.record.id,
        )
        db.session.add(task_record)
        db.session.commit()

    def remove_from_task(self):
        """Delete the task record linked to this query, if one exists."""
        task = db.Tasks.query.filter(
            db.Tasks.object_type == self.OBJECT_TYPE,
            db.Tasks.object_id == self.record.id,
            db.Tasks.company_id == ctx.company_id,
        ).first()

        if task is not None:
            db.session.delete(task)
            db.session.commit()

    def set_params(self, params: dict):
        """
        Store parameters of the step which is about to be split into partitions.

        :param params: step parameters; `batch_size` defaults to 1000 when absent
        """
        # copy to avoid mutating the caller's dict when filling defaults
        params = dict(params)
        if "batch_size" not in params:
            params["batch_size"] = 1000

        self.record.parameters = params
        self.batch_size = self.record.parameters["batch_size"]
        db.session.commit()

    def get_max_track_value(self, df: pd.DataFrame):
        """
        Return the max value of the track column (a scalar, not a DataFrame —
        the previous annotation was wrong) to use in `set_progress`.
        This function is called before executing substeps,
        the `set_progress` function — after.
        Returns None in stream mode (no track column defined).
        """
        if "track_column" in self.record.parameters:
            track_column = self.record.parameters["track_column"]
            return df[track_column].max()
        else:
            # stream mode
            return None

    def set_progress(self, processed_rows: int = None, max_track_value=None):
        """
        Store progress of the query; called after processing of a batch.

        :param processed_rows: number of rows processed in the last batch
        :param max_track_value: max value of the track column seen in the last batch
        """

        if processed_rows is not None and processed_rows > 0:
            self.record.processed_rows = self.record.processed_rows + processed_rows
            db.session.commit()

        if max_track_value is not None:
            cur_value = self.record.context.get("track_value")
            # keep only the greatest value seen so far
            if cur_value is None or max_track_value > cur_value:
                self.record.context["track_value"] = max_track_value
                flag_modified(self.record, "context")
                db.session.commit()

    def on_error(self, error: Exception, step_num: int, steps_data: dict):
        """
        Saves error of the query in database.
        Also saves step data and current step num to be able to resume query.
        """
        self.record.error = str(error)
        self.record.context["step_num"] = step_num
        flag_modified(self.record, "context")

        # save steps_data (protocol 5 matches what get_state expects to load)
        cache = get_cache("steps_data")
        data = pickle.dumps(steps_data, protocol=5)
        cache.set(str(self.record.id), data)

        db.session.commit()

    def mark_as_run(self):
        """
        Mark query as running and reset error of the query.

        :raises RuntimeError: if the query already finished or appears to be running
        """
        if self.record.finished_at is not None:
            raise RuntimeError("The query already finished")

        if self.record.started_at is None:
            self.record.started_at = dt.datetime.now()
            db.session.commit()
        elif self.record.error is not None:
            self.record.error = None
            db.session.commit()
        else:
            raise RuntimeError("The query might be running already")

    def get_state(self) -> dict:
        """
        Returns stored state for resuming the query.

        :raises RuntimeError: if no saved state is found in the cache
        """
        cache = get_cache("steps_data")
        # use the same (string) key that on_error stored the data under;
        # previously the raw int id was used here, mismatching the writer
        key = str(self.record.id)
        data = cache.get(key)
        cache.delete(key)

        if data is None:
            # previously this failed inside pickle.loads with an obscure TypeError
            raise RuntimeError(f"No saved state found for query: {self.record.id}")

        steps_data = pickle.loads(data)

        return {
            "step_num": self.record.context.get("step_num"),
            "steps_data": steps_data,
        }

    def finish(self):
        """
        Mark query as finished.
        """

        self.record.finished_at = dt.datetime.now()
        db.session.commit()

255 

class QueryContextController:
    """
    Manages query contexts (values of LAST keywords per context) and the
    registry of running/resumable queries.
    """

    # special context name: skip storing/applying LAST values entirely
    IGNORE_CONTEXT = "<IGNORE>"

    def handle_db_context_vars(self, query: ASTNode, dn, session) -> tuple:
        """
        Check context variables in query and replace them with values.
        Should be used before exec query in database.

        Input:
        - query: input query
          - params are used to find current values of context variables
        - dn: datanode
        - session: mindsdb server session

        Returns:
        - query with replaced context variables
        - callback to call with result of the query. it is used to update context variables
        """
        context_name = self.get_current_context()

        l_query = LastQuery(query)
        if l_query.query is None:
            # no last keyword, exit
            return query, None

        if context_name == self.IGNORE_CONTEXT:
            # return with empty constants
            return l_query.query, None

        query_str = l_query.to_string()

        rec = self._get_context_record(context_name, query_str)

        if rec is None or len(rec.values) == 0:
            # first run (or empty stored values): initialize LAST values from the source
            values = self._get_init_last_values(l_query, dn, session)
            if rec is None:
                self.__add_context_record(context_name, query_str, values)
                if context_name.startswith("job-if-"):
                    # add context for job also
                    self.__add_context_record(context_name.replace("job-if", "job"), query_str, values)
            else:
                rec.values = values
        else:
            values = rec.values

        db.session.commit()

        query_out = l_query.apply_values(values)

        def callback(df, columns_info):
            self._result_callback(l_query, context_name, query_str, df, columns_info)

        return query_out, callback

    def remove_lasts(self, query):
        """
        Replace every LAST-based condition in the query with the constant
        comparison `0 = 0`, effectively neutralizing it.

        :param query: AST query, modified in place
        :return: the same query object
        """

        def replace_lasts(node, **kwargs):
            # find last in where
            if isinstance(node, BinaryOperation):
                arg1, arg2 = node.args
                if not isinstance(arg1, Identifier):
                    arg1, arg2 = arg2, arg1

                # one of the args must be identifier
                if not isinstance(arg1, Identifier):
                    return

                # another must be LAST or function with LAST in args
                if isinstance(arg2, Last) or (
                    isinstance(arg2, Function) and any(isinstance(arg, Last) for arg in arg2.args)
                ):
                    node.args = [Constant(0), Constant(0)]
                    node.op = "="
                    return node

        # find lasts
        query_traversal(query, replace_lasts)
        return query

    def _result_callback(
        self, l_query: LastQuery, context_name: str, query_str: str, df: pd.DataFrame, columns_info: list
    ):
        """
        Handles the result of an executed query and updates context variables with new values.

        Input:
        - l_query: LastQuery object
        - To identify context:
          - context_name: name of the context
          - query_str: rendered query to search in context table
        - result of the query:
          - df: result dataframe
          - columns_info: list of column descriptors
        """
        if len(df) == 0:
            return

        values = {}
        # get max values of each LAST column
        for info in l_query.get_last_columns():
            target_idx = info["target_idx"]
            if target_idx is not None:
                # get by index
                col_name = columns_info[target_idx]["name"]
            else:
                col_name = info["column_name"]
            # get by name
            if col_name not in df:
                continue

            column_values = df[col_name].dropna()
            # mixed-type columns may not be directly comparable: fall back to
            # float, then string comparison before giving up on the column
            try:
                value = max(column_values)
            except (TypeError, ValueError):
                try:
                    # try to convert to float
                    value = max(map(float, column_values))
                except (TypeError, ValueError):
                    try:
                        # try to convert to str
                        value = max(map(str, column_values))
                    except (TypeError, ValueError):
                        continue

            if value is not None:
                values[info["table_name"]] = {info["column_name"]: value}

        self.__update_context_record(context_name, query_str, values)

    def drop_query_context(self, object_type: str, object_id: int = None):
        """
        Drop context for object
        :param object_type: type of the object
        :param object_id: id
        """

        context_name = self.gen_context_name(object_type, object_id)
        for rec in (
            db.session.query(db.QueryContext).filter_by(context_name=context_name, company_id=ctx.company_id).all()
        ):
            db.session.delete(rec)
        db.session.commit()

    def _get_init_last_values(self, l_query: LastQuery, dn, session) -> dict:
        """
        Gets current last values for query.
        Creates and executes a query for it:
            select <col> from <table> order by <col> desc limit 1
        """
        last_values = {}
        for query, info in l_query.get_init_queries():
            response = dn.query(query=query, session=session)
            data = response.data_frame
            columns_info = response.columns

            if len(data) == 0:
                value = None
            else:
                row = list(data.iloc[0])

                # locate the wanted column by (case-insensitive) name
                idx = None
                for i, col in enumerate(columns_info):
                    if col["name"].upper() == info["column_name"].upper():
                        idx = i
                        break

                if idx is None or len(row) == 1:
                    value = row[0]
                else:
                    value = row[idx]

            if value is not None:
                last_values[info["table_name"]] = {info["column_name"]: value}

        return last_values

    # Context

    def _context_stack(self) -> list:
        """Return the current context stack; empty list if ctx has no stack yet."""
        try:
            return ctx.context_stack or []
        except AttributeError:
            return []

    def get_current_context(self) -> str:
        """
        Returns current context name ('' when no context is set).
        """
        context_stack = self._context_stack()
        if len(context_stack) > 0:
            return context_stack[-1]
        else:
            return ""

    def set_context(self, object_type: str = None, object_id: int = None):
        """
        Updates current context name, using object name and id.
        Previous context names are stored on lower levels of the stack.
        """
        context_stack = self._context_stack()
        context_stack.append(self.gen_context_name(object_type, object_id))
        ctx.context_stack = context_stack

    def release_context(self, object_type: str = None, object_id: int = None):
        """
        Removes current context (defined by object type and id) and restores the previous one.
        """
        context_stack = self._context_stack()
        if len(context_stack) == 0:
            return
        context_name = self.gen_context_name(object_type, object_id)
        # only pop when the top of the stack matches the released context
        if context_stack[-1] == context_name:
            context_stack.pop()
        ctx.context_stack = context_stack

    def gen_context_name(self, object_type: str, object_id: int) -> str:
        """
        Generates the name of the context from object type and id.
        :return: context name
        """

        if object_type is None:
            return ""
        if object_id is not None:
            object_type += "-" + str(object_id)
        return object_type

    def get_context_vars(self, object_type: str, object_id: int) -> List[dict]:
        """
        Return variables stored in context (defined by object type and id)

        :return: list of all context variables related to context name as they are stored in context table
        """
        context_name = self.gen_context_name(object_type, object_id)
        # renamed from `vars` to avoid shadowing the builtin
        var_list = []
        for rec in db.session.query(db.QueryContext).filter_by(context_name=context_name, company_id=ctx.company_id):
            if rec.values is not None:
                var_list.append(rec.values)

        return var_list

    # DB
    def _get_context_record(self, context_name: str, query_str: str) -> db.QueryContext:
        """
        Find and return record for context and query string
        """

        return (
            db.session.query(db.QueryContext)
            .filter_by(query=query_str, context_name=context_name, company_id=ctx.company_id)
            .first()
        )

    def __add_context_record(self, context_name: str, query_str: str, values: dict) -> db.QueryContext:
        """
        Creates record (for context and query string) with values and returns it
        """
        rec = db.QueryContext(query=query_str, context_name=context_name, company_id=ctx.company_id, values=values)
        db.session.add(rec)
        return rec

    def __update_context_record(self, context_name: str, query_str: str, values: dict):
        """
        Updates context record with new values
        """
        rec = self._get_context_record(context_name, query_str)
        rec.values = values
        db.session.commit()

    def get_query(self, query_id: int) -> RunningQuery:
        """
        Get running query by id.

        :raises RuntimeError: if the query does not exist
        """

        rec = db.Queries.query.filter(db.Queries.id == query_id, db.Queries.company_id == ctx.company_id).first()

        if rec is None:
            raise RuntimeError(f"Query not found: {query_id}")
        return RunningQuery(rec)

    def create_query(self, query: ASTNode, database: str = None) -> RunningQuery:
        """
        Create a new running query from AST query.
        Also garbage-collects queries that finished more than a day ago.
        """

        # remove old queries
        remove_query = db.session.query(db.Queries).filter(
            db.Queries.company_id == ctx.company_id, db.Queries.finished_at < (dt.datetime.now() - dt.timedelta(days=1))
        )
        for rec in remove_query.all():
            # wrap the record we already have instead of re-fetching it via get_query
            RunningQuery(rec).remove_from_task()
            db.session.delete(rec)

        rec = db.Queries(
            sql=str(query),
            database=database,
            company_id=ctx.company_id,
        )

        db.session.add(rec)
        db.session.commit()
        return RunningQuery(rec)

    def list_queries(self) -> List[dict]:
        """
        Get list of all running queries with metadata
        """

        query = db.session.query(db.Queries).filter(db.Queries.company_id == ctx.company_id)
        return [RunningQuery(record).get_info() for record in query]

    def cancel_query(self, query_id: int):
        """
        Cancels running query by id.

        :raises RuntimeError: if the query does not exist
        """
        # get_query performs the lookup and raises on a missing id;
        # previously the record was fetched twice
        running_query = self.get_query(query_id)
        running_query.remove_from_task()

        # the query in progress will fail when it tries to update status
        db.session.delete(running_query.record)
        db.session.commit()

585 

# Module-level singleton used across the server to access query contexts.
query_context_controller = QueryContextController()