Coverage for mindsdb / integrations / utilities / handlers / auth_utilities / snowflake / snowflake_jwt_gen.py: 31%
63 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1# Based on https://docs.snowflake.com/en/developer-guide/sql-api/authenticating
3import time
4import base64
5import hashlib
6import logging
7from datetime import timedelta, timezone, datetime
9from cryptography.hazmat.primitives.serialization import load_pem_private_key
10from cryptography.hazmat.primitives.serialization import Encoding
11from cryptography.hazmat.primitives.serialization import PublicFormat
12from cryptography.hazmat.backends import default_backend
13import jwt
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Keys of the JWT claims written into the token payload.
ISSUER = "iss"
SUBJECT = "sub"
ISSUE_TIME = "iat"
EXPIRE_TIME = "exp"
class JWTGenerator(object):
    """
    Creates and signs JWTs with the specified private key, username, and account identifier.

    Every call to ``get_token`` signs a brand-new token. The most recently generated token is also stored
    in ``self.token``, but this class never reuses it or checks its expiry; callers that want to reuse a
    token have to decide that themselves.
    """

    LIFETIME = timedelta(minutes=60)  # Default lifetime of a generated token: 60 minutes
    ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256

    def __init__(self, account: str, user: str, private_key: str, lifetime: timedelta = LIFETIME):
        """
        __init__ creates an object that generates JWTs for the specified user, account identifier, and private key.

        :param account: Your Snowflake account identifier. See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Note that if you are using the account locator, exclude any region information from the account locator.
        :param user: The Snowflake username. It is upper-cased before use.
        :param private_key: The PEM-encoded private key contents (not a file path) used for signing the JWTs.
            ``str`` or ``bytes`` are accepted. The key is loaded without a password, so it must be unencrypted.
        :param lifetime: How long (as a timedelta) each generated token will be valid.
        """

        logger.info(
            """Creating JWTGenerator with arguments
            account : %s, user : %s, lifetime : %s""",
            account,
            user,
            lifetime,
        )

        # Construct the fully qualified name of the user in uppercase.
        self.account = self.prepare_account_name_for_jwt(account)
        self.user = user.upper()
        self.qualified_username = self.account + "." + self.user

        self.lifetime = lifetime
        # NOTE: renew_time is recorded at construction but is not consulted anywhere in this class;
        # the attribute is kept so that external code reading it keeps working.
        self.renew_time = datetime.now(timezone.utc)
        self.token = None

        # load_pem_private_key needs bytes; accept both text and already-encoded PEM data.
        pem_data = private_key.encode() if isinstance(private_key, str) else private_key
        self.private_key = load_pem_private_key(pem_data, None, default_backend())

    def prepare_account_name_for_jwt(self, raw_account: str) -> str:
        """
        Prepare the account identifier for use in the JWT.

        For the JWT, the account identifier must not include the subdomain or any region or cloud provider information.

        :param raw_account: The specified account identifier.
        :return: The account identifier, upper-cased, in a form that can be used to generate JWT.
        """
        # General case: cut at the first "."; replication case (".global" present): cut at the first "-".
        separator = "-" if ".global" in raw_account else "."
        cut = raw_account.find(separator)
        # A separator that is absent (-1) or at position 0 leaves the identifier untouched.
        account = raw_account[:cut] if cut > 0 else raw_account
        # Use uppercase for the account identifier.
        return account.upper()

    def get_token(self) -> str:
        """
        Generates and signs a new JWT, stores it in ``self.token``, and returns it.

        :return: the new token as a string
        """
        now = datetime.now(timezone.utc)  # Fetch the current time

        # Generate the public key fingerprint for the issuer in the payload.
        public_key_fp = self.calculate_public_key_fingerprint(self.private_key)

        payload = {
            # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
            ISSUER: self.qualified_username + "." + public_key_fp,
            # Set the subject to the fully qualified username.
            SUBJECT: self.qualified_username,
            # Set the issue time to now.
            ISSUE_TIME: now,
            # Set the expiration time, based on the lifetime specified for this object.
            EXPIRE_TIME: now + self.lifetime,
        }

        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
        # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string, rather than a string.
        # If the token is a byte string, convert it to a string.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        self.token = token

        return self.token

    def calculate_public_key_fingerprint(self, private_key) -> str:
        """
        Given a loaded private key object, return the fingerprint of its public key.

        :param private_key: private key object exposing ``public_key()`` (for example ``self.private_key``
            as loaded in ``__init__``); a PEM string is NOT accepted here.
        :return: public key fingerprint in the form ``SHA256:<base64 digest>``
        """
        # Get the raw DER bytes of the public key.
        public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)

        # Get the sha256 hash of the raw bytes.
        digest = hashlib.sha256(public_key_raw).digest()

        # Base64-encode the value and prepend the prefix 'SHA256:'.
        public_key_fp = "SHA256:" + base64.b64encode(digest).decode("utf-8")
        logger.info("Public key fingerprint is %s", public_key_fp)

        return public_key_fp
def get_validated_jwt(token: str, account: str, user: str, private_key: str) -> str:
    """
    Return ``token`` if it has not expired yet, otherwise generate and return a new JWT.

    The existing token is only inspected, not verified: its signature is not checked and the only claim
    read is ``exp``. In particular the token's issuer/subject are not compared with ``account``/``user``.
    The token is reused when it expires more than 5 seconds from now.

    :param token: a previously issued JWT. A missing, malformed or undecodable token leads to a new one.
    :param account: account identifier passed to ``JWTGenerator`` when a new token is needed.
    :param user: username passed to ``JWTGenerator`` when a new token is needed.
    :param private_key: PEM private key passed to ``JWTGenerator`` when a new token is needed.
    :return: the still-valid ``token`` or a newly generated one.
    :raises ValueError: if a new token is needed and ``private_key`` is None.
    """
    expiry_margin = 5  # seconds; do not hand out a token that is about to expire

    try:
        content = jwt.decode(token, algorithms=[JWTGenerator.ALGORITHM], options={"verify_signature": False})
    except jwt.InvalidTokenError:
        # The token could not be decoded (missing or malformed); the signature itself is never checked.
        content = {}

    expires_at = content.get("exp", 0)
    # A non-numeric "exp" is treated like an expired token instead of raising TypeError.
    if isinstance(expires_at, (int, float)) and expires_at - expiry_margin > time.time():
        # keep the same
        return token

    # generate new token
    if private_key is None:
        raise ValueError("Private key is missing")
    return JWTGenerator(account, user, private_key).get_token()