Skip to content

Commit

Permalink
Change AuthCache logic to request per user+resource keypair. (#684)
Browse files Browse the repository at this point in the history
Removed all logic that explicitly adds a user to the auth cache. Instead, AuthCache should function like a true cache: callers simply read through it, and the caching is abstracted away behind those reads.

We no longer cache all of a user's resources at once; instead, we make a separate request to Den for each resource + user pairing. Auth calls have been simplified throughout.
  • Loading branch information
rohinb2 committed Apr 1, 2024
1 parent d2edbf9 commit 218776f
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 86 deletions.
6 changes: 0 additions & 6 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ async def aset_cluster_config_value(self, key: str, value: Any):
##############################################
# Auth cache internal functions
##############################################
async def aadd_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
self._auth_cache.add_user(token, refresh_cache)

async def aresource_access_level(
self, token: str, resource_uri: str
) -> Union[str, None]:
Expand All @@ -98,9 +95,6 @@ async def aresource_access_level(
return ResourceAccess.WRITE
return self._auth_cache.lookup_access_level(token, resource_uri)

async def auser_resources(self, token: str) -> dict:
return self._auth_cache.get_user_resources(token)

async def aget_username(self, token: str) -> str:
return self._auth_cache.get_username(token)

Expand Down
71 changes: 30 additions & 41 deletions runhouse/servers/http/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
from typing import Optional, Union

Expand All @@ -15,50 +14,51 @@ def __init__(self):
self.CACHE = {}
self.USERNAMES = {}

def get_user_resources(self, token: str) -> dict:
"""Get resources associated with a particular token"""
return self.CACHE.get(token, {})

def get_username(self, token: str) -> Optional[str]:
"""Get username associated with a particular token"""
return self.USERNAMES.get(token)
if token not in self.USERNAMES:
username = username_from_token(token)
if username is not None:
self.USERNAMES[token] = username

def lookup_access_level(self, token: str, resource_uri: str) -> Union[str, None]:
resources: dict = self.get_user_resources(token)
return resources.get(resource_uri)
return self.USERNAMES.get(token)

def add_user(self, token, refresh_cache=True):
"""Refresh the server cache with the latest resources and access levels for a particular token"""
def lookup_access_level(
self, token: str, resource_uri: str, refresh_cache=True
) -> Union[str, None]:
"""Get the access level of a particular resource for a user"""
if token is None:
return

if not refresh_cache and token in self.CACHE:
return
# Also add this user to the username cache
self.get_username(token)

if (token, resource_uri) in self.CACHE and not refresh_cache:
return self.CACHE[(token, resource_uri)]

if resource_uri.startswith("/"):
resource_uri_to_send = resource_uri[1:].replace("/", ":")
else:
resource_uri_to_send = resource_uri.replace("/", ":")

resp = rns_client.session.get(
f"{rns_client.api_server_url}/resource",
f"{rns_client.api_server_url}/resource/{resource_uri_to_send}",
headers={"Authorization": f"Bearer {token}"},
)

if resp.status_code == 404:
logger.error(f"Resource not found: {load_resp_content(resp)}")
return

if resp.status_code != 200:
logger.error(
f"Failed to load resources for user: {load_resp_content(resp)}"
f"Failed to load access level for resource: {load_resp_content(resp)}"
)
return

username = username_from_token(token)
if username is None:
raise ValueError("Failed to find Runhouse user from provided token.")
self.USERNAMES[token] = username

resp_data = json.loads(resp.content)
# Support access_level and access_type for BC
all_resources: dict = {
resource["name"]: resource.get("access_level")
or resource.get("access_type")
for resource in resp_data["data"]
}
# Update server cache with a user's resources and access type
self.CACHE[token] = all_resources
self.CACHE[(token, resource_uri)] = resp.json()["data"]["access_level"]

return self.CACHE[(token, resource_uri)]

def clear_cache(self, token: str = None):
"""Clear the server cache, If a token is specified, clear the cache for that particular user only"""
Expand Down Expand Up @@ -90,17 +90,6 @@ async def averify_cluster_access(
):
return True

# Check if user already has saved resources in cache
cached_resources: dict = await obj_store.auser_resources(token)

# e.g. {"/jlewitt1/bert-preproc": "read"}
cluster_access_level = cached_resources.get(cluster_uri)

if cluster_access_level is None:
# Reload from cache and check again
await obj_store.aadd_user_to_auth_cache(token)

cached_resources: dict = await obj_store.auser_resources(token)
cluster_access_level = cached_resources.get(cluster_uri)
cluster_access_level = await obj_store.aresource_access_level(token, cluster_uri)

return cluster_access_level in [ResourceAccess.WRITE, ResourceAccess.READ]
3 changes: 0 additions & 3 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ async def wrapper(*args, **kwargs):
ctx_token = obj_store.set_ctx(request_id=request_id, token=token)

try:
if func_call and token:
await obj_store.aadd_user_to_auth_cache(token, refresh_cache=False)

if den_auth_enabled and not func_call:
if token is None:
raise HTTPException(
Expand Down
29 changes: 3 additions & 26 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,7 @@ def set_ctx(**ctx_args):
async def apopulate_ctx_locally(self):
from runhouse.globals import configs

den_auth_enabled = (await self.aget_cluster_config()).get("den_auth")
token = configs.token
if den_auth_enabled and token:
await self.aadd_user_to_auth_cache(token, refresh_cache=False)
return self.set_ctx(request_id=str(uuid.uuid4()), token=token)

@staticmethod
Expand Down Expand Up @@ -378,14 +375,6 @@ def set_cluster_config_value(self, key: str, value: Any):
##############################################
# Auth cache internal functions
##############################################
async def aadd_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
return await self.acall_actor_method(
self.cluster_servlet, "aadd_user_to_auth_cache", token, refresh_cache
)

def add_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
return sync_function(self.aadd_user_to_auth_cache)(token, refresh_cache)

async def aresource_access_level(self, token: str, resource_uri: str):
return await self.acall_actor_method(
self.cluster_servlet,
Expand All @@ -397,14 +386,6 @@ async def aresource_access_level(self, token: str, resource_uri: str):
def resource_access_level(self, token: str, resource_uri: str):
return sync_function(self.aresource_access_level)(token, resource_uri)

async def auser_resources(self, token: str):
return await self.acall_actor_method(
self.cluster_servlet, "auser_resources", token
)

def user_resources(self, token: str):
return sync_function(self.auser_resources)(token)

async def aget_username(self, token: str):
return await self.acall_actor_method(
self.cluster_servlet, "aget_username", token
Expand All @@ -422,6 +403,8 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:

# The logged-in user always has full access to the cluster and its resources. This is especially
# important if they flip on Den Auth without saving the cluster.

# configs.token is the token stored on the cluster itself
if configs.token:
if configs.token == token:
return True
Expand All @@ -445,10 +428,7 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:
return False

resource_access_level = await self.aresource_access_level(token, resource_uri)
if resource_access_level not in [ResourceAccess.WRITE, ResourceAccess.READ]:
return False

return True
return resource_access_level in [ResourceAccess.WRITE, ResourceAccess.READ]

async def aclear_auth_cache(self, token: str = None):
return await self.acall_actor_method(
Expand Down Expand Up @@ -1032,9 +1012,6 @@ async def acall_local(
if (await self.aget_cluster_config()).get("den_auth"):
if not isinstance(obj, Resource) or obj.visibility not in [
ResourceVisibility.UNLISTED,
ResourceVisibility.PUBLIC,
"unlisted",
"public",
]:
ctx = req_ctx.get()
if not ctx or not ctx.token:
Expand Down
10 changes: 0 additions & 10 deletions tests/test_servers/test_server_obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,21 +355,11 @@ def test_save_resources_to_obj_store_cache(self, obj_store):
token = test_account_dict["token"]

# Add test account resources to the local cache
obj_store.add_user_to_auth_cache(token)
resources = obj_store.user_resources(token)
assert resources

resource_uri = f"/{test_account_dict['username']}/summer"
access_level = obj_store.resource_access_level(token, resource_uri)

assert access_level == "write"

@pytest.mark.level("unit")
def test_no_resources_for_invalid_token(self, obj_store):
token = "abc"
resources = obj_store.user_resources(token)
assert not resources

@pytest.mark.level("unit")
def test_no_resource_access_for_invalid_token(self, obj_store):
with friend_account() as test_account_dict:
Expand Down

0 comments on commit 218776f

Please sign in to comment.