Coverage for mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/snowflake_jwt_gen.py: 31%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1# Based on https://docs.snowflake.com/en/developer-guide/sql-api/authenticating 

2 

3import time 

4import base64 

5import hashlib 

6import logging 

7from datetime import timedelta, timezone, datetime 

8 

9from cryptography.hazmat.primitives.serialization import load_pem_private_key 

10from cryptography.hazmat.primitives.serialization import Encoding 

11from cryptography.hazmat.primitives.serialization import PublicFormat 

12from cryptography.hazmat.backends import default_backend 

13import jwt 

14 

logger = logging.getLogger(__name__)

# Registered JWT claim names (RFC 7519, section 4.1) used as payload keys:
# issuer, expiration time, issued-at time, and subject.
ISSUER, EXPIRE_TIME, ISSUE_TIME, SUBJECT = "iss", "exp", "iat", "sub"

21 

22 

class JWTGenerator:
    """
    Creates and signs a JWT for Snowflake key-pair authentication, using the given PEM
    private key, username, and account identifier.

    Every call to ``get_token`` builds and signs a brand-new token; the most recently
    generated one is also stored on ``self.token``. This class does not reuse tokens or
    renew them on a schedule itself.
    """

    LIFETIME = timedelta(minutes=60)  # Default token lifetime: 60 minutes.
    ALGORITHM = "RS256"  # Tokens are signed using RSA with SHA-256.

    def __init__(self, account: str, user: str, private_key: str, lifetime: timedelta = LIFETIME):
        """
        Create an object that generates JWTs for the given user, account identifier, and private key.

        :param account: Your Snowflake account identifier. See
            https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Subdomain, region
            and cloud suffixes are stripped by ``prepare_account_name_for_jwt``.
        :param user: The Snowflake username; it is upper-cased for the JWT.
        :param private_key: PEM text of the private key used to sign the JWTs. No passphrase is
            supplied when loading it, so the key must be unencrypted.
        :param lifetime: How long (as a timedelta) each generated token is valid.
        """

        logger.info(
            """Creating JWTGenerator with arguments
        account : %s, user : %s, lifetime : %s""",
            account,
            user,
            lifetime,
        )

        # Construct the fully qualified, upper-cased user name: ACCOUNT.USER
        self.account = self.prepare_account_name_for_jwt(account)
        self.user = user.upper()
        self.qualified_username = self.account + "." + self.user

        self.lifetime = lifetime
        # NOTE: renew_time is not read anywhere in this class; it is kept so that any
        # external code accessing the attribute keeps working.
        self.renew_time = datetime.now(timezone.utc)
        self.token = None

        # Password is None: encrypted PEM keys are not supported here.
        self.private_key = load_pem_private_key(private_key.encode(), None, default_backend())

    def prepare_account_name_for_jwt(self, raw_account: str) -> str:
        """
        Prepare the account identifier for use in the JWT.

        For the JWT, the account identifier must not include the subdomain or any region or
        cloud provider information:

        * general case: everything from the first "." on is dropped
          (``xy12345.us-east-2.aws`` -> ``XY12345``);
        * replication case (the identifier contains ``.global``, matched case-insensitively
          because account identifiers are case-insensitive): everything from the first "-"
          on is dropped (``xy12345-abc.global`` -> ``XY12345``).

        :param raw_account: The specified account identifier.
        :return: The upper-cased account identifier in a form that can be used to generate the JWT.
        """
        account = raw_account
        if ".global" not in account.lower():
            # Handle the general case.
            idx = account.find(".")
        else:
            # Handle the replication case.
            idx = account.find("-")
        if idx > 0:
            account = account[:idx]
        # Use uppercase for the account identifier.
        return account.upper()

    def get_token(self) -> str:
        """
        Generate and sign a new JWT.

        The token's issuer is ``ACCOUNT.USER.<public key fingerprint>``, its subject is
        ``ACCOUNT.USER``, it is issued now (UTC) and expires after ``self.lifetime``.
        The token is also stored on ``self.token``.

        :return: The new token as a string.
        """
        now = datetime.now(timezone.utc)  # Fetch the current time

        # The issuer must contain the fingerprint of the public key matching our private key.
        public_key_fp = self.calculate_public_key_fingerprint(self.private_key)

        payload = {
            # Issuer: fully qualified username concatenated with the public key fingerprint.
            ISSUER: self.qualified_username + "." + public_key_fp,
            # Subject: the fully qualified username.
            SUBJECT: self.qualified_username,
            # Issue time: now.
            ISSUE_TIME: now,
            # Expiration time, based on the lifetime specified for this object.
            EXPIRE_TIME: now + self.lifetime,
        }

        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
        # PyJWT versions prior to 2.0 return bytes from jwt.encode rather than str.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        self.token = token

        return self.token

    def calculate_public_key_fingerprint(self, private_key) -> str:
        """
        Return the public key fingerprint for a loaded private key.

        The fingerprint is ``"SHA256:"`` followed by the base64-encoded SHA-256 digest of the
        DER-encoded SubjectPublicKeyInfo of the key's public half.

        :param private_key: A loaded private key object (as produced by ``load_pem_private_key``
            in ``__init__``), not PEM text.
        :return: The public key fingerprint.
        """
        # Get the raw DER bytes of the public key.
        public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)

        # Get the sha256 hash of the raw bytes.
        sha256hash = hashlib.sha256()
        sha256hash.update(public_key_raw)

        # Base64-encode the digest and prepend the prefix 'SHA256:'.
        public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
        logger.info("Public key fingerprint is %s", public_key_fp)

        return public_key_fp

def get_validated_jwt(token: str, account: str, user: str, private_key: str) -> str:
    """
    Return ``token`` if it is still usable, otherwise sign and return a new JWT.

    The incoming token is decoded WITHOUT signature verification; only its ``exp`` claim is
    inspected. The token is reused when it expires more than 5 seconds from now. A token
    that cannot be decoded (including ``None``), or that has no numeric ``exp`` claim,
    is replaced by a freshly generated one.

    :param token: A previously issued JWT, or ``None`` when there is none yet.
    :param account: Snowflake account identifier, forwarded to ``JWTGenerator``.
    :param user: Snowflake username, forwarded to ``JWTGenerator``.
    :param private_key: PEM text of the private key used to sign a new token; may be ``None``
        as long as the existing token is still valid.
    :return: A JWT string.
    :raises ValueError: If a new token is needed but ``private_key`` is ``None``.
    """
    try:
        content = jwt.decode(token, algorithms=[JWTGenerator.ALGORITHM], options={"verify_signature": False})
    except jwt.InvalidTokenError:
        # Malformed/undecodable token (or no token at all): deliberately fall through
        # and generate a new one below.
        pass
    else:
        expires_at = content.get("exp")
        # Keep the same token only while it stays valid for more than 5 seconds;
        # a missing or non-numeric exp claim means it must be regenerated.
        if isinstance(expires_at, (int, float)) and expires_at - 5 > time.time():
            return token

    # Generate a new token.
    if private_key is None:
        raise ValueError("Private key is missing")
    return JWTGenerator(account, user, private_key).get_token()