Coverage for mindsdb / api / executor / sql_query / result_set.py: 56%

242 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import copy 

2from array import array 

3from typing import Any 

4from dataclasses import dataclass, field, MISSING 

5 

6import numpy as np 

7import pandas as pd 

8from pandas.api import types as pd_types 

9import sqlalchemy.types as sqlalchemy_types 

10 

11from mindsdb_sql_parser.ast import TableColumn 

12 

13from mindsdb.utilities import log 

14from mindsdb.api.executor.exceptions import WrongArgumentError 

15from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE 

16 

17 

18logger = log.getLogger(__name__) 

19 

20 

def get_mysql_data_type_from_series(series: pd.Series, do_infer: bool = False) -> MYSQL_DATA_TYPE:
    """Map a pandas Series dtype to the closest MySQL data type.

    Inspects ``series.dtype`` and returns the matching MYSQL_DATA_TYPE enum
    member. For object dtypes the caller may opt in to dtype inference to get
    a more specific result.

    Args:
        series (pd.Series): series whose dtype is examined
        do_infer (bool): when True and the series has object dtype, run
            ``Series.infer_objects`` first and classify the inferred dtype

    Returns:
        MYSQL_DATA_TYPE: corresponding MySQL type; TEXT is the fallback
    """
    dtype = series.dtype
    if pd_types.is_object_dtype(dtype) and do_infer is True:
        dtype = series.infer_objects().dtype

    # Order matters: the specific checks (datetime, bool, integer) must run
    # before the broader ones (string, numeric).
    dtype_checks = (
        (pd_types.is_object_dtype, MYSQL_DATA_TYPE.TEXT),
        (pd_types.is_datetime64_dtype, MYSQL_DATA_TYPE.DATETIME),
        (pd_types.is_string_dtype, MYSQL_DATA_TYPE.TEXT),
        (pd_types.is_bool_dtype, MYSQL_DATA_TYPE.BOOL),
        (pd_types.is_integer_dtype, MYSQL_DATA_TYPE.INT),
        (pd_types.is_numeric_dtype, MYSQL_DATA_TYPE.FLOAT),
    )
    for predicate, mysql_type in dtype_checks:
        if predicate(dtype):
            return mysql_type
    return MYSQL_DATA_TYPE.TEXT

52 

53 

54def _dump_vector(value: Any) -> Any: 

55 if isinstance(value, array): 

56 return value.tolist() 

57 return value 

58 

59 

@dataclass(kw_only=True, slots=True)
class Column:
    """Metadata describing a single ResultSet column.

    All fields are keyword-only. ``name`` is the only required field;
    ``alias`` and ``table_alias`` fall back to ``name`` / ``table_name``
    in ``__post_init__`` when not provided.
    """

    # Required field: equivalent to the previous `field(default=MISSING)`,
    # which is just a verbose spelling of "no default".
    name: str
    alias: str | None = None
    table_name: str | None = None
    table_alias: str | None = None
    type: MYSQL_DATA_TYPE | None = None
    database: str | None = None
    # Fixed annotation: the default is None, so the type is `dict | None`,
    # not plain `dict` as previously declared.
    flags: dict | None = None
    charset: str | None = None

    def __post_init__(self):
        # Aliases default to the underlying names.
        if self.alias is None:
            self.alias = self.name
        if self.table_alias is None:
            self.table_alias = self.table_name

    def get_hash_name(self, prefix: str) -> str:
        """Return an identifier of the form ``<prefix>_<table>_<name>``.

        Prefers aliases over raw names; used to build unique dataframe
        column labels (see ``ResultSet.to_df_cols``).
        """
        table_name = self.table_name if self.table_alias is None else self.table_alias
        name = self.name if self.alias is None else self.alias

        name = f"{prefix}_{table_name}_{name}"
        return name

83 

84 

85def rename_df_columns(df: pd.DataFrame, names: list | None = None) -> None: 

86 """Inplace rename of dataframe columns 

87 

88 Args: 

89 df (pd.DataFrame): dataframe 

90 names (Optional[List]): columns names to set 

91 """ 

92 if names is not None: 

93 df.columns = names 

94 else: 

95 df.columns = list(range(len(df.columns))) 

96 

97 

98class ResultSet: 

99 def __init__( 

100 self, 

101 columns: list[Column] | None = None, 

102 values: list[list] | None = None, 

103 df: pd.DataFrame | None = None, 

104 affected_rows: int | None = None, 

105 is_prediction: bool = False, 

106 mysql_types: list[MYSQL_DATA_TYPE] | None = None, 

107 ): 

108 """ 

109 Args: 

110 columns: list of Columns 

111 values (List[List]): data of resultSet, have to be list of lists with length equal to column 

112 df (pd.DataFrame): injected dataframe, have to have enumerated columns and length equal to columns 

113 affected_rows (int): number of affected rows 

114 """ 

115 if columns is None: 

116 columns = [] 

117 self._columns = columns 

118 

119 if df is None: 

120 if values is None: 120 ↛ 123line 120 didn't jump to line 123 because the condition on line 120 was always true

121 df = None 

122 else: 

123 df = pd.DataFrame(values) 

124 self._df = df 

125 

126 self.affected_rows = affected_rows 

127 

128 self.is_prediction = is_prediction 

129 

130 self.mysql_types = mysql_types 

131 

132 def __repr__(self): 

133 col_names = ", ".join([col.name for col in self._columns]) 

134 

135 return f"{self.__class__.__name__}({self.length()} rows, cols: {col_names})" 

136 

137 def __len__(self) -> int: 

138 if self._df is None: 

139 return 0 

140 return len(self._df) 

141 

142 def __getitem__(self, slice_val): 

143 # return resultSet with sliced dataframe 

144 df = self._df[slice_val] 

145 return ResultSet(columns=self.columns, df=df) 

146 

147 # --- converters --- 

148 

149 @classmethod 

150 def from_df( 

151 cls, 

152 df: pd.DataFrame, 

153 database=None, 

154 table_name=None, 

155 table_alias=None, 

156 is_prediction: bool = False, 

157 mysql_types: list[MYSQL_DATA_TYPE] | None = None, 

158 ): 

159 match mysql_types: 

160 case None: 

161 mysql_types = [None] * len(df.columns) 

162 case list() if len(mysql_types) != len(df.columns): 162 ↛ 163line 162 didn't jump to line 163 because the pattern on line 162 never matched

163 raise WrongArgumentError(f"Mysql types length mismatch: {len(mysql_types)} != {len(df.columns)}") 

164 

165 columns = [ 

166 Column(name=column_name, table_name=table_name, table_alias=table_alias, database=database, type=mysql_type) 

167 for column_name, mysql_type in zip(df.columns, mysql_types) 

168 ] 

169 

170 rename_df_columns(df) 

171 return cls(df=df, columns=columns, is_prediction=is_prediction, mysql_types=mysql_types) 

172 

173 @classmethod 

174 def from_df_cols(cls, df: pd.DataFrame, columns_dict: dict[str, Column], strict: bool = True) -> "ResultSet": 

175 """Create ResultSet from dataframe and dictionary of columns 

176 

177 Args: 

178 df (pd.DataFrame): dataframe 

179 columns_dict (dict[str, Column]): dictionary of columns 

180 strict (bool): if True, raise an error if a column is not found in columns_dict 

181 

182 Returns: 

183 ResultSet: result set 

184 

185 Raises: 

186 ValueError: if a column is not found in columns_dict and strict is True 

187 """ 

188 alias_idx = {column.alias: column for column in columns_dict.values() if column.alias is not None} 

189 

190 columns = [] 

191 for column_name in df.columns: 

192 if strict and column_name not in columns_dict: 

193 raise ValueError(f"Column {column_name} not found in columns_dict") 

194 column = columns_dict.get(column_name) or alias_idx.get(column_name) or Column(name=column_name) 

195 columns.append(column) 

196 

197 rename_df_columns(df) 

198 

199 return cls(columns=columns, df=df) 

200 

201 def to_df(self): 

202 columns_names = self.get_column_names() 

203 df = self.get_raw_df() 

204 rename_df_columns(df, columns_names) 

205 return df 

206 

207 def to_df_cols(self, prefix: str = "") -> tuple[pd.DataFrame, dict[str, Column]]: 

208 # returns dataframe and dict of columns 

209 # can be restored to ResultSet by from_df_cols method 

210 

211 columns = [] 

212 col_names = {} 

213 for col in self._columns: 

214 name = col.get_hash_name(prefix) 

215 columns.append(name) 

216 col_names[name] = col 

217 

218 df = self.get_raw_df() 

219 rename_df_columns(df, columns) 

220 return df, col_names 

221 

222 # --- tables --- 

223 

224 def get_tables(self): 

225 tables_idx = [] 

226 tables = [] 

227 cols = ["database", "table_name", "table_alias"] 

228 for col in self._columns: 

229 table = (col.database, col.table_name, col.table_alias) 

230 if table not in tables_idx: 

231 tables_idx.append(table) 

232 tables.append(dict(zip(cols, table))) 

233 return tables 

234 

235 # --- columns --- 

236 

237 def get_col_index(self, col): 

238 """ 

239 Get column index 

240 :param col: column object 

241 :return: index of column 

242 """ 

243 

244 col_idx = None 

245 for i, col0 in enumerate(self._columns): 

246 if col0 is col: 

247 col_idx = i 

248 break 

249 if col_idx is None: 

250 raise WrongArgumentError(f"Column is not found: {col}") 

251 return col_idx 

252 

253 def add_column(self, col, values=None): 

254 self._columns.append(col) 

255 

256 col_idx = len(self._columns) - 1 

257 if self._df is not None: 257 ↛ 258line 257 didn't jump to line 258 because the condition on line 257 was never true

258 self._df[col_idx] = values 

259 return col_idx 

260 

261 def del_column(self, col): 

262 idx = self.get_col_index(col) 

263 self._columns.pop(idx) 

264 

265 self._df.drop(idx, axis=1, inplace=True) 

266 rename_df_columns(self._df) 

267 

268 @property 

269 def columns(self): 

270 return self._columns 

271 

272 def get_column_names(self): 

273 columns = [col.name if col.alias is None else col.alias for col in self._columns] 

274 return columns 

275 

276 def find_columns(self, alias=None, table_alias=None): 

277 col_list = [] 

278 for col in self.columns: 

279 if alias is not None and col.alias.lower() != alias.lower(): 279 ↛ 281line 279 didn't jump to line 281 because the condition on line 279 was always true

280 continue 

281 if table_alias is not None and col.table_alias.lower() != table_alias.lower(): 

282 continue 

283 col_list.append(col) 

284 

285 return col_list 

286 

287 def copy_column_to(self, col, result_set2): 

288 # copy with values 

289 idx = self.get_col_index(col) 

290 

291 values = [row[idx] for row in self._records] 

292 

293 col2 = copy.deepcopy(col) 

294 

295 result_set2.add_column(col2, values) 

296 return col2 

297 

298 def set_col_type(self, col_idx, type_name): 

299 self.columns[col_idx].type = type_name 

300 if self._df is not None: 

301 self._df[col_idx] = self._df[col_idx].astype(type_name) 

302 

303 # --- records --- 

304 

305 def get_raw_df(self): 

306 if self._df is None: 306 ↛ 307line 306 didn't jump to line 307 because the condition on line 306 was never true

307 names = range(len(self._columns)) 

308 return pd.DataFrame([], columns=names) 

309 return self._df 

310 

311 def add_raw_df(self, df): 

312 if len(df.columns) != len(self._columns): 

313 raise WrongArgumentError(f"Record length mismatch columns length: {len(df.columns)} != {len(self.columns)}") 

314 

315 rename_df_columns(df) 

316 

317 if self._df is None: 

318 self._df = df 

319 else: 

320 self._df = pd.concat([self._df, df], ignore_index=True) 

321 

322 def add_raw_values(self, values): 

323 # If some values are None, the DataFrame could have incorrect integer types, since 'NaN' is technically a float, so it will convert ints to floats automatically. 

324 df = pd.DataFrame(values).convert_dtypes( 

325 convert_integer=True, 

326 convert_floating=True, 

327 infer_objects=False, 

328 convert_string=False, 

329 convert_boolean=False, 

330 ) 

331 self.add_raw_df(df) 

332 

333 def get_ast_columns(self) -> list[TableColumn]: 

334 """Converts ResultSet columns to a list of TableColumn objects with SQLAlchemy types. 

335 

336 This method processes each column in the ResultSet, determines its MySQL data type 

337 (inferring it if necessary), and maps it to the appropriate SQLAlchemy type. 

338 The resulting TableColumn objects most likely will be used in CREATE TABLE statement. 

339 

340 Returns: 

341 list[TableColumn]: A list of TableColumn objects with properly mapped SQLAlchemy types 

342 """ 

343 columns: list[TableColumn] = [] 

344 

345 type_mapping = { 

346 MYSQL_DATA_TYPE.TINYINT: sqlalchemy_types.INTEGER, 

347 MYSQL_DATA_TYPE.SMALLINT: sqlalchemy_types.INTEGER, 

348 MYSQL_DATA_TYPE.MEDIUMINT: sqlalchemy_types.INTEGER, 

349 MYSQL_DATA_TYPE.INT: sqlalchemy_types.INTEGER, 

350 MYSQL_DATA_TYPE.BIGINT: sqlalchemy_types.INTEGER, 

351 MYSQL_DATA_TYPE.YEAR: sqlalchemy_types.INTEGER, 

352 MYSQL_DATA_TYPE.BOOL: sqlalchemy_types.BOOLEAN, 

353 MYSQL_DATA_TYPE.BOOLEAN: sqlalchemy_types.BOOLEAN, 

354 MYSQL_DATA_TYPE.FLOAT: sqlalchemy_types.FLOAT, 

355 MYSQL_DATA_TYPE.DOUBLE: sqlalchemy_types.FLOAT, 

356 MYSQL_DATA_TYPE.TIME: sqlalchemy_types.Time, 

357 MYSQL_DATA_TYPE.DATE: sqlalchemy_types.Date, 

358 MYSQL_DATA_TYPE.DATETIME: sqlalchemy_types.DateTime, 

359 MYSQL_DATA_TYPE.TIMESTAMP: sqlalchemy_types.TIMESTAMP, 

360 } 

361 

362 for i, column in enumerate(self._columns): 

363 column_type: MYSQL_DATA_TYPE | None = column.type 

364 

365 # infer MYSQL_DATA_TYPE if not set 

366 if isinstance(column_type, MYSQL_DATA_TYPE) is False: 366 ↛ 367line 366 didn't jump to line 367 because the condition on line 366 was never true

367 if column_type is not None: 

368 logger.warning(f"Unexpected column type: {column_type}") 

369 if self._df is None: 

370 column_type = MYSQL_DATA_TYPE.TEXT 

371 else: 

372 column_type = get_mysql_data_type_from_series(self._df.iloc[:, i]) 

373 

374 sqlalchemy_type = type_mapping.get(column_type, sqlalchemy_types.TEXT) 

375 

376 columns.append(TableColumn(name=column.alias, type=sqlalchemy_type)) 

377 return columns 

378 

379 def to_lists(self, json_types=False): 

380 """ 

381 :param type_cast: cast numpy types 

382 array->list, datetime64->str 

383 :return: list of lists 

384 """ 

385 

386 if len(self.get_raw_df()) == 0: 386 ↛ 387line 386 didn't jump to line 387 because the condition on line 386 was never true

387 return [] 

388 # output for APIs. simplify types 

389 if json_types: 

390 df = self.get_raw_df().copy() 

391 for name, dtype in df.dtypes.to_dict().items(): 

392 if pd.api.types.is_datetime64_any_dtype(dtype): 392 ↛ 393line 392 didn't jump to line 393 because the condition on line 392 was never true

393 df[name] = df[name].dt.strftime("%Y-%m-%d %H:%M:%S.%f") 

394 for i, column in enumerate(self.columns): 

395 if column.type == MYSQL_DATA_TYPE.VECTOR: 395 ↛ 396line 395 didn't jump to line 396 because the condition on line 395 was never true

396 df[i] = df[i].apply(_dump_vector) 

397 df.replace({np.nan: None}, inplace=True) 

398 return df.to_records(index=False).tolist() 

399 

400 # slower but keep timestamp type 

401 df = self._df.replace({np.nan: None}) # TODO rework 

402 return df.to_dict("split")["data"] 

403 

404 def get_column_values(self, col_idx): 

405 # get by column index 

406 df = self.get_raw_df() 

407 return list(df[df.columns[col_idx]]) 

408 

409 def set_column_values(self, col_name, values): 

410 # values is one value or list of values 

411 cols = self.find_columns(col_name) 

412 if len(cols) == 0: 

413 col_idx = self.add_column(Column(name=col_name)) 

414 else: 

415 col_idx = self.get_col_index(cols[0]) 

416 

417 if self._df is not None: 

418 self._df[col_idx] = values 

419 

420 def add_from_result_set(self, rs): 

421 source_names = rs.get_column_names() 

422 

423 col_sequence = [] 

424 for name in self.get_column_names(): 

425 col_sequence.append(source_names.index(name)) 

426 

427 raw_df = rs.get_raw_df()[col_sequence] 

428 

429 self.add_raw_df(raw_df) 

430 

431 @property 

432 def records(self): 

433 return list(self.get_records()) 

434 

435 def get_records(self): 

436 # get records as dicts. 

437 # !!! Attention: !!! 

438 # if resultSet contents duplicate column name: only one of them will be in output 

439 names = self.get_column_names() 

440 for row in self.to_lists(): 

441 yield dict(zip(names, row)) 

442 

443 def length(self): 

444 return len(self)