Coverage for mindsdb / api / executor / sql_query / steps / update_step.py: 9%
64 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from mindsdb_sql_parser.ast import (
2 BinaryOperation,
3 Identifier,
4 Constant,
5 Update,
6)
7from mindsdb.api.executor.planner.steps import UpdateToTable
8from mindsdb.integrations.utilities.query_traversal import query_traversal
10from mindsdb.api.executor.sql_query.result_set import ResultSet
11from mindsdb.api.executor.exceptions import WrongArgumentError
13from .base import BaseStepCall
class UpdateToTableCall(BaseStepCall):
    """Step handler that executes an ``UpdateToTable`` planner step.

    The step is run in one of three ways, depending on the update command:

    * ``update_command.keys`` is set: an ``Update`` query is built from the
      columns of a previous step's result; key columns go into ``WHERE`` and
      the remaining columns into the ``SET`` list. The query is executed once
      per input row.
    * no keys and no input dataframe: the update command is sent to the data
      node as is.
    * no keys, with an input dataframe: identifiers that refer to the aliased
      sub-select are replaced by constants which are filled from every input
      row, and the query is executed once per row.
    """

    bind = UpdateToTable

    def call(self, step):
        """Run the update described by ``step`` and report affected rows.

        Args:
            step: the ``UpdateToTable`` step. The method reads ``step.table``,
                ``step.dataframe`` and ``step.update_command`` (``keys``,
                ``update_columns``, ``where``, ``from_select_alias``).

        Returns:
            ResultSet: carries ``affected_rows`` of the last executed update
            statement, or ``0`` when the input data contains no rows and
            therefore no statement was executed.

        Raises:
            WrongArgumentError: when the key list is empty, when no column is
                left to update, when a sub-select has no alias, or when a
                referenced field is missing from the input data.
        """
        # The first part of a multi-part table name selects the integration;
        # otherwise the current database from the context is used.
        if len(step.table.parts) > 1:
            integration_name = step.table.parts[0]
            table_name_parts = step.table.parts[1:]
        else:
            integration_name = self.context['database']
            table_name_parts = step.table.parts

        dn = self.session.datahub.get(integration_name)

        result_step = step.dataframe

        # Pairs of [input field name, Constant node]. The Constant nodes are
        # embedded in the update query, so assigning ``.value`` on them
        # re-parameterizes the query without rebuilding it.
        params_map_index = []

        if step.update_command.keys is not None:
            result_data = self.steps_data[result_step.result.step_num]

            where = None
            update_columns = {}

            key_columns = [key.to_string() for key in step.update_command.keys]
            if not key_columns:
                raise WrongArgumentError('No key columns in update statement')

            # NOTE(review): parameters are registered by ``col.name`` here but
            # validated against ``col.alias`` below - confirm both always match.
            for col in result_data.columns:
                name = col.name
                value = Constant(None)

                if name in key_columns:
                    # key column: add an equality condition to WHERE
                    condition = BinaryOperation(
                        op='=',
                        args=[Identifier(name), value]
                    )
                    if where is None:
                        where = condition
                    else:
                        where = BinaryOperation(
                            op='and',
                            args=[where, condition]
                        )
                else:
                    # regular column: add to the SET list
                    update_columns[name] = value

                params_map_index.append([name, value])

            # The original check was ``len(update_columns) is None``, which is
            # never true, so an empty SET list slipped through.
            if not update_columns:
                raise WrongArgumentError(f'No columns for update found in: {result_data.columns}')

            # NOTE(review): ``where`` is still None if none of the key columns
            # is present in the input columns - confirm whether an unfiltered
            # update is acceptable in that case.
            update_query = Update(
                table=Identifier(parts=table_name_parts),
                update_columns=update_columns,
                where=where
            )

        else:
            # make command
            update_query = Update(
                table=Identifier(parts=table_name_parts),
                update_columns=step.update_command.update_columns,
                where=step.update_command.where
            )

            if result_step is None:
                # no input data: run as is
                response = dn.query(query=update_query, session=self.session)
                return ResultSet(affected_rows=response.affected_rows)
            result_data = self.steps_data[result_step.result.step_num]

            # link nodes with parameters for fast replacing with values
            input_table_alias = step.update_command.from_select_alias
            if input_table_alias is None:
                raise WrongArgumentError('Subselect in update requires alias')

            def prepare_map_index(node, is_table, **kwargs):
                if isinstance(node, Identifier) and not is_table:
                    if node.parts[0] == input_table_alias.parts[0]:
                        # field of the input table: replace the identifier
                        # with a constant that is filled for every row
                        node2 = Constant(None)
                        param_name = node.parts[-1]
                        params_map_index.append([param_name, node2])
                        return node2
                    elif node.parts[0] == table_name_parts[0]:
                        # remove updated table alias
                        node.parts = node.parts[1:]

            # do mapping
            query_traversal(update_query, prepare_map_index)

        # check that every parameter is present in the input data
        data_header = [col.alias for col in result_data.columns]

        for param_name, _ in params_map_index:
            if param_name not in data_header:
                raise WrongArgumentError(f'Field {param_name} not found in input data. Input fields: {data_header}')

        # Run the update once for every row of the input data. The default
        # keeps the return well-defined when there are no rows; previously
        # ``response`` was unbound in that case.
        affected_rows = 0
        for row in result_data.get_records():
            # fill params
            for param_name, param in params_map_index:
                param.value = row[param_name]

            response = dn.query(query=update_query, session=self.session)
            affected_rows = response.affected_rows
        return ResultSet(affected_rows=affected_rows)