Coverage for mindsdb / api / executor / sql_query / steps / map_reduce_step.py: 12%
95 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import copy
3from mindsdb_sql_parser.ast import (
4 BinaryOperation,
5 UnaryOperation,
6 Constant,
7)
8from mindsdb.api.executor.planner.steps import (
9 MapReduceStep,
10 FetchDataframeStep,
11 MultipleSteps,
12 SubSelectStep,
13)
15from mindsdb.api.executor.sql_query.result_set import ResultSet
16from mindsdb.api.executor.exceptions import LogicError
17from mindsdb.utilities.partitioning import process_dataframe_in_partitions
19from .base import BaseStepCall
def markQueryVar(where):
    """Tag ``$var[...]`` placeholder constants inside a WHERE expression tree.

    Walks the tree in place. For a BinaryOperation it descends into
    ``args[0]`` then ``args[1]``; for a UnaryOperation it descends into
    ``args[0]``. Any other node type, including ``None``, is ignored.

    Every Constant whose value, converted with ``str()``, starts with
    ``'$var['`` gets two new attributes: ``is_var = True`` and ``var_name``,
    which is set to the constant's raw value. Returns None.
    """
    if isinstance(where, BinaryOperation):
        child_positions = (0, 1)
    elif isinstance(where, UnaryOperation):
        child_positions = (0,)
    elif isinstance(where, Constant):
        raw_value = where.value
        if str(raw_value).startswith('$var['):
            where.is_var = True
            where.var_name = raw_value
        return
    else:
        # Unsupported node type (or None): nothing to tag.
        return

    for position in child_positions:
        markQueryVar(where.args[position])
def replaceQueryVar(where, var_value, var_name):
    """Substitute *var_value* for the ``$var[<var_name>]`` placeholder in *where*.

    Traverses the same node types as ``markQueryVar``: both operands of a
    BinaryOperation (``args[0]`` then ``args[1]``) and the first operand of a
    UnaryOperation. Any other node type is ignored.

    A Constant's ``value`` is overwritten in place only when two conditions
    hold: it was previously tagged with ``is_var = True``, and its value
    equals ``$var[<var_name>]``. Untagged constants are never modified.
    Returns None.
    """
    if isinstance(where, BinaryOperation):
        child_positions = (0, 1)
    elif isinstance(where, UnaryOperation):
        child_positions = (0,)
    elif isinstance(where, Constant):
        if getattr(where, 'is_var', False) is True and where.value == f'$var[{var_name}]':
            where.value = var_value
        return
    else:
        return

    for position in child_positions:
        replaceQueryVar(where.args[position], var_value, var_name)
def join_query_data(target, source):
    """Union *source* into the accumulator *target* and return the accumulator.

    The result depends on whether *target* has columns yet:

    * It has none (for example, a freshly created ResultSet): *source* itself
      is returned and becomes the new accumulator. It is not copied.
    * It has some: the rows of *source* are appended to *target* in place via
      ``target.add_from_result_set`` and *target* is returned.
    """
    if len(target.columns) > 0:
        target.add_from_result_set(source)
        return target
    return source
class MapReduceStepCall(BaseStepCall):
    """Executor for ``MapReduceStep``.

    Runs the step's sub-step(s) several times and unions the results into
    one ResultSet. Only ``reduce == 'union'`` is supported.

    The "map" phase has two modes, chosen by the optional ``partition``
    attribute on the step:

    * ``partition`` set: the ResultSet produced by step ``step.values`` is
      fed to ``process_dataframe_in_partitions``, and the sub-step(s) run
      once per chunk against that chunk (see ``_reduce_partition``).
    * ``partition`` unset: every row of step ``step.values`` is treated as a
      set of variable bindings. ``$var[name]`` placeholders in the sub-step's
      WHERE clause are filled from the row, and the sub-step runs once per
      row (see ``_reduce_vars``).
    """

    bind = MapReduceStep

    def call(self, step: MapReduceStep):
        """Execute *step* and return the unioned ResultSet.

        Raises:
            LogicError: ``step.reduce`` is anything other than ``'union'``.
            ValueError: ``step.partition`` is set but is not a positive int.
        """
        if step.reduce != 'union':
            raise LogicError(f'Unknown MapReduceStep type: {step.reduce}')

        partition = getattr(step, 'partition', None)
        if partition is not None:
            return self._reduce_partition(step, partition)
        return self._reduce_vars(step)

    def _reduce_partition(self, step, partition):
        """Execute the sub-step(s) once per partition of the input and union the results.

        Args:
            step: the MapReduceStep. ``step.values.step_num`` identifies the
                input step; ``step.step`` is a single sub-step or a list of
                them.
            partition: passed unchanged to ``process_dataframe_in_partitions``.
                NOTE(review): presumably the chunk size in rows; confirm
                against that helper.

        Returns:
            The unioned ResultSet. This is an empty ``ResultSet()`` if every
            chunk produced a falsy result.

        Raises:
            ValueError: *partition* is not an int (bool is rejected) or is <= 0.
        """
        # bool is a subclass of int; reject it explicitly so that
        # `partition=True` is not silently treated as a partition of 1.
        if isinstance(partition, bool) or not isinstance(partition, int):
            raise ValueError('Only integers are supported in partition definition.')
        if partition <= 0:
            raise ValueError('Partition must be a positive number')

        input_idx = step.values.step_num
        input_data = self.steps_data[input_idx]
        input_columns = list(input_data.columns)

        substeps = step.step if isinstance(step.step, list) else [step.step]

        data = ResultSet()
        df = input_data.get_raw_df()

        def callback(chunk):
            return self._exec_partition(chunk, substeps, input_idx, input_columns)

        for result in process_dataframe_in_partitions(df, callback, partition):
            # A chunk whose sub-steps produced nothing (None or a falsy
            # result) contributes no rows.
            if result:
                data = join_query_data(data, result)
        return data

    def _exec_partition(self, df, substeps, input_idx, input_columns):
        """Run *substeps* against one chunk of the input and return the last result.

        A shallow copy of ``self.steps_data`` is made, with the input step's
        slot replaced by a ResultSet built from *df*. This makes the sub-steps
        read the chunk instead of the full input. Each sub-step's result is
        stored back into the copy under its ``step_num`` so that later
        sub-steps can reference it.

        Returns:
            The final sub-step's result, or None if *substeps* is empty.
        """
        chunk_data = ResultSet(columns=input_columns.copy())
        chunk_data.add_raw_df(df)

        # Execute with the previous results modified for this chunk only.
        local_steps_data = self.steps_data.copy()
        local_steps_data[input_idx] = chunk_data

        sub_data = None
        for substep in substeps:
            sub_data = self.sql_query.execute_step(substep, steps_data=local_steps_data)
            local_steps_data[substep.step_num] = sub_data
        return sub_data

    def _reduce_vars(self, step):
        """Execute the sub-step once per input row, filling ``$var[...]`` placeholders.

        Each record of step ``step.values`` becomes a dict of variable
        bindings, excluding the ``__mindsdb_row_id`` column. For every
        binding set, ``step.step`` is deep-copied, its placeholders are
        filled, the copy is executed, and the results are unioned.

        Returns:
            The unioned ResultSet. This is an empty ``ResultSet()`` if there
            are no input rows.
        """
        step_data = self.steps_data[step.values.step_num]
        # Every binding set is materialised before any sub-step runs.
        var_groups = [
            {name: value for name, value in row.items() if name != '__mindsdb_row_id'}
            for row in step_data.get_records()
        ]

        template = step.step

        data = ResultSet()
        for var_group in var_groups:
            # Deep copy so the template keeps its placeholders for the next row.
            filled_step = copy.deepcopy(template)
            self._fill_vars(filled_step, var_group)
            sub_data = self.sql_query.execute_step(filled_step)
            data = join_query_data(data, sub_data)
        return data

    def _fill_vars(self, step, var_group):
        """Replace ``$var[name]`` placeholders in *step* (in place) with values from *var_group*.

        Recurses into ``MultipleSteps.steps``. For ``FetchDataframeStep`` and
        ``SubSelectStep``, placeholders are searched for in
        ``step.query.where``. Any other step type is left unchanged.
        """
        if isinstance(step, MultipleSteps):
            for substep in step.steps:
                self._fill_vars(substep, var_group)
        elif isinstance(step, (FetchDataframeStep, SubSelectStep)):
            where = step.query.where
            markQueryVar(where)
            for name, value in var_group.items():
                replaceQueryVar(where, value, name)