Coverage for mindsdb/integrations/handlers/ibm_cos_handler/ibm_cos_handler.py: 0%

197 statements  

coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

from contextlib import contextmanager
from typing import Text, Dict, Optional, List

import ibm_boto3
from ibm_botocore.client import ClientError
import pandas as pd
import duckdb

from mindsdb_sql_parser.ast.base import ASTNode
from mindsdb_sql_parser.ast import Select, Identifier, Insert, Star, Constant

from mindsdb.utilities import log
from mindsdb.integrations.libs.response import (
    HandlerStatusResponse as StatusResponse,
    HandlerResponse as Response,
    RESPONSE_TYPE,
)

from mindsdb.integrations.libs.api_handler import APIResource, APIHandler
from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator

logger = log.getLogger(__name__)


class ListFilesTable(APIResource):
    def list(
        self,
        targets: List[str] = None,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        *args,
        **kwargs,
    ) -> pd.DataFrame:
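        """List objects across the scanned buckets as rows of file metadata.

        A `bucket` filter condition (EQUAL or IN) narrows the scan, and only
        files whose extension is in the handler's `supported_file_formats`
        are returned.
        """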

        buckets = None
        for condition in conditions or []:
            if condition.column == "bucket":
                if condition.op == FilterOperator.IN:
                    buckets = condition.value
                elif condition.op == FilterOperator.EQUAL:
                    buckets = [condition.value]
                condition.applied = True

        data = []
        for obj in self.handler.get_objects(limit=limit, buckets=buckets):
            path = obj["Key"]
            # Match on the final dot-separated token so multi-dot filenames
            # resolve to their real extension.
            if obj["Filename"].split(".")[-1] in self.handler.supported_file_formats:
                item = {
                    "path": path,
                    "bucket": obj["Bucket"],
                    "name": path[path.rfind("/") + 1 :],
                    "extension": path[path.rfind(".") + 1 :],
                }

                data.append(item)

        return pd.DataFrame(data=data, columns=self.get_columns())

    def get_columns(self) -> List[str]:
        return ["path", "name", "extension", "bucket", "content"]


class FileTable(APIResource):
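    """Exposes a single COS object as a table that can be read or appended to."""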

    def list(self, targets: List[str] = None, table_name=None, *args, **kwargs) -> pd.DataFrame:
        return self.handler.read_as_table(table_name)

    def add(self, data, table_name=None):
        df = pd.DataFrame(data)
        return self.handler.add_data_to_table(table_name, df)


class IBMCloudObjectStorageHandler(APIHandler):
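    """Handler that exposes IBM Cloud Object Storage through its S3-compatible API.

    Connects with HMAC credentials (`cos_hmac_access_key_id`,
    `cos_hmac_secret_access_key`) against `cos_endpoint_url`; an optional
    `bucket` parameter scopes every operation to that single bucket.
    """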

    name = "ibm_cos"
    supported_file_formats = ["csv", "tsv", "json", "parquet"]

    def __init__(self, name: Text, connection_data: Optional[Dict] = None, **kwargs):
        super().__init__(name)
        self.connection_data = connection_data or {}
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False
        self.cache_thread_safe = True
        self._regions = {}

        self.bucket = self.connection_data.get("bucket")
        self._files_table = ListFilesTable(self)

    def __del__(self):
        if self.is_connected is True:
            self.disconnect()

    def connect(self):
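        """Validate the required connection parameters and establish the COS
        client, returning the cached client if one is already connected.
        """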

        if self.is_connected is True:
            return self.connection

        required_params = [
            "cos_hmac_access_key_id",
            "cos_hmac_secret_access_key",
            "cos_endpoint_url",
        ]
        if not all(key in self.connection_data for key in required_params):
            raise ValueError(
                "Required parameters (cos_hmac_access_key_id, cos_hmac_secret_access_key, cos_endpoint_url) must be provided."
            )

        self.connection = self._connect_ibm_boto3()
        self.is_connected = True

        return self.connection

    def _connect_ibm_boto3(self) -> ibm_boto3.client:
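        """Build an `ibm_boto3` S3 client from the HMAC credentials and verify
        access up front via `head_bucket` (scoped) or `list_buckets` (unscoped).
        """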

        config = {
            "aws_access_key_id": self.connection_data["cos_hmac_access_key_id"],
            "aws_secret_access_key": self.connection_data["cos_hmac_secret_access_key"],
            "endpoint_url": self.connection_data["cos_endpoint_url"],
        }

        client = ibm_boto3.client("s3", **config)

        if self.bucket is not None:
            client.head_bucket(Bucket=self.bucket)
        else:
            client.list_buckets()

        return client

    def disconnect(self):
        if not self.is_connected:
            return
        self.connection = None
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
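        """Attempt a fresh connection and report success or the error message."""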

        response = StatusResponse(False)
        need_to_close = self.is_connected is False

        try:
            self._connect_ibm_boto3()
            response.success = True
        except (ClientError, ValueError) as e:
            logger.error(f"Error connecting to IBM COS with the given credentials, {e}!")
            response.error_message = str(e)

        if response.success and need_to_close:
            self.disconnect()

        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    @contextmanager
    def _connect_duckdb(self):
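        """Yield an in-memory DuckDB connection with the `httpfs` extension
        configured to reach COS through its S3-compatible endpoint.
        """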

        duckdb_conn = duckdb.connect(":memory:")
        duckdb_conn.execute("INSTALL httpfs")
        duckdb_conn.execute("LOAD httpfs")

        duckdb_conn.execute(f"SET s3_access_key_id='{self.connection_data['cos_hmac_access_key_id']}'")
        duckdb_conn.execute(f"SET s3_secret_access_key='{self.connection_data['cos_hmac_secret_access_key']}'")

        # DuckDB expects the endpoint without a scheme; remember whether the
        # original URL was plain HTTP so SSL can be disabled accordingly.
        endpoint_url = self.connection_data["cos_endpoint_url"]
        use_ssl = not endpoint_url.startswith("http://")
        if endpoint_url.startswith("https://"):
            endpoint_url = endpoint_url[len("https://") :]
        elif endpoint_url.startswith("http://"):
            endpoint_url = endpoint_url[len("http://") :]

        duckdb_conn.execute(f"SET s3_endpoint='{endpoint_url}'")
        duckdb_conn.execute("SET s3_url_style='path'")
        duckdb_conn.execute(f"SET s3_use_ssl={'true' if use_ssl else 'false'}")

        try:
            yield duckdb_conn
        finally:
            duckdb_conn.close()

    def _get_bucket(self, key):
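        """Resolve `key` into a `(bucket, key)` pair: use the configured bucket
        when one is set, otherwise treat the first path segment as the bucket.
        """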

        if self.bucket is not None:
            return self.bucket, key

        ar = key.split("/")
        return ar[0], "/".join(ar[1:])

    def read_as_table(self, key) -> pd.DataFrame:
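        """Load the object at `key` into a DataFrame via DuckDB's S3 reader."""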

        bucket, key = self._get_bucket(key)

        with self._connect_duckdb() as connection:
            cursor = connection.execute(f"SELECT * FROM 's3://{bucket}/{key}'")

            return cursor.fetchdf()

    def _read_as_content(self, key) -> bytes:
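        """Fetch the raw bytes of the object at `key`."""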

        bucket, key = self._get_bucket(key)

        client = self.connect()

        obj = client.get_object(Bucket=bucket, Key=key)
        content = obj["Body"].read()
        return content

    def add_data_to_table(self, key, df) -> None:
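        """Append `df` to the existing object at `key`.

        The object is loaded into a temporary DuckDB table, the new rows are
        inserted by column name, and the result is copied back to COS. Raises
        `ClientError` if the object does not exist.
        """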

        bucket, key = self._get_bucket(key)

        try:
            client = self.connect()
            client.head_object(Bucket=bucket, Key=key)
        except ClientError as e:
            logger.error(f"Error querying the file {key} in the bucket {bucket}, {e}!")
            raise e

        with self._connect_duckdb() as connection:
            connection.execute(f"CREATE TABLE tmp_table AS SELECT * FROM 's3://{bucket}/{key}'")

            connection.execute("INSERT INTO tmp_table BY NAME SELECT * FROM df")

            connection.execute(f"COPY tmp_table TO 's3://{bucket}/{key}'")

    def query(self, query: ASTNode) -> Response:
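        """Execute a parsed SQL statement.

        SELECTs on the virtual `files` table list objects, materialising the
        `content` column only when it is explicitly requested; SELECTs and
        INSERTs on an object path are delegated to a `FileTable`. Any other
        statement type raises `NotImplementedError`.
        """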

        self.connect()
        if isinstance(query, Select):
            table_name = query.from_table.parts[-1]

            if table_name == "files":
                table = self._files_table
                df = table.select(query)

                has_content = False
                for target in query.targets:
                    if isinstance(target, Identifier) and target.parts[-1].lower() == "content":
                        has_content = True
                        break
                if has_content:
                    df["content"] = df["path"].apply(self._read_as_content)
            else:
                extension = table_name.split(".")[-1]
                if extension not in self.supported_file_formats:
                    logger.error(f"The file format {extension} is not supported!")
                    raise ValueError(f"The file format {extension} is not supported!")

                table = FileTable(self, table_name=table_name)
                df = table.select(query)

            response = Response(RESPONSE_TYPE.TABLE, data_frame=df)
        elif isinstance(query, Insert):
            table_name = query.table.parts[-1]
            table = FileTable(self, table_name=table_name)
            table.insert(query)
            response = Response(RESPONSE_TYPE.OK)
        else:
            raise NotImplementedError

        return response

    def get_objects(self, limit=None, buckets=None) -> List[dict]:
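        """List objects from the configured bucket, or from all buckets
        (optionally filtered by `buckets`), up to `limit` entries. Keys are
        prefixed with the bucket name when no default bucket is configured.
        """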

        client = self.connect()
        if self.bucket is not None:
            add_bucket_to_name = False
            scan_buckets = [self.bucket]
        else:
            add_bucket_to_name = True
            resp = client.list_buckets()
            scan_buckets = [b["Name"] for b in resp["Buckets"]]

        objects = []
        for bucket in scan_buckets:
            # Stop scanning further buckets once the limit has been reached.
            if limit is not None and len(objects) >= limit:
                break
            if buckets is not None and bucket not in buckets:
                continue

            resp = client.list_objects_v2(Bucket=bucket)
            if "Contents" not in resp:
                continue

            for obj in resp["Contents"]:
                obj["Bucket"] = bucket
                obj["Filename"] = obj["Key"]
                if add_bucket_to_name:
                    obj["Key"] = f"{bucket}/{obj['Key']}"
                objects.append(obj)
                if limit is not None and len(objects) >= limit:
                    break

        return objects

    def get_tables(self) -> Response:
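        """Return the virtual `files` table plus every object whose extension
        is a supported file format.
        """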

        supported_names = [
            obj["Key"] for obj in self.get_objects() if obj["Key"].split(".")[-1] in self.supported_file_formats
        ]

        supported_names.insert(0, "files")

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(supported_names, columns=["table_name"]),
        )

        return response

    def get_columns(self, table_name: str) -> Response:
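        """Infer column names and dtypes by selecting one row from `table_name`."""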

        query = Select(
            targets=[Star()],
            from_table=Identifier(parts=[table_name]),
            limit=Constant(1),
        )

        result = self.query(query)

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(
                {
                    "column_name": result.data_frame.columns,
                    "data_type": [
                        str(dtype) if str(dtype) != "object" else "string" for dtype in result.data_frame.dtypes
                    ],
                }
            ),
        )

        return response