Coverage for mindsdb / api / common / middleware.py: 22%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import hmac 

3import secrets 

4import hashlib 

5from http import HTTPStatus 

6from typing import Optional 

7 

8from starlette.middleware.base import BaseHTTPMiddleware 

9from starlette.responses import JSONResponse 

10from starlette.requests import Request 

11 

12from mindsdb.utilities import log 

13from mindsdb.utilities.config import config 

14 

logger = log.getLogger(__name__)

# HMAC key ("pepper") used when fingerprinting tokens. If AUTH_SECRET_KEY is
# unset or set to an empty string, a random key is generated for this process.
SECRET_KEY = os.environ.get("AUTH_SECRET_KEY", "") or secrets.token_urlsafe(32)

# Fingerprints of the currently valid tokens; raw tokens are never stored.
# The list lives only in process memory, so every issued token stops being
# valid (i.e. everyone is logged out) when the process restarts.
TOKENS = []

20 

21 

def get_pat_fingerprint(token: str) -> str:
    """Return the HMAC-SHA256 fingerprint of ``token`` as a hex string.

    The module-level ``SECRET_KEY`` is the HMAC key (a pepper), so the same
    token produces a different fingerprint under a different key.
    """
    mac = hmac.new(SECRET_KEY.encode(), msg=token.encode(), digestmod=hashlib.sha256)
    return mac.hexdigest()

25 

26 

def generate_pat() -> str:
    """Issue a new personal access token and register it as active.

    The token is ``"pat_"`` followed by a random URL-safe string. Only its
    fingerprint is appended to ``TOKENS``; the raw token is returned to the
    caller and is not kept anywhere in this module.
    """
    logger.debug("Generating new auth token")
    new_token = f"pat_{secrets.token_urlsafe(32)}"
    fingerprint = get_pat_fingerprint(new_token)
    TOKENS.append(fingerprint)
    return new_token

32 

33 

def verify_pat(raw_token: str) -> bool:
    """Check whether ``raw_token`` is a currently active personal access token.

    Args:
        raw_token: The token exactly as presented by the client.

    Returns:
        ``True`` if the token's fingerprint matches one stored in ``TOKENS``,
        ``False`` otherwise (including for an empty or ``None`` token).
    """
    if not raw_token:
        return False
    candidate = get_pat_fingerprint(raw_token)
    # compare_digest is used instead of == so that the comparison time does not
    # reveal how many leading characters of a fingerprint matched.
    return any(hmac.compare_digest(candidate, stored) for stored in TOKENS)

45 

46 

def revoke_pat(raw_token: str) -> bool:
    """Deactivate ``raw_token`` by removing its fingerprint from ``TOKENS``.

    Returns:
        ``True`` if a matching fingerprint was found and removed, ``False`` if
        the token is empty/``None`` or was not active.
    """
    if not raw_token:
        return False
    target = get_pat_fingerprint(raw_token)
    # First stored fingerprint equal to the target, compared with compare_digest.
    found = next((fp for fp in TOKENS if hmac.compare_digest(target, fp)), None)
    if found is None:
        return False
    TOKENS.remove(found)
    return True

57 

58 

class PATAuthMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that enforces personal-access-token auth on HTTP requests.

    Auth is skipped when ``http_auth_enabled`` in the ``auth`` config section is
    absent or exactly ``False``. Otherwise each request must carry an
    ``Authorization: Bearer <token>`` header whose token passes ``verify_pat``,
    or a 401 JSON response is returned.
    """

    def _extract_bearer(self, request: Request) -> Optional[str]:
        """Return the token from a Bearer ``Authorization`` header, or ``None``.

        The scheme name is matched case-insensitively, as HTTP requires for
        authentication schemes. Surrounding whitespace around the token is
        stripped, and an empty token yields ``None``.
        """
        header = request.headers.get("Authorization")
        if not header:
            return None
        scheme, _, credentials = header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    async def dispatch(self, request: Request, call_next):
        """Authenticate ``request`` when HTTP auth is enabled, then forward it."""
        # Only a missing setting or an explicit False disables auth; any other
        # value (even a falsy one such as None or 0) leaves it enforced.
        if config.get("auth", {}).get("http_auth_enabled", False) is False:
            return await call_next(request)

        token = self._extract_bearer(request)
        if not token or not verify_pat(token):
            return JSONResponse({"detail": "Unauthorized"}, status_code=HTTPStatus.UNAUTHORIZED)

        # Every valid token maps to the single configured user.
        request.state.user = config["auth"].get("username")
        return await call_next(request)

76 

77 

78# Used by mysql protocol 

def check_auth(username, password, scramble_func, salt, company_id, config):
    """Validate MySQL-protocol login credentials against the configured user.

    The client's password is accepted if it equals either the configured
    plaintext password (encoded to bytes) or
    ``scramble_func(configured_password, salt)``.

    Args:
        username: User name sent by the client; must equal
            ``config["auth"]["username"]``.
        password: Password sent by the client, as ``str`` or bytes; ``None`` is
            treated as an empty password.
        scramble_func: Callable ``(password, salt)`` producing the scrambled form
            of the configured password that the client may send instead.
        salt: Salt passed to ``scramble_func``.
        company_id: Not used by this function.
        config: Mapping with an ``"auth"`` section. Inside this function the
            parameter shadows the module-level ``config`` import.

    Returns:
        ``{"success": True, "username": username}`` when the credentials match,
        otherwise ``{"success": False}`` (also when an unexpected error is
        raised while checking).
    """

    def _secret_equals(supplied, expected) -> bool:
        # Byte strings are compared in constant time so that response timing
        # does not reveal how much of the secret matched. Any other type
        # combination (e.g. if scramble_func returns a non-bytes value) keeps
        # plain equality semantics.
        bytes_like = (bytes, bytearray)
        if isinstance(supplied, bytes_like) and isinstance(expected, bytes_like):
            return hmac.compare_digest(supplied, expected)
        return supplied == expected

    try:
        hardcoded_user = config["auth"].get("username")
        hardcoded_password = config["auth"].get("password")
        if hardcoded_password is None:
            hardcoded_password = ""
        # The scrambled form is computed from the str password, before encoding.
        hardcoded_password_hash = scramble_func(hardcoded_password, salt)
        hardcoded_password = hardcoded_password.encode()

        if password is None:
            password = ""
        if isinstance(password, str):
            password = password.encode()

        if username != hardcoded_user:
            logger.warning(f"Check auth, user={username}: user mismatch")
            return {"success": False}

        password_ok = _secret_equals(password, hardcoded_password) or _secret_equals(
            password, hardcoded_password_hash
        )
        if not password_ok:
            logger.warning(f"check auth, user={username}: password mismatch")
            return {"success": False}

        logger.info(f"Check auth, user={username}: Ok")
        return {"success": True, "username": username}
    except Exception:
        logger.exception(f"Check auth, user={username}: ERROR")
        # Previously this path fell through and returned None; return the same
        # dict shape as every other failure path instead.
        return {"success": False}