Coverage for mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py: 12% (138 statements)
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
import copy
from typing import List, Optional

import pandas as pd

from mindsdb_sql_parser import ASTNode, Constant
from mindsdb.api.executor.planner.steps import FetchDataframeStepPartition
from mindsdb.integrations.utilities.query_traversal import query_traversal
from mindsdb.interfaces.query_context.context_controller import RunningQuery, query_context_controller
from mindsdb.api.executor.sql_query.result_set import ResultSet
from mindsdb.utilities import log
from mindsdb.utilities.config import config
from mindsdb.utilities.partitioning import get_max_thread_count, split_data_frame
from mindsdb.api.executor.sql_query.steps.fetch_dataframe import get_table_alias, get_fill_param_fnc
from mindsdb.utilities.context_executor import ContextThreadPoolExecutor

from .base import BaseStepCall
logger = log.getLogger(__name__)


class FetchDataframePartitionCall(BaseStepCall):
    """
    Alternative to FetchDataframeCall, but fetches data in batches by wrapping the user's query as:

        select * from ({user query})
        where {track_column} > {previous value}
        order by {track_column}
        limit {batch_size}
    """

    bind = FetchDataframeStepPartition

    def call(self, step: FetchDataframeStepPartition) -> ResultSet:
        """
        Parameters:
        - batch_size - count of rows to fetch from the database per iteration; optional, default 1000
        - threads - run partitioning in threads, bool or int; optional, if set:
          - int value: use it as the number of threads
          - true: enable threads, autodetect the number of threads
          - false: disable threads even if the ml task queue is enabled
        - track_column - column used for creating partitions;
          the query will be sorted by this column and the select will be limited by batch_size
        - error (default 'raise')
          - when `error='skip'`, errors in a partition are skipped and execution continues
        """

        self.dn = self.session.datahub.get(step.integration)
        query = step.query

        # fill params
        fill_params = get_fill_param_fnc(self.steps_data)
        query_traversal(query, fill_params)

        self.table_alias = get_table_alias(step.query.from_table, self.context.get("database"))
        self.current_step_num = step.step_num

        if step.condition is not None and "limit" in step.condition:
            return self.repeat_till_reach_limit(step, step.condition["limit"])

        # get the query record
        run_query = self.sql_query.run_query
        if run_query is None:
            raise RuntimeError("Error with partitioning of the query")
        run_query.set_params(step.params)

        self.substeps = step.steps

        # enable threads by default if the ml task queue is enabled
        use_threads, thread_count = False, None
        if config["ml_task_queue"]["type"] == "redis":
            use_threads = True

        # the user's `threads` param overrides the default
        if "threads" in step.params:
            threads = step.params["threads"]
            # check bool first: a Python bool is also an instance of int
            if isinstance(threads, bool):
                # true: enable threads with autodetected thread count
                # false: disable threads even when the ml task queue is enabled
                use_threads = threads
            elif isinstance(threads, int):
                thread_count = threads
                use_threads = True

        on_error = step.params.get("error", "raise")
        if use_threads:
            return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error)
        else:
            return self.fetch_iterate(run_query, query, on_error=on_error)

    def repeat_till_reach_limit(self, step: FetchDataframeStepPartition, limit: int):
        """
        Repeatedly query the source with an increasing limit until the substeps
        produce at least `limit` rows (or the source is exhausted).
        """
        first_table_limit = limit * 2
        dn = self.session.datahub.get(step.integration)

        query = step.query

        # fill db context variables into the query
        query, context_callback = query_context_controller.handle_db_context_vars(query, dn, self.session)

        try_num = 1
        while True:
            self.substeps = copy.deepcopy(step.steps)
            query2 = copy.deepcopy(query)

            if first_table_limit is not None:
                query2.limit = Constant(first_table_limit)
            else:
                query2.limit = None

            response = dn.query(query=query2, session=self.session)
            df = response.data_frame

            result = self.exec_sub_steps(df)

            if len(result) >= limit or first_table_limit is None or len(df) < first_table_limit:
                # we have enough results,
                # OR the last (unlimited) try was already performed,
                # OR the first table returned fewer rows than requested (the source is exhausted)
                result = result[:limit]
                break

            if try_num > 3:
                # the last try: fetch without a limit
                first_table_limit = None
                continue

            # not enough results
            if len(result) > 0:
                # forecast the required limit (based on how many rows are still missing)
                first_table_limit = int(first_table_limit * limit / len(result) * try_num + 10**try_num)
            else:
                first_table_limit = first_table_limit * 10
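
            # Worked example (illustrative numbers): with limit=100, first_table_limit=200
            # and 40 rows returned on try_num=1, the next limit is
            #   int(200 * 100 / 40 * 1 + 10**1) = 510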

            try_num += 1

        if context_callback:
            context_callback(df, response.columns)

        return result

    def fetch_iterate(self, run_query: RunningQuery, query: ASTNode, on_error: Optional[str] = None) -> ResultSet:
        """
        Process batches one by one in a loop
        """

        results = []

        for df in run_query.get_partitions(self.dn, self, query):
            try:
                sub_data = self.exec_sub_steps(df)
                run_query.set_progress(processed_rows=len(df))
                results.append(sub_data)
            except Exception as e:
                if on_error == "skip":
                    logger.error(e)
                else:
                    raise e

        return self.concat_results(results)

    def concat_results(self, results: List[ResultSet]) -> ResultSet:
        """
        Concatenate a list of result sets into a single result set
        """
        df_list = []
        for res in results:
            df, col_names = res.to_df_cols()
            if len(df) > 0:
                df_list.append(df)

        data = ResultSet()
        if len(df_list) > 0:
            data = ResultSet.from_df_cols(pd.concat(df_list), col_names)

        return data

    def exec_sub_steps(self, df: pd.DataFrame) -> ResultSet:
        """
        FetchDataframeStepPartition has substeps defined.
        Every batch of data is used to execute these substeps:
        - the batch of data is set as the result of FetchDataframeStepPartition
        - the substeps are executed using the result of the previous step (as if all fetched data were available)
        - the final result is returned and concatenated outside with the results of the other batches
        """
        input_data = ResultSet.from_df(
            df, table_name=self.table_alias[1], table_alias=self.table_alias[2], database=self.table_alias[0]
        )

        if len(self.substeps) == 0:
            return input_data

        # execute the substeps with modified previous results
        steps_data2 = self.steps_data.copy()
        steps_data2[self.current_step_num] = input_data

        sub_data = None
        for substep in self.substeps:
            sub_data = self.sql_query.execute_step(substep, steps_data=steps_data2)
            steps_data2[substep.step_num] = sub_data
        return sub_data

    def fetch_threads(
        self, run_query: RunningQuery, query: ASTNode, thread_count: Optional[int] = None, on_error: Optional[str] = None
    ) -> ResultSet:
        """
        Process batches in threads:
        - spawn the required number of threads in a pool
        - split every batch into chunks and submit them to the pool
        - collect the results from the returned futures
        """

        if thread_count is None:
            thread_count = get_max_thread_count()

        # split every batch into one chunk per worker
        partition_size = int(run_query.batch_size / thread_count)
        # enforce a minimal partition size
        if partition_size < 10:
            partition_size = 10
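
        # Worked example (illustrative numbers): batch_size=1000 and thread_count=4
        # give partition_size=250, so every fetched batch is split into 4 chunks
        # of ~250 rows and processed concurrently.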

        results = []

        with ContextThreadPoolExecutor(max_workers=thread_count) as executor:
            for df in run_query.get_partitions(self.dn, self, query):
                # split the batch into chunks and send them to the workers
                futures = []
                for df2 in split_data_frame(df, partition_size):
                    futures.append((executor.submit(self.exec_sub_steps, df2), len(df2)))

                error = None
                for future, rows_count in futures:
                    try:
                        results.append(future.result())
                        run_query.set_progress(processed_rows=rows_count)
                    except Exception as e:
                        if on_error == "skip":
                            logger.error(e)
                        else:
                            executor.shutdown()
                            error = e

                if error:
                    raise error
                if self.sql_query.stop_event is not None and self.sql_query.stop_event.is_set():
                    executor.shutdown()
                    raise RuntimeError("Query is interrupted")

        return self.concat_results(results)