Coverage for mindsdb / api / executor / sql_query / steps / update_step.py: 9%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from mindsdb_sql_parser.ast import ( 

2 BinaryOperation, 

3 Identifier, 

4 Constant, 

5 Update, 

6) 

7from mindsdb.api.executor.planner.steps import UpdateToTable 

8from mindsdb.integrations.utilities.query_traversal import query_traversal 

9 

10from mindsdb.api.executor.sql_query.result_set import ResultSet 

11from mindsdb.api.executor.exceptions import WrongArgumentError 

12 

13from .base import BaseStepCall 

14 

15 

class UpdateToTableCall(BaseStepCall):
    """Execute an ``UpdateToTable`` plan step against the target integration.

    Two modes are handled, selected by ``step.update_command.keys``:

    * keys given: an ``UPDATE <table> SET <non-key columns> WHERE <key
      columns> = ...`` statement is built from the columns of the input
      dataframe step and executed once per input row;
    * keys not given: the step's own ``update_command`` (SET / WHERE) is
      used. Without an input dataframe it is executed once as-is; otherwise
      identifiers qualified with the sub-select alias become per-row
      parameters and the statement is executed once per input row.
    """

    bind = UpdateToTable

    def call(self, step):
        """Run the update described by ``step``.

        :param step: the ``UpdateToTable`` plan step to execute.
        :return: ``ResultSet`` whose ``affected_rows`` is the integration's
            count for a single as-is update, or the sum of the per-row
            counts when the update is driven by input data.
        :raises WrongArgumentError: if the key list is empty, no non-key
            column remains to update, a sub-select has no alias, or a
            referenced field is missing from the input data.
        """
        # "<integration>.<table...>" names the integration explicitly;
        # otherwise the update targets the session's current database.
        if len(step.table.parts) > 1:
            integration_name = step.table.parts[0]
            table_name_parts = step.table.parts[1:]
        else:
            integration_name = self.context['database']
            table_name_parts = step.table.parts

        dn = self.session.datahub.get(integration_name)

        result_step = step.dataframe

        # [param_name, Constant] pairs. Each Constant is a placeholder node
        # living inside update_query; its .value is overwritten for every
        # input row before the query is re-sent.
        params_map_index = []

        if step.update_command.keys is not None:
            result_data = self.steps_data[result_step.result.step_num]

            key_columns = [i.to_string() for i in step.update_command.keys]
            if not key_columns:
                raise WrongArgumentError('No key columns in update statement')

            where = None
            update_columns = {}
            for col in result_data.columns:
                name = col.name
                value = Constant(None)

                if name in key_columns:
                    # key column -> AND-ed equality condition in WHERE
                    condition = BinaryOperation(
                        op='=',
                        args=[Identifier(name), value]
                    )
                    if where is None:
                        where = condition
                    else:
                        where = BinaryOperation(
                            op='and',
                            args=[where, condition]
                        )
                else:
                    # non-key column -> SET clause
                    update_columns[name] = value

                params_map_index.append([name, value])

            # Previously `len(update_columns) is None`, which can never be
            # true, so a key-only input silently produced an empty SET.
            if not update_columns:
                raise WrongArgumentError(f'No columns for update found in: {result_data.columns}')

            update_query = Update(
                table=Identifier(parts=table_name_parts),
                update_columns=update_columns,
                where=where
            )
        else:
            # make command from the parsed UPDATE
            update_query = Update(
                table=Identifier(parts=table_name_parts),
                update_columns=step.update_command.update_columns,
                where=step.update_command.where
            )

            if result_step is None:
                # no input data: run as is
                response = dn.query(query=update_query, session=self.session)
                return ResultSet(affected_rows=response.affected_rows)
            result_data = self.steps_data[result_step.result.step_num]

            # link nodes with parameters for fast replacing with values
            input_table_alias = step.update_command.from_select_alias
            if input_table_alias is None:
                raise WrongArgumentError('Subselect in update requires alias')

            def prepare_map_index(node, is_table, **kwargs):
                if isinstance(node, Identifier) and not is_table:
                    if node.parts[0] == input_table_alias.parts[0]:
                        # field of the input (sub-select) table: replace the
                        # identifier with a placeholder filled per row
                        node2 = Constant(None)
                        param_name = node.parts[-1]
                        params_map_index.append([param_name, node2])
                        return node2
                    elif len(node.parts) > 1 and node.parts[0] == table_name_parts[0]:
                        # drop the updated table's qualifier; a bare
                        # single-part name is left alone so it never becomes
                        # an empty identifier
                        node.parts = node.parts[1:]

            # do mapping
            query_traversal(update_query, prepare_map_index)

        # every parameter must be backed by a column of the input data
        data_header = [col.alias for col in result_data.columns]
        for param_name, _ in params_map_index:
            if param_name not in data_header:
                raise WrongArgumentError(f'Field {param_name} not found in input data. Input fields: {data_header}')

        # One UPDATE per input row. The counter starts at 0 so empty input
        # returns 0 instead of hitting an unbound `response`, and counts are
        # summed so the result reflects all rows, not just the last one.
        affected_rows = 0
        for row in result_data.get_records():
            # presumably records are mappings keyed by column alias
            # (matches the data_header check above) - confirm in ResultSet
            for param_name, param in params_map_index:
                param.value = row[param_name]

            response = dn.query(query=update_query, session=self.session)
            # a handler that does not report a count contributes 0
            affected_rows += response.affected_rows or 0
        return ResultSet(affected_rows=affected_rows)