Coverage for mindsdb / integrations / handlers / scylla_handler / scylla_handler.py: 0%
114 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import tempfile
3import pandas as pd
4import requests
6from cassandra.cluster import Cluster
7from cassandra.auth import PlainTextAuthProvider
8from cassandra.util import Date
10from mindsdb_sql_parser import parse_sql
11from mindsdb_sql_parser.ast.base import ASTNode
12from mindsdb_sql_parser import ast
13from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender
15from mindsdb.integrations.libs.base import DatabaseHandler
16from mindsdb.integrations.libs.response import (
17 HandlerStatusResponse as StatusResponse,
18 HandlerResponse as Response,
19 RESPONSE_TYPE
20)
21from mindsdb.utilities import log
23logger = log.getLogger(__name__)
class ScyllaHandler(DatabaseHandler):
    """
    This handler handles connection and execution of the Scylla statements.

    Connection parameters are read from the ``connection_data`` keyword
    argument. The keys used by this class are ``user``, ``password``,
    ``protocol_version``, ``secure_connect_bundle``, ``host``, ``port`` and
    ``keyspace``.
    """
    name = 'scylla'

    def __init__(self, name=None, **kwargs):
        """
        Store the connection arguments; no connection is opened here.

        :param name: Handler instance name, forwarded to the base class.
        :param kwargs: May contain ``connection_data`` (dict of connection parameters).
        """
        super().__init__(name)
        self.parser = parse_sql
        connection_data = kwargs.get('connection_data')
        # Fall back to an empty dict so later ``in`` / ``.get`` lookups never hit ``None``.
        self.connection_args = {} if connection_data is None else connection_data
        self.cluster = None
        self.session = None
        self.is_connected = False

    def download_secure_bundle(self, url, max_size=10 * 1024 * 1024):
        """
        Downloads the secure bundle from a given URL and stores it in a temporary file.

        :param url: URL of the secure bundle to be downloaded.
        :param max_size: Maximum allowable size of the bundle in bytes. Defaults to 10MB.
        :return: Path to the downloaded secure bundle saved as a temporary file.
        :raises ValueError: If the secure bundle size exceeds the allowed `max_size`.

        TODO:
        - Find a way to periodically clean up or delete the temporary files
        after they have been used to prevent filling up storage over time.
        """
        import os  # local import: only needed to remove a rejected/partial download

        too_large = "Secure bundle is larger than the allowed size!"

        # The context manager releases the streamed connection on every exit path.
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()

            # Cheap early rejection based on the advertised size (0 when the header is absent).
            content_length = int(response.headers.get('content-length', 0))
            if content_length > max_size:
                raise ValueError(too_large)

            temp_file = tempfile.NamedTemporaryFile(delete=False)
            try:
                with temp_file:
                    size_downloaded = 0
                    for chunk in response.iter_content(chunk_size=8192):
                        # The header can be missing or wrong, so enforce the limit while streaming.
                        size_downloaded += len(chunk)
                        if size_downloaded > max_size:
                            raise ValueError(too_large)
                        temp_file.write(chunk)
            except BaseException:
                # The file is closed at this point; do not leave a partial bundle behind.
                try:
                    os.unlink(temp_file.name)
                except OSError:
                    pass
                raise

        return temp_file.name

    def connect(self):
        """
        Handles the connection to a Scylla keystore.

        :return: The open session; an already established session is reused.
        :raises ValueError: If only one of 'user' / 'password' is provided.
        """
        if self.is_connected is True:
            return self.session

        # A previous cluster may still be around after a failed health check; release it first.
        self._shutdown_cluster()

        has_user = 'user' in self.connection_args
        has_password = 'password' in self.connection_args
        if has_user != has_password:
            raise ValueError("If authentication is required, both 'user' and 'password' must be provided!")

        auth_provider = None
        if has_user:
            auth_provider = PlainTextAuthProvider(
                username=self.connection_args['user'], password=self.connection_args['password']
            )

        connection_props = {
            'auth_provider': auth_provider,
            'protocol_version': self.connection_args.get('protocol_version', 4),
        }

        secure_connect_bundle = self.connection_args.get('secure_connect_bundle')
        if secure_connect_bundle:
            # Check if the secure bundle is a URL
            if secure_connect_bundle.startswith(('http://', 'https://')):
                secure_connect_bundle = self.download_secure_bundle(secure_connect_bundle)
            connection_props['cloud'] = {
                'secure_connect_bundle': secure_connect_bundle
            }
        else:
            connection_props['contact_points'] = [self.connection_args['host']]
            connection_props['port'] = int(self.connection_args['port'])

        cluster = Cluster(**connection_props)
        try:
            session = cluster.connect(self.connection_args.get('keyspace'))
        except Exception:
            # Do not leak the cluster's resources when the session cannot be opened.
            cluster.shutdown()
            raise

        self.cluster = cluster
        self.session = session
        self.is_connected = True
        return self.session

    def _shutdown_cluster(self):
        """
        Shut down the stored cluster (if any) and forget the cluster and session.
        """
        cluster, self.cluster = self.cluster, None
        self.session = None
        if cluster is not None:
            try:
                cluster.shutdown()
            except Exception as e:
                # Best effort: a failing shutdown must not block a reconnect.
                logger.warning(f'Error shutting down Scylla cluster: {e}')

    def disconnect(self):
        """
        Close the connection to Scylla, if one is open.
        """
        self._shutdown_cluster()
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Check the connection of the Scylla database
        :return: success status and error message if error occurs
        """
        response = StatusResponse(False)

        try:
            session = self.connect()
            # TODO: change the healthcheck
            session.execute('SELECT release_version FROM system.local').one()
            response.success = True
        except Exception as e:
            # 'keyspace' is optional, so use .get() to avoid raising from inside the handler.
            logger.error(f'Error connecting to Scylla {self.connection_args.get("keyspace")}, {e}!')
            response.error_message = str(e)

        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def prepare_response(self, resp):
        """
        Convert driver rows into a list of plain dicts.

        :param resp: Iterable of rows exposing ``_asdict()``.
        :return: List of dicts; cassandra ``Date`` values are replaced by ``v.date()``.
        """
        # replace cassandra types
        data = []
        for row in resp:
            data.append({
                k: v.date() if isinstance(v, Date) else v
                for k, v in row._asdict().items()
            })
        return data

    def native_query(self, query: str) -> Response:
        """
        Receive a query string and run it on Scylla
        :param query: The query to run
        :return: TABLE response with the returned rows, OK response when no
            rows are returned, or ERROR response when execution fails
        """
        session = self.connect()
        try:
            resp = session.execute(query).all()
            resp = self.prepare_response(resp)
            if resp:
                response = Response(
                    RESPONSE_TYPE.TABLE,
                    pd.DataFrame(
                        resp
                    )
                )
            else:
                response = Response(RESPONSE_TYPE.OK)
        except Exception as e:
            # 'keyspace' is optional, so use .get() to avoid raising from inside the handler.
            logger.error(f'Error running query: {query} on {self.connection_args.get("keyspace")}!')
            response = Response(
                RESPONSE_TYPE.ERROR,
                error_message=str(e)
            )
        return response

    def query(self, query: ASTNode) -> Response:
        """
        Retrieve the data from the SQL statement.

        For ``Select`` nodes whose source is a plain identifier, the table alias
        and table-name qualifiers on the selected fields are removed in place
        before rendering.

        :param query: Parsed query; may be modified in place.
        :return: Result of running the rendered query through ``native_query``.
        """
        if isinstance(query, ast.Select):
            from_table = query.from_table
            # Only a plain identifier has an alias / parts to work with.
            if isinstance(from_table, ast.Identifier):
                # remove table alias because Cassandra Query Language doesn't support it
                if from_table.alias is not None:
                    from_table.alias = None

                # remove table name from fields
                table_name = from_table.parts[-1]
                for target in query.targets:
                    if (
                        isinstance(target, ast.Identifier)
                        # keep single-part names intact, e.g. a column named like the table
                        and len(target.parts) > 1
                        and target.parts[0] == table_name
                    ):
                        target.parts.pop(0)

        renderer = SqlalchemyRender('mysql')
        query_str = renderer.get_string(query, with_failback=True)
        return self.native_query(query_str)

    def get_tables(self) -> Response:
        """
        Get a list with all of the tables in Scylla

        :return: Response of ``DESCRIBE TABLES;`` with the first column renamed
            to ``table_name`` when a data frame is returned.
        """
        q = "DESCRIBE TABLES;"
        result = self.native_query(q)
        df = result.data_frame
        # ERROR / OK responses carry no data frame; return them unchanged.
        if df is not None and len(df.columns) > 0:
            result.data_frame = df.rename(columns={df.columns[0]: 'table_name'})
        return result

    def get_columns(self, table_name) -> Response:
        """
        Show details about the table

        :param table_name: Name of the table to describe.
        :return: Response of ``DESCRIBE <table_name>;``.
        """
        q = f"DESCRIBE {table_name};"
        result = self.native_query(q)
        return result