Coverage for mindsdb / interfaces / query_context / context_controller.py: 26%
282 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import List, Optional, Iterable
2import pickle
3import datetime as dt
5from sqlalchemy.orm.attributes import flag_modified
6import pandas as pd
8from mindsdb_sql_parser import Select, Star, OrderBy
10from mindsdb_sql_parser.ast import Identifier, BinaryOperation, Last, Constant, ASTNode, Function
11from mindsdb.integrations.utilities.query_traversal import query_traversal
12from mindsdb.utilities.cache import get_cache
14from mindsdb.interfaces.storage import db
15from mindsdb.utilities.context import context as ctx
16from mindsdb.utilities.config import config
18from .last_query import LastQuery
class RunningQuery:
    """
    A query in progress.

    Wraps a ``db.Queries`` record and executes the query in partitions,
    persisting progress (processed row count and the last seen value of the
    track column) so the query can be resumed after a failure.
    """

    # object_type value used for records in the Tasks table linked to this query
    OBJECT_TYPE = "query"

    def __init__(self, record: db.Queries):
        self.record = record
        self.sql = record.sql
        self.database = record.database or config.get("default_project")
        # Default batch size so the object is usable (e.g. when resuming) even
        # before `set_params` is called again; `set_params` overrides this.
        self.batch_size = (record.parameters or {}).get("batch_size", 1000)

    def get_partitions(self, dn, step_call, query: Select) -> Iterable:
        """
        Gets chunks of data from the data handler for executing them in next steps of the planner.
        Checks whether the datanode supports fetching with a stream.
        :param dn: datanode to execute query
        :param step_call: instance of StepCall to get some parameters from it
        :param query: AST query to execute
        :return: generator with query results
        """
        if hasattr(dn, "has_support_stream") and dn.has_support_stream():
            # Streaming mode: issue a single query, results arrive in chunks.
            query2 = self.get_partition_query(step_call.current_step_num, query, stream=True)

            for df in dn.query_stream(query2, fetch_size=self.batch_size):
                # Capture the max track value before yielding; progress is only
                # committed after the consumer has processed the chunk.
                max_track_value = self.get_max_track_value(df)
                yield df
                self.set_progress(max_track_value=max_track_value)
        else:
            # Batch mode: repeatedly re-issue the query filtered by
            # `track_column > last seen value` until no rows are returned.
            while True:
                query2 = self.get_partition_query(step_call.current_step_num, query, stream=False)

                response = dn.query(query=query2, session=step_call.session)
                df = response.data_frame

                if df is None or len(df) == 0:
                    break

                max_track_value = self.get_max_track_value(df)
                yield df
                self.set_progress(max_track_value=max_track_value)

    def get_partition_query(self, step_num: int, query: Select, stream=False) -> Select:
        """
        Generate a query for fetching the next partition.
        It wraps the query into:
          select * from ({query})
          where {track_column} > {previous_value}
          order by track_column
          limit {batch_size}
        and fills track_column, previous_value, batch_size.

        If stream is true:
        - if track_column is defined: don't add limit
        - else: return the user query without modifications
        """
        track_column = self.record.parameters.get("track_column")
        if track_column is None and stream:
            # No track column for stream fetching: the query is not resumable,
            # execute the original query.

            # Check whether this is the first run of the query.
            if self.record.processed_rows > 0:
                raise RuntimeError("Can't resume query without track_column")
            return query

        if not stream and track_column is None:
            raise ValueError("Track column is not defined")

        query = Select(
            targets=[Star()],
            from_table=query,
            order_by=[OrderBy(Identifier(track_column))],
        )
        if not stream:
            query.limit = Constant(self.batch_size)

        track_value = self.record.context.get("track_value")
        # Is it a different step? If so the stored track_value belongs to a
        # previous step and must be reset.
        cur_step_num = self.record.context.get("step_num")
        if cur_step_num is not None and cur_step_num != step_num:
            track_value = None
            self.record.context["track_value"] = None
            self.record.context["step_num"] = step_num
            # `context` is a JSON column: in-place mutation must be flagged
            # for SQLAlchemy to persist it.
            flag_modified(self.record, "context")
            db.session.commit()

        if track_value is not None:
            query.where = BinaryOperation(
                op=">",
                args=[Identifier(track_column), Constant(track_value)],
            )

        return query

    def get_info(self) -> dict:
        """
        Return metadata about this query as a plain dict.
        """
        record = self.record
        return {
            "id": record.id,
            "sql": record.sql,
            "database": record.database,
            "started_at": record.started_at,
            "finished_at": record.finished_at,
            "parameters": record.parameters,
            "context": record.context,
            "processed_rows": record.processed_rows,
            "error": record.error,
            "updated_at": record.updated_at,
        }

    def add_to_task(self):
        """
        Create a Tasks record linked to this query so it is tracked by the task system.
        """
        task_record = db.Tasks(
            company_id=ctx.company_id,
            user_class=ctx.user_class,
            object_type=self.OBJECT_TYPE,
            object_id=self.record.id,
        )
        db.session.add(task_record)
        db.session.commit()

    def remove_from_task(self):
        """
        Delete the Tasks record linked to this query, if one exists.
        """
        task = db.Tasks.query.filter(
            db.Tasks.object_type == self.OBJECT_TYPE,
            db.Tasks.object_id == self.record.id,
            db.Tasks.company_id == ctx.company_id,
        ).first()

        if task is not None:
            db.session.delete(task)
            db.session.commit()

    def set_params(self, params: dict):
        """
        Store parameters of the step which is about to be split into partitions.
        NOTE: mutates the passed dict (fills in the default batch_size).
        """
        if "batch_size" not in params:
            params["batch_size"] = 1000

        self.record.parameters = params
        self.batch_size = self.record.parameters["batch_size"]
        db.session.commit()

    def get_max_track_value(self, df: pd.DataFrame):
        """
        Return the max value of the track column in *df*, to be passed to
        `set_progress` after the substeps are executed. Returns None in
        stream mode (no track column configured).
        """
        if "track_column" in self.record.parameters:
            track_column = self.record.parameters["track_column"]
            return df[track_column].max()
        else:
            # stream mode
            return None

    def set_progress(self, processed_rows: int = None, max_track_value=None):
        """
        Store progress of the query; called after processing a batch.
        :param processed_rows: number of rows processed in the batch
        :param max_track_value: max track-column value seen in the batch
        """
        if processed_rows is not None and processed_rows > 0:
            self.record.processed_rows = self.record.processed_rows + processed_rows
            db.session.commit()

        if max_track_value is not None:
            cur_value = self.record.context.get("track_value")
            # Only advance the stored track_value, never move it backwards.
            if cur_value is None or max_track_value > cur_value:
                self.record.context["track_value"] = max_track_value
                flag_modified(self.record, "context")
                db.session.commit()

    def on_error(self, error: Exception, step_num: int, steps_data: dict):
        """
        Save the error of the query in the database.
        Also saves step data and the current step number to be able to resume the query.
        """
        self.record.error = str(error)
        self.record.context["step_num"] = step_num
        flag_modified(self.record, "context")

        # Save steps_data in the cache (pickle is safe here: the data is
        # produced and consumed by the server itself, never user-supplied).
        cache = get_cache("steps_data")
        data = pickle.dumps(steps_data, protocol=5)
        cache.set(str(self.record.id), data)

        db.session.commit()

    def mark_as_run(self):
        """
        Mark the query as running and reset its error.
        :raises RuntimeError: if the query already finished or appears to be running
        """
        if self.record.finished_at is not None:
            raise RuntimeError("The query already finished")

        if self.record.started_at is None:
            self.record.started_at = dt.datetime.now()
            db.session.commit()
        elif self.record.error is not None:
            # Previous run failed: clear the error and retry.
            self.record.error = None
            db.session.commit()
        else:
            raise RuntimeError("The query might be running already")

    def get_state(self) -> dict:
        """
        Return the stored state for resuming the query.
        The state is removed from the cache after being read.
        :raises RuntimeError: if no saved state is available
        """
        cache = get_cache("steps_data")
        # Use a string key to match `on_error`, which stores with str(id).
        key = str(self.record.id)
        data = cache.get(key)
        cache.delete(key)

        if data is None:
            raise RuntimeError(f"No saved state found for query: {self.record.id}")

        steps_data = pickle.loads(data)

        return {
            "step_num": self.record.context.get("step_num"),
            "steps_data": steps_data,
        }

    def finish(self):
        """
        Mark the query as finished.
        """
        self.record.finished_at = dt.datetime.now()
        db.session.commit()
class QueryContextController:
    """
    Manages query contexts: tracks LAST values per (context, query) pair so
    that queries using the LAST keyword return only new rows on re-execution,
    and manages the lifecycle of resumable running queries.
    """

    # Special context name: LAST values are replaced with empty constants
    # and no context record is created or updated.
    IGNORE_CONTEXT = "<IGNORE>"

    def handle_db_context_vars(self, query: ASTNode, dn, session) -> tuple:
        """
        Check context variables in the query and replace them with values.
        Should be used before executing the query in a database.

        Input:
        - query: input query
          - params are used to find current values of context variables
        - dn: datanode
        - session: mindsdb server session

        Returns:
        - query with replaced context variables
        - callback to call with the result of the query; it is used to update context variables
        """
        context_name = self.get_current_context()

        l_query = LastQuery(query)
        if l_query.query is None:
            # no LAST keyword, exit
            return query, None

        if context_name == self.IGNORE_CONTEXT:
            # return with empty constants
            return l_query.query, None

        query_str = l_query.to_string()

        rec = self._get_context_record(context_name, query_str)

        if rec is None or len(rec.values) == 0:
            # First run (or empty stored values): initialize LAST values from the database.
            values = self._get_init_last_values(l_query, dn, session)
            if rec is None:
                self.__add_context_record(context_name, query_str, values)
                if context_name.startswith("job-if-"):
                    # also add the context for the job itself
                    self.__add_context_record(context_name.replace("job-if", "job"), query_str, values)
            else:
                rec.values = values
        else:
            values = rec.values

        db.session.commit()

        query_out = l_query.apply_values(values)

        def callback(df, columns_info):
            self._result_callback(l_query, context_name, query_str, df, columns_info)

        return query_out, callback

    def remove_lasts(self, query):
        """
        Replace LAST conditions in the query's WHERE clause with an
        always-false condition (`0 = 0` with op '='), effectively disabling them.
        """

        def replace_lasts(node, **kwargs):
            # find LAST in where
            if isinstance(node, BinaryOperation):
                arg1, arg2 = node.args
                if not isinstance(arg1, Identifier):
                    arg1, arg2 = arg2, arg1

                # one of the args must be an identifier
                if not isinstance(arg1, Identifier):
                    return

                # the other must be LAST or a function with LAST in its args
                if isinstance(arg2, Last) or (
                    isinstance(arg2, Function) and any(isinstance(arg, Last) for arg in arg2.args)
                ):
                    node.args = [Constant(0), Constant(0)]
                    node.op = "="
                    return node

        # find LASTs
        query_traversal(query, replace_lasts)
        return query

    def _result_callback(
        self, l_query: LastQuery, context_name: str, query_str: str, df: pd.DataFrame, columns_info: list
    ):
        """
        Handles the result of the executed query and updates context variables with new values.

        Input:
        - l_query: LastQuery object
        - to identify the context:
          - context_name: name of the context
          - query_str: rendered query to search in the context table
        - result of the query:
          - df: result dataframe
          - columns_info: list of column descriptions
        """
        if len(df) == 0:
            return

        values = {}
        # get max values per LAST column
        for info in l_query.get_last_columns():
            target_idx = info["target_idx"]
            if target_idx is not None:
                # get by index
                col_name = columns_info[target_idx]["name"]
            else:
                col_name = info["column_name"]
                # get by name
                if col_name not in df:
                    continue

            column_values = df[col_name].dropna()
            try:
                value = max(column_values)
            except (TypeError, ValueError):
                # Mixed / incomparable types: fall back to comparing as
                # floats, then as strings, before giving up on the column.
                try:
                    value = max(map(float, column_values))
                except (TypeError, ValueError):
                    try:
                        value = max(map(str, column_values))
                    except (TypeError, ValueError):
                        continue

            if value is not None:
                values[info["table_name"]] = {info["column_name"]: value}

        self.__update_context_record(context_name, query_str, values)

    def drop_query_context(self, object_type: str, object_id: int = None):
        """
        Drop context for an object.
        :param object_type: type of the object
        :param object_id: id
        """
        context_name = self.gen_context_name(object_type, object_id)
        for rec in (
            db.session.query(db.QueryContext).filter_by(context_name=context_name, company_id=ctx.company_id).all()
        ):
            db.session.delete(rec)
        db.session.commit()

    def _get_init_last_values(self, l_query: LastQuery, dn, session) -> dict:
        """
        Gets current LAST values for the query.
        Creates and executes a query for it:
          select <col> from <table> order by <col> desc limit 1
        :return: {table_name: {column_name: value}} for columns with a non-null value
        """
        last_values = {}
        for query, info in l_query.get_init_queries():
            response = dn.query(query=query, session=session)
            data = response.data_frame
            columns_info = response.columns

            if len(data) == 0:
                value = None
            else:
                row = list(data.iloc[0])

                # Locate the target column by name (case-insensitive).
                idx = None
                for i, col in enumerate(columns_info):
                    if col["name"].upper() == info["column_name"].upper():
                        idx = i
                        break

                if idx is None or len(row) == 1:
                    value = row[0]
                else:
                    value = row[idx]

            if value is not None:
                last_values[info["table_name"]] = {info["column_name"]: value}

        return last_values

    # Context

    def get_current_context(self) -> str:
        """
        Return the current context name (top of the context stack), or '' if empty.
        """
        try:
            context_stack = ctx.context_stack or []
        except AttributeError:
            context_stack = []
        if len(context_stack) > 0:
            return context_stack[-1]
        else:
            return ""

    def set_context(self, object_type: str = None, object_id: int = None):
        """
        Update the current context name using the object type and id.
        Previous context names are kept on lower levels of the stack.
        """
        try:
            context_stack = ctx.context_stack or []
        except AttributeError:
            context_stack = []
        context_stack.append(self.gen_context_name(object_type, object_id))
        ctx.context_stack = context_stack

    def release_context(self, object_type: str = None, object_id: int = None):
        """
        Remove the current context (defined by object type and id) and restore the previous one.
        """
        try:
            context_stack = ctx.context_stack or []
        except AttributeError:
            context_stack = []
        if len(context_stack) == 0:
            return
        context_name = self.gen_context_name(object_type, object_id)
        # Only pop if the top of the stack matches the requested context.
        if context_stack[-1] == context_name:
            context_stack.pop()
        ctx.context_stack = context_stack

    def gen_context_name(self, object_type: str, object_id: int) -> str:
        """
        Generate the context name from the object type and id.
        :return: context name, e.g. 'job-42', or '' if object_type is None
        """
        if object_type is None:
            return ""
        if object_id is not None:
            object_type += "-" + str(object_id)
        return object_type

    def get_context_vars(self, object_type: str, object_id: int) -> List[dict]:
        """
        Return variables stored in the context (defined by object type and id).

        :return: list of all context variables related to the context name as they are stored in the context table
        """
        context_name = self.gen_context_name(object_type, object_id)
        context_values = []
        for rec in db.session.query(db.QueryContext).filter_by(context_name=context_name, company_id=ctx.company_id):
            if rec.values is not None:
                context_values.append(rec.values)

        return context_values

    # DB
    def _get_context_record(self, context_name: str, query_str: str) -> db.QueryContext:
        """
        Find and return the record for the context and query string, or None.
        """
        return (
            db.session.query(db.QueryContext)
            .filter_by(query=query_str, context_name=context_name, company_id=ctx.company_id)
            .first()
        )

    def __add_context_record(self, context_name: str, query_str: str, values: dict) -> db.QueryContext:
        """
        Create a record (for the context and query string) with values and return it.
        NOTE: the record is added to the session but not committed here.
        """
        rec = db.QueryContext(query=query_str, context_name=context_name, company_id=ctx.company_id, values=values)
        db.session.add(rec)
        return rec

    def __update_context_record(self, context_name: str, query_str: str, values: dict):
        """
        Update the context record with new values.
        """
        rec = self._get_context_record(context_name, query_str)
        rec.values = values
        db.session.commit()

    def get_query(self, query_id: int) -> RunningQuery:
        """
        Get a running query by id.
        :raises RuntimeError: if the query does not exist for the current company
        """
        rec = db.Queries.query.filter(db.Queries.id == query_id, db.Queries.company_id == ctx.company_id).first()

        if rec is None:
            raise RuntimeError(f"Query not found: {query_id}")
        return RunningQuery(rec)

    def create_query(self, query: ASTNode, database: str = None) -> RunningQuery:
        """
        Create a new running query from an AST query.
        Also garbage-collects queries finished more than a day ago.
        """
        # remove old queries
        remove_query = db.session.query(db.Queries).filter(
            db.Queries.company_id == ctx.company_id, db.Queries.finished_at < (dt.datetime.now() - dt.timedelta(days=1))
        )
        for rec in remove_query.all():
            # The record is already in hand: no need to re-query it by id.
            RunningQuery(rec).remove_from_task()
            db.session.delete(rec)

        rec = db.Queries(
            sql=str(query),
            database=database,
            company_id=ctx.company_id,
        )

        db.session.add(rec)
        db.session.commit()
        return RunningQuery(rec)

    def list_queries(self) -> List[dict]:
        """
        Get a list of all running queries with metadata.
        """
        query = db.session.query(db.Queries).filter(db.Queries.company_id == ctx.company_id)
        return [RunningQuery(record).get_info() for record in query]

    def cancel_query(self, query_id: int):
        """
        Cancel a running query by id.
        :raises RuntimeError: if the query does not exist for the current company
        """
        rec = db.Queries.query.filter(db.Queries.id == query_id, db.Queries.company_id == ctx.company_id).first()
        if rec is None:
            raise RuntimeError(f"Query not found: {query_id}")

        RunningQuery(rec).remove_from_task()

        # the query in progress will fail when it tries to update its status
        db.session.delete(rec)
        db.session.commit()
# Module-level singleton: the shared controller instance used across the server.
query_context_controller = QueryContextController()