Coverage for mindsdb / integrations / handlers / snowflake_handler / auth_types.py: 71%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from abc import ABC, abstractmethod 

2from typing import Dict, Any, Union 

3from pathlib import Path 

4 

5from cryptography.hazmat.primitives import serialization 

6from cryptography.hazmat.backends import default_backend 

7 

8 

class SnowflakeAuthType(ABC):
    """Abstract base for Snowflake authentication strategies.

    Each concrete subclass converts caller-supplied keyword arguments into a
    dictionary of connection settings for one authentication method.
    """

    @abstractmethod
    def get_config(self, **kwargs) -> Dict[str, Any]:
        """Return the settings dictionary built from ``kwargs``.

        Subclasses must override this method; the base implementation has no
        body and returns ``None``.
        """

13 

14 

class PasswordAuthType(SnowflakeAuthType):
    """Authentication strategy for Snowflake username/password login."""

    def get_config(self, **kwargs) -> Dict[str, Any]:
        """Validate ``kwargs`` and return the password-auth settings.

        Args:
            **kwargs: ``account``, ``user``, ``database`` and ``password`` must
                be present and truthy; ``schema``, ``role`` and ``warehouse``
                are optional and default to ``None``.

        Returns:
            A dict containing the connection fields above plus
            ``auth_type`` set to ``"password"``.

        Raises:
            ValueError: If a required connection field or the password is
                missing or empty.
        """
        # Equivalent to `not all(...)`: fail as soon as one field is falsy.
        if any(not kwargs.get(field) for field in ("account", "user", "database")):
            raise ValueError("Required parameters (account, user, database) must be provided.")

        if not kwargs.get("password"):
            raise ValueError("Password must be provided when auth_type is 'password'.")

        # Build the fields in a fixed order, then tag the auth method last.
        settings = {
            field: kwargs.get(field)
            for field in ("account", "user", "password", "database", "schema", "role", "warehouse")
        }
        settings["auth_type"] = "password"
        return settings

33 

34 

class KeyPairAuthType(SnowflakeAuthType):
    """Authentication strategy for Snowflake key-pair (JWT) login.

    The private key is supplied either inline via ``private_key`` (PEM text or
    bytes) or as a file path via ``private_key_path``. An optional
    ``private_key_passphrase`` is used for encrypted keys.
    """

    def get_config(self, **kwargs) -> Dict[str, Any]:
        """Validate ``kwargs`` and return the key-pair auth settings.

        Args:
            **kwargs: ``account``, ``user`` and ``database`` must be truthy,
                and at least one of ``private_key`` / ``private_key_path``
                must be set. ``schema``, ``role``, ``warehouse`` and
                ``private_key_passphrase`` are optional.

        Returns:
            A dict with the connection fields, ``authenticator`` set to
            ``"SNOWFLAKE_JWT"`` and ``auth_type`` set to ``"key_pair"``. It
            also holds either a loaded ``private_key`` object (inline key) or
            ``private_key_file``, plus ``private_key_file_pwd`` when a
            passphrase is given.

        Raises:
            ValueError: In any of these cases:
                - a required parameter is missing;
                - no key source is given;
                - the key path is not an existing regular file;
                - the inline key cannot be loaded.
        """
        if not all(kwargs.get(key) for key in ["account", "user", "database"]):
            raise ValueError("Required parameters (account, user, database) must be provided.")

        private_key_value = kwargs.get("private_key")
        private_key_path = kwargs.get("private_key_path")
        passphrase = kwargs.get("private_key_passphrase")

        if not private_key_value and not private_key_path:
            raise ValueError("Either private_key or private_key_path must be provided when auth_type is 'key_pair'.")

        config = {
            "account": kwargs.get("account"),
            "user": kwargs.get("user"),
            "database": kwargs.get("database"),
            "schema": kwargs.get("schema"),
            "role": kwargs.get("role"),
            "warehouse": kwargs.get("warehouse"),
            "authenticator": "SNOWFLAKE_JWT",
            "auth_type": "key_pair",
        }

        # An inline key takes precedence when both sources are supplied.
        if private_key_value:
            config["private_key"] = self._load_private_key(private_key_value, passphrase)
        else:
            # is_file() rather than exists(): a directory path would pass
            # exists() but can never be read as a key file.
            if not Path(private_key_path).is_file():
                raise ValueError(f"Private key file not found: {private_key_path}")
            config["private_key_file"] = private_key_path
            if passphrase:
                config["private_key_file_pwd"] = passphrase
        return config

    def _load_private_key(self, private_key: Union[str, bytes], passphrase: Union[str, None] = None):
        """Parse a PEM-encoded private key into a key object.

        Args:
            private_key: The PEM key as text or bytes. For text input, each
                backslash followed by ``n`` is replaced with a real newline
                before encoding. This lets keys stored on one line with
                escaped newlines load correctly.
            passphrase: Password for an encrypted key. ``None`` or an empty
                string means no password is passed to the loader.

        Returns:
            The key object returned by
            ``cryptography``'s ``serialization.load_pem_private_key``.

        Raises:
            ValueError: If ``private_key`` is neither ``str`` nor ``bytes``,
                or if loading it fails for any reason. The original error is
                chained as ``__cause__``.
        """
        if isinstance(private_key, str):
            private_key = private_key.replace("\\n", "\n").encode()
        elif not isinstance(private_key, bytes):
            raise ValueError("private_key must be a string or bytes.")

        password = passphrase.encode() if passphrase else None
        try:
            return serialization.load_pem_private_key(private_key, password=password, backend=default_backend())
        except Exception as exc:
            # Normalize every loader failure into one ValueError for callers,
            # keeping the underlying cause chained.
            raise ValueError(
                "Failed to load private_key. Ensure it is a valid PEM-encoded key and the passphrase is correct."
            ) from exc