Coverage for mindsdb / api / executor / sql_query / result_set.py: 56%
242 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import copy
2from array import array
3from typing import Any
4from dataclasses import dataclass, field, MISSING
6import numpy as np
7import pandas as pd
8from pandas.api import types as pd_types
9import sqlalchemy.types as sqlalchemy_types
11from mindsdb_sql_parser.ast import TableColumn
13from mindsdb.utilities import log
14from mindsdb.api.executor.exceptions import WrongArgumentError
15from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
# Module-level logger, named after this module per project convention.
logger = log.getLogger(__name__)
def get_mysql_data_type_from_series(series: pd.Series, do_infer: bool = False) -> MYSQL_DATA_TYPE:
    """Map a pandas Series dtype to the matching MySQL data type.

    Args:
        series (pd.Series): series whose dtype is inspected
        do_infer (bool): when True and the series has object dtype, attempt to
            infer a narrower dtype before mapping

    Returns:
        MYSQL_DATA_TYPE: the MySQL type that best matches the series dtype
    """
    dtype = series.dtype
    if do_infer is True and pd_types.is_object_dtype(dtype):
        dtype = series.infer_objects().dtype

    # Ordered checks: more specific dtype predicates are tried before the
    # generic numeric fallback; anything unmatched is treated as TEXT.
    dtype_checks = (
        (pd_types.is_object_dtype, MYSQL_DATA_TYPE.TEXT),
        (pd_types.is_datetime64_dtype, MYSQL_DATA_TYPE.DATETIME),
        (pd_types.is_string_dtype, MYSQL_DATA_TYPE.TEXT),
        (pd_types.is_bool_dtype, MYSQL_DATA_TYPE.BOOL),
        (pd_types.is_integer_dtype, MYSQL_DATA_TYPE.INT),
        (pd_types.is_numeric_dtype, MYSQL_DATA_TYPE.FLOAT),
    )
    for predicate, mysql_type in dtype_checks:
        if predicate(dtype):
            return mysql_type
    return MYSQL_DATA_TYPE.TEXT
54def _dump_vector(value: Any) -> Any:
55 if isinstance(value, array):
56 return value.tolist()
57 return value
@dataclass(kw_only=True, slots=True)
class Column:
    """Metadata describing a single column of a ResultSet."""

    # With kw_only=True a field with no default is simply required, so the
    # previous `field(default=MISSING)` was a redundant no-op and is dropped.
    name: str
    alias: str | None = None
    table_name: str | None = None
    table_alias: str | None = None
    type: MYSQL_DATA_TYPE | None = None
    database: str | None = None
    # NOTE: annotation fixed from `dict` to `dict | None` — the default is None.
    flags: dict | None = None
    charset: str | None = None

    def __post_init__(self):
        # Fall back to the base identifiers when aliases are not provided.
        if self.alias is None:
            self.alias = self.name
        if self.table_alias is None:
            self.table_alias = self.table_name

    def get_hash_name(self, prefix: str) -> str:
        """Return an identifier of the form '<prefix>_<table>_<name>'.

        Aliases take precedence over the underlying table/column names.
        Used to build unique dataframe column labels (see ResultSet.to_df_cols).
        """
        table_name = self.table_name if self.table_alias is None else self.table_alias
        name = self.name if self.alias is None else self.alias

        name = f"{prefix}_{table_name}_{name}"
        return name
85def rename_df_columns(df: pd.DataFrame, names: list | None = None) -> None:
86 """Inplace rename of dataframe columns
88 Args:
89 df (pd.DataFrame): dataframe
90 names (Optional[List]): columns names to set
91 """
92 if names is not None:
93 df.columns = names
94 else:
95 df.columns = list(range(len(df.columns)))
98class ResultSet:
99 def __init__(
100 self,
101 columns: list[Column] | None = None,
102 values: list[list] | None = None,
103 df: pd.DataFrame | None = None,
104 affected_rows: int | None = None,
105 is_prediction: bool = False,
106 mysql_types: list[MYSQL_DATA_TYPE] | None = None,
107 ):
108 """
109 Args:
110 columns: list of Columns
111 values (List[List]): data of resultSet, have to be list of lists with length equal to column
112 df (pd.DataFrame): injected dataframe, have to have enumerated columns and length equal to columns
113 affected_rows (int): number of affected rows
114 """
115 if columns is None:
116 columns = []
117 self._columns = columns
119 if df is None:
120 if values is None: 120 ↛ 123line 120 didn't jump to line 123 because the condition on line 120 was always true
121 df = None
122 else:
123 df = pd.DataFrame(values)
124 self._df = df
126 self.affected_rows = affected_rows
128 self.is_prediction = is_prediction
130 self.mysql_types = mysql_types
132 def __repr__(self):
133 col_names = ", ".join([col.name for col in self._columns])
135 return f"{self.__class__.__name__}({self.length()} rows, cols: {col_names})"
137 def __len__(self) -> int:
138 if self._df is None:
139 return 0
140 return len(self._df)
142 def __getitem__(self, slice_val):
143 # return resultSet with sliced dataframe
144 df = self._df[slice_val]
145 return ResultSet(columns=self.columns, df=df)
147 # --- converters ---
149 @classmethod
150 def from_df(
151 cls,
152 df: pd.DataFrame,
153 database=None,
154 table_name=None,
155 table_alias=None,
156 is_prediction: bool = False,
157 mysql_types: list[MYSQL_DATA_TYPE] | None = None,
158 ):
159 match mysql_types:
160 case None:
161 mysql_types = [None] * len(df.columns)
162 case list() if len(mysql_types) != len(df.columns): 162 ↛ 163line 162 didn't jump to line 163 because the pattern on line 162 never matched
163 raise WrongArgumentError(f"Mysql types length mismatch: {len(mysql_types)} != {len(df.columns)}")
165 columns = [
166 Column(name=column_name, table_name=table_name, table_alias=table_alias, database=database, type=mysql_type)
167 for column_name, mysql_type in zip(df.columns, mysql_types)
168 ]
170 rename_df_columns(df)
171 return cls(df=df, columns=columns, is_prediction=is_prediction, mysql_types=mysql_types)
173 @classmethod
174 def from_df_cols(cls, df: pd.DataFrame, columns_dict: dict[str, Column], strict: bool = True) -> "ResultSet":
175 """Create ResultSet from dataframe and dictionary of columns
177 Args:
178 df (pd.DataFrame): dataframe
179 columns_dict (dict[str, Column]): dictionary of columns
180 strict (bool): if True, raise an error if a column is not found in columns_dict
182 Returns:
183 ResultSet: result set
185 Raises:
186 ValueError: if a column is not found in columns_dict and strict is True
187 """
188 alias_idx = {column.alias: column for column in columns_dict.values() if column.alias is not None}
190 columns = []
191 for column_name in df.columns:
192 if strict and column_name not in columns_dict:
193 raise ValueError(f"Column {column_name} not found in columns_dict")
194 column = columns_dict.get(column_name) or alias_idx.get(column_name) or Column(name=column_name)
195 columns.append(column)
197 rename_df_columns(df)
199 return cls(columns=columns, df=df)
201 def to_df(self):
202 columns_names = self.get_column_names()
203 df = self.get_raw_df()
204 rename_df_columns(df, columns_names)
205 return df
207 def to_df_cols(self, prefix: str = "") -> tuple[pd.DataFrame, dict[str, Column]]:
208 # returns dataframe and dict of columns
209 # can be restored to ResultSet by from_df_cols method
211 columns = []
212 col_names = {}
213 for col in self._columns:
214 name = col.get_hash_name(prefix)
215 columns.append(name)
216 col_names[name] = col
218 df = self.get_raw_df()
219 rename_df_columns(df, columns)
220 return df, col_names
222 # --- tables ---
224 def get_tables(self):
225 tables_idx = []
226 tables = []
227 cols = ["database", "table_name", "table_alias"]
228 for col in self._columns:
229 table = (col.database, col.table_name, col.table_alias)
230 if table not in tables_idx:
231 tables_idx.append(table)
232 tables.append(dict(zip(cols, table)))
233 return tables
235 # --- columns ---
237 def get_col_index(self, col):
238 """
239 Get column index
240 :param col: column object
241 :return: index of column
242 """
244 col_idx = None
245 for i, col0 in enumerate(self._columns):
246 if col0 is col:
247 col_idx = i
248 break
249 if col_idx is None:
250 raise WrongArgumentError(f"Column is not found: {col}")
251 return col_idx
253 def add_column(self, col, values=None):
254 self._columns.append(col)
256 col_idx = len(self._columns) - 1
257 if self._df is not None: 257 ↛ 258line 257 didn't jump to line 258 because the condition on line 257 was never true
258 self._df[col_idx] = values
259 return col_idx
261 def del_column(self, col):
262 idx = self.get_col_index(col)
263 self._columns.pop(idx)
265 self._df.drop(idx, axis=1, inplace=True)
266 rename_df_columns(self._df)
268 @property
269 def columns(self):
270 return self._columns
272 def get_column_names(self):
273 columns = [col.name if col.alias is None else col.alias for col in self._columns]
274 return columns
276 def find_columns(self, alias=None, table_alias=None):
277 col_list = []
278 for col in self.columns:
279 if alias is not None and col.alias.lower() != alias.lower(): 279 ↛ 281line 279 didn't jump to line 281 because the condition on line 279 was always true
280 continue
281 if table_alias is not None and col.table_alias.lower() != table_alias.lower():
282 continue
283 col_list.append(col)
285 return col_list
287 def copy_column_to(self, col, result_set2):
288 # copy with values
289 idx = self.get_col_index(col)
291 values = [row[idx] for row in self._records]
293 col2 = copy.deepcopy(col)
295 result_set2.add_column(col2, values)
296 return col2
298 def set_col_type(self, col_idx, type_name):
299 self.columns[col_idx].type = type_name
300 if self._df is not None:
301 self._df[col_idx] = self._df[col_idx].astype(type_name)
303 # --- records ---
305 def get_raw_df(self):
306 if self._df is None: 306 ↛ 307line 306 didn't jump to line 307 because the condition on line 306 was never true
307 names = range(len(self._columns))
308 return pd.DataFrame([], columns=names)
309 return self._df
311 def add_raw_df(self, df):
312 if len(df.columns) != len(self._columns):
313 raise WrongArgumentError(f"Record length mismatch columns length: {len(df.columns)} != {len(self.columns)}")
315 rename_df_columns(df)
317 if self._df is None:
318 self._df = df
319 else:
320 self._df = pd.concat([self._df, df], ignore_index=True)
322 def add_raw_values(self, values):
323 # If some values are None, the DataFrame could have incorrect integer types, since 'NaN' is technically a float, so it will convert ints to floats automatically.
324 df = pd.DataFrame(values).convert_dtypes(
325 convert_integer=True,
326 convert_floating=True,
327 infer_objects=False,
328 convert_string=False,
329 convert_boolean=False,
330 )
331 self.add_raw_df(df)
333 def get_ast_columns(self) -> list[TableColumn]:
334 """Converts ResultSet columns to a list of TableColumn objects with SQLAlchemy types.
336 This method processes each column in the ResultSet, determines its MySQL data type
337 (inferring it if necessary), and maps it to the appropriate SQLAlchemy type.
338 The resulting TableColumn objects most likely will be used in CREATE TABLE statement.
340 Returns:
341 list[TableColumn]: A list of TableColumn objects with properly mapped SQLAlchemy types
342 """
343 columns: list[TableColumn] = []
345 type_mapping = {
346 MYSQL_DATA_TYPE.TINYINT: sqlalchemy_types.INTEGER,
347 MYSQL_DATA_TYPE.SMALLINT: sqlalchemy_types.INTEGER,
348 MYSQL_DATA_TYPE.MEDIUMINT: sqlalchemy_types.INTEGER,
349 MYSQL_DATA_TYPE.INT: sqlalchemy_types.INTEGER,
350 MYSQL_DATA_TYPE.BIGINT: sqlalchemy_types.INTEGER,
351 MYSQL_DATA_TYPE.YEAR: sqlalchemy_types.INTEGER,
352 MYSQL_DATA_TYPE.BOOL: sqlalchemy_types.BOOLEAN,
353 MYSQL_DATA_TYPE.BOOLEAN: sqlalchemy_types.BOOLEAN,
354 MYSQL_DATA_TYPE.FLOAT: sqlalchemy_types.FLOAT,
355 MYSQL_DATA_TYPE.DOUBLE: sqlalchemy_types.FLOAT,
356 MYSQL_DATA_TYPE.TIME: sqlalchemy_types.Time,
357 MYSQL_DATA_TYPE.DATE: sqlalchemy_types.Date,
358 MYSQL_DATA_TYPE.DATETIME: sqlalchemy_types.DateTime,
359 MYSQL_DATA_TYPE.TIMESTAMP: sqlalchemy_types.TIMESTAMP,
360 }
362 for i, column in enumerate(self._columns):
363 column_type: MYSQL_DATA_TYPE | None = column.type
365 # infer MYSQL_DATA_TYPE if not set
366 if isinstance(column_type, MYSQL_DATA_TYPE) is False: 366 ↛ 367line 366 didn't jump to line 367 because the condition on line 366 was never true
367 if column_type is not None:
368 logger.warning(f"Unexpected column type: {column_type}")
369 if self._df is None:
370 column_type = MYSQL_DATA_TYPE.TEXT
371 else:
372 column_type = get_mysql_data_type_from_series(self._df.iloc[:, i])
374 sqlalchemy_type = type_mapping.get(column_type, sqlalchemy_types.TEXT)
376 columns.append(TableColumn(name=column.alias, type=sqlalchemy_type))
377 return columns
379 def to_lists(self, json_types=False):
380 """
381 :param type_cast: cast numpy types
382 array->list, datetime64->str
383 :return: list of lists
384 """
386 if len(self.get_raw_df()) == 0: 386 ↛ 387line 386 didn't jump to line 387 because the condition on line 386 was never true
387 return []
388 # output for APIs. simplify types
389 if json_types:
390 df = self.get_raw_df().copy()
391 for name, dtype in df.dtypes.to_dict().items():
392 if pd.api.types.is_datetime64_any_dtype(dtype): 392 ↛ 393line 392 didn't jump to line 393 because the condition on line 392 was never true
393 df[name] = df[name].dt.strftime("%Y-%m-%d %H:%M:%S.%f")
394 for i, column in enumerate(self.columns):
395 if column.type == MYSQL_DATA_TYPE.VECTOR: 395 ↛ 396line 395 didn't jump to line 396 because the condition on line 395 was never true
396 df[i] = df[i].apply(_dump_vector)
397 df.replace({np.nan: None}, inplace=True)
398 return df.to_records(index=False).tolist()
400 # slower but keep timestamp type
401 df = self._df.replace({np.nan: None}) # TODO rework
402 return df.to_dict("split")["data"]
404 def get_column_values(self, col_idx):
405 # get by column index
406 df = self.get_raw_df()
407 return list(df[df.columns[col_idx]])
409 def set_column_values(self, col_name, values):
410 # values is one value or list of values
411 cols = self.find_columns(col_name)
412 if len(cols) == 0:
413 col_idx = self.add_column(Column(name=col_name))
414 else:
415 col_idx = self.get_col_index(cols[0])
417 if self._df is not None:
418 self._df[col_idx] = values
420 def add_from_result_set(self, rs):
421 source_names = rs.get_column_names()
423 col_sequence = []
424 for name in self.get_column_names():
425 col_sequence.append(source_names.index(name))
427 raw_df = rs.get_raw_df()[col_sequence]
429 self.add_raw_df(raw_df)
431 @property
432 def records(self):
433 return list(self.get_records())
435 def get_records(self):
436 # get records as dicts.
437 # !!! Attention: !!!
438 # if resultSet contents duplicate column name: only one of them will be in output
439 names = self.get_column_names()
440 for row in self.to_lists():
441 yield dict(zip(names, row))
443 def length(self):
444 return len(self)