Coverage for mindsdb / integrations / libs / api_handler_generator.py: 0%
327 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from dataclasses import dataclass
2import re
3from io import StringIO
4import json
5from typing import Dict, List, Any
6import yaml
7try:
8 from yaml import CLoader as Loader
9except ImportError:
10 from yaml import Loader
13import pandas as pd
14import requests
15from requests.auth import HTTPBasicAuth
17from mindsdb.integrations.utilities.sql_utils import (
18 FilterCondition, FilterOperator, SortColumn
19)
20from mindsdb.integrations.libs.api_handler import APIResource
class ApiRequestException(Exception):
    """Raised when an API request cannot be prepared (invalid spec URL, missing parameters or credentials)."""
class ApiResponseException(Exception):
    """Raised when the API answers with an unexpected status; the message carries the response body."""
@dataclass
class APIInfo:
    """
    Holder for general information describing an API.

    Both fields are optional and default to None.
    """
    base_url: str = None  # root URL of the API
    auth: dict = None  # authentication settings; structure is not defined in this module
@dataclass
class APIEndpoint:
    """
    One operation from the OpenAPI specification that is exposed as a table.
    """
    url: str  # path template as written in the spec, e.g. '/users/{id}'
    method: str  # HTTP method key from the spec (lower case, e.g. 'get')
    params: dict  # parameter name -> APIEndpointParam
    response: dict  # {'type': record type name, 'path': keys to the record list, 'view': 'table' | 'record'}
@dataclass
class APIResourceType:
    """
    Simplified description of an OpenAPI schema.

    Attributes:
        type_name: the schema 'type' ('object', 'array', 'string', ...) or the name
            of a referenced component schema; None when it cannot be determined.
        sub_type: for arrays, the type name of the items.
        properties: for objects, property name -> nested APIResourceType
            (None for non-object schemas).
    """
    type_name: str
    sub_type: str = None
    properties: Dict[str, 'APIResourceType'] = None
@dataclass
class APIEndpointParam:
    """
    A parameter of an API endpoint.

    Attributes:
        name: parameter name as declared in the specification.
        type: type name of the parameter schema (e.g. 'string', 'array', or the name
            of a referenced component schema); None if it cannot be determined.
            Callers compare it with strings such as 'array'.
        where: the parameter location from the spec ('query', 'path', ...).
        default: default value from the parameter schema, if any.
    """
    name: str
    type: str
    where: str = None
    default: Any = None
def find_common_url_prefix(urls):
    """
    Return the longest prefix, made of whole '/'-separated segments, shared by all urls.

    >>> find_common_url_prefix(['/api/v1/users', '/api/v1/orders'])
    '/api/v1'
    >>> find_common_url_prefix([])
    ''
    """
    if not urls:
        return ''

    segment_lists = [u.split('/') for u in urls]
    shared = []
    # zip stops at the shortest url, so only positions present in every url are compared
    for column in zip(*segment_lists):
        head = column[0]
        if any(segment != head for segment in column[1:]):
            break
        shared.append(head)
    return '/'.join(shared)
class OpenAPISpecParser:
    """
    A class to parse the OpenAPI specification.

    The specification is downloaded from an http(s) URL and parsed as JSON when
    the URL ends with '.json', otherwise as YAML.
    """

    def __init__(self, openapi_spec_path: str) -> None:
        """
        Download and parse the OpenAPI specification.

        Args:
            openapi_spec_path (str): http:// or https:// URL of the specification.

        Raises:
            ApiRequestException: if `openapi_spec_path` is not an http(s) URL.
            requests.HTTPError: if the server answers with an error status.
        """
        if not openapi_spec_path.startswith(('http://', 'https://')):
            raise ApiRequestException('URL is required')

        response = requests.get(openapi_spec_path)
        response.raise_for_status()

        if openapi_spec_path.endswith('.json'):
            self.openapi_spec = response.json()
        else:
            # SECURITY: the spec comes from a user-supplied URL, i.e. untrusted input.
            # The full `yaml.Loader` / `CLoader` can instantiate arbitrary Python objects
            # through tags such as `!!python/object/apply`, so only the safe loader is
            # used (the C implementation when libyaml is available). A valid OpenAPI
            # document never needs Python-specific tags.
            safe_loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
            self.openapi_spec = yaml.load(StringIO(response.text), Loader=safe_loader)

    def get_security_schemes(self) -> dict:
        """
        Returns the security schemes defined in the OpenAPI specification.

        Returns:
            dict: 'components.securitySchemes' of the specification, or {} if absent.
        """
        return self.openapi_spec.get('components', {}).get('securitySchemes', {})

    def get_schemas(self) -> dict:
        """
        Returns the schemas defined in the OpenAPI specification.

        Returns:
            dict: 'components.schemas' of the specification, or {} if absent.
        """
        return self.openapi_spec.get('components', {}).get('schemas', {})

    def get_paths(self) -> dict:
        """
        Returns the paths defined in the OpenAPI specification.

        Returns:
            dict: the 'paths' section of the specification, or {} if absent.
        """
        return self.openapi_spec.get('paths', {})

    def get_specs(self) -> dict:
        """Returns the whole parsed specification."""
        return self.openapi_spec
class APIResourceGenerator:
    """
    A class to generate API resources based on the OpenAPI specification.
    """

    def __init__(self, url, connection_data, url_base=None, options=None) -> None:
        """
        Args:
            url: http(s) URL of the OpenAPI specification (downloaded immediately).
            connection_data: connection parameters handed to every generated table.
            url_base: if set, only paths starting with it (and not equal to it) are used.
            options: generator settings. Keys read in this module: 'check_connection_table',
                'total_column', 'page_size_param', 'offset_param', 'page_num_param'.
        """
        self.openapi_spec_parser = OpenAPISpecParser(url)
        self.connection_data = connection_data
        self.url_base = url_base
        self.options = options or {}
        self.resources = {}

    def check_connection(self):
        """
        Request one record from the table named by option 'check_connection_table'.

        Does nothing when the option is not set or no such table was generated;
        errors raised by the request propagate to the caller.
        """
        if 'check_connection_table' in self.options:
            table = self.resources.get(self.options['check_connection_table'])
            if table:
                table.list(targets=[], limit=1, conditions=[])

    def generate_api_resources(self, handler, table_name_format='{url}') -> Dict[str, APIResource]:
        """
        Generate one table per GET endpoint of the OpenAPI specification.

        The table name is the endpoint path with the prefix common to all endpoints
        removed, placeholders replaced by 'x' and '/' replaced by '_', substituted into
        `table_name_format` (fields `{url}` and `{method}`) and lower-cased.

        Args:
            handler: handler passed to every generated RestApiTable.
            table_name_format (str): format string for table names.

        Returns:
            Dict[str, APIResource]: the generated tables keyed by table name
            (also stored in `self.resources`).
        """
        paths = self.openapi_spec_parser.get_paths()
        schemas = self.openapi_spec_parser.get_schemas()
        self.security_schemes = self.openapi_spec_parser.get_security_schemes()

        self.resource_types = self.process_resource_types(schemas)
        endpoints = self.process_endpoints(paths)

        prefix_len = len(find_common_url_prefix([i.url for i in endpoints]))

        for endpoint in endpoints:
            url = self._url_to_table_name(endpoint.url[prefix_len:])
            if not url:
                # The common prefix covers this whole path (a single endpoint, or
                # '/users' next to '/users/{id}'): name the table after the full path
                # instead of producing an empty table name.
                url = self._url_to_table_name(endpoint.url)
            table_name = table_name_format.format(url=url, method=endpoint.method).lower()
            self.resources[table_name] = RestApiTable(handler, endpoint=endpoint, resource_gen=self)

        return self.resources

    @staticmethod
    def _url_to_table_name(url: str) -> str:
        """Replace placeholders with 'x' and '/' with '_' (e.g. '/users/{id}' -> 'users_x')."""
        url = re.sub(r"{(\w+)}", 'x', url)
        return url.replace('/', '_').strip('_')

    def process_resource_types(self, schemas: dict) -> dict:
        """Convert every component schema into an APIResourceType, keyed by schema name."""
        return {name: self._convert_to_resource_type(schema_info) for name, schema_info in schemas.items()}

    def process_endpoints(self, paths: dict) -> List[APIEndpoint]:
        """
        Collect the GET endpoints defined in the OpenAPI specification.

        Parameters declared on the path item apply to all of its operations; an
        operation parameter with the same name overrides the path-level one.

        Args:
            paths (Dict): the 'paths' section of the OpenAPI specification.

        Returns:
            List[APIEndpoint]: endpoints whose 200 response could be resolved to a type.
        """
        endpoints = []
        for path, path_info in paths.items():
            # filter endpoints by url base
            if self.url_base is not None and (not path.startswith(self.url_base) or path == self.url_base):
                continue

            # parameters shared by every operation of this path (e.g. path placeholders)
            path_parameters = list(path_info.get('parameters', []))

            for http_method, method_info in path_info.items():
                if http_method != 'get':
                    continue

                # operation parameters come last so they win on name clashes
                raw_parameters = path_parameters + list(method_info.get('parameters', []))
                parameters = self._process_endpoint_parameters(raw_parameters) if raw_parameters else {}

                response = self._process_endpoint_response(method_info.get('responses', {}))
                if response['type'] is None:
                    continue

                endpoints.append(APIEndpoint(
                    url=path,
                    method=http_method,
                    params=parameters,
                    response=response
                ))

        return endpoints

    def get_ref_object(self, ref):
        """
        Resolve a local `$ref` (e.g. '#/components/schemas/Pet') inside the specification.

        Reference tokens are unescaped as JSON Pointer tokens ('~1' -> '/', '~0' -> '~');
        numeric tokens index into lists. References to external documents are not supported.

        Raises:
            KeyError / IndexError: if the reference does not point to an existing object.
        """
        el = self.openapi_spec_parser.get_specs()
        for token in ref.lstrip('#').split('/'):
            if not token:
                continue
            # '~1' must be decoded before '~0' so that '~01' becomes '~1', not '/'
            token = token.replace('~1', '/').replace('~0', '~')
            el = el[int(token)] if isinstance(el, list) else el[token]
        return el

    def _process_endpoint_parameters(self, parameters: list) -> Dict[str, APIEndpointParam]:
        """
        Processes the parameters defined in the OpenAPI specification.

        Args:
            parameters (List): parameter objects (or `$ref`s to them) of an endpoint.

        Returns:
            Dict[str, APIEndpointParam]: processed parameters keyed by name; a later
            parameter with the same name replaces an earlier one.
        """
        endpoint_parameters = {}
        for parameter in parameters:
            if '$ref' in parameter:
                parameter = self.get_ref_object(parameter['$ref'])

            # a parameter may be described via 'content' instead of 'schema';
            # its type is then unknown (None) instead of raising KeyError
            schema = parameter.get('schema', {})

            endpoint_parameters[parameter['name']] = APIEndpointParam(
                name=parameter['name'],
                type=self.get_resource_type(schema),
                default=schema.get('default'),
                where=parameter['in'],
            )

        return endpoint_parameters

    def _process_endpoint_response(self, responses: dict):
        """
        Work out which records the successful (200) response of an endpoint contains.

        Returns:
            dict with keys:
                'type': record type name, or None if it could not be resolved
                        (process_endpoints then skips the endpoint);
                'path': keys to follow in the JSON response to reach the records;
                'view': 'table' for a list of records, 'record' for a single object.
            Only {'type': None} is returned when there is no 200 response.
        """
        response = None
        response_path = []  # used to find list in response

        # YAML parses an unquoted `200:` key as an int
        resp_success = responses.get('200', responses.get(200))
        if resp_success is None:
            return {'type': None}

        view = 'table'

        if '$ref' in resp_success:
            resp_success = self.get_ref_object(resp_success['$ref'])

        # a 200 response may legitimately have no 'content' (no body)
        for content_type, resp_info in resp_success.get('content', {}).items():
            # ignore media type parameters such as '; charset=utf-8'
            if content_type.split(';', 1)[0].strip().lower() != 'application/json':
                continue

            if 'schema' not in resp_info:
                continue

            resource_type = self._convert_to_resource_type(resp_info['schema'])

            # resolve named type
            type_name = None
            if resource_type.type_name in self.resource_types:
                type_name = resource_type.type_name
                resource_type = self.resource_types[resource_type.type_name]

            if resource_type.type_name == 'array':
                response = resource_type.sub_type
            elif resource_type.type_name == 'object':
                if resource_type.properties is None:
                    raise NotImplementedError

                # the object is a page of a table if it has one of the 'total' columns;
                # the records are then in its first array property
                total_columns = self.options.get('total_column', [])
                if any(col in resource_type.properties for col in total_columns):
                    for k, v in resource_type.properties.items():
                        if v.type_name == 'array':
                            response = v.sub_type
                            response_path.append(k)
                            break
                else:
                    response = type_name
                    view = 'record'
            # only the first JSON content type is considered
            break

        return {
            'type': response,
            'path': response_path,
            'view': view
        }

    def _convert_to_resource_type(self, schema: dict) -> APIResourceType:
        """
        Converts the schema information to a resource type.

        Args:
            schema (Dict): a schema object from the specification.

        Returns:
            APIResourceType: `properties` is filled (possibly empty) for objects,
            `sub_type` for arrays with 'items'.
        """
        type_name = self.get_resource_type(schema)

        kwargs = {
            'type_name': type_name,
        }

        if type_name == 'object':
            properties = {}
            if 'properties' in schema:
                for k, v in schema['properties'].items():
                    properties[k] = self._convert_to_resource_type(v)
            elif 'additionalProperties' in schema:
                # map-like object: no fixed properties are recorded.
                # NOTE(review): the value type below only feeds the 'array' check that
                # follows; it does not change kwargs['type_name'].
                if isinstance(schema['additionalProperties'], dict) and 'type' in schema['additionalProperties']:
                    type_name = schema['additionalProperties']['type']
                else:
                    type_name = 'string'

            kwargs['properties'] = properties
        if type_name == 'array' and 'items' in schema:
            kwargs['sub_type'] = self.get_resource_type(schema['items'])

        return APIResourceType(**kwargs)

    def get_resource_type(self, schema: dict) -> str:
        """
        Return the type name of a schema: its 'type', the referenced schema name for a
        '$ref', or the type of the first 'allOf' member; None when none applies.

        OpenAPI 3.1 allows 'type' to be a list (e.g. ['string', 'null']); the first
        non-'null' entry is returned so the result is always a hashable string.
        """
        if 'type' in schema:
            type_name = schema['type']
            if isinstance(type_name, list):
                non_null = [t for t in type_name if t != 'null']
                type_name = non_null[0] if non_null else 'null'
            return type_name

        elif '$ref' in schema:
            return schema['$ref'].split('/')[-1]

        elif schema.get('allOf'):
            # TODO Get only the first type.
            return self.get_resource_type(schema['allOf'][0])

        return None
class RestApiTable(APIResource):
    """
    A table backed by a single GET endpoint of an OpenAPI specification.

    WHERE conditions on endpoint parameters become request parameters; the JSON
    response is returned as a DataFrame.
    """

    def __init__(self, *args, endpoint: APIEndpoint = None, resource_gen=None, **kwargs):
        """
        Args:
            *args, **kwargs: passed on to APIResource.__init__.
            endpoint (APIEndpoint): the endpoint this table reads from.
            resource_gen (APIResourceGenerator): provides resource types, connection
                data, security schemes and options.
        """
        self.endpoint = endpoint
        resource_types = resource_gen.resource_types
        self.connection_data = resource_gen.connection_data
        self.security_schemes = resource_gen.security_schemes
        self.options = resource_gen.options

        response_type = endpoint.response['type']
        output_columns = None
        if response_type in resource_types:
            output_columns = resource_types[response_type].properties
        if output_columns is None:
            # Not an object schema (a primitive, or a named non-object schema whose
            # `properties` is None): let it be single column with this type.
            output_columns = {'value': response_type}
        self.output_columns = output_columns

        # names of all endpoint parameters, and of those taking a list of values
        self.params, self.list_params = [], []
        for name, param in endpoint.params.items():
            self.params.append(name)
            if param.type == 'array':
                self.list_params.append(name)

        super().__init__(*args, **kwargs)

    def repr_value(self, value):
        """
        Convert a JSON value into something that fits in a table cell.

        Dicts become JSON strings (keys with None values removed), lists become
        comma-separated strings; other values are returned unchanged.
        """
        if isinstance(value, dict):
            # remove empty keys
            value = {k: v for k, v in value.items() if v is not None}
            value = json.dumps(value)
        elif isinstance(value, list):
            value = ",".join([str(i) for i in value])
        return value

    def _handle_auth(self) -> dict:
        """
        Build the authentication keyword arguments for `requests.request`.

        A 'token' in the connection data is sent as a Bearer token and takes precedence;
        otherwise, if the specification declares a security scheme named 'basicAuth',
        HTTP basic auth with 'username'/'password' from the connection data is used.

        Returns:
            dict: {'headers': {...}}, {'auth': HTTPBasicAuth(...)} or {} when neither applies.

        Raises:
            ApiRequestException: if 'basicAuth' applies but username or password is missing.
        """
        # NOTE: only these two mechanisms are supported, and 'basicAuth' is matched by
        # scheme name, not by scheme type. If the API supports several mechanisms,
        # the token wins.
        if 'token' in self.connection_data:
            return {'headers': {'Authorization': f'Bearer {self.connection_data["token"]}'}}

        if 'basicAuth' in self.security_schemes:
            if not all(key in self.connection_data for key in ["username", "password"]):
                raise ApiRequestException(
                    "The username and password are required for basic authentication."
                )
            return {
                "auth": HTTPBasicAuth(
                    self.connection_data["username"],
                    self.connection_data["password"],
                ),
            }
        return {}

    def get_columns(self) -> List[str]:
        """Return the output column names of the table."""
        return list(self.output_columns.keys())

    def get_setting_param(self, setting_name: str) -> str:
        """
        Return the first name listed in option `setting_name` (e.g. 'page_size_param')
        that is a parameter of this endpoint, or None.
        """
        for col in self.options.get(setting_name, []):
            if col in self.endpoint.params:
                return col
        return None

    def get_user_params(self):
        """Return the connection parameters other than credentials and 'api_base'."""
        return {
            k: v
            for k, v in self.connection_data.items()
            if k not in ('username', 'password', 'token', 'api_base')
        }

    def _api_request(self, filters):
        """
        Call the endpoint once.

        Args:
            filters (dict): parameter name -> value; every name must be a parameter of
                the endpoint. Values are routed by the parameter location: query string,
                path placeholder (URL-encoded), header, cookie, otherwise request body.

        Returns:
            tuple: (records, total). `records` is the JSON response after following
            endpoint.response['path'], wrapped in a list for 'record' endpoints; `total`
            is the first 'total_column' option value found in the response, or None.

        Raises:
            ApiRequestException: if a path placeholder has no value.
            ApiResponseException: if the status code is not 200.
        """
        from urllib.parse import quote

        query, body, path_vars, headers, cookies = {}, {}, {}, {}, {}
        for name, value in filters.items():
            param = self.endpoint.params[name]
            if param.where == 'query':
                query[name] = value
            elif param.where == 'path':
                # encode so that values containing '/', '?' or '#' stay in their segment
                path_vars[name] = quote(str(value), safe='')
            elif param.where == 'header':
                headers[name] = str(value)
            elif param.where == 'cookie':
                cookies[name] = str(value)
            else:
                body[name] = value

        url = self.connection_data['api_base'] + self.endpoint.url
        # check placeholders without a value before formatting, so a partially filled
        # path reports the missing parameters instead of raising a bare KeyError
        missing = [p for p in re.findall(r"{(\w+)}", url) if p not in path_vars]
        if missing:
            raise ApiRequestException('Parameters are required: ' + ', '.join(missing))
        if path_vars:
            url = url.format(**path_vars)

        kwargs = self._handle_auth()
        if headers:
            # authentication headers take precedence over endpoint header parameters
            kwargs['headers'] = {**headers, **kwargs.get('headers', {})}
        if cookies:
            kwargs['cookies'] = cookies
        req = requests.request(self.endpoint.method, url, params=query, data=body, **kwargs)

        if req.status_code != 200:
            raise ApiResponseException(req.text)
        resp = req.json()

        total = None
        if 'total_column' in self.options and isinstance(resp, dict):
            for col in self.options['total_column']:
                if col in resp:
                    total = resp[col]
                    break

        for item in self.endpoint.response['path']:
            resp = resp[item]

        if self.endpoint.response['view'] == 'record':
            # response is one record, make table
            resp = [resp]
        return resp, total

    def _conditions_to_filters(self, conditions) -> dict:
        """
        Turn WHERE conditions on endpoint parameters into request parameters.

        '=' is supported for every parameter (wrapped in a list for array parameters)
        and 'IN' for array parameters. Those conditions are marked as applied; any
        other condition is left unapplied for the caller to evaluate.
        """
        filters = {}
        for condition in conditions or []:
            if condition.column not in self.params:
                continue

            if condition.op == FilterOperator.EQUAL:
                value = condition.value
                if condition.column in self.list_params:
                    value = [value]
                filters[condition.column] = value
            elif condition.op == FilterOperator.IN and condition.column in self.list_params:
                filters[condition.column] = condition.value
            else:
                # the API cannot express this operator: do not send it, do not mark it
                continue
            condition.applied = True
        return filters

    def _records_to_dataframe(self, records) -> pd.DataFrame:
        """Build the result DataFrame; scalar records go into the single output column."""
        columns = self.get_columns()
        data = []
        for record in records:
            if isinstance(record, dict):
                data.append({name: self.repr_value(value) for name, value in record.items()})
            elif columns:
                # response is value
                data.append({columns[0]: self.repr_value(record)})
        return pd.DataFrame(data, columns=columns)

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ) -> pd.DataFrame:
        """
        Fetch records from the endpoint, following pagination until `limit` is reached.

        Args:
            conditions: WHERE conditions; supported ones are sent to the API and marked
                as applied (see _conditions_to_filters).
            limit: maximum number of rows, 20 when None.
            sort, targets, **kwargs: accepted for interface compatibility; not used here.

        Returns:
            pd.DataFrame: at most `limit` rows with the columns from get_columns().
        """
        if limit is None:
            limit = 20

        filters = self._conditions_to_filters(conditions)

        # user params from the connection, restricted to parameters this endpoint
        # declares (unknown names would otherwise fail in _api_request)
        for name, value in self.get_user_params().items():
            if name in self.endpoint.params:
                filters[name] = value

        page_size_param = self.get_setting_param('page_size_param')
        page_size = None
        if page_size_param is not None:
            # use default value for page size
            page_size = self.endpoint.params[page_size_param].default
            if page_size:
                filters[page_size_param] = page_size
        resp, total = self._api_request(filters)

        # pagination
        offset_param = self.get_setting_param('offset_param')
        page_num_param = self.get_setting_param('page_num_param')
        if offset_param is not None or page_num_param is not None:
            page_num = 1
            last_page_len = len(resp)
            while True:
                count = len(resp)
                if limit <= count:
                    break

                if total is not None and total <= count:
                    # total is reached
                    break

                if page_size is not None and page_size > last_page_len:
                    # the last page was not full: there are no more pages
                    break

                # download more pages
                if offset_param:
                    filters[offset_param] = count
                else:
                    page_num += 1
                    filters[page_num_param] = page_num
                resp2, total = self._api_request(filters)
                if len(resp2) == 0:
                    # no results from next page
                    break
                resp.extend(resp2)
                last_page_len = len(resp2)

        return self._records_to_dataframe(resp[:limit])