Coverage for mindsdb / integrations / handlers / plaid_handler / plaid_handler.py: 0%
96 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import pandas as pd
2from mindsdb.utilities import log
3from mindsdb.integrations.libs.api_handler import APIHandler, FuncParser
4from mindsdb.integrations.libs.response import (
5 HandlerStatusResponse as StatusResponse,
6 HandlerResponse as Response,
7 RESPONSE_TYPE
8)
10from datetime import datetime
11import plaid
12from plaid.api import plaid_api
13from plaid.model.accounts_balance_get_request import AccountsBalanceGetRequest
14from plaid.model.accounts_balance_get_request_options import AccountsBalanceGetRequestOptions
15from plaid.model.transactions_get_request import TransactionsGetRequest
16from plaid.model.transactions_get_request_options import TransactionsGetRequestOptions
17from .plaid_tables import BalanceTable, TransactionTable
18from .utils import parse_transaction
# Lookup table from the user-facing `plaid_env` connection argument to the
# Plaid SDK host constant used when building `plaid.Configuration`.
# NOTE(review): Plaid has been retiring its Development environment; confirm
# `plaid.Environment.Development` still exists in the pinned plaid-python.
PLAID_ENV = dict(
    production=plaid.Environment.Production,
    development=plaid.Environment.Development,
    sandbox=plaid.Environment.Sandbox,
)

logger = log.getLogger(__name__)
class PlaidHandler(APIHandler):
    '''A class for handling connections and interactions with the Plaid API.

    Connection arguments (read from ``connection_data``):
        plaid_env (str): Environment used by the user: 'sandbox' (default),
            'development' or 'production'. Must be a key of ``PLAID_ENV``.
        client_id (str): Your Plaid API client_id.
        secret (str): Your Plaid API secret.
        access_token (str): The access token for the Plaid account.

    Attributes:
        plaid_config (plaid.Configuration): Host and API keys for the client.
        access_token (str): Access token sent with every request.
        api (plaid_api.PlaidApi | None): Client created lazily by ``connect``.
        is_connected (bool): True once ``connect`` has built ``api``; reset by
            a failed ``check_connection``.
    '''

    def __init__(self, name=None, **kwargs):
        super().__init__(name)

        # `connection_data` may be missing or explicitly None.
        args = kwargs.get('connection_data') or {}

        self.plaid_config = plaid.Configuration(
            host=PLAID_ENV[args.get('plaid_env', 'sandbox')],
            api_key={
                'clientId': args.get('client_id'),
                'secret': args.get('secret'),
            },
        )

        self.access_token = args.get('access_token')

        self.api = None
        self.is_connected = False

        balance = BalanceTable(self)
        transactions = TransactionTable(self)
        self._register_table('balance', balance)
        self._register_table('transactions', transactions)

    def connect(self):
        '''Create (once) and return the Plaid API client.

        The client is built from ``self.plaid_config`` (environment host,
        client_id and secret). ``access_token`` is not used here; it is passed
        with each request.

        Returns:
            plaid_api.PlaidApi: The cached client.
        '''
        if self.is_connected is True:
            return self.api

        api_client = plaid.ApiClient(self.plaid_config)
        self.api = plaid_api.PlaidApi(api_client)
        self.is_connected = True
        return self.api

    def check_connection(self) -> StatusResponse:
        '''Check that the Plaid credentials work.

        It requests account balances with the configured access token. On any
        failure the error is logged and ``is_connected`` is reset so that the
        next ``connect`` call builds a fresh client.

        Returns:
            HandlerStatusResponse: ``success`` True when the probe succeeded,
            otherwise ``error_message`` describes the failure.
        '''
        response = StatusResponse(False)

        try:
            api = self.connect()
            api.accounts_balance_get(
                AccountsBalanceGetRequest(access_token=self.access_token)
            )
            response.success = True
        except Exception as e:
            response.error_message = f'Error connecting to Plaid api: {e}. '
            logger.error(response.error_message)

        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def native_query(self, query_string: str = None):
        '''Run a native Plaid call written as a function-call string.

        The string is parsed with ``FuncParser().from_string`` into a method
        name and a params mapping. It is then dispatched through
        ``call_plaid_api``.

        Args:
            query_string (str): Native query, e.g. a ``get_balance(...)`` or
                ``get_transactions(...)`` call.

        Returns:
            HandlerResponse: A TABLE response wrapping the resulting DataFrame.
        '''
        method_name, params = FuncParser().from_string(query_string)
        df = self.call_plaid_api(method_name, params)
        return Response(
            RESPONSE_TYPE.TABLE,
            data_frame=df
        )

    def call_plaid_api(self, method_name: str = None, params: dict = None):
        '''Dispatch a named Plaid call and return its column-filtered result.

        Args:
            method_name (str): Either 'get_balance' or 'get_transactions'.
            params (dict, optional): Options forwarded to the chosen method.

        Returns:
            pandas.DataFrame: The filtered result. Unknown method names yield
            an empty DataFrame (a warning is logged).
        '''
        if params is None:
            params = {}

        if method_name == 'get_balance':
            result = self.get_balance(params=params)
            return BalanceTable(self).filter_columns(result=result)

        if method_name == 'get_transactions':
            result = self.get_transactions(params=params)
            return TransactionTable(self).filter_columns(result=result)

        # Preserve the historical empty-frame result, but make it visible.
        logger.warning('Unknown Plaid method requested: %s', method_name)
        return pd.DataFrame()

    @staticmethod
    def _parse_datetime(value):
        '''Convert a ``last_updated_datetime`` option into a ``datetime``.

        It accepts a ``datetime`` instance or an ISO 8601 string such as
        ``2024-01-31T12:00:00Z`` or ``2024-01-31``. A trailing ``Z`` is
        rewritten to ``+00:00`` because ``datetime.fromisoformat`` only parses
        it natively from Python 3.11. Naive values are treated as UTC; Plaid
        documents this field as a UTC ISO 8601 timestamp.

        Raises:
            ValueError: If the string is not valid ISO 8601.
        '''
        from datetime import timezone

        if isinstance(value, datetime):
            parsed = value
        else:
            text = str(value).strip()
            if text.endswith(('Z', 'z')):
                text = text[:-1] + '+00:00'
            parsed = datetime.fromisoformat(text)

        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _flatten_account(account):
        '''Flatten one Plaid account object into a single-level dict.

        The rules are:
        - Keys that already start with ``account_`` are kept as-is.
        - The nested ``balances`` object is expanded into ``balance_<field>``
          entries.
        - Every other key is prefixed with ``account_``.
        '''
        row = {}
        for key in account.to_dict().keys():
            if key.startswith('account_'):
                row[key] = account[key]
            elif key == 'balances':
                for field, value in account[key].to_dict().items():
                    row[f'balance_{field}'] = value
            else:
                row[f'account_{key}'] = account[key]
        return row

    def get_balance(self, params=None):
        '''Fetch account balances and return one flattened row per account.

        Args:
            params (dict, optional): May contain ``last_updated_datetime``
                (ISO 8601 string or ``datetime``). It is sent as
                ``min_last_updated_datetime`` in the request options.

        Returns:
            pandas.DataFrame: One row per account. It has ``account_*``
            columns plus ``balance_*`` columns taken from the nested
            ``balances`` object.
        '''
        params = params or {}
        api = self.connect()

        last_updated = params.get('last_updated_datetime')
        if last_updated is not None:
            request = AccountsBalanceGetRequest(
                access_token=self.access_token,
                options=AccountsBalanceGetRequestOptions(
                    min_last_updated_datetime=self._parse_datetime(last_updated)
                ),
            )
        else:
            request = AccountsBalanceGetRequest(access_token=self.access_token)

        response = api.accounts_balance_get(request)
        return pd.DataFrame(
            [self._flatten_account(account) for account in response['accounts']]
        )

    def get_transactions(self, params=None):
        '''Fetch all transactions in a date range, paginating as needed.

        Args:
            params (dict): Must contain ``start_date`` and ``end_date`` as
                'YYYY-MM-DD' strings.

        Returns:
            pandas.DataFrame: One row per transaction as produced by
            ``parse_transaction``. ``date`` / ``authorized_date`` are converted
            to ``datetime.date`` when present.

        Raises:
            ValueError: If either date is missing.
        '''
        params = params or {}
        api = self.connect()

        if params.get('start_date') and params.get('end_date'):
            start_date = datetime.strptime(params.get('start_date'), '%Y-%m-%d').date()
            end_date = datetime.strptime(params.get('end_date'), '%Y-%m-%d').date()
        else:
            raise ValueError('start_date and end_date is required in format YYYY-MM-DD ')

        def fetch_page(offset):
            # Offset-based pagination; offset 0 is the first page.
            request = TransactionsGetRequest(
                access_token=self.access_token,
                start_date=start_date,
                end_date=end_date,
                options=TransactionsGetRequestOptions(offset=offset),
            )
            return api.transactions_get(request)

        response = fetch_page(0)
        transactions = parse_transaction(response['transactions'])

        # Keep requesting pages until we hold every transaction the API reports.
        while len(transactions) < response['total_transactions']:
            response = fetch_page(len(transactions))
            page = parse_transaction(response['transactions'])
            if not page:
                # The API reported more transactions than it returned (e.g. the
                # data changed mid-pagination); stop instead of looping forever.
                break
            transactions.extend(page)

        df = pd.DataFrame(transactions)
        # An empty result has no columns, so only convert those present.
        for column in ('date', 'authorized_date'):
            if column in df.columns:
                df[column] = pd.to_datetime(df[column]).dt.date

        return df