diff --git a/trino/auth.py b/trino/auth.py index d71c373e..8b0f6ec4 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -17,7 +17,7 @@ import re import threading import webbrowser -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple from urllib.parse import urlparse from requests import PreparedRequest, Request, Response, Session @@ -26,6 +26,7 @@ import trino.logging from trino.client import exceptions +from trino.constants import HEADER_USER logger = trino.logging.get_logger(__name__) @@ -208,32 +209,33 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_token_from_cache(self, host: Optional[str]) -> Optional[str]: + def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: pass @abc.abstractmethod - def store_token_to_cache(self, host: Optional[str], token: str) -> None: + def store_token_to_cache(self, key: Optional[str], token: str) -> None: pass class _OAuth2TokenInMemoryCache(_OAuth2TokenCache): """ - In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache. + In-memory token cache implementation. The token is stored per host and user pair, + so multiple clients can share the same cache. """ def __init__(self) -> None: self._cache: Dict[Optional[str], str] = {} - def get_token_from_cache(self, host: Optional[str]) -> Optional[str]: - return self._cache.get(host) + def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: + return self._cache.get(key) - def store_token_to_cache(self, host: Optional[str], token: str) -> None: - self._cache[host] = token + def store_token_to_cache(self, key: Optional[str], token: str) -> None: + self._cache[key] = token class _OAuth2KeyRingTokenCache(_OAuth2TokenCache): """ - Keyring Token Cache implementation + Keyring token cache implementation """ def __init__(self) -> None: @@ -248,18 +250,18 @@ def is_keyring_available(self) -> bool: return self._keyring is not None \ and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring) - def get_token_from_cache(self, host: Optional[str]) -> Optional[str]: + def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: try: - return self._keyring.get_password(host, "token") + return self._keyring.get_password(key, "token") except self._keyring.errors.NoKeyringError as e: raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " "detected, check https://pypi.org/project/keyring/ for more " "information.") from e - def store_token_to_cache(self, host: Optional[str], token: str) -> None: + def store_token_to_cache(self, key: Optional[str], token: str) -> None: try: # keyring is installed, so we can store the token for reuse within multiple threads - self._keyring.set_password(host, "token", token) + self._keyring.set_password(key, "token", token) except self._keyring.errors.NoKeyringError as e: raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " "detected, check https://pypi.org/project/keyring/ for more " @@ -268,7 +270,7 @@ def store_token_to_cache(self, host: Optional[str], token: str) -> None: class _OAuth2TokenBearer(AuthBase): """ - Custom implementation of Trino Oauth2 based authorization to get the token + Custom implementation of Trino OAuth2 based authentication to get the token """ MAX_OAUTH_ATTEMPTS = 5 _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) @@ -283,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]): def __call__(self, r: PreparedRequest) -> PreparedRequest: host = self._determine_host(r.url) - token = self._get_token_from_cache(host) + user = self._determine_user(r.headers) + key = self._construct_cache_key(host, user) + token = self._get_token_from_cache(key) if token is not None: r.headers['Authorization'] = "Bearer " + token @@ -341,7 +345,9 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: request = response.request host = self._determine_host(request.url) - self._store_token_to_cache(host, token) + user = self._determine_user(request.headers) + key = self._construct_cache_key(host, user) + self._store_token_to_cache(key, token) def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]: request = response.request.copy() @@ -349,7 +355,9 @@ def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response request.prepare_cookies(request._cookies) # type: ignore host = self._determine_host(response.request.url) - token = self._get_token_from_cache(host) + user = self._determine_user(request.headers) + key = self._construct_cache_key(host, user) + token = self._get_token_from_cache(key) if token is not None: request.headers['Authorization'] = "Bearer " + token retry_response = response.connection.send(request, **kwargs) # type: ignore @@ -382,18 +390,26 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token") - def _get_token_from_cache(self, host: Optional[str]) -> Optional[str]: + def _get_token_from_cache(self, key: Optional[str]) -> Optional[str]: with self._token_lock: - return self._token_cache.get_token_from_cache(host) + return self._token_cache.get_token_from_cache(key) - def _store_token_to_cache(self, host: Optional[str], token: str) -> None: + def _store_token_to_cache(self, key: Optional[str], token: str) -> None: with self._token_lock: - self._token_cache.store_token_to_cache(host, token) + self._token_cache.store_token_to_cache(key, token) @staticmethod def _determine_host(url: Optional[str]) -> Any: return urlparse(url).hostname + @staticmethod + def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]: + return headers.get(HEADER_USER) + + @staticmethod + def _construct_cache_key(host: Optional[str], user: Optional[str]) -> str: + return f"{host}@{user}" + class OAuth2Authentication(Authentication): def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([