Coverage for mindsdb / api / http / namespaces / auth.py: 29%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import base64 

2import secrets 

3import time 

4import urllib 

5 

6import requests 

7from flask import redirect, request, url_for 

8from flask_restx import Resource 

9 

10from mindsdb.api.http.namespaces.configs.auth import ns_conf 

11from mindsdb.metrics.metrics import api_endpoint_metrics 

12from mindsdb.utilities.config import Config 

13from mindsdb.utilities import log 

14 

15logger = log.getLogger(__name__) 

16 

17 

def get_access_token() -> str:
    """Look up the cloud OAuth access token stored in the config.

    Walks the ``auth`` -> ``oauth`` -> ``tokens`` config sections, treating
    any missing section as empty.

    Returns:
        str: the stored access token, or None if it (or any enclosing
        section) is absent
    """
    stored_tokens = Config().get("auth", {}).get("oauth", {}).get("tokens", {})
    return stored_tokens.get("access_token")

25 

26 

def request_user_info(access_token: str = None) -> dict:
    """Request the user's profile from the cloud auth server.

    Args:
        access_token (str, optional): bearer token used to authenticate the
            request. When omitted, the token stored in the config (as returned
            by ``get_access_token``) is used.

    Returns:
        dict: user data, i.e. the decoded JSON body of ``/auth/userinfo``

    Raises:
        KeyError: no token was passed and none is stored in the config.
        Exception: the auth server answered with a status other than 200.
    """
    if access_token is None:
        access_token = get_access_token()
        if access_token is None:
            # name the missing item so the failure is diagnosable in logs
            raise KeyError("access_token")

    auth_server = Config()["auth"]["oauth"]["server_host"]

    response = requests.get(
        f"https://{auth_server}/auth/userinfo",
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=5,
    )
    if response.status_code != 200:
        raise Exception(f"Wrong response: {response.status_code}, {response.text}")

    return response.json()

52 

53 

@ns_conf.route("/callback", methods=["GET"])
@ns_conf.route("/callback/cloud_home", methods=["GET"])
@ns_conf.hide
class Auth(Resource):
    """OAuth callback endpoint called by the cloud auth server.

    Registered on both ``/callback`` and ``/callback/cloud_home``. The path the
    request arrived on decides where the browser is sent after login.
    """

    @ns_conf.doc(params={"code": "authentication code"})
    @api_endpoint_metrics("GET", "/auth/code")
    def get(self):
        """Callback from the auth server if authentication is successful.

        Steps:

        1. Exchanges the ``code`` query argument for tokens at the auth
           server's ``/auth/token`` endpoint.
        2. Fetches the user's profile.
        3. Rejects a user different from the one already stored in the config.
        4. Saves the username and tokens to the config.
        5. Reports this instance to the cloud server (best effort).
        6. Redirects the browser.

        Returns:
            A redirect to ``/forbidden`` when another user is already bound to
            this instance. Otherwise, a redirect to the auth server root (for
            the ``cloud_home`` callback) or to the local ``root_index`` view.

        Raises:
            Exception: the token endpoint answered with a status other than 200
                (errors from ``request_user_info`` also propagate).
        """
        config = Config()
        code = request.args.get("code")

        aws_meta_data = config["aws_meta_data"]
        public_hostname = aws_meta_data["public-hostname"]
        instance_id = aws_meta_data["instance-id"]

        oauth_meta = config["auth"]["oauth"]
        client_id = oauth_meta["client_id"]
        client_secret = oauth_meta["client_secret"]
        auth_server = oauth_meta["server_host"]
        # client credentials sent to the token endpoint via HTTP Basic auth
        client_basic = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()

        # rebuilt from the path this request arrived on, so it matches the
        # callback URL the browser was sent back to
        redirect_uri = f"https://{public_hostname}{request.path}"
        response = requests.post(
            f"https://{auth_server}/auth/token",
            data={
                "code": code,
                "grant_type": "authorization_code",
                "redirect_uri": redirect_uri,
            },
            headers={"Authorization": f"Basic {client_basic}"},
            # without a timeout an unresponsive auth server would block this worker forever
            timeout=5,
        )
        if response.status_code != 200:
            # fail with the server's answer instead of an opaque KeyError on tokens["access_token"]
            raise Exception(f"Wrong response: {response.status_code}, {response.text}")
        tokens = response.json()
        if "expires_in" in tokens:
            # store an absolute expiry timestamp (1 s safety margin) instead of a relative lifetime
            tokens["expires_at"] = round(time.time() + tokens["expires_in"] - 1)
            del tokens["expires_in"]

        user_data = request_user_info(tokens["access_token"])

        # an instance stays bound to the first user that logged in; anyone else is rejected
        # before their tokens are saved
        previous_username = config["auth"]["oauth"].get("username")
        new_username = user_data["name"]
        if previous_username is not None and new_username != previous_username:
            return redirect("/forbidden")

        config.update(
            {
                "auth": {
                    "provider": "cloud",
                    "oauth": {"username": new_username, "tokens": tokens},
                }
            }
        )

        # best effort: failing to register the instance must not break the login itself
        try:
            resp = requests.put(
                f"https://{auth_server}/cloud/instance",
                json={
                    "instance_id": instance_id,
                    "public_hostname": public_hostname,
                    "ami_id": aws_meta_data.get("ami-id"),
                },
                headers={"Authorization": f"Bearer {tokens['access_token']}"},
                timeout=5,
            )
            if resp.status_code != 200:
                logger.warning(f"Wrong response from cloud server: {resp.status_code}")
        except Exception as e:
            logger.warning(f"Can't send request to cloud server: {e}", exc_info=True)

        if request.path.endswith("/auth/callback/cloud_home"):
            return redirect(f"https://{auth_server}")
        else:
            return redirect(url_for("root_index"))

126 

127 

@ns_conf.route("/cloud_login", methods=["GET"])
@ns_conf.hide
class CloudLoginRoute(Resource):
    """Starts the cloud OAuth login by redirecting to the auth server."""

    @ns_conf.doc(
        responses={302: "Redirect to auth server"},
        params={"location": "final redirection should lead to that location"},
    )
    @api_endpoint_metrics("GET", "/auth/cloud_login")
    def get(self):
        """Redirect to the cloud login form.

        The ``location`` query argument picks the callback used after login:

        - ``"cloud_home"`` selects ``/api/auth/callback/cloud_home``.
        - Anything else selects ``/api/auth/callback``.

        Returns:
            A redirect to the auth server's ``/auth/authorize`` endpoint with
            the authorization-code request parameters.
        """
        # `import urllib` alone does not load the `urllib.parse` submodule; import it
        # explicitly rather than relying on some other module having imported it first
        from urllib.parse import urlencode

        location = request.args.get("location")
        config = Config()

        aws_meta_data = config["aws_meta_data"]
        public_hostname = aws_meta_data["public-hostname"]
        oauth_meta = config["auth"]["oauth"]
        auth_server = oauth_meta["server_host"]

        if location == "cloud_home":
            redirect_uri = f"https://{public_hostname}/api/auth/callback/cloud_home"
        else:
            redirect_uri = f"https://{public_hostname}/api/auth/callback"

        # NOTE(review): the nonce is not stored and no `state` parameter is sent, so the
        # callback handler cannot verify either; consider a session-bound `state` value
        # for CSRF protection of the login flow.
        args = urlencode(
            {
                "client_id": oauth_meta["client_id"],
                "scope": "openid profile aws_marketplace",
                "response_type": "code",
                "nonce": secrets.token_urlsafe(),
                "redirect_uri": redirect_uri,
            }
        )
        return redirect(f"https://{auth_server}/auth/authorize?{args}")