Coverage for mindsdb / integrations / handlers / salesforce_handler / salesforce_handler.py: 95%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Any, Dict, List, Optional, Text 

2 

3import pandas as pd 

4import salesforce_api 

5from salesforce_api.exceptions import AuthenticationError, RestRequestCouldNotBeUnderstoodError 

6 

7from mindsdb.integrations.libs.api_handler import MetaAPIHandler 

8from mindsdb.integrations.libs.response import ( 

9 HandlerResponse as Response, 

10 HandlerStatusResponse as StatusResponse, 

11 RESPONSE_TYPE, 

12) 

13from mindsdb.integrations.handlers.salesforce_handler.salesforce_tables import create_table_class 

14from mindsdb.integrations.handlers.salesforce_handler.constants import get_soql_instructions 

15from mindsdb.utilities import log 

16 

17 

# Module-level logger, named after this module via mindsdb's logging utility.
# Used by SalesforceHandler to report connection, validation and query problems.
logger = log.getLogger(__name__)

19 

20 

class SalesforceHandler(MetaAPIHandler):
    """
    This handler handles the connection and execution of SOQL statements on Salesforce.

    One table is registered per queryable Salesforce resource when the connection is
    established. The set of resources can be narrowed through the optional
    ``include_tables`` / ``tables`` and ``exclude_tables`` connection parameters.
    """

    name = "salesforce"

    def __init__(self, name: Text, connection_data: Dict, **kwargs: Any) -> None:
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): The connection data required to connect to the Salesforce API.
            kwargs: Arbitrary keyword arguments.
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        # Populated lazily by connect().
        self.connection = None
        self.is_connected = False
        # NOTE(review): presumably consumed by the base handler / cache layer - not read in this class.
        self.cache_thread_safe = True
        # Cache of the resource names resolved by _get_resource_names().
        self.resource_names = []

    def connect(self) -> salesforce_api.client.Client:
        """
        Establishes a connection to the Salesforce API and registers one table per resource.

        The connection is created once; subsequent calls return the cached client while
        ``is_connected`` is True.

        Raises:
            ValueError: If the required connection parameters are not provided.
            AuthenticationError: If an authentication error occurs while connecting to the Salesforce API.
            Exception: Any other error raised while connecting or registering tables is logged and re-raised.

        Returns:
            salesforce_api.client.Client: A connection object to the Salesforce API.
        """
        if self.is_connected is True:
            return self.connection

        # Mandatory connection parameters.
        required_keys = ("username", "password", "client_id", "client_secret")
        if not all(key in self.connection_data for key in required_keys):
            raise ValueError("Required parameters (username, password, client_id, client_secret) must be provided.")

        try:
            self.connection = salesforce_api.Salesforce(
                username=self.connection_data["username"],
                password=self.connection_data["password"],
                client_id=self.connection_data["client_id"],
                client_secret=self.connection_data["client_secret"],
                is_sandbox=self.connection_data.get("is_sandbox", False),
            )
            # The flag is raised before table registration (original ordering kept) so that a
            # connect() call made while tables are being built returns the client instead of
            # re-entering this method. It is rolled back below if registration fails.
            self.is_connected = True

            for resource_name in self._get_resource_names():
                table_class = create_table_class(resource_name.lower())
                self._register_table(resource_name, table_class(self))

            return self.connection
        except AuthenticationError as auth_error:
            self._reset_connection_state()
            logger.error(f"Authentication error connecting to Salesforce, {auth_error}!")
            raise
        except Exception as unknown_error:
            self._reset_connection_state()
            logger.error(f"Unknown error connecting to Salesforce, {unknown_error}!")
            raise

    def _reset_connection_state(self) -> None:
        """
        Marks the handler as disconnected after a failed connection attempt.

        Without this, a failure during resource discovery or table registration would leave
        ``is_connected`` set and every later connect() call would return early.
        """
        self.connection = None
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the Salesforce API.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)

        try:
            self.connect()
            response.success = True
        except (AuthenticationError, ValueError) as known_error:
            logger.error(f"Connection check to Salesforce failed, {known_error}!")
            response.error_message = str(known_error)
        except Exception as unknown_error:
            logger.error(f"Connection check to Salesforce failed due to an unknown error, {unknown_error}!")
            response.error_message = str(unknown_error)

        # Keep the handler flag in sync with the outcome of the check.
        self.is_connected = response.success

        return response

    def native_query(self, query: Text) -> Response:
        """
        Executes a native SOQL query on Salesforce and returns the result.

        Each returned record is flattened: the ``attributes`` entry is dropped and the fields of
        nested resources are promoted to ``<Resource>_<field>`` columns.

        Args:
            query (Text): The SOQL query to be executed.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        connection = self.connect()

        try:
            results = connection.sobjects.query(query)

            # Build the lookup once per query instead of scanning the list for every key of every row.
            resource_names = set(self.resource_names)
            parsed_results = [self._flatten_record(record, resource_names) for record in results]

            response = Response(RESPONSE_TYPE.TABLE, pd.DataFrame(parsed_results))
        except RestRequestCouldNotBeUnderstoodError as rest_error:
            logger.error(f"Error running query: {query} on Salesforce, {rest_error}!")
            response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(rest_error))
        except Exception as unknown_error:
            logger.error(f"Error running query: {query} on Salesforce, {unknown_error}!")
            response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(unknown_error))

        return response

    @staticmethod
    def _flatten_record(record: Dict, resource_names: set) -> Dict:
        """
        Flattens a single query result record into a one-level dictionary.

        Args:
            record (Dict): A record as returned by the Salesforce client.
            resource_names (set): Names of the known Salesforce resources.

        Returns:
            Dict: The record without ``attributes`` entries, with nested resources expanded
                into ``<key>_<sub_key>`` items. A nested resource whose value is not a
                dictionary (for example a null relationship) is kept as a plain value.
        """
        flattened = {}
        for key, value in record.items():
            if key == "attributes":
                continue

            if key in resource_names and isinstance(value, dict):
                flattened.update(
                    {
                        f"{key}_{sub_key}": sub_value
                        for sub_key, sub_value in value.items()
                        if sub_key != "attributes"
                    }
                )
            else:
                flattened[key] = value

        return flattened

    @staticmethod
    def _as_table_list(value: Any) -> List[str]:
        """
        Normalizes a table filter connection parameter into a list of table names.

        Args:
            value: None, a single string (optionally comma-separated) or an iterable of names.

        Returns:
            List[str]: The table names; an empty list when nothing was specified.
        """
        if not value:
            return []
        if isinstance(value, str):
            # A bare string would otherwise be iterated character by character.
            return [part.strip() for part in value.split(",") if part.strip()]
        return list(value)

    def _get_resource_names(self) -> List[str]:
        """
        Retrieves the names of the Salesforce resources with optimized pre-filtering.

        The result is cached in ``self.resource_names``; an empty cache triggers a new lookup.

        Returns:
            List[str]: A list of filtered resource names.
        """
        if not self.resource_names:
            # Check for user-specified table filtering first.
            include_tables = self._as_table_list(
                self.connection_data.get("include_tables") or self.connection_data.get("tables")
            )
            exclude_tables = self._as_table_list(self.connection_data.get("exclude_tables"))

            if include_tables:
                # OPTIMIZATION: Skip expensive global describe() call.
                # Only validate the specified tables.
                logger.info(f"Using pre-filtered table list: {include_tables}")
                self.resource_names = self._validate_specified_tables(include_tables, exclude_tables)
            else:
                # Fallback to full discovery with hard-coded filtering.
                logger.info("No table filter specified, performing full discovery...")
                self.resource_names = self._discover_all_tables_with_filtering(exclude_tables)

        return self.resource_names

    def _validate_specified_tables(self, include_tables: List[str], exclude_tables: List[str]) -> List[str]:
        """
        Validate user-specified tables without expensive global describe() call.

        Each table is described individually; tables that are excluded, not queryable, or whose
        describe call fails are skipped (best effort, failures are only logged).

        Args:
            include_tables: List of table names to include
            exclude_tables: List of table names to exclude

        Returns:
            List[str]: Validated and filtered table names
        """
        excluded = set(exclude_tables or ())
        validated_tables = []

        for table_name in include_tables:
            # Skip if explicitly excluded.
            if table_name in excluded:
                logger.info(f"Skipping excluded table: {table_name}")
                continue

            try:
                # Quick validation: check if table exists and is queryable.
                # This is much faster than global describe().
                metadata = getattr(self.connection.sobjects, table_name).describe()
                if metadata.get("queryable", False):
                    validated_tables.append(table_name)
                    logger.debug(f"Validated table: {table_name}")
                else:
                    logger.warning(f"Table {table_name} is not queryable, skipping")
            except Exception as e:
                # Deliberately tolerant: one bad table name must not prevent the others from loading.
                logger.warning(f"Table {table_name} not found or accessible: {e}")

        logger.info(f"Validated {len(validated_tables)} tables from include_tables")
        return validated_tables

    def _discover_all_tables_with_filtering(self, exclude_tables: List[str]) -> List[str]:
        """
        Fallback method: discover all tables with hard-coded filtering.

        Args:
            exclude_tables: List of table names to exclude

        Returns:
            List[str]: Filtered table names
        """
        # Hard-coded filters for resources that are not exposed as tables.
        ignore_suffixes = ("Share", "History", "Feed", "ChangeEvent", "Tag", "Permission", "Setup", "Consent")
        ignore_prefixes = (
            "Apex",
            "CommPlatform",
            "Lightning",
            "Flow",
            "Transaction",
            "AI",
            "Aura",
            "ContentWorkspace",
            "Collaboration",
            "Datacloud",
        )
        ignore_exact = {
            "EntityDefinition",
            "FieldDefinition",
            "RecordType",
            "CaseStatus",
            "UserRole",
            "UserLicense",
            "UserPermissionAccess",
            "UserRecordAccess",
            "Folder",
            "Group",
            "Note",
            "ProcessDefinition",
            "ProcessInstance",
            "ContentFolder",
            "ContentDocumentSubscription",
            "DashboardComponent",
            "Report",
            "Dashboard",
            "Topic",
            "TopicAssignment",
            "Period",
            "Partner",
            "PackageLicense",
            "ColorDefinition",
            "DataUsePurpose",
            "DataUseLegalBasis",
        }
        ignore_substrings = (
            "CleanInfo",
            "Template",
            "Rule",
            "Definition",
            "Status",
            "Policy",
            "Setting",
            "Access",
            "Config",
            "Subscription",
            "DataType",
            "MilestoneType",
            "Entitlement",
            "Auth",
        )
        # User exclusions; tolerate None so callers may pass the raw connection parameter.
        excluded = set(exclude_tables or ())

        # This is the original expensive approach - only used when no include_tables specified.
        return [
            resource["name"]
            for resource in self.connection.sobjects.describe()["sobjects"]
            if resource.get("queryable", False)
            and not resource["name"].endswith(ignore_suffixes)
            and not resource["name"].startswith(ignore_prefixes)
            and not any(sub in resource["name"] for sub in ignore_substrings)
            and resource["name"] not in ignore_exact
            and resource["name"] not in excluded
        ]

    def meta_get_handler_info(self, **kwargs) -> str:
        """
        Retrieves information about the design and implementation of the API handler.
        This should include, but not be limited to, the following:
        - The type of SQL queries and operations that the handler supports.
        - etc.

        Args:
            kwargs: Additional keyword arguments that may be used in generating the handler information.

        Returns:
            str: A string containing information about the API handler's design and implementation.
        """
        return get_soql_instructions(self.name)

    def meta_get_tables(self, table_names: Optional[List[str]] = None) -> Response:
        """
        Retrieves metadata for the specified tables (or all tables if no list is provided).

        Table names are matched case-insensitively against the Salesforce resource names.

        Args:
            table_names (List): A list of table names for which to retrieve metadata.

        Returns:
            Response: A response object containing the table metadata.
        """
        connection = self.connect()

        # Retrieve the metadata for all Salesforce resources.
        all_resources = connection.sobjects.describe()["sobjects"]
        if table_names:
            # Filter the metadata for the specified tables; normalise the requested names so that
            # mixed-case input (e.g. "Account") matches the lower-cased resource name.
            wanted = {table_name.lower() for table_name in table_names}
            main_metadata = [resource for resource in all_resources if resource["name"].lower() in wanted]
        else:
            main_metadata = all_resources

        return super().meta_get_tables(table_names=table_names, main_metadata=main_metadata)

342 

343 return super().meta_get_tables(table_names=table_names, main_metadata=main_metadata)