Skip to content

Commit

Permalink
Cache OAuth access token per host and user pair
Browse files Browse the repository at this point in the history
  • Loading branch information
Przemek Denkiewicz committed Dec 29, 2023
1 parent 0517c65 commit 292bd0f
Showing 1 changed file with 37 additions and 21 deletions.
58 changes: 37 additions & 21 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import threading
import webbrowser
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from urllib.parse import urlparse

from requests import PreparedRequest, Request, Response, Session
Expand All @@ -26,6 +26,7 @@

import trino.logging
from trino.client import exceptions
from trino.constants import HEADER_USER

logger = trino.logging.get_logger(__name__)

Expand Down Expand Up @@ -208,32 +209,33 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
pass

@abc.abstractmethod
def store_token_to_cache(self, host: Optional[str], token: str) -> None:
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
pass


class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
"""
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
In-memory token cache implementation. The token is stored per host and user pair,
so multiple clients can share the same cache.
"""

def __init__(self) -> None:
self._cache: Dict[Optional[str], str] = {}

def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
return self._cache.get(host)
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
return self._cache.get(key)

def store_token_to_cache(self, host: Optional[str], token: str) -> None:
self._cache[host] = token
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
self._cache[key] = token


class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
"""
Keyring Token Cache implementation
Keyring token cache implementation
"""

def __init__(self) -> None:
Expand All @@ -248,18 +250,18 @@ def is_keyring_available(self) -> bool:
return self._keyring is not None \
and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring)

def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
try:
return self._keyring.get_password(host, "token")
return self._keyring.get_password(key, "token")
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information.") from e

def store_token_to_cache(self, host: Optional[str], token: str) -> None:
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
try:
# keyring is installed, so we can store the token for reuse within multiple threads
self._keyring.set_password(host, "token", token)
self._keyring.set_password(key, "token", token)
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
Expand All @@ -268,7 +270,7 @@ def store_token_to_cache(self, host: Optional[str], token: str) -> None:

class _OAuth2TokenBearer(AuthBase):
"""
Custom implementation of Trino Oauth2 based authorization to get the token
Custom implementation of Trino OAuth2 based authentication to get the token
"""
MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
Expand All @@ -283,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):

def __call__(self, r: PreparedRequest) -> PreparedRequest:
host = self._determine_host(r.url)
token = self._get_token_from_cache(host)
user = self._determine_user(r.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)

if token is not None:
r.headers['Authorization'] = "Bearer " + token
Expand Down Expand Up @@ -341,15 +345,19 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:

request = response.request
host = self._determine_host(request.url)
self._store_token_to_cache(host, token)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
self._store_token_to_cache(key, token)

def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:
request = response.request.copy()
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore
request.prepare_cookies(request._cookies) # type: ignore

host = self._determine_host(response.request.url)
token = self._get_token_from_cache(host)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)
if token is not None:
request.headers['Authorization'] = "Bearer " + token
retry_response = response.connection.send(request, **kwargs) # type: ignore
Expand Down Expand Up @@ -382,18 +390,26 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st

raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token")

def _get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def _get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
with self._token_lock:
return self._token_cache.get_token_from_cache(host)
return self._token_cache.get_token_from_cache(key)

def _store_token_to_cache(self, host: Optional[str], token: str) -> None:
def _store_token_to_cache(self, key: Optional[str], token: str) -> None:
with self._token_lock:
self._token_cache.store_token_to_cache(host, token)
self._token_cache.store_token_to_cache(key, token)

@staticmethod
def _determine_host(url: Optional[str]) -> Any:
return urlparse(url).hostname

@staticmethod
def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]:
return headers.get(HEADER_USER)

@staticmethod
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> str:
return f"{host}@{user}"


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
Expand Down

0 comments on commit 292bd0f

Please sign in to comment.