Coverage for mindsdb / utilities / functions.py: 49%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import base64 

3import hashlib 

4import json 

5import datetime 

6import textwrap 

7from contextlib import ContextDecorator 

8 

9from cryptography.fernet import Fernet 

10from mindsdb_sql_parser.ast import Identifier 

11 

12from mindsdb.utilities.fs import create_process_mark, delete_process_mark, set_process_mark 

13from mindsdb.utilities import log 

14from mindsdb.utilities.config import Config 

15 

16 

# Module-level logger; cast_row_types writes its debug output through it.
logger = log.getLogger(__name__)

18 

19 

def get_handler_install_message(handler_name):
    """Build the instructions shown to a user who needs to install a handler.

    The wording depends on ``Config().use_docker_env``. When it is truthy the
    user is told to run ``pip`` through ``docker exec`` against this container,
    whose ID is read from the ``HOSTNAME`` environment variable (a
    ``<container_id>`` placeholder is used when the variable is unset).
    Otherwise the plain ``pip install`` commands are listed.

    Args:
        handler_name: Handler name, interpolated into the text as the pip extra.

    Returns:
        str: The multi-line, dedented instruction text.
    """
    if Config().use_docker_env:
        container_id = os.environ.get("HOSTNAME", "<container_id>")
        message = f"""\
            To install the {handler_name} handler, run the following in your terminal outside the docker container
            ({container_id} is the ID of this container):

                docker exec {container_id} pip install 'mindsdb[{handler_name}]'"""
    else:
        message = f"""\
            To install the {handler_name} handler, run the following in your terminal:

                pip install 'mindsdb[{handler_name}]'  # If you installed mindsdb via pip
                pip install '.[{handler_name}]'        # If you installed mindsdb from source"""

    # Strip the source-code indentation shared by every line of the literal.
    return textwrap.dedent(message)

34 

35 

def cast_row_types(row, field_types):
    """Coerce selected values of ``row`` in place, driven by ``field_types``.

    Only keys present in both ``row`` and ``field_types`` are considered:

    * ``"Timestamp"``: an ``int``/``float`` value is read as seconds since the
      epoch (UTC) and replaced by a ``"%Y-%m-%d %H:%M:%S"`` string.
    * ``"Date"``: same interpretation, replaced by a ``"%Y-%m-%d"`` string.
    * ``"Int"``: an ``int``/``float``/``str`` value is replaced by
      ``int(value)``; a value that cannot be converted is left untouched.

    Any other type name, or a value of a non-matching Python type, is left
    as it is.

    Args:
        row: Mutable mapping of column name to value; modified in place.
        field_types: Mapping of column name to one of the type names above.

    Returns:
        None. The result is the mutation of ``row``.
    """
    # Snapshot the keys first so the mapping is not iterated while being updated.
    for key in [k for k in row.keys() if k in field_types]:
        target = field_types[key]
        value = row[key]

        if target == "Timestamp" and isinstance(value, (int, float)):
            moment = datetime.datetime.fromtimestamp(value, datetime.timezone.utc)
            row[key] = moment.strftime("%Y-%m-%d %H:%M:%S")
        elif target == "Date" and isinstance(value, (int, float)):
            moment = datetime.datetime.fromtimestamp(value, datetime.timezone.utc)
            row[key] = moment.strftime("%Y-%m-%d")
        elif target == "Int" and isinstance(value, (int, float, str)):
            try:
                converted = int(value)
            except (TypeError, ValueError, OverflowError):
                # Best-effort cast: non-numeric text, NaN and infinities keep
                # their original value instead of aborting the whole row.
                continue
            # Logging happens outside the try block so that a logging problem
            # can no longer be mistaken for a failed conversion.
            logger.debug("cast %s to %s", value, converted)
            row[key] = converted

53 

54 

class mark_process(ContextDecorator):
    """Context manager / decorator that holds a process mark while code runs.

    On entry a mark is registered for ``name``: through ``create_process_mark``
    when no ``custom_mark`` was given, otherwise through ``set_process_mark``
    with that value. On exit the stored mark is removed again with
    ``delete_process_mark``. Exceptions raised by the wrapped code are never
    suppressed.

    Args:
        name: Name passed to the process-mark helpers.
        custom_mark: Optional explicit mark value; ``None`` lets
            ``create_process_mark`` produce one.
    """

    def __init__(self, name: str, custom_mark: str = None):
        self.name, self.custom_mark = name, custom_mark
        # Filled in by __enter__ with whatever the mark helper returns.
        self.mark = None

    def __enter__(self):
        has_custom_mark = self.custom_mark is not None
        self.mark = (
            set_process_mark(self.name, self.custom_mark)
            if has_custom_mark
            else create_process_mark(self.name)
        )

    def __exit__(self, exc_type, exc, tb):
        delete_process_mark(self.name, self.mark)
        # A falsy return value lets any in-flight exception propagate.
        return False

70 

71 

def init_lexer_parsers():
    """Create a new MindsDB SQL lexer and parser.

    The ``mindsdb_sql_parser`` classes are imported inside the function, so
    those modules are only loaded once this function is actually called.

    Returns:
        tuple: ``(MindsDBLexer instance, MindsDBParser instance)``.
    """
    from mindsdb_sql_parser.lexer import MindsDBLexer
    from mindsdb_sql_parser.parser import MindsDBParser

    lexer = MindsDBLexer()
    parser = MindsDBParser()
    return lexer, parser

77 

78 

def resolve_table_identifier(identifier: Identifier, default_database: str = None) -> tuple:
    """Split a table identifier into ``(database_name, table_name)``.

    Args:
        identifier: Identifier whose ``parts`` hold one or two name components.
        default_database: Accepted for interface compatibility; the current
            implementation never reads it, so a one-part identifier always
            yields ``None`` as the database.

    Returns:
        tuple: ``(None, table)`` for a one-part identifier and
        ``(database, table)`` for a two-part identifier.

    Raises:
        Exception: If the identifier does not have exactly one or two parts.
    """
    parts = identifier.parts
    count = len(parts)

    # Reject anything that is not "table" or "database.table" up front.
    if count not in (1, 2):
        raise Exception(f"Table identifier must contain max 2 parts: {parts}")

    database = parts[0] if count == 2 else None
    table = parts[count - 1]
    return (database, table)

89 

90 

91def resolve_model_identifier(identifier: Identifier) -> tuple: 

92 """ 

93 Splits a model identifier into its database, model name, and version components. 

94 

95 The identifier may contain one, two, or three parts. 

96 The function supports both quoted and unquoted identifiers, and normalizes names to lowercase if unquoted. 

97 

98 Examples: 

99 >>> resolve_model_identifier(Identifier(parts=['a', 'b'])) 

100 ('a', 'b', None) 

101 >>> resolve_model_identifier(Identifier(parts=['a', '1'])) 

102 (None, 'a', 1) 

103 >>> resolve_model_identifier(Identifier(parts=['a'])) 

104 (None, 'a', None) 

105 >>> resolve_model_identifier(Identifier(parts=['a', 'b', 'c'])) 

106 (None, None, None) # not found 

107 

108 Args: 

109 identifier (Identifier): The identifier object containing parts and is_quoted attributes. 

110 

111 Returns: 

112 tuple: (database_name, model_name, model_version) 

113 - database_name (str or None): The name of the database/project, or None if not specified. 

114 - model_name (str or None): The name of the model, or None if not found. 

115 - model_version (int or None): The model version as an integer, or None if not specified. 

116 """ 

117 model_name = None 

118 db_name = None 

119 version = None 

120 model_name_quoted = None 

121 db_name_quoted = None 

122 

123 match identifier.parts, identifier.is_quoted: 

124 case [model_name], [model_name_quoted]: 

125 ... 

126 case [model_name, str(version)], [model_name_quoted, _] if version.isdigit(): 

127 ... 

128 case [model_name, int(version)], [model_name_quoted, _]: 

129 ... 

130 case [db_name, model_name], [db_name_quoted, model_name_quoted]: 

131 ... 

132 case [db_name, model_name, str(version)], [db_name_quoted, model_name_quoted, _] if version.isdigit(): 

133 ... 

134 case [db_name, model_name, int(version)], [db_name_quoted, model_name_quoted, _]: 

135 ... 

136 case [db_name, model_name, str(version)], [db_name_quoted, model_name_quoted, _]: 

137 # for back compatibility. May be delete? 

138 return (None, None, None) 

139 case _: 

140 ... # may be raise ValueError? 

141 

142 if model_name_quoted is False: 142 ↛ 145line 142 didn't jump to line 145 because the condition on line 142 was always true

143 model_name = model_name.lower() 

144 

145 if db_name_quoted is False: 

146 db_name = db_name.lower() 

147 

148 if isinstance(version, int) or isinstance(version, str) and version.isdigit(): 148 ↛ 149line 148 didn't jump to line 149 because the condition on line 148 was never true

149 version = int(version) 

150 else: 

151 version = None 

152 

153 return db_name, model_name, version 

154 

155 

def _fernet_from_key(key: str):
    """Build the Fernet cipher that ``encrypt`` and ``decrypt`` derive from ``key``."""
    # Fernet expects a url-safe base64 encoding of 32 bytes; the SHA-256 digest
    # of the passphrase supplies exactly 32 bytes.
    # NOTE(review): an unsalted single hash is not a password-hardening KDF.
    # It is kept unchanged because a different derivation would make data
    # encrypted by earlier versions undecryptable.
    digest = hashlib.sha256(key.encode()).digest()
    return Fernet(base64.urlsafe_b64encode(digest))


def encrypt(string: bytes, key: str) -> bytes:
    """Encrypt ``string`` with a Fernet cipher derived from ``key``.

    Args:
        string: Plain bytes to encrypt.
        key: Passphrase; its SHA-256 digest is used as the Fernet key material.

    Returns:
        bytes: The token produced by ``Fernet.encrypt``.
    """
    return _fernet_from_key(key).encrypt(string)


def decrypt(encripted: bytes, key: str) -> bytes:
    """Decrypt a token produced by ``encrypt`` using the same ``key``.

    Args:
        encripted: Token to decrypt. The spelling of this parameter name is
            part of the keyword interface and is therefore left unchanged.
        key: Passphrase that was used for encryption.

    Returns:
        bytes: The value returned by ``Fernet.decrypt``. Errors raised by
        ``Fernet.decrypt`` are not caught here.
    """
    return _fernet_from_key(key).decrypt(encripted)

172 

173 

def encrypt_json(data: dict, key: str) -> bytes:
    """Serialize ``data`` to JSON and encrypt the encoded text with ``key``.

    Args:
        data: Object passed to ``json.dumps``.
        key: Passphrase forwarded to ``encrypt``.

    Returns:
        bytes: Whatever ``encrypt`` returns for the encoded JSON text.
    """
    payload = json.dumps(data).encode()
    return encrypt(payload, key)

177 

178 

def decrypt_json(encrypted_data: bytes, key: str) -> dict:
    """Decrypt ``encrypted_data`` with ``key`` and parse the result as JSON.

    Args:
        encrypted_data: Token forwarded to ``decrypt``.
        key: Passphrase forwarded to ``decrypt``.

    Returns:
        dict: The object produced by ``json.loads`` on the decrypted bytes.
    """
    return json.loads(decrypt(encrypted_data, key))