Coverage for mindsdb / integrations / handlers / trino_handler / trino_handler.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import re 

2from typing import Dict 

3import pandas as pd 

4from pyhive import sqlalchemy_trino 

5from mindsdb_sql_parser import parse_sql, ASTNode 

6from trino.auth import BasicAuthentication 

7from trino.dbapi import connect 

8from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

9from mindsdb.integrations.libs.base import DatabaseHandler 

10from mindsdb.utilities import log 

11from mindsdb.integrations.libs.response import ( 

12 HandlerStatusResponse as StatusResponse, 

13 HandlerResponse as Response, 

14 RESPONSE_TYPE 

15) 

16 

17logger = log.getLogger(__name__) 

18 

19 

class TrinoHandler(DatabaseHandler):
    """
    This handler handles connection and execution of the Trino statements

    kerberos is not implemented yet
    """

    name = 'trino'

    def __init__(self, name, connection_data, **kwargs):
        """
        Store the connection settings; no connection is opened here.

        :param name: handler instance name, forwarded to DatabaseHandler
        :param connection_data: mapping of connection settings. connect()
            reads 'host', 'port', 'user', 'catalog' and 'schema' (required)
            plus 'auth', 'password', 'http_scheme' and 'with' (optional)
        :param kwargs: accepted but currently unused
        """
        super().__init__(name)
        self.parser = parse_sql
        self.connection_data = connection_data
        # TODO: Kerberos support is not implemented. The sketch below is kept
        # for reference only and is not executed:
        #
        # service_name = kwargs.get('service_name')
        # self.config_file_name = kwargs.get('config_file_name')
        # self.trino_config_provider = TrinoConfigProvider(config_file_name=self.config_file_name)
        # self.kerberos_config = self.trino_config_provider.get_trino_kerberos_config()
        # self.http_scheme = self.kerberos_config['http_scheme']
        # self.dialect = self.kerberos_config['dialect']
        # config = self.kerberos_config['config']
        # hostname_override = self.kerberos_config['hostname_override']
        # principal = f"{kwargs.get('user')}@{hostname_override}"
        # ca_bundle = self.kerberos_config['ca_bundle']
        # self.auth_config = KerberosAuthentication(config=config,
        #                                           service_name=service_name,
        #                                           principal=principal,
        #                                           ca_bundle=ca_bundle,
        #                                           hostname_override=hostname_override)
        self.connection = None
        self.is_connected = False
        # Text appended after CREATE TABLE statements in query(); filled in
        # from connection_data['with'] by connect().
        self.with_clause = ""

    def connect(self):
        """
        Handles the connection to a Trino instance.

        The connection is cached: while ``is_connected`` is True the stored
        connection is returned without reconnecting.

        :return: the connection object returned by ``trino.dbapi.connect``
        :raises Exception: if a password is supplied together with
            ``auth == 'kerberos'``
        :raises KeyError: if 'host', 'port', 'user', 'catalog' or 'schema'
            is missing from ``connection_data``
        """
        if self.is_connected is True:
            return self.connection

        # option configuration
        auth = self.connection_data.get('auth')
        password = self.connection_data.get('password')
        http_scheme = self.connection_data.get('http_scheme', 'http')
        auth_config = None
        if 'with' in self.connection_data:
            self.with_clause = self.connection_data['with']

        if password and auth == 'kerberos':
            raise Exception("Kerberos authorization doesn't support password.")
        elif password:
            auth_config = BasicAuthentication(self.connection_data['user'], password)

        connect_args = {
            'host': self.connection_data['host'],
            'port': self.connection_data['port'],
            'user': self.connection_data['user'],
            'catalog': self.connection_data['catalog'],
            'schema': self.connection_data['schema'],
        }
        # NOTE(review): http_scheme and the password-based auth_config are
        # only forwarded when the 'auth' option is set; without it both are
        # ignored. Behavior kept as-is - confirm whether this is intended.
        if auth:
            connect_args['http_scheme'] = http_scheme
            connect_args['auth'] = auth_config

        # `connect` here is the module-level trino.dbapi.connect, not this method.
        conn = connect(**connect_args)

        self.is_connected = True
        self.connection = conn
        return conn

    def check_connection(self) -> StatusResponse:
        """
        Check the connection of the Trino instance
        :return: success status and error message if error occurs
        """
        response = StatusResponse(False)

        try:
            connection = self.connect()
            cur = connection.cursor()
            cur.execute("SELECT 1")
            response.success = True
        except Exception as e:
            # .get() so that a missing 'schema' key cannot raise from inside
            # the handler and hide the original error.
            logger.error(f'Error connecting to Trino {self.connection_data.get("schema")}, {e}!')
            response.error_message = str(e)

        # Drop the cached "connected" flag so the next connect() retries.
        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def native_query(self, query: str) -> Response:
        """
        Receive SQL query and runs it
        :param query: The SQL query to run in Trino
        :return: a TABLE response holding the records when the statement
            produced a result set, an OK response otherwise, or an ERROR
            response carrying the exception text if anything failed
        """
        try:
            connection = self.connect()
            cur = connection.cursor()
            result = cur.execute(query)
            if result and cur.description:
                column_names = [col[0] for col in cur.description]
                response = Response(
                    RESPONSE_TYPE.TABLE,
                    data_frame=pd.DataFrame(result, columns=column_names)
                )
            else:
                response = Response(RESPONSE_TYPE.OK)
            connection.commit()
        except Exception as e:
            # .get() so that a missing 'schema' key cannot raise from inside
            # the handler and hide the original error.
            logger.error(f'Error connecting to Trino {self.connection_data.get("schema")}, {e}!')
            response = Response(
                RESPONSE_TYPE.ERROR,
                error_message=str(e)
            )
        return response

    def query(self, query: ASTNode) -> Response:
        """
        Render an AST query with the Trino dialect and execute it.

        When a ``with`` clause is configured it is appended after every
        ``CREATE ... TABLE ... (...)`` match in the rendered string.

        :param query: parsed query to render and run
        :return: the response produced by native_query()
        """
        # Utilize trino dialect from sqlalchemy
        # implement WITH clause as default for all table
        # in future, this behavior should be changed to support more detail
        # level
        # also, for simplicity the current implementation uses the rendered
        # query string; directly manipulating the ASTNode would be preferred
        renderer = SqlalchemyRender(sqlalchemy_trino.TrinoDialect)
        query_str = renderer.get_string(query, with_failback=True)

        with_clause = self.with_clause
        if with_clause:
            # A callable replacement inserts the clause literally. A string
            # template would interpret backslashes in the user-supplied clause
            # as regex escapes (re.error on e.g. "\d", or silent rewriting of
            # "\n" / "\1").
            query_str = re.sub(
                r"(?is)(CREATE.+TABLE.+\(.*\))",
                lambda match: f"{match.group(1)} {with_clause}",
                query_str
            )
        return self.native_query(query_str)

    def get_tables(self) -> Response:
        """
        List all tables in Trino
        :return: response whose first column is renamed to 'table_name'; if
            the underlying query produced no data frame the response is
            returned unchanged
        """
        response = self.native_query("SHOW TABLES")
        df = response.data_frame
        # native_query() builds its ERROR response without a data frame, so
        # guard before touching the columns and hand the response back as-is.
        if df is not None and len(df.columns) > 0:
            response.data_frame = df.rename(columns={df.columns[0]: 'table_name'})
        return response

    def get_columns(self, table_name: str) -> Response:
        """
        Describe the columns of a table.

        :param table_name: table name; it is wrapped in double quotes as a
            single identifier
        :return: the response of running ``DESCRIBE`` through native_query()
        """
        # Double any embedded quote so the name cannot terminate the quoted
        # identifier early.
        quoted_name = table_name.replace('"', '""')
        query = f'DESCRIBE "{quoted_name}"'
        return self.native_query(query)