Coverage for mindsdb / api / common / middleware.py: 22%
74 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import hmac
3import secrets
4import hashlib
5from http import HTTPStatus
6from typing import Optional
8from starlette.middleware.base import BaseHTTPMiddleware
9from starlette.responses import JSONResponse
10from starlette.requests import Request
12from mindsdb.utilities import log
13from mindsdb.utilities.config import config
logger = log.getLogger(__name__)

# HMAC pepper used to fingerprint personal access tokens. When AUTH_SECRET_KEY
# is unset (or empty) a random key is generated for this process.
SECRET_KEY = os.getenv("AUTH_SECRET_KEY") or secrets.token_urlsafe(32)

# Fingerprints of the currently active tokens. They are held only in this
# process's memory, so every token is invalidated (everyone is logged out)
# whenever the process restarts.
TOKENS = list()
def get_pat_fingerprint(token: str) -> str:
    """Return the hex HMAC-SHA256 fingerprint of *token*.

    SECRET_KEY is the HMAC key (acting as a pepper), so a fingerprint can only
    be reproduced by code holding the same SECRET_KEY.
    """
    key = SECRET_KEY.encode()
    message = token.encode()
    return hmac.new(key, message, digestmod=hashlib.sha256).hexdigest()
def generate_pat() -> str:
    """Create a new personal access token and register it as active.

    Only the token's fingerprint is appended to TOKENS; the raw token is
    returned to the caller and is not stored by this function.
    """
    logger.debug("Generating new auth token")
    raw = f"pat_{secrets.token_urlsafe(32)}"
    fingerprint = get_pat_fingerprint(raw)
    TOKENS.append(fingerprint)
    return raw
def verify_pat(raw_token: str) -> bool:
    """Check whether *raw_token* is a currently active personal access token.

    The token is fingerprinted with get_pat_fingerprint and compared against
    the stored fingerprints in TOKENS using hmac.compare_digest.

    Args:
        raw_token: the token presented by the client (e.g. the value of a
            Bearer Authorization header).

    Returns:
        True if the token's fingerprint is present in TOKENS, False otherwise
        (including when raw_token is empty or None).
    """
    if not raw_token:
        return False
    fingerprint = get_pat_fingerprint(raw_token)
    return any(hmac.compare_digest(fingerprint, stored) for stored in TOKENS)
def revoke_pat(raw_token: str) -> bool:
    """Deactivate *raw_token* by removing its fingerprint from TOKENS.

    Returns:
        True if a matching fingerprint was found and removed; False if the
        token was empty/None or was not active.
    """
    if not raw_token:
        return False
    target = get_pat_fingerprint(raw_token)
    # First stored fingerprint that matches, compared with hmac.compare_digest.
    match = next((fp for fp in TOKENS if hmac.compare_digest(target, fp)), None)
    if match is None:
        return False
    TOKENS.remove(match)
    return True
class PATAuthMiddleware(BaseHTTPMiddleware):
    """Require a valid personal access token on HTTP requests when enabled.

    Enforcement is controlled by ``config["auth"]["http_auth_enabled"]``.
    Requests without a valid ``Authorization: Bearer <token>`` header get a
    401 JSON response. Otherwise the configured username is stored on
    ``request.state.user`` and the request continues.
    """

    def _extract_bearer(self, request: Request) -> Optional[str]:
        """Return the token from an ``Authorization: Bearer <token>`` header.

        The scheme name is matched case-insensitively, as required by
        RFC 7235 section 2.1. Returns None when the header is missing, uses
        a different scheme, or carries no token.
        """
        header = request.headers.get("Authorization")
        if not header:
            return None
        scheme, _, credentials = header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    async def dispatch(self, request: Request, call_next):
        # Only a literal False (or a missing setting) disables auth. Any other
        # value keeps it enforced, i.e. the check fails closed.
        if config.get("auth", {}).get("http_auth_enabled", False) is False:
            return await call_next(request)

        token = self._extract_bearer(request)
        if not token or not verify_pat(token):
            # RFC 6750 section 3: a 401 for a Bearer-protected resource should
            # advertise the scheme via the WWW-Authenticate header.
            return JSONResponse(
                {"detail": "Unauthorized"},
                status_code=HTTPStatus.UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Expose the configured username to downstream handlers.
        request.state.user = config["auth"].get("username")
        return await call_next(request)
def _secure_equals(left, right) -> bool:
    """Compare two credentials, in constant time when both are bytes.

    Any other type combination falls back to ``==``. This keeps results
    identical to a plain equality check, because hmac.compare_digest raises
    TypeError for mixed or non-ASCII str operands.
    """
    if isinstance(left, bytes) and isinstance(right, bytes):
        return hmac.compare_digest(left, right)
    return left == right


# Used by mysql protocol
def check_auth(username, password, scramble_func, salt, company_id, config):
    """Authenticate a MySQL-protocol login against the single configured user.

    The client's password is accepted either as the plain configured password
    or as ``scramble_func(configured_password, salt)``.

    Args:
        username: user name sent by the client.
        password: password sent by the client (str or bytes); None is treated
            as an empty password.
        scramble_func: callable ``(password, salt)`` that produces the
            scrambled form of the configured password.
        salt: salt passed through to ``scramble_func``.
        company_id: not used by this function; kept for interface
            compatibility.
        config: mapping with an ``"auth"`` section holding ``username`` and
            ``password``.

    Returns:
        ``{"success": True, "username": username}`` on success, otherwise
        ``{"success": False}``, including when an unexpected error occurs.
    """
    try:
        hardcoded_user = config["auth"].get("username")
        hardcoded_password = config["auth"].get("password")
        if hardcoded_password is None:
            hardcoded_password = ""
        hardcoded_password_hash = scramble_func(hardcoded_password, salt)
        hardcoded_password = hardcoded_password.encode()

        if password is None:
            password = ""
        if isinstance(password, str):
            password = password.encode()

        if username != hardcoded_user:
            logger.warning(f"Check auth, user={username}: user mismatch")
            return {"success": False}

        if not (
            _secure_equals(password, hardcoded_password)
            or _secure_equals(password, hardcoded_password_hash)
        ):
            logger.warning(f"check auth, user={username}: password mismatch")
            return {"success": False}

        logger.info(f"Check auth, user={username}: Ok")
        return {"success": True, "username": username}
    except Exception:
        logger.exception(f"Check auth, user={username}: ERROR")
        # Fail closed with the same shape as the other failure paths instead
        # of implicitly returning None.
        return {"success": False}