Coverage for mindsdb / integrations / handlers / snowflake_handler / auth_types.py: 71%
44 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from abc import ABC, abstractmethod
2from typing import Dict, Any, Union
3from pathlib import Path
5from cryptography.hazmat.primitives import serialization
6from cryptography.hazmat.backends import default_backend
class SnowflakeAuthType(ABC):
    """Abstract strategy for producing Snowflake connection parameters.

    Each concrete subclass implements one authentication method and turns the
    caller-supplied connection arguments into a configuration dictionary.
    """

    @abstractmethod
    def get_config(self, **kwargs) -> Dict[str, Any]:
        """Return the connection configuration built from ``kwargs``.

        Subclasses must override this; the base implementation returns None.
        """
class PasswordAuthType(SnowflakeAuthType):
    """Connection configuration for Snowflake username/password authentication."""

    def get_config(self, **kwargs) -> Dict[str, Any]:
        """Validate ``kwargs`` and build a password-auth connection config.

        Args:
            **kwargs: Connection parameters. ``account``, ``user``,
                ``database`` and ``password`` must be present and truthy;
                ``schema``, ``role`` and ``warehouse`` are optional and
                default to None.

        Returns:
            A dict holding the connection parameters plus
            ``auth_type="password"``.

        Raises:
            ValueError: If a required parameter or the password is missing.
        """
        # Equivalent to "not all(...)": any required value is missing/empty.
        if any(not kwargs.get(name) for name in ("account", "user", "database")):
            raise ValueError("Required parameters (account, user, database) must be provided.")

        if not kwargs.get("password"):
            raise ValueError("Password must be provided when auth_type is 'password'.")

        field_order = ("account", "user", "password", "database", "schema", "role", "warehouse")
        settings = {field: kwargs.get(field) for field in field_order}
        settings["auth_type"] = "password"
        return settings
class KeyPairAuthType(SnowflakeAuthType):
    """Connection configuration for Snowflake key-pair (JWT) authentication.

    The private key may be given inline (``private_key``) or as a path to a
    key file (``private_key_path``). When both are supplied, the inline key
    takes precedence.
    """

    def get_config(self, **kwargs) -> Dict[str, Any]:
        """Validate ``kwargs`` and build a key-pair connection config.

        Args:
            **kwargs: Connection parameters. ``account``, ``user`` and
                ``database`` must be present and truthy. One of
                ``private_key`` (PEM text or bytes) or ``private_key_path``
                must be set. ``schema``, ``role``, ``warehouse`` and
                ``private_key_passphrase`` are optional.

        Returns:
            A dict with the connection parameters, ``authenticator`` set to
            ``"SNOWFLAKE_JWT"`` and ``auth_type`` set to ``"key_pair"``.
            An inline key is stored under ``private_key`` as the object
            returned by the PEM loader. A key file is passed through under
            ``private_key_file``, plus ``private_key_file_pwd`` when a
            passphrase is given.

        Raises:
            ValueError: If a required parameter is missing, no key source is
                given, the key path is not an existing regular file, or the
                inline key cannot be loaded.
        """
        required_keys = ["account", "user", "database"]
        if not all(kwargs.get(key) for key in required_keys):
            raise ValueError("Required parameters (account, user, database) must be provided.")

        private_key_value = kwargs.get("private_key")
        private_key_path = kwargs.get("private_key_path")
        passphrase = kwargs.get("private_key_passphrase")

        if not private_key_value and not private_key_path:
            raise ValueError("Either private_key or private_key_path must be provided when auth_type is 'key_pair'.")

        config = {
            "account": kwargs.get("account"),
            "user": kwargs.get("user"),
            "database": kwargs.get("database"),
            "schema": kwargs.get("schema"),
            "role": kwargs.get("role"),
            "warehouse": kwargs.get("warehouse"),
            "authenticator": "SNOWFLAKE_JWT",
            "auth_type": "key_pair",
        }

        if private_key_value:
            config["private_key"] = self._load_private_key(private_key_value, passphrase)
            return config

        # is_file() rather than exists(): a directory would otherwise pass
        # validation here and only fail later when the key file is opened.
        if not Path(private_key_path).is_file():
            raise ValueError(f"Private key file not found: {private_key_path}")
        config["private_key_file"] = private_key_path
        if passphrase:
            config["private_key_file_pwd"] = passphrase
        return config

    def _load_private_key(self, private_key: Union[str, bytes], passphrase: Union[str, bytes, None] = None):
        """Parse a PEM-encoded private key, optionally decrypting it.

        Args:
            private_key: PEM key as text or bytes. In text input, literal
                ``\\n`` sequences are converted to real newlines first.
            passphrase: Optional key passphrase as text (UTF-8 encoded) or
                bytes. Empty or None means the key is unencrypted.

        Returns:
            The key object returned by
            ``serialization.load_pem_private_key``.

        Raises:
            ValueError: If ``private_key`` is neither str nor bytes, or the
                key cannot be loaded with the given passphrase.
        """
        if isinstance(private_key, str):
            # Keys pasted into single-line fields often carry escaped "\n"
            # sequences; restore real newlines so the PEM framing parses.
            private_key = private_key.replace("\\n", "\n").encode()
        elif not isinstance(private_key, bytes):
            raise ValueError("private_key must be a string or bytes.")

        if not passphrase:
            password = None
        elif isinstance(passphrase, bytes):
            password = passphrase
        else:
            password = passphrase.encode()

        try:
            return serialization.load_pem_private_key(private_key, password=password, backend=default_backend())
        except Exception as exc:
            # Deliberately broad: the loader can raise several exception
            # types (not only ValueError), and callers are expected to see a
            # single ValueError for any bad key or passphrase.
            raise ValueError(
                "Failed to load private_key. Ensure it is a valid PEM-encoded key and the passphrase is correct."
            ) from exc