Coverage for mindsdb / integrations / handlers / scylla_handler / scylla_handler.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import tempfile 

2 

3import pandas as pd 

4import requests 

5 

6from cassandra.cluster import Cluster 

7from cassandra.auth import PlainTextAuthProvider 

8from cassandra.util import Date 

9 

10from mindsdb_sql_parser import parse_sql 

11from mindsdb_sql_parser.ast.base import ASTNode 

12from mindsdb_sql_parser import ast 

13from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

14 

15from mindsdb.integrations.libs.base import DatabaseHandler 

16from mindsdb.integrations.libs.response import ( 

17 HandlerStatusResponse as StatusResponse, 

18 HandlerResponse as Response, 

19 RESPONSE_TYPE 

20) 

21from mindsdb.utilities import log 

22 

23logger = log.getLogger(__name__) 

24 

25 

class ScyllaHandler(DatabaseHandler):
    """
    This handler handles connection and execution of the Scylla statements.
    """
    name = 'scylla'

    def __init__(self, name=None, **kwargs):
        super().__init__(name)
        self.parser = parse_sql
        # Fall back to an empty dict so the membership tests and ``.get`` calls
        # below never fail with TypeError when no connection data was supplied.
        self.connection_args = kwargs.get('connection_data') or {}
        self.session = None
        self.is_connected = False

    def download_secure_bundle(self, url, max_size=10 * 1024 * 1024):
        """
        Download the secure bundle from a given URL and store it in a temporary file.

        :param url: URL of the secure bundle to be downloaded.
        :param max_size: Maximum allowable size of the bundle in bytes. Defaults to 10MB.
        :return: Path to the downloaded secure bundle saved as a temporary file.
        :raises ValueError: If the secure bundle size exceeds the allowed `max_size`.
        :raises requests.HTTPError: If the server responds with an error status.

        The temporary file is created with ``delete=False`` so that it survives
        after this method returns (its path is handed to the Cassandra driver).
        If the download fails part-way, the partial file is removed here.

        TODO:
            - Find a way to periodically clean up or delete the temporary files
              after they have been used to prevent filling up storage over time.
        """
        import contextlib  # local imports: only needed on the cleanup path
        import os

        # Using the response as a context manager releases the streamed
        # connection on every exit path, including the size-limit errors.
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()

            content_length = int(response.headers.get('content-length', 0))
            if content_length > max_size:
                raise ValueError("Secure bundle is larger than the allowed size!")

            temp_file = tempfile.NamedTemporaryFile(delete=False)
            try:
                with temp_file:
                    size_downloaded = 0
                    # Content-Length may be missing or wrong, so the limit is
                    # also enforced on the bytes actually received.
                    for chunk in response.iter_content(chunk_size=8192):
                        size_downloaded += len(chunk)
                        if size_downloaded > max_size:
                            raise ValueError("Secure bundle is larger than the allowed size!")
                        temp_file.write(chunk)
            except BaseException:
                # Don't leave a partial or oversized bundle behind; re-raise.
                with contextlib.suppress(OSError):
                    os.remove(temp_file.name)
                raise
        return temp_file.name

    def connect(self):
        """
        Handle the connection to a Scylla keyspace.

        Returns the cached session when already connected; otherwise builds a
        ``Cluster`` from ``connection_args`` and connects to the optional
        ``keyspace``.

        :return: the driver session.
        :raises ValueError: if only one of 'user' / 'password' is provided.
        """
        if self.is_connected is True:
            return self.session

        auth_provider = None
        if any(key in self.connection_args for key in ('user', 'password')):
            if all(key in self.connection_args for key in ('user', 'password')):
                auth_provider = PlainTextAuthProvider(
                    username=self.connection_args['user'], password=self.connection_args['password']
                )
            else:
                raise ValueError("If authentication is required, both 'user' and 'password' must be provided!")

        connection_props = {
            'auth_provider': auth_provider
        }
        connection_props['protocol_version'] = self.connection_args.get('protocol_version', 4)
        secure_connect_bundle = self.connection_args.get('secure_connect_bundle')

        if secure_connect_bundle:
            # Check if the secure bundle is a URL; if so, fetch it to a local file.
            if secure_connect_bundle.startswith(('http://', 'https://')):
                secure_connect_bundle = self.download_secure_bundle(secure_connect_bundle)
            connection_props['cloud'] = {
                'secure_connect_bundle': secure_connect_bundle
            }
        else:
            connection_props['contact_points'] = [self.connection_args['host']]
            connection_props['port'] = int(self.connection_args['port'])

        cluster = Cluster(**connection_props)
        session = cluster.connect(self.connection_args.get('keyspace'))

        self.is_connected = True
        self.session = session
        return self.session

    def disconnect(self):
        """
        Shut down the driver cluster behind the cached session, if any, and
        mark the handler as disconnected.
        """
        if self.session is not None:
            self.session.cluster.shutdown()
        self.session = None
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Check the connection of the Scylla database.

        :return: success status and error message if error occurs
        """
        response = StatusResponse(False)

        try:
            session = self.connect()
            # TODO: change the healthcheck
            session.execute('SELECT release_version FROM system.local').one()
            response.success = True
        except Exception as e:
            # keyspace is optional (see connect), so use .get to avoid a
            # KeyError that would mask the real connection error.
            logger.error(f'Error connecting to Scylla {self.connection_args.get("keyspace")}, {e}!')
            response.error_message = str(e)

        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def prepare_response(self, resp):
        """
        Convert driver rows into plain dicts, replacing cassandra ``Date``
        values with ``datetime.date`` objects.

        :param resp: iterable of driver rows exposing ``_asdict()``.
        :return: list of dicts, one per row.
        """
        data = []
        for row in resp:
            row2 = {}
            for k, v in row._asdict().items():
                if isinstance(v, Date):
                    v = v.date()
                row2[k] = v
            data.append(row2)
        return data

    def native_query(self, query: str) -> Response:
        """
        Receive a CQL query string and run it.

        :param query: The CQL query to run in Scylla
        :return: a TABLE response with the returned rows, OK when no rows are
            returned, or ERROR with the driver's message if execution fails.
        """
        session = self.connect()
        try:
            resp = session.execute(query).all()
            resp = self.prepare_response(resp)
            if resp:
                response = Response(
                    RESPONSE_TYPE.TABLE,
                    pd.DataFrame(
                        resp
                    )
                )
            else:
                response = Response(RESPONSE_TYPE.OK)
        except Exception as e:
            logger.error(f'Error running query: {query} on {self.connection_args.get("keyspace")}! {e}')
            response = Response(
                RESPONSE_TYPE.ERROR,
                error_message=str(e)
            )
        return response

    def query(self, query: ASTNode) -> Response:
        """
        Retrieve the data from the SQL statement.

        For a SELECT over a single table, the table alias is removed (CQL does
        not support it), and any table-name or alias prefix is stripped from
        the selected identifiers before rendering.
        """
        if isinstance(query, ast.Select) and isinstance(query.from_table, ast.Identifier):
            from_table = query.from_table

            # Prefixes that may qualify selected columns: the table name itself
            # and, if present, the alias being removed below.
            prefixes = {from_table.parts[-1]}
            alias_parts = getattr(from_table.alias, 'parts', None)
            if alias_parts:
                prefixes.add(alias_parts[-1])

            # remove table alias because Cassandra Query Language doesn't support it
            from_table.alias = None

            # remove table name (or alias) from fields; keep single-part
            # identifiers intact even if the column shares the table's name.
            for target in query.targets or []:
                if (
                    isinstance(target, ast.Identifier)
                    and len(target.parts) > 1
                    and target.parts[0] in prefixes
                ):
                    target.parts.pop(0)

        renderer = SqlalchemyRender('mysql')
        query_str = renderer.get_string(query, with_failback=True)
        return self.native_query(query_str)

    def get_tables(self) -> Response:
        """
        Get a list with all of the tables in the current keyspace.

        The first column of the DESCRIBE output is renamed to ``table_name``.
        An empty result yields an empty ``table_name`` table; errors are
        returned unchanged.
        """
        q = "DESCRIBE TABLES;"
        result = self.native_query(q)
        df = result.data_frame
        if df is None:
            if result.type == RESPONSE_TYPE.ERROR:
                return result
            # native_query returns OK (no data frame) when there are no rows.
            return Response(RESPONSE_TYPE.TABLE, pd.DataFrame(columns=['table_name']))
        # NOTE(review): assumes the first DESCRIBE column holds the table
        # name; confirm against the Scylla server's DESCRIBE output.
        result.data_frame = df.rename(columns={df.columns[0]: 'table_name'})
        return result

    def get_columns(self, table_name) -> Response:
        """
        Show details about the table.

        :param table_name: table to describe; interpolated directly into the
            CQL text, so it must come from a trusted source.
        """
        q = f"DESCRIBE {table_name};"
        result = self.native_query(q)
        return result