Coverage for mindsdb / api / executor / sql_query / steps / map_reduce_step.py: 12%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import copy 

2 

3from mindsdb_sql_parser.ast import ( 

4 BinaryOperation, 

5 UnaryOperation, 

6 Constant, 

7) 

8from mindsdb.api.executor.planner.steps import ( 

9 MapReduceStep, 

10 FetchDataframeStep, 

11 MultipleSteps, 

12 SubSelectStep, 

13) 

14 

15from mindsdb.api.executor.sql_query.result_set import ResultSet 

16from mindsdb.api.executor.exceptions import LogicError 

17from mindsdb.utilities.partitioning import process_dataframe_in_partitions 

18 

19from .base import BaseStepCall 

20 

21 

def markQueryVar(where):
    """Flag ``$var[...]`` placeholder constants inside a WHERE expression.

    The expression is walked recursively: the first two arguments of a
    ``BinaryOperation`` and the first argument of a ``UnaryOperation`` are
    visited. A ``Constant`` whose string form starts with ``$var[`` gets
    ``is_var = True`` and ``var_name`` set to its current value. Any other
    node (including ``None``) is left untouched.

    The nodes are modified in place; nothing is returned.
    """
    if isinstance(where, BinaryOperation):
        for position in (0, 1):
            markQueryVar(where.args[position])
        return

    if isinstance(where, UnaryOperation):
        markQueryVar(where.args[0])
        return

    if not isinstance(where, Constant):
        return

    # str() because the constant's value is not necessarily a string.
    if str(where.value).startswith('$var['):
        where.is_var = True
        where.var_name = where.value

32 

33 

def replaceQueryVar(where, var_value, var_name):
    """Substitute one named placeholder in a WHERE expression, in place.

    The expression is walked the same way as in ``markQueryVar`` (first two
    arguments of a ``BinaryOperation``, first argument of a
    ``UnaryOperation``). A ``Constant`` is rewritten only when it carries
    ``is_var is True`` and its value equals ``$var[<var_name>]``; its value
    then becomes ``var_value``. Constants without the ``is_var`` attribute
    are never touched.

    Nothing is returned.
    """
    if isinstance(where, BinaryOperation):
        for position in (0, 1):
            replaceQueryVar(where.args[position], var_value, var_name)
        return

    if isinstance(where, UnaryOperation):
        replaceQueryVar(where.args[0], var_value, var_name)
        return

    if not isinstance(where, Constant):
        return

    # Only constants explicitly flagged with is_var = True are candidates.
    flagged = getattr(where, 'is_var', None) is True
    if flagged and where.value == f'$var[{var_name}]':
        where.value = var_value

43 

44 

def join_query_data(target, source):
    """Combine ``source`` into ``target`` and return the combined object.

    When ``target`` has no columns, ``source`` itself is returned (no copy
    is made). Otherwise ``target.add_from_result_set(source)`` is called and
    ``target`` is returned.
    """
    if len(target.columns) != 0:
        target.add_from_result_set(source)
        return target

    # Nothing accumulated yet: the incoming result becomes the accumulator.
    return source

51 

52 

class MapReduceStepCall(BaseStepCall):
    """Step handler for ``MapReduceStep``.

    Runs the step's sub-plan repeatedly and merges the partial results with
    ``join_query_data``. Two modes exist:

    * partition mode (the step has a non-``None`` ``partition`` attribute):
      the input dataframe is handed to ``process_dataframe_in_partitions``
      and the sub-plan is executed once per chunk;
    * variable mode (otherwise): the sub-plan is executed once per input
      record, with ``$var[<column>]`` placeholders in its WHERE clause
      replaced by that record's values.

    Only ``reduce == 'union'`` is supported.
    """

    # Step type handled by this class.
    # NOTE(review): presumably used by the executor for dispatch; confirm in BaseStepCall.
    bind = MapReduceStep

    def call(self, step: MapReduceStep):
        """Execute ``step`` and return the merged result.

        Raises:
            LogicError: if ``step.reduce`` is anything other than ``'union'``.
            ValueError: in partition mode, if the partition size is not a
                positive integer.
        """
        if step.reduce != 'union':
            raise LogicError(f'Unknown MapReduceStep type: {step.reduce}')

        # The attribute may be absent on the step, hence getattr.
        partition = getattr(step, 'partition', None)
        if partition is not None:
            return self._reduce_partition(step, partition)
        return self._reduce_vars(step)

    def _reduce_partition(self, step, partition):
        """Run the sub-plan over the input dataframe chunk by chunk.

        Args:
            step: the map-reduce step; ``step.values.step_num`` selects the
                input in ``self.steps_data`` and ``step.step`` is the
                sub-step or list of sub-steps to execute.
            partition: positive integer passed to
                ``process_dataframe_in_partitions``.
                NOTE(review): presumably the chunk size in rows; confirm.

        Returns:
            The merged results of all chunks (an empty ``ResultSet`` when no
            chunk produced a truthy result).

        Raises:
            ValueError: if ``partition`` is not an integer or is not positive.
        """
        # bool is a subclass of int, so True would otherwise slip through as 1.
        if isinstance(partition, bool) or not isinstance(partition, int):
            raise ValueError('Only integers are supported in partition definition.')
        if partition <= 0:
            raise ValueError('Partition must be a positive number')

        input_idx = step.values.step_num
        input_data = self.steps_data[input_idx]
        input_columns = list(input_data.columns)

        substeps = step.step
        if not isinstance(substeps, list):
            substeps = [substeps]

        data = ResultSet()

        df = input_data.get_raw_df()

        def callback(chunk):
            return self._exec_partition(chunk, substeps, input_idx, input_columns)

        for result in process_dataframe_in_partitions(df, callback, partition):
            # Falsy results (e.g. None when there are no sub-steps) are skipped.
            if result:
                data = join_query_data(data, result)

        return data

    def _exec_partition(self, df, substeps, input_idx, input_columns):
        """Execute ``substeps`` with ``df`` standing in for the input step.

        A shallow copy of ``self.steps_data`` is used, so the shared step
        results are not overwritten; the entry at ``input_idx`` is replaced
        by a ``ResultSet`` built from ``df`` and ``input_columns``.

        Returns:
            The result of the last sub-step, or ``None`` when ``substeps``
            is empty.
        """
        input_data2 = ResultSet(columns=input_columns.copy())
        input_data2.add_raw_df(df)

        # execute with modified previous results
        steps_data2 = self.steps_data.copy()
        steps_data2[input_idx] = input_data2

        sub_data = None
        for substep in substeps:
            sub_data = self.sql_query.execute_step(substep, steps_data=steps_data2)
            # Later sub-steps can see the results of earlier ones.
            steps_data2[substep.step_num] = sub_data

        return sub_data

    def _reduce_vars(self, step):
        """Run the sub-plan once per input record with variables filled in.

        Each record of ``self.steps_data[step.values.step_num]`` yields one
        group of variables (all of its fields except ``__mindsdb_row_id``).
        For each group a deep copy of ``step.step`` is made, its placeholders
        are filled and the copy is executed.

        Returns:
            The merged results of all executions (an empty ``ResultSet``
            when there are no input records).
        """
        # extract vars
        step_data = self.steps_data[step.values.step_num]
        var_groups = [
            {name: value for name, value in row.items() if name != '__mindsdb_row_id'}
            for row in step_data.get_records()
        ]

        substep = step.step

        data = ResultSet()

        for var_group in var_groups:
            # Work on a copy: filling placeholders mutates the query in place.
            steps2 = copy.deepcopy(substep)

            self._fill_vars(steps2, var_group)

            sub_data = self.sql_query.execute_step(steps2)
            data = join_query_data(data, sub_data)

        return data

    def _fill_vars(self, step, var_group):
        """Replace ``$var[<name>]`` placeholders in ``step`` with values.

        ``MultipleSteps`` are processed recursively. For a
        ``FetchDataframeStep`` or ``SubSelectStep`` only the WHERE clause of
        ``step.query`` is processed. Other step types are ignored.

        Args:
            step: the step to modify in place.
            var_group: mapping of variable name to value.
        """
        if isinstance(step, MultipleSteps):
            for substep in step.steps:
                self._fill_vars(substep, var_group)
        elif isinstance(step, (FetchDataframeStep, SubSelectStep)):
            where = step.query.where
            markQueryVar(where)
            for name, value in var_group.items():
                replaceQueryVar(where, value, name)