Coverage for mindsdb / integrations / handlers / elasticsearch_handler / elasticsearch_handler.py: 0%

123 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Text, Dict, Optional 

2 

3from elasticsearch import Elasticsearch 

4from elasticsearch.exceptions import ( 

5 ConnectionError, 

6 AuthenticationException, 

7 TransportError, 

8 RequestError, 

9) 

10from es.elastic.sqlalchemy import ESDialect 

11from pandas import DataFrame 

12from mindsdb_sql_parser.ast.base import ASTNode 

13from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

14 

15from mindsdb.integrations.libs.base import DatabaseHandler 

16from mindsdb.integrations.libs.response import ( 

17 HandlerResponse as Response, 

18 HandlerStatusResponse as StatusResponse, 

19 RESPONSE_TYPE, 

20) 

21from mindsdb.utilities import log 

22 

23 

24logger = log.getLogger(__name__) 

25 

26 

27class ElasticsearchHandler(DatabaseHandler): 

28 """ 

29 This handler handles the connection and execution of SQL statements on Elasticsearch. 

30 """ 

31 

32 name = "elasticsearch" 

33 

34 def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs) -> None: 

35 """ 

36 Initializes the handler. 

37 

38 Args: 

39 name (Text): The name of the handler instance. 

40 connection_data (Dict): The connection data required to connect to the AWS (S3) account. 

41 kwargs: Arbitrary keyword arguments. 

42 """ 

43 super().__init__(name) 

44 self.connection_data = connection_data 

45 self.kwargs = kwargs 

46 

47 self.connection = None 

48 self.is_connected = False 

49 

50 def __del__(self) -> None: 

51 """ 

52 Closes the connection when the handler instance is deleted. 

53 """ 

54 if self.is_connected: 

55 self.disconnect() 

56 

57 def connect(self) -> Elasticsearch: 

58 """ 

59 Establishes a connection to the Elasticsearch host. 

60 

61 Raises: 

62 ValueError: If the expected connection parameters are not provided. 

63 

64 Returns: 

65 elasticsearch.Elasticsearch: A connection object to the Elasticsearch host. 

66 """ 

67 if self.is_connected is True: 

68 return self.connection 

69 

70 config = {} 

71 

72 # Mandatory connection parameters. 

73 if ("hosts" not in self.connection_data) and ( 

74 "cloud_id" not in self.connection_data 

75 ): 

76 raise ValueError( 

77 "Either the hosts or cloud_id parameter should be provided!" 

78 ) 

79 

80 # Optional/Additional connection parameters. 

81 optional_parameters = ["hosts", "cloud_id", "api_key"] 

82 for parameter in optional_parameters: 

83 if parameter in self.connection_data: 

84 if parameter == "hosts": 

85 config["hosts"] = self.connection_data[parameter].split(",") 

86 else: 

87 config[parameter] = self.connection_data[parameter] 

88 

89 # Ensure that if either user or password is provided, both are provided. 

90 if ("user" in self.connection_data) != ("password" in self.connection_data): 

91 raise ValueError( 

92 "Both user and password should be provided if one of them is provided!" 

93 ) 

94 

95 if "user" in self.connection_data: 

96 config["http_auth"] = ( 

97 self.connection_data["user"], 

98 self.connection_data["password"], 

99 ) 

100 

101 try: 

102 self.connection = Elasticsearch( 

103 **config, 

104 ) 

105 self.is_connected = True 

106 return self.connection 

107 except ConnectionError as conn_error: 

108 logger.error( 

109 f"Connection error when connecting to Elasticsearch: {conn_error}" 

110 ) 

111 raise 

112 except AuthenticationException as auth_error: 

113 logger.error( 

114 f"Authentication error when connecting to Elasticsearch: {auth_error}" 

115 ) 

116 raise 

117 except Exception as unknown_error: 

118 logger.error( 

119 f"Unknown error when connecting to Elasticsearch: {unknown_error}" 

120 ) 

121 raise 

122 

123 def disconnect(self) -> None: 

124 """ 

125 Closes the connection to the Elasticsearch host if it's currently open. 

126 """ 

127 if self.is_connected is False: 

128 return 

129 

130 self.connection.close() 

131 self.is_connected = False 

132 

133 def check_connection(self) -> StatusResponse: 

134 """ 

135 Checks the status of the connection to the Elasticsearch host. 

136 

137 Returns: 

138 StatusResponse: An object containing the success status and an error message if an error occurs. 

139 """ 

140 response = StatusResponse(False) 

141 need_to_close = self.is_connected is False 

142 

143 try: 

144 connection = self.connect() 

145 

146 # Execute a simple query to test the connection. 

147 connection.sql.query(body={"query": "SELECT 1"}) 

148 response.success = True 

149 # All exceptions are caught here to ensure that the connection is closed if an error occurs. 

150 except Exception as error: 

151 logger.error(f"Error connecting to Elasticsearch, {error}!") 

152 response.error_message = str(error) 

153 

154 if response.success and need_to_close: 

155 self.disconnect() 

156 

157 elif not response.success and self.is_connected: 

158 self.is_connected = False 

159 

160 return response 

161 

162 def native_query(self, query: Text) -> Response: 

163 """ 

164 Executes a native SQL query on the Elasticsearch host and returns the result. 

165 

166 Args: 

167 query (str): The SQL query to be executed. 

168 

169 Returns: 

170 Response: A response object containing the result of the query or an error message. 

171 """ 

172 need_to_close = self.is_connected is False 

173 

174 connection = self.connect() 

175 try: 

176 response = connection.sql.query(body={"query": query}) 

177 records = response["rows"] 

178 columns = response["columns"] 

179 

180 new_records = True 

181 while new_records: 

182 try: 

183 if response["cursor"]: 

184 response = connection.sql.query( 

185 body={"query": query, "cursor": response["cursor"]} 

186 ) 

187 

188 new_records = response["rows"] 

189 records = records + new_records 

190 except KeyError: 

191 new_records = False 

192 

193 column_names = [column["name"] for column in columns] 

194 if not records: 

195 null_record = [None] * len(column_names) 

196 records = [null_record] 

197 

198 response = Response( 

199 RESPONSE_TYPE.TABLE, 

200 data_frame=DataFrame(records, columns=column_names), 

201 ) 

202 

203 except (TransportError, RequestError) as transport_or_request_error: 

204 logger.error( 

205 f"Error running query: {query} on Elasticsearch, {transport_or_request_error}!" 

206 ) 

207 response = Response( 

208 RESPONSE_TYPE.ERROR, error_message=str(transport_or_request_error) 

209 ) 

210 except Exception as unknown_error: 

211 logger.error( 

212 f"Unknown error running query: {query} on Elasticsearch, {unknown_error}!" 

213 ) 

214 response = Response(RESPONSE_TYPE.ERROR, error_message=str(unknown_error)) 

215 

216 if need_to_close is True: 

217 self.disconnect() 

218 

219 return response 

220 

221 def query(self, query: ASTNode) -> Response: 

222 """ 

223 Executes a SQL query represented by an ASTNode on the Elasticsearch host and retrieves the data. 

224 

225 Args: 

226 query (ASTNode): An ASTNode representing the SQL query to be executed. 

227 

228 Returns: 

229 Response: The response from the `native_query` method, containing the result of the SQL query execution. 

230 """ 

231 # TODO: Add support for other query types. 

232 renderer = SqlalchemyRender(ESDialect) 

233 query_str = renderer.get_string(query, with_failback=True) 

234 logger.debug(f"Executing SQL query: {query_str}") 

235 return self.native_query(query_str) 

236 

237 def get_tables(self) -> Response: 

238 """ 

239 Retrieves a list of all non-system tables (indexes) in the Elasticsearch host. 

240 

241 Returns: 

242 Response: A response object containing a list of tables (indexes) in the Elasticsearch host. 

243 """ 

244 query = """ 

245 SHOW TABLES 

246 """ 

247 result = self.native_query(query) 

248 

249 df = result.data_frame 

250 

251 # Remove indices that are system indices: These are indices that start with a period. 

252 df = df[~df["name"].str.startswith(".")] 

253 

254 df = df.drop(["catalog", "kind"], axis=1) 

255 result.data_frame = df.rename( 

256 columns={"name": "table_name", "type": "table_type"} 

257 ) 

258 

259 return result 

260 

261 def get_columns(self, table_name: Text) -> Response: 

262 """ 

263 Retrieves column (field) details for a specified table (index) in the Elasticsearch host. 

264 

265 Args: 

266 table_name (str): The name of the table for which to retrieve column information. 

267 

268 Raises: 

269 ValueError: If the 'table_name' is not a valid string. 

270 

271 Returns: 

272 Response: A response object containing the column details. 

273 """ 

274 if not table_name or not isinstance(table_name, str): 

275 raise ValueError("Invalid table name provided.") 

276 

277 query = f""" 

278 DESCRIBE {table_name} 

279 """ 

280 result = self.native_query(query) 

281 

282 df = result.data_frame 

283 df = df.drop("mapping", axis=1) 

284 result.data_frame = df.rename( 

285 columns={"column": "column_name", "type": "data_type"} 

286 ) 

287 

288 return result