Coverage for mindsdb / integrations / handlers / plaid_handler / plaid_handler.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import pandas as pd 

2from mindsdb.utilities import log 

3from mindsdb.integrations.libs.api_handler import APIHandler, FuncParser 

4from mindsdb.integrations.libs.response import ( 

5 HandlerStatusResponse as StatusResponse, 

6 HandlerResponse as Response, 

7 RESPONSE_TYPE 

8) 

9 

10from datetime import datetime 

11import plaid 

12from plaid.api import plaid_api 

13from plaid.model.accounts_balance_get_request import AccountsBalanceGetRequest 

14from plaid.model.accounts_balance_get_request_options import AccountsBalanceGetRequestOptions 

15from plaid.model.transactions_get_request import TransactionsGetRequest 

16from plaid.model.transactions_get_request_options import TransactionsGetRequestOptions 

17from .plaid_tables import BalanceTable, TransactionTable 

18from .utils import parse_transaction 

19 

20 

# Maps the user-facing `plaid_env` connection parameter to the Plaid API host.
# Hosts are looked up defensively: some plaid-python releases may not define
# every environment constant (e.g. `Development`, which Plaid retired). A
# missing one simply has no entry, so selecting it fails with the same KeyError
# as any other unknown `plaid_env`, instead of breaking this module's import.
PLAID_ENV = {
    env_name: host
    for env_name, host in (
        ('production', getattr(plaid.Environment, 'Production', None)),
        ('development', getattr(plaid.Environment, 'Development', None)),
        ('sandbox', getattr(plaid.Environment, 'Sandbox', None)),
    )
    if host is not None
}

logger = log.getLogger(__name__)

28 

29 

class PlaidHandler(APIHandler):
    '''A class for handling connections and interactions with the Plaid API.

    Connection parameters (read from ``connection_data``):
        plaid_env (str): Environment used by the user: 'sandbox' (default),
            'development' or 'production'. Must be a key of ``PLAID_ENV``.
        client_id (str): Your Plaid API client_id.
        secret (str): Your Plaid API secret.
        access_token (str): The access token for the Plaid account.

    Registers two tables: ``balance`` and ``transactions``.
    '''

    def __init__(self, name=None, **kwargs):
        '''Build the Plaid client configuration and register the handler tables.

        Args:
            name (str, optional): Name of this handler instance.
            **kwargs: May contain ``connection_data`` (dict) holding the
                connection parameters described on the class.

        Raises:
            KeyError: If ``plaid_env`` is not a key of ``PLAID_ENV``.
        '''
        super().__init__(name)

        # Tolerate a missing or explicitly-None `connection_data`.
        args = kwargs.get('connection_data') or {}

        self.plaid_config = plaid.Configuration(
            host=PLAID_ENV[args.get('plaid_env', 'sandbox')],
            api_key={
                'clientId': args.get('client_id'),
                'secret': args.get('secret'),
            },
        )

        self.access_token = args.get('access_token')

        # The API client is built lazily by `connect()`.
        self.api = None
        self.is_connected = False

        balance = BalanceTable(self)
        transactions = TransactionTable(self)
        self._register_table('balance', balance)
        self._register_table('transactions', transactions)

    def connect(self):
        '''Create (once) and return the Plaid API client.

        The client is built from ``self.plaid_config`` (environment host,
        ``client_id`` and ``secret``). This method makes no API call; use
        ``check_connection`` to verify the credentials.

        Returns:
            plaid_api.PlaidApi: The cached API client.
        '''
        if self.is_connected is True:
            return self.api

        api_client = plaid.ApiClient(self.plaid_config)
        self.api = plaid_api.PlaidApi(api_client)
        self.is_connected = True
        return self.api

    def check_connection(self) -> StatusResponse:
        '''Check that the Plaid API is reachable with the configured credentials.

        Performs an ``accounts_balance_get`` call using the configured access token.

        Returns:
            HandlerStatusResponse: ``success`` is True when the call succeeds;
            otherwise ``error_message`` describes the failure.
        '''
        response = StatusResponse(False)

        try:
            api = self.connect()
            api.accounts_balance_get(
                AccountsBalanceGetRequest(access_token=self.access_token)
            )
            response.success = True
        except Exception as e:
            # Any failure (credentials, network, token) is reported, not raised.
            response.error_message = f'Error connecting to Plaid api: {e}. '
            logger.error(response.error_message)

        if response.success is False and self.is_connected is True:
            # Force the client to be rebuilt on the next `connect()` call.
            self.is_connected = False

        return response

    def native_query(self, query_string: str = None):
        '''Parse a native query string and execute it against the Plaid API.

        The string is parsed by ``FuncParser`` into a method name and
        parameters, which are dispatched through ``call_plaid_api``.

        Args:
            query_string (str): The native query to execute.

        Returns:
            HandlerResponse: A TABLE response wrapping the resulting DataFrame.
        '''
        method_name, params = FuncParser().from_string(query_string)
        df = self.call_plaid_api(method_name, params)
        return Response(
            RESPONSE_TYPE.TABLE,
            data_frame=df
        )

    def call_plaid_api(self, method_name: str = None, params: dict = None):
        '''Dispatch a supported method name to the matching Plaid call.

        Args:
            method_name (str): Either 'get_balance' or 'get_transactions'.
            params (dict, optional): Parameters forwarded to that method.

        Returns:
            pandas.DataFrame: The column-filtered result, or an empty DataFrame
            when ``method_name`` is not supported.
        '''
        if params is None:
            params = {}

        if method_name == 'get_balance':
            result = self.get_balance(params=params)
            return BalanceTable(self).filter_columns(result=result)

        if method_name == 'get_transactions':
            result = self.get_transactions(params=params)
            return TransactionTable(self).filter_columns(result=result)

        # Kept as an empty result for backward compatibility, but made visible.
        logger.warning('Unsupported Plaid method: %s', method_name)
        return pd.DataFrame()

    def get_balance(self, params=None):
        '''Fetch account balances from Plaid and return them as a flat DataFrame.

        Args:
            params (dict, optional): Options for the request. The supported key is
                ``last_updated_datetime``, an ISO 8601 string (or datetime)
                forwarded as ``min_last_updated_datetime``.

        Returns:
            pandas.DataFrame: One row per account (see ``_flatten_account``).

        Raises:
            ValueError: If ``last_updated_datetime`` is not valid ISO 8601.
        '''
        api = self.connect()
        params = params or {}

        request_kwargs = {'access_token': self.access_token}
        last_updated = params.get('last_updated_datetime')
        if last_updated is not None:
            # Only pass `options` when needed; the generated request model
            # may not accept an explicit None.
            request_kwargs['options'] = AccountsBalanceGetRequestOptions(
                min_last_updated_datetime=self._parse_datetime(last_updated)
            )

        response = api.accounts_balance_get(AccountsBalanceGetRequest(**request_kwargs))

        return pd.DataFrame(
            [self._flatten_account(account) for account in response['accounts']]
        )

    @staticmethod
    def _flatten_account(account):
        '''Flatten one Plaid account object into a single-level dict.

        Keys already starting with ``account_`` are kept as-is, ``balances`` is
        expanded into ``balance_<field>`` keys, and every other key gets an
        ``account_`` prefix.
        '''
        row = {}
        for key in account.to_dict():
            if key.startswith('account_'):
                row[key] = account[key]
            elif key == 'balances':
                for balance_key, value in account[key].to_dict().items():
                    row[f'balance_{balance_key}'] = value
            else:
                row[f'account_{key}'] = account[key]
        return row

    @staticmethod
    def _parse_datetime(value):
        '''Convert an ISO 8601 string (or a datetime) to a timezone-aware datetime.

        A trailing 'Z' is accepted (``datetime.fromisoformat`` only supports it
        natively from Python 3.11), and naive values are assumed to be UTC.

        Raises:
            ValueError: If ``value`` is not a valid ISO 8601 datetime string.
        '''
        from datetime import timezone

        if isinstance(value, datetime):
            parsed = value
        else:
            text = str(value).strip()
            if text.endswith(('Z', 'z')):
                text = text[:-1] + '+00:00'
            parsed = datetime.fromisoformat(text)

        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def get_transactions(self, params=None):
        '''Fetch all transactions in a date range from Plaid as a DataFrame.

        Args:
            params (dict): Must contain ``start_date`` and ``end_date`` as
                'YYYY-MM-DD' strings.

        Returns:
            pandas.DataFrame: One row per transaction. When present, the ``date``
            and ``authorized_date`` columns are converted to ``datetime.date``.

        Raises:
            Exception: If ``start_date`` or ``end_date`` is missing.
            ValueError: If a date is not in 'YYYY-MM-DD' format.
        '''
        api = self.connect()
        params = params or {}

        if params.get('start_date') and params.get('end_date'):
            start_date = datetime.strptime(params.get('start_date'), '%Y-%m-%d').date()
            end_date = datetime.strptime(params.get('end_date'), '%Y-%m-%d').date()
        else:
            raise Exception('start_date and end_date is required in format YYYY-MM-DD ')

        # Page through results by advancing `offset` until Plaid's reported
        # total is reached. An empty page also stops the loop so that an
        # inconsistent total cannot cause endless requests.
        transactions = []
        while True:
            request = TransactionsGetRequest(
                access_token=self.access_token,
                start_date=start_date,
                end_date=end_date,
                options=TransactionsGetRequestOptions(offset=len(transactions)),
            )
            response = api.transactions_get(request)
            page = parse_transaction(response['transactions'])
            transactions.extend(page)
            if not page or len(transactions) >= response['total_transactions']:
                break

        df = pd.DataFrame(transactions)
        # Convert date columns from str; skip them when absent (e.g. no rows).
        for column in ('date', 'authorized_date'):
            if column in df.columns:
                df[column] = pd.to_datetime(df[column]).dt.date

        return df