Coverage for mindsdb / integrations / libs / api_handler_generator.py: 0%

327 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from dataclasses import dataclass 

2import re 

3from io import StringIO 

4import json 

5from typing import Dict, List, Any 

6import yaml 

7try: 

8 from yaml import CLoader as Loader 

9except ImportError: 

10 from yaml import Loader 

11 

12 

13import pandas as pd 

14import requests 

15from requests.auth import HTTPBasicAuth 

16 

17from mindsdb.integrations.utilities.sql_utils import ( 

18 FilterCondition, FilterOperator, SortColumn 

19) 

20from mindsdb.integrations.libs.api_handler import APIResource 

21 

22 

class ApiRequestException(Exception):
    """
    Raised when a request to the API cannot be prepared: the spec location is not
    an http(s) URL, basic-auth credentials are missing, or path parameters have no value.
    """

25 

26 

class ApiResponseException(Exception):
    """
    Raised when the API answers with a status other than 200; the message is the response body.
    """

29 

30 

@dataclass
class APIInfo:
    """
    Container for the basic connection settings of an API.
    """
    # root URL of the API (None when not provided)
    base_url: str = None
    # authentication settings (None when not provided)
    auth: dict = None

38 

39 

@dataclass
class APIEndpoint:
    """
    One operation of the API that is exposed as a table.
    """
    # path template from the spec, possibly with '{placeholder}' segments
    url: str
    # HTTP method name as written in the spec, e.g. 'get'
    method: str
    # parameter name -> APIEndpointParam
    params: dict
    # response description with 'type', 'path' and 'view' keys
    response: dict

46 

47 

@dataclass
class APIResourceType:
    """
    Simplified description of a schema type from the specification.
    """
    # schema 'type' ('object', 'array', 'string', ...) or the name of a referenced schema
    type_name: str
    # for arrays: type name of the items
    sub_type: str = None
    # for objects: property name -> type
    # (the values stored are APIResourceType instances despite the annotation)
    properties: dict[str, str] = None

53 

54 

@dataclass
class APIEndpointParam:
    """
    A parameter accepted by an endpoint.
    """
    # parameter name as declared in the spec
    name: str
    # type name of the parameter schema, e.g. 'array'
    # (a plain string is stored here despite the annotation)
    type: APIResourceType
    # value of the OpenAPI 'in' field, e.g. 'query' or 'path'
    where: str = None
    # 'default' taken from the parameter schema, if any
    default: Any = None

61 

62 

def find_common_url_prefix(urls):
    """
    Return the longest common prefix of ``urls``, compared '/'-segment by segment.

    The final segment of the shortest URL is never included in the prefix. When
    the prefix is stripped from each URL, every URL therefore keeps at least its
    last segment. This holds even for a single URL, or when one URL is a prefix
    of another. Without it, such a URL would strip down to an empty table name.

    Args:
        urls: list of URL paths, e.g. ['/api/v1/users', '/api/v1/orders'].

    Returns:
        str: the common prefix ('/api/v1' for the example above), or '' for an empty list.
    """
    if not urls:
        return ''

    split_urls = [url.split('/') for url in urls]
    first = split_urls[0]
    # stop before the last segment of the shortest URL
    max_len = min(len(parts) for parts in split_urls) - 1

    for i in range(max_len):
        if any(parts[i] != first[i] for parts in split_urls[1:]):
            return '/'.join(first[:i])

    return '/'.join(first[:max_len])

79 

80 

class OpenAPISpecParser:
    """
    Downloads an OpenAPI specification from an http(s) URL and exposes its main sections.
    """
    def __init__(self, openapi_spec_path: str) -> None:
        """
        Fetch and parse the specification.

        Args:
            openapi_spec_path: http:// or https:// URL of the spec. A URL ending in
                '.json' is decoded as JSON; anything else is parsed as YAML.

        Raises:
            ApiRequestException: the path is not an http(s) URL.
            requests.HTTPError: the download returned an error status (via raise_for_status).
        """
        if not openapi_spec_path.startswith(('http://', 'https://')):
            raise ApiRequestException('URL is required')

        response = requests.get(openapi_spec_path)
        response.raise_for_status()

        if openapi_spec_path.endswith('.json'):
            self.openapi_spec = response.json()
            return

        # NOTE(security): `Loader` here is PyYAML's CLoader/Loader. That is the unsafe loader,
        # which can build arbitrary Python objects from YAML tags, and the document comes from
        # a remote URL. Consider yaml.SafeLoader / CSafeLoader. An OpenAPI spec needs no Python tags.
        self.openapi_spec = yaml.load(StringIO(response.text), Loader=Loader)

    def get_security_schemes(self) -> dict:
        """
        Return ``components.securitySchemes`` of the spec ({} when absent).
        """
        components = self.openapi_spec.get('components', {})
        return components.get('securitySchemes', {})

    def get_schemas(self) -> dict:
        """
        Return ``components.schemas`` of the spec ({} when absent).
        """
        components = self.openapi_spec.get('components', {})
        return components.get('schemas', {})

    def get_paths(self) -> dict:
        """
        Return the ``paths`` section of the spec ({} when absent).
        """
        return self.openapi_spec.get('paths', {})

    def get_specs(self) -> dict:
        """
        Return the whole parsed specification.
        """
        return self.openapi_spec

127 

128 

class APIResourceGenerator:
    """
    A class to generate API resources based on the OpenAPI specification.

    Each GET endpoint becomes a RestApiTable if its 200 response has a JSON schema
    that resolves to a record type.
    """
    def __init__(self, url, connection_data, url_base=None, options=None) -> None:
        """
        Args:
            url: http(s) URL of the OpenAPI specification.
            connection_data: connection parameters, handed to every generated table.
            url_base: if given, only paths that start with it become tables
                (the path equal to ``url_base`` itself is excluded).
            options: generator settings. Keys read in this module: 'check_connection_table',
                'total_column', 'page_size_param', 'offset_param', 'page_num_param'.
        """
        self.openapi_spec_parser = OpenAPISpecParser(url)
        self.connection_data = connection_data
        self.url_base = url_base
        self.options = options or {}
        self.resources = {}
        # filled in by generate_api_resources(); defined here so the attributes always exist
        self.resource_types = {}
        self.security_schemes = {}

    def check_connection(self):
        """
        List one row of the table named by the 'check_connection_table' option.

        Does nothing if the option is unset or no such table was generated.
        Errors raised by the request propagate to the caller.
        """
        if 'check_connection_table' in self.options:
            table = self.resources.get(self.options['check_connection_table'])
            if table:
                table.list(targets=[], limit=1, conditions=[])

    def generate_api_resources(self, handler, table_name_format='{url}') -> Dict[str, APIResource]:
        """
        Generate one RestApiTable per usable GET endpoint of the specification.

        Table names are built from the endpoint path as follows:
        - the prefix common to all endpoints is removed;
        - '{placeholder}' segments become 'x';
        - '/' becomes '_' and the result is passed through ``table_name_format``
          (fields ``url`` and ``method``), then lowercased.
        A later endpoint with the same table name replaces an earlier one.

        Args:
            handler: passed as the first positional argument to every RestApiTable.
            table_name_format: format string used to build table names.

        Returns:
            Dict[str, APIResource]: table name -> table (also stored in ``self.resources``).
        """
        paths = self.openapi_spec_parser.get_paths()
        schemas = self.openapi_spec_parser.get_schemas()
        self.security_schemes = self.openapi_spec_parser.get_security_schemes()

        self.resource_types = self.process_resource_types(schemas)
        endpoints = self.process_endpoints(paths)

        prefix_len = len(find_common_url_prefix([endpoint.url for endpoint in endpoints]))

        for endpoint in endpoints:
            url = endpoint.url[prefix_len:]
            # replace path placeholders with x, e.g. /users/{user-id} -> /users/x
            url = re.sub(r"{[^{}/]+}", 'x', url)
            url = url.replace('/', '_').strip('_')
            table_name = table_name_format.format(url=url, method=endpoint.method).lower()
            self.resources[table_name] = RestApiTable(handler, endpoint=endpoint, resource_gen=self)

        return self.resources

    def process_resource_types(self, schemas: dict) -> dict:
        """
        Convert every named schema of ``components.schemas`` into an APIResourceType.

        Returns:
            dict: schema name -> APIResourceType.
        """
        return {name: self._convert_to_resource_type(schema_info) for name, schema_info in schemas.items()}

    def process_endpoints(self, paths: dict) -> List[APIEndpoint]:
        """
        Collect the GET endpoints that can be exposed as tables.

        Parameters declared on the path item apply to every operation of that path.
        Operation-level parameters with the same name override them, per the OpenAPI
        Path Item object.

        Args:
            paths: the ``paths`` section of the specification.

        Returns:
            List[APIEndpoint]: endpoints whose 200 response resolves to a usable type.
        """
        endpoints = []
        for path, path_info in paths.items():
            # filter endpoints by url base
            if self.url_base is not None and (not path.startswith(self.url_base) or path == self.url_base):
                continue

            path_parameters = path_info.get('parameters', [])

            for http_method, method_info in path_info.items():
                if http_method != 'get':
                    continue

                parameters = self._process_endpoint_parameters(
                    path_parameters + method_info.get('parameters', [])
                )

                response = self._process_endpoint_response(method_info.get('responses', {}))
                if response['type'] is None:
                    continue

                endpoints.append(APIEndpoint(
                    url=path,
                    method=http_method,
                    params=parameters,
                    response=response
                ))

        return endpoints

    def get_ref_object(self, ref):
        """
        Resolve a local '$ref' such as '#/components/schemas/User' against the spec.

        Reference tokens are unescaped per JSON Pointer (RFC 6901): '~1' -> '/', '~0' -> '~'.
        A missing key raises KeyError.
        """
        el = self.openapi_spec_parser.get_specs()
        for token in ref.lstrip('#').split('/'):
            if token:
                el = el[token.replace('~1', '/').replace('~0', '~')]
        return el

    def _process_endpoint_parameters(self, parameters: list) -> Dict[str, APIEndpointParam]:
        """
        Convert the OpenAPI parameter list of an endpoint.

        Args:
            parameters: parameter objects or '$ref's to them.

        Returns:
            Dict[str, APIEndpointParam]: parameter name -> parameter. A later entry with the
            same name replaces an earlier one.
        """
        endpoint_parameters = {}
        for parameter in parameters:
            if '$ref' in parameter:
                parameter = self.get_ref_object(parameter['$ref'])

            # a parameter may describe its type via 'content' instead of 'schema'
            schema = parameter.get('schema', {})

            endpoint_parameters[parameter['name']] = APIEndpointParam(
                name=parameter['name'],
                type=self.get_resource_type(schema),
                default=schema.get('default'),
                where=parameter['in'],
            )

        return endpoint_parameters

    def _process_endpoint_response(self, responses: dict):
        """
        Work out what rows an endpoint returns from its 200 response.

        Returns:
            dict with keys:
                'type': type name of a row; None if the endpoint can't be used.
                'path': keys to follow in the JSON response to reach the list of rows.
                'view': 'table' for a list of rows, or 'record' for a single object.
            When there is no 200 response, only {'type': None} is returned.
        """
        response = None
        response_path = []  # used to find list in response
        view = 'table'

        # status-code keys may be ints when the YAML spec leaves them unquoted
        if '200' in responses:
            resp_success = responses['200']
        elif 200 in responses:
            resp_success = responses[200]
        else:
            return {'type': None}

        if '$ref' in resp_success:
            resp_success = self.get_ref_object(resp_success['$ref'])

        # a response without 'content' (no body) leaves the type as None
        for content_type, resp_info in resp_success.get('content', {}).items():
            # ignore media-type parameters such as '; charset=utf-8'
            if content_type.split(';')[0].strip() != 'application/json':
                continue

            if 'schema' not in resp_info:
                continue

            resource_type = self._convert_to_resource_type(resp_info['schema'])

            # resolve a named schema to its definition
            type_name = None
            if resource_type.type_name in self.resource_types:
                type_name = resource_type.type_name
                resource_type = self.resource_types[resource_type.type_name]

            if resource_type.type_name == 'array':
                response = resource_type.sub_type
            elif resource_type.type_name == 'object':
                if resource_type.properties is None:
                    raise NotImplementedError

                # an object carrying one of the 'total_column' fields is a page of a table:
                # the rows are in its first array property
                is_table = any(
                    col in resource_type.properties
                    for col in self.options.get('total_column', [])
                )

                if is_table:
                    for k, v in resource_type.properties.items():
                        if v.type_name == 'array':
                            response = v.sub_type
                            response_path.append(k)
                            break
                else:
                    # a single object: only usable when it is a named schema
                    response = type_name
                    view = 'record'
            # only the first JSON media type is considered
            break

        return {
            'type': response,
            'path': response_path,
            'view': view
        }

    def _convert_to_resource_type(self, schema: dict) -> APIResourceType:
        """
        Converts the schema information to a resource type.

        Object properties are converted recursively; for arrays only the item type name
        is kept.

        Args:
            schema (Dict): A dictionary containing the schema information.

        Returns:
            APIResourceType: An object containing the resource type information.
        """
        type_name = self.get_resource_type(schema)

        kwargs = {
            'type_name': type_name,
        }

        if type_name == 'object':
            properties = {}
            if 'properties' in schema:
                for k, v in schema['properties'].items():
                    properties[k] = self._convert_to_resource_type(v)
            elif 'additionalProperties' in schema:
                # NOTE(review): this only changes the local used by the 'array' check below;
                # the returned type stays 'object' with empty properties. Intent unclear.
                if isinstance(schema['additionalProperties'], dict) and 'type' in schema['additionalProperties']:
                    type_name = schema['additionalProperties']['type']
                else:
                    type_name = 'string'

            kwargs['properties'] = properties
        if type_name == 'array' and 'items' in schema:
            kwargs['sub_type'] = self.get_resource_type(schema['items'])

        return APIResourceType(**kwargs)

    def get_resource_type(self, schema: dict) -> str:
        """
        Return the type name of a schema, in this order of preference:
        - its 'type';
        - the schema name of its '$ref';
        - the type of the first 'allOf' entry.
        Returns None if none of these is present.
        """
        if 'type' in schema:
            return schema['type']

        if '$ref' in schema:
            return schema['$ref'].split('/')[-1]

        if 'allOf' in schema:
            # TODO Get only the first type.
            return self.get_resource_type(schema['allOf'][0])

        return None

356 

357 

class RestApiTable(APIResource):
    """
    Table backed by a single GET endpoint of an OpenAPI-described API.

    Output columns come from the endpoint's response schema. SQL conditions on
    columns that match endpoint parameters are sent as request parameters.
    Pagination is driven by the options 'page_size_param', 'offset_param',
    'page_num_param' and 'total_column'.
    """

    # '{name}' placeholders in a path template
    _PLACEHOLDER_RE = re.compile(r"{([^{}/]+)}")

    def __init__(self, *args, endpoint: APIEndpoint = None, resource_gen=None, **kwargs):
        """
        Args:
            *args, **kwargs: passed on to APIResource.__init__.
            endpoint: the endpoint this table reads from.
            resource_gen: the generator that created this table. Its resource_types,
                connection_data, security_schemes and options are used.
        """
        self.endpoint = endpoint
        resource_types = resource_gen.resource_types
        self.connection_data = resource_gen.connection_data
        self.security_schemes = resource_gen.security_schemes
        self.options = resource_gen.options

        response_type = endpoint.response['type']
        named_type = resource_types.get(response_type)
        if named_type is not None and named_type.properties is not None:
            # rows are objects of a named schema: one column per property
            self.output_columns = named_type.properties
        else:
            # scalar rows (or a named non-object schema, whose properties are None):
            # let it be single column with this type
            self.output_columns = {'value': response_type}

        # parameters that conditions can be pushed down to;
        # array-typed ones also accept IN (...) lists
        self.params = [name for name in endpoint.params]
        self.list_params = [
            name for name, param in endpoint.params.items() if param.type == 'array'
        ]

        super().__init__(*args, **kwargs)

    def repr_value(self, value):
        """
        Convert a response value into a table cell.

        Dicts become JSON strings with None-valued keys removed. Lists become
        comma-separated strings. Other values are returned unchanged.
        """
        if isinstance(value, dict):
            return json.dumps({k: v for k, v in value.items() if v is not None})
        if isinstance(value, list):
            return ",".join(str(i) for i in value)
        return value

    def _handle_auth(self) -> dict:
        """
        Build the authentication keyword arguments for ``requests.request``.

        A 'token' in the connection data is sent as a Bearer token and takes precedence.
        Otherwise, if the spec declares a 'basicAuth' security scheme, HTTP basic auth
        is built from 'username' and 'password'.

        Returns:
            dict: {'headers': {...}}, {'auth': HTTPBasicAuth(...)}, or {} when neither applies.

        Raises:
            ApiRequestException: basic auth is declared but username/password are missing.
        """
        # NOTE: only one mechanism is used even if the API supports several.
        if 'token' in self.connection_data:
            return {'headers': {'Authorization': f'Bearer {self.connection_data["token"]}'}}

        if 'basicAuth' in self.security_schemes:
            if not all(key in self.connection_data for key in ["username", "password"]):
                raise ApiRequestException(
                    "The username and password are required for basic authentication."
                )
            return {
                "auth": HTTPBasicAuth(
                    self.connection_data["username"],
                    self.connection_data["password"],
                ),
            }
        return {}

    def get_columns(self) -> List[str]:
        """Return the output column names."""
        return list(self.output_columns.keys())

    def get_setting_param(self, setting_name: str) -> str:
        """
        Return the endpoint parameter that serves a pagination setting.

        ``self.options[setting_name]`` lists candidate parameter names. The first one
        this endpoint declares is returned. Returns None if the option is unset or no
        candidate matches.
        """
        for col in self.options.get(setting_name, []):
            if col in self.endpoint.params:
                return col
        return None

    def get_user_params(self):
        """
        Return connection data to be sent as request parameters. This is everything
        except 'username', 'password', 'token' and 'api_base'.
        """
        return {
            k: v for k, v in self.connection_data.items()
            if k not in ('username', 'password', 'token', 'api_base')
        }

    def _api_request(self, filters):
        """
        Send one request to the endpoint.

        Args:
            filters: parameter name -> value. Every name must be an endpoint parameter.

        Returns:
            tuple: (resp, total).
                resp: the JSON found at the endpoint's response path, wrapped in a
                    list for 'record' endpoints.
                total: the value of the first 'total_column' option present in the
                    top-level response, or None.

        Raises:
            ApiRequestException: a path placeholder has no value.
            ApiResponseException: the API returned a status other than 200.
        """
        query, body, headers, cookies, path_vars = {}, {}, {}, {}, {}
        for name, value in filters.items():
            where = self.endpoint.params[name].where
            if where == 'query':
                query[name] = value
            elif where == 'path':
                path_vars[name] = value
            elif where == 'header':
                headers[name] = str(value)
            elif where == 'cookie':
                cookies[name] = str(value)
            else:
                body[name] = value

        # substitute the placeholders we have values for; the rest are reported below
        path = self._PLACEHOLDER_RE.sub(
            lambda m: str(path_vars[m.group(1)]) if m.group(1) in path_vars else m.group(0),
            self.endpoint.url,
        )
        placeholders = self._PLACEHOLDER_RE.findall(path)
        if placeholders:
            raise ApiRequestException('Parameters are required: ' + ', '.join(placeholders))
        url = self.connection_data['api_base'] + path

        kwargs = self._handle_auth()
        if headers:
            # authentication headers win over header parameters of the same name
            kwargs['headers'] = {**headers, **kwargs.get('headers', {})}
        if cookies:
            kwargs['cookies'] = cookies
        req = requests.request(self.endpoint.method, url, params=query, data=body, **kwargs)

        if req.status_code != 200:
            raise ApiResponseException(req.text)
        resp = req.json()

        total = None
        if 'total_column' in self.options and isinstance(resp, dict):
            for col in self.options['total_column']:
                if col in resp:
                    total = resp[col]
                    break

        for item in self.endpoint.response['path']:
            resp = resp[item]

        if self.endpoint.response['view'] == 'record':
            # response is one record, make table
            resp = [resp]
        return resp, total

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ) -> pd.DataFrame:
        """
        Query the endpoint and return up to ``limit`` rows (20 when not given).

        These conditions on endpoint parameters are sent with the request and marked
        applied:
        - '=' on any parameter;
        - 'IN' on array parameters.
        Conditions with other operators are left unapplied. ``sort`` and ``targets``
        are not used here.
        """
        if limit is None:
            limit = 20

        filters = self._conditions_to_filters(conditions)

        # connection-level params override conditions. They are sent only when this
        # endpoint declares them; unknown names can't be placed by _api_request.
        for name, value in self.get_user_params().items():
            if name in self.endpoint.params:
                filters[name] = value

        records = self._fetch_records(filters, limit)
        return self._records_to_dataframe(records[:limit])

    def _conditions_to_filters(self, conditions) -> dict:
        """Turn the conditions that can be pushed down into request params; mark them applied."""
        filters = {}
        for condition in conditions or []:
            column = condition.column
            if column not in self.params:
                continue

            if condition.op == FilterOperator.EQUAL:
                # array parameters get a single value as a one-element list
                value = [condition.value] if column in self.list_params else condition.value
            elif condition.op == FilterOperator.IN and column in self.list_params:
                value = condition.value
            else:
                # the operator can't be expressed as a request parameter
                continue

            filters[column] = value
            condition.applied = True
        return filters

    def _fetch_records(self, filters: dict, limit: int):
        """Fetch the first page and, if paginated, more pages until limit, total or last page."""
        page_size_param = self.get_setting_param('page_size_param')
        page_size = None
        if page_size_param is not None:
            # use the parameter's schema default as the page size
            page_size = self.endpoint.params[page_size_param].default
            if page_size:
                filters[page_size_param] = page_size

        records, total = self._api_request(filters)

        offset_param = self.get_setting_param('offset_param')
        page_num_param = self.get_setting_param('page_num_param')
        if offset_param is None and page_num_param is None:
            return records

        page_num = 1
        last_page_len = len(records)
        while True:
            count = len(records)
            if limit <= count:
                break

            if total is not None and total <= count:
                # total is reached
                break

            if page_size is not None and page_size > last_page_len:
                # the last page was not full, so nothing follows it
                break

            # download more pages
            if offset_param is not None:
                filters[offset_param] = count
            else:
                page_num += 1
                filters[page_num_param] = page_num
            page, total = self._api_request(filters)
            if len(page) == 0:
                # no results from next page
                break
            records.extend(page)
            last_page_len = len(page)

        return records

    def _records_to_dataframe(self, records) -> pd.DataFrame:
        """Build the result frame: dict records by key, plain values into the first column."""
        columns = self.get_columns()
        data = []
        for record in records:
            if isinstance(record, dict):
                data.append({name: self.repr_value(value) for name, value in record.items()})
            elif columns:
                # response item is a plain value
                data.append({columns[0]: self.repr_value(record)})
        return pd.DataFrame(data, columns=columns)