Coverage for mindsdb / integrations / utilities / handlers / query_utilities / insert_query_utilities.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from mindsdb_sql_parser import ast 

2from typing import Text, List, Dict, Any, Optional 

3 

4from .exceptions import UnsupportedColumnException, MandatoryColumnException, ColumnCountMismatchException 

5 

6 

class INSERTQueryParser:
    """
    Parses an INSERT query into its component parts.

    Parameters
    ----------
    query : ast.Insert
        Given SQL INSERT query.
    supported_columns : List[Text], Optional
        List of columns supported by the table for inserting. If empty or None, any column is accepted.
    mandatory_columns : List[Text], Optional
        List of columns that must be present in the query for inserting. If empty or None, no column is mandatory.
    all_mandatory : bool, Optional (default=True)
        Whether all mandatory columns must be present in the query. If False, only one of the mandatory columns must be present.
    """

    def __init__(
        self,
        query: "ast.Insert",
        supported_columns: Optional[List[Text]] = None,
        mandatory_columns: Optional[List[Text]] = None,
        all_mandatory: bool = True,
    ):
        self.query = query
        self.supported_columns = supported_columns
        self.mandatory_columns = mandatory_columns
        # Only evaluated for truthiness in parse_columns.
        self.all_mandatory = all_mandatory

    def parse_query(self) -> List[Dict[Text, Any]]:
        """
        Parses a SQL INSERT statement into its components: columns, values and returns a list of dictionaries with the values to insert.

        Returns
        -------
        List[Dict[Text, Any]]
            One dictionary per row of values, mapping each column name to the value at the same position.

        Raises
        ------
        UnsupportedColumnException, MandatoryColumnException
            Propagated from `parse_columns`.
        ColumnCountMismatchException
            If any row has a different number of values than there are columns.
        """
        columns = self.parse_columns()
        values = self.parse_values()

        values_to_insert = []
        for row in values:
            if len(row) != len(columns):
                raise ColumnCountMismatchException("Number of columns does not match the number of values")
            values_to_insert.append(dict(zip(columns, row)))

        return values_to_insert

    def parse_columns(self) -> List[Text]:
        """
        Parses the columns in the query. Raises an exception if the columns are not supported or if mandatory columns are missing.

        Returns
        -------
        List[Text]
            The `name` of each column in the query, in query order.

        Raises
        ------
        UnsupportedColumnException
            If `supported_columns` is set and the query uses a column not in it.
        MandatoryColumnException
            If `mandatory_columns` is set and the mandatory-column requirement
            (all of them, or at least one when `all_mandatory` is falsy) is not met.
        """
        # An INSERT without an explicit column list may leave `columns` unset (None); treat it as empty so
        # the checks below (and the row-length check in parse_query) report a domain error instead of a TypeError.
        columns = [col.name for col in (self.query.columns or [])]
        present = set(columns)

        if self.supported_columns:
            supported = set(self.supported_columns)
            # dict.fromkeys de-duplicates while keeping order, so the error message is deterministic
            # (iterating a set of strings yields a different order from run to run).
            unsupported_columns = [col for col in dict.fromkeys(columns) if col not in supported]
            if unsupported_columns:
                raise UnsupportedColumnException(f"Unsupported columns: {', '.join(unsupported_columns)}")

        if self.mandatory_columns:
            missing_mandatory_columns = [col for col in dict.fromkeys(self.mandatory_columns) if col not in present]
            if self.all_mandatory:
                violated = bool(missing_mandatory_columns)
            else:
                # Only one mandatory column is required; fail when none of them is present.
                violated = not any(col in present for col in self.mandatory_columns)
            if violated:
                raise MandatoryColumnException(f"Mandatory columns missing: {', '.join(missing_mandatory_columns)}")

        return columns

    def parse_values(self):
        """
        Parses the values in the query.

        Returns the query's `values` attribute unchanged; `parse_query` iterates it as a sequence of rows,
        each row being a sequence of values aligned with the parsed columns.
        """
        # NOTE(review): for INSERT ... SELECT this is presumably None, which makes parse_query fail with a
        # TypeError — confirm against mindsdb_sql_parser's ast.Insert before relying on it.
        return self.query.values