diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 16ff3b120..a2f33af16 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -83,9 +83,6 @@ async def aset_cluster_config_value(self, key: str, value: Any): ############################################## # Auth cache internal functions ############################################## - async def aadd_user_to_auth_cache(self, token: str, refresh_cache: bool = True): - self._auth_cache.add_user(token, refresh_cache) - async def aresource_access_level( self, token: str, resource_uri: str ) -> Union[str, None]: @@ -98,9 +95,6 @@ async def aresource_access_level( return ResourceAccess.WRITE return self._auth_cache.lookup_access_level(token, resource_uri) - async def auser_resources(self, token: str) -> dict: - return self._auth_cache.get_user_resources(token) - async def aget_username(self, token: str) -> str: return self._auth_cache.get_username(token) diff --git a/runhouse/servers/http/auth.py b/runhouse/servers/http/auth.py index 0b2ef3677..3fa5afe84 100644 --- a/runhouse/servers/http/auth.py +++ b/runhouse/servers/http/auth.py @@ -1,4 +1,3 @@ -import json import logging from typing import Optional, Union @@ -15,50 +14,51 @@ def __init__(self): self.CACHE = {} self.USERNAMES = {} - def get_user_resources(self, token: str) -> dict: - """Get resources associated with a particular token""" - return self.CACHE.get(token, {}) - def get_username(self, token: str) -> Optional[str]: """Get username associated with a particular token""" - return self.USERNAMES.get(token) + if token not in self.USERNAMES: + username = username_from_token(token) + if username is not None: + self.USERNAMES[token] = username - def lookup_access_level(self, token: str, resource_uri: str) -> Union[str, None]: - resources: dict = self.get_user_resources(token) - return resources.get(resource_uri) + return self.USERNAMES.get(token) - def add_user(self, token, refresh_cache=True): - """Refresh the server cache with the latest resources and access levels for a particular token""" + def lookup_access_level( + self, token: str, resource_uri: str, refresh_cache=True + ) -> Union[str, None]: + """Get the access level of a particular resource for a user""" if token is None: return - if not refresh_cache and token in self.CACHE: - return + # Also add this user to the username cache + self.get_username(token) + + if (token, resource_uri) in self.CACHE and not refresh_cache: + return self.CACHE[(token, resource_uri)] + + if resource_uri.startswith("/"): + resource_uri_to_send = resource_uri[1:].replace("/", ":") + else: + resource_uri_to_send = resource_uri.replace("/", ":") resp = rns_client.session.get( - f"{rns_client.api_server_url}/resource", + f"{rns_client.api_server_url}/resource/{resource_uri_to_send}", headers={"Authorization": f"Bearer {token}"}, ) + + if resp.status_code == 404: + logger.error(f"Resource not found: {load_resp_content(resp)}") + return + if resp.status_code != 200: logger.error( - f"Failed to load resources for user: {load_resp_content(resp)}" + f"Failed to load access level for resource: {load_resp_content(resp)}" ) return - username = username_from_token(token) - if username is None: - raise ValueError("Failed to find Runhouse user from provided token.") - self.USERNAMES[token] = username - - resp_data = json.loads(resp.content) - # Support access_level and access_type for BC - all_resources: dict = { - resource["name"]: resource.get("access_level") - or resource.get("access_type") - for resource in resp_data["data"] - } - # Update server cache with a user's resources and access type - self.CACHE[token] = all_resources + self.CACHE[(token, resource_uri)] = resp.json()["data"]["access_level"] + + return self.CACHE[(token, resource_uri)] def clear_cache(self, token: str = None): """Clear the server cache, If a token is specified, clear the cache for that particular user only""" @@ -90,17 +90,6 @@ async def averify_cluster_access( ): return True - # Check if user already has saved resources in cache - cached_resources: dict = await obj_store.auser_resources(token) - - # e.g. {"/jlewitt1/bert-preproc": "read"} - cluster_access_level = cached_resources.get(cluster_uri) - - if cluster_access_level is None: - # Reload from cache and check again - await obj_store.aadd_user_to_auth_cache(token) - - cached_resources: dict = await obj_store.auser_resources(token) - cluster_access_level = cached_resources.get(cluster_uri) + cluster_access_level = await obj_store.aresource_access_level(token, cluster_uri) return cluster_access_level in [ResourceAccess.WRITE, ResourceAccess.READ] diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index a051df2f5..09d0eede2 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -73,9 +73,6 @@ async def wrapper(*args, **kwargs): ctx_token = obj_store.set_ctx(request_id=request_id, token=token) try: - if func_call and token: - await obj_store.aadd_user_to_auth_cache(token, refresh_cache=False) - if den_auth_enabled and not func_call: if token is None: raise HTTPException( diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index 8a7de31e9..5ab5efdb7 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -337,10 +337,7 @@ def set_ctx(**ctx_args): async def apopulate_ctx_locally(self): from runhouse.globals import configs - den_auth_enabled = (await self.aget_cluster_config()).get("den_auth") token = configs.token - if den_auth_enabled and token: - await self.aadd_user_to_auth_cache(token, refresh_cache=False) return self.set_ctx(request_id=str(uuid.uuid4()), token=token) @staticmethod @@ -378,14 +375,6 @@ def set_cluster_config_value(self, key: str, value: Any): ############################################## # Auth cache internal functions ############################################## - async def aadd_user_to_auth_cache(self, token: str, refresh_cache: bool = True): - return await self.acall_actor_method( - self.cluster_servlet, "aadd_user_to_auth_cache", token, refresh_cache - ) - - def add_user_to_auth_cache(self, token: str, refresh_cache: bool = True): - return sync_function(self.aadd_user_to_auth_cache)(token, refresh_cache) - async def aresource_access_level(self, token: str, resource_uri: str): return await self.acall_actor_method( self.cluster_servlet, @@ -397,14 +386,6 @@ async def aresource_access_level(self, token: str, resource_uri: str): def resource_access_level(self, token: str, resource_uri: str): return sync_function(self.aresource_access_level)(token, resource_uri) - async def auser_resources(self, token: str): - return await self.acall_actor_method( - self.cluster_servlet, "auser_resources", token - ) - - def user_resources(self, token: str): - return sync_function(self.auser_resources)(token) - async def aget_username(self, token: str): return await self.acall_actor_method( self.cluster_servlet, "aget_username", token @@ -422,6 +403,8 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool: # The logged-in user always has full access to the cluster and its resources. This is especially # important if they flip on Den Auth without saving the cluster. + + # configs.token is the token stored on the cluster itself if configs.token: if configs.token == token: return True @@ -445,10 +428,7 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool: return False resource_access_level = await self.aresource_access_level(token, resource_uri) - if resource_access_level not in [ResourceAccess.WRITE, ResourceAccess.READ]: - return False - - return True + return resource_access_level in [ResourceAccess.WRITE, ResourceAccess.READ] async def aclear_auth_cache(self, token: str = None): return await self.acall_actor_method( @@ -1032,9 +1012,6 @@ async def acall_local( if (await self.aget_cluster_config()).get("den_auth"): if not isinstance(obj, Resource) or obj.visibility not in [ ResourceVisibility.UNLISTED, - ResourceVisibility.PUBLIC, - "unlisted", - "public", ]: ctx = req_ctx.get() if not ctx or not ctx.token: diff --git a/tests/test_servers/test_server_obj_store.py b/tests/test_servers/test_server_obj_store.py index 33840b9d0..ea4fab681 100644 --- a/tests/test_servers/test_server_obj_store.py +++ b/tests/test_servers/test_server_obj_store.py @@ -355,21 +355,11 @@ def test_save_resources_to_obj_store_cache(self, obj_store): token = test_account_dict["token"] # Add test account resources to the local cache - obj_store.add_user_to_auth_cache(token) - resources = obj_store.user_resources(token) - assert resources - resource_uri = f"/{test_account_dict['username']}/summer" access_level = obj_store.resource_access_level(token, resource_uri) assert access_level == "write" - @pytest.mark.level("unit") - def test_no_resources_for_invalid_token(self, obj_store): - token = "abc" - resources = obj_store.user_resources(token) - assert not resources - @pytest.mark.level("unit") def test_no_resource_access_for_invalid_token(self, obj_store): with friend_account() as test_account_dict: