From 56097698c22779027b1a86e88eeb318f6fe594f9 Mon Sep 17 00:00:00 2001 From: jlewitt1 Date: Tue, 3 Sep 2024 19:31:29 +0300 Subject: [PATCH] load cluster token from Den instead of generating client side --- runhouse/main.py | 4 +- runhouse/resources/hardware/cluster.py | 7 +-- runhouse/rns/rns_client.py | 46 ++++++++++++++++--- runhouse/servers/cluster_servlet.py | 15 ++++-- runhouse/servers/http/auth.py | 9 ++-- runhouse/servers/http/http_client.py | 29 ++++-------- runhouse/servers/obj_store.py | 13 +++--- .../test_clusters/test_cluster.py | 6 ++- tests/test_resources/test_resource_sharing.py | 26 +++++------ tests/test_servers/test_http_client.py | 27 ++++++----- tests/test_servers/test_http_server.py | 3 +- 11 files changed, 110 insertions(+), 75 deletions(-) diff --git a/runhouse/main.py b/runhouse/main.py index 6ff581300..3e5738ad4 100644 --- a/runhouse/main.py +++ b/runhouse/main.py @@ -562,9 +562,7 @@ def status( current_cluster = cluster_or_local # cluster_or_local = rh.here try: - cluster_status = current_cluster.status( - resource_address=current_cluster.rns_address, send_to_den=send_to_den - ) + cluster_status = current_cluster.status(send_to_den=send_to_den) except ValueError: console.print("Failed to load status for cluster.") diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index b83ecc008..7debb8cf8 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -832,11 +832,10 @@ def connect_server_client(self, force_reconnect=False): system=self, ) - def status(self, resource_address: str = None, send_to_den: bool = False): + def status(self, send_to_den: bool = False): """Load the status of the Runhouse daemon running on a cluster. Args: - resource_address (str, optional): send_to_den (bool, optional): Whether to send and update the status in Den. Only applies to clusters that are saved to Den. (Default: ``False``) """ @@ -848,7 +847,6 @@ def status(self, resource_address: str = None, send_to_den: bool = False): else: status, den_resp_status_code = self.call_client_method( "status", - resource_address=resource_address or self.rns_address, send_to_den=send_to_den, ) @@ -1041,7 +1039,7 @@ def restart_server( user_config = yaml.safe_dump( { "token": rns_client.cluster_token( - rns_client.token, rns_client.username + resource_address=rns_client.username ), "username": rns_client.username, "default_folder": rns_client.default_folder, @@ -1174,7 +1172,6 @@ def call( method_to_call, module_name, method_name, - resource_address=self.rns_address, stream_logs=stream_logs, data={"args": args, "kwargs": kwargs}, run_name=run_name, diff --git a/runhouse/rns/rns_client.py b/runhouse/rns/rns_client.py index 721030623..f50652b0f 100644 --- a/runhouse/rns/rns_client.py +++ b/runhouse/rns/rns_client.py @@ -1,4 +1,3 @@ -import hashlib import importlib import json import os @@ -197,7 +196,6 @@ def request_headers( headers: dict = self._configs.request_headers if not headers: - # TODO: allow this? means we failed to load token from configs return None if "Authorization" not in headers: @@ -220,18 +218,52 @@ def request_headers( "Failed to extract token from request auth header. Expected in format: Bearer " ) - hashed_token = self.cluster_token(den_token, resource_address) + hashed_token = self.cluster_token(resource_address) return {"Authorization": f"Bearer {hashed_token}"} - def cluster_token(self, den_token: str, resource_address: str): + def cluster_token( + self, resource_address: str, username: str = None, den_token: str = None + ): + """Load the hashed token as generated in Den. Cache the token value in-memory for a given resource address. + Optionally provide a username and den token instead of using the default values stored in local configs.""" if resource_address and "/" in resource_address: # If provided as a full rns address, extract the top level directory resource_address = self.base_folder(resource_address) - hash_input = (den_token + resource_address).encode("utf-8") - hash_hex = hashlib.sha256(hash_input).hexdigest() - return f"{hash_hex}+{resource_address}+{self._configs.username}" + uri = f"{self.api_server_url}/auth/token/cluster" + token_payload = { + "resource_address": resource_address, + "username": username or self._configs.username, + } + + headers = ( + {"Authorization": f"Bearer {den_token}"} + if den_token + else self._configs.request_headers + ) + resp = self.session.post( + uri, + data=json.dumps(token_payload), + headers=headers, + ) + if resp.status_code != 200: + raise Exception( + f"Received [{resp.status_code}] from Den POST '{uri}': Failed to load cluster token: {load_resp_content(resp)}" + ) + + resp_data = read_resp_data(resp) + return resp_data.get("token") + + def validate_cluster_token(self, cluster_token: str, cluster_uri: str) -> bool: + """Checks whether a particular cluster token is valid for the given cluster URI""" + request_uri = self.resource_uri(cluster_uri) + uri = f"{self.api_server_url}/auth/token/cluster/{request_uri}" + resp = self.session.get( + uri, + headers={"Authorization": f"Bearer {cluster_token}"}, + ) + return resp.status_code == 200 def resource_request_payload(self, payload) -> dict: payload = remove_null_values_from_dict(payload) diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 6def0eb42..1ac95dd3a 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -155,11 +155,16 @@ async def aresource_access_level( ) -> Union[str, None]: # If the token in this request matches that of the owner of the cluster, # they have access to everything - if configs.token and ( - configs.token == token - or rns_client.cluster_token(configs.token, resource_uri) == token - ): - return ResourceAccess.WRITE + config_token = configs.token + if config_token: + if config_token == token: + return ResourceAccess.WRITE + + if resource_uri and rns_client.validate_cluster_token( + cluster_token=token, cluster_uri=resource_uri + ): + return ResourceAccess.WRITE + return self._auth_cache.lookup_access_level(token, resource_uri) async def aget_username(self, token: str) -> str: diff --git a/runhouse/servers/http/auth.py b/runhouse/servers/http/auth.py index 501e86be2..7ca1a7f65 100644 --- a/runhouse/servers/http/auth.py +++ b/runhouse/servers/http/auth.py @@ -85,13 +85,14 @@ async def averify_cluster_access( from runhouse.globals import configs, obj_store # The logged-in user always has full access to the cluster. This is especially important if they flip on - # Den Auth without saving the cluster. We may need to generate a subtoken here to check. + # Den Auth without saving the cluster. Note: The token saved in the cluster config is a hashed cluster token, + # which may match the token provided in the request headers. if configs.token: if configs.token == token: return True - if ( - cluster_uri - and rns_client.cluster_token(configs.token, cluster_uri) == token + + if cluster_uri and rns_client.validate_cluster_token( + cluster_token=token, cluster_uri=cluster_uri ): return True diff --git a/runhouse/servers/http/http_client.py b/runhouse/servers/http/http_client.py index 75dab47d3..6f141a15d 100644 --- a/runhouse/servers/http/http_client.py +++ b/runhouse/servers/http/http_client.py @@ -99,6 +99,7 @@ def __init__( ) self.log_formatter = ClusterLogsFormatter(self.system) + self._request_headers = rns_client.request_headers(self.resource_address) def _certs_are_self_signed(self) -> bool: """Checks whether the cert provided is self-signed. If it is, all client requests will include the path @@ -165,7 +166,6 @@ def request( self, endpoint, req_type="post", - resource_address=None, data=None, env=None, stream_logs=True, @@ -175,7 +175,7 @@ def request( timeout=None, headers: Union[Dict, None] = None, ): - headers = rns_client.request_headers(resource_address, headers) + headers = headers or self._request_headers json_dict = { "data": data, "env": env, @@ -202,11 +202,7 @@ def request_json( headers: Union[Dict, None] = None, ): # Support use case where we explicitly do not want to provide headers (e.g. requesting a cert) - headers = ( - rns_client.request_headers(self.resource_address) - if headers != {} - else headers - ) + headers = self._request_headers if headers != {} else headers req_fn = ( session.get if req_type == "get" @@ -276,13 +272,11 @@ def check_server(self): f"but local Runhouse version is ({runhouse.__version__})" ) - def status(self, resource_address: str, send_to_den: bool = False): + def status(self, send_to_den: bool = False): """Load the remote cluster's status.""" - # Note: Resource address must be specified in order to construct the cluster subtoken return self.request( f"status?send_to_den={send_to_den}", req_type="get", - resource_address=resource_address, ) def folder_ls(self, path: Union[str, Path], full_paths: bool, sort: bool): @@ -390,11 +384,11 @@ def call( method_name: str, data: Any = None, serialization: Optional[str] = None, - resource_address=None, run_name: Optional[str] = None, stream_logs: bool = True, remote: bool = False, save=False, + headers=None, ): """wrapper to temporarily support cluster's call signature""" return self.call_module_method( @@ -402,12 +396,12 @@ def call( method_name, data=data, serialization=serialization, - resource_address=resource_address or self.resource_address, run_name=run_name, stream_logs=stream_logs, remote=remote, save=save, system=self.system, + headers=headers, ) def call_module_method( @@ -416,12 +410,12 @@ def call_module_method( method_name: str, data: Any = None, serialization: Optional[str] = None, - resource_address=None, run_name: Optional[str] = None, stream_logs: bool = True, remote: bool = False, save=False, system=None, + headers=None, ): """ Client function to call the rpc for call_module_method @@ -451,7 +445,7 @@ def call_module_method( remote=remote, ).model_dump(), stream=True, - headers=rns_client.request_headers(resource_address), + headers=headers or self._request_headers, auth=self.auth, verify=self.verify, ) @@ -504,7 +498,6 @@ async def acall( method_name: str, data: Any = None, serialization: Optional[str] = None, - resource_address=None, run_name: Optional[str] = None, stream_logs: bool = True, remote: bool = False, @@ -517,7 +510,6 @@ async def acall( method_name, data=data, serialization=serialization, - resource_address=resource_address or self.resource_address, run_name=run_name, stream_logs=stream_logs, remote=remote, @@ -532,7 +524,6 @@ async def acall_module_method( method_name: str, data: Any = None, serialization: Optional[str] = None, - resource_address=None, run_name: Optional[str] = None, stream_logs: bool = True, remote: bool = False, @@ -569,7 +560,7 @@ async def acall_module_method( remote=remote, run_async=run_async, ).model_dump(), - headers=rns_client.request_headers(resource_address), + headers=self._request_headers, ) as res: if res.status_code != 200: raise ValueError( @@ -675,7 +666,7 @@ def set_settings(self, new_settings: Dict[str, Any]): res = retry_with_exponential_backoff(session.post)( self._formatted_url("settings"), json=new_settings, - headers=rns_client.request_headers(self.resource_address), + headers=self._request_headers, auth=self.auth, verify=self.verify, ) diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index 2016c2236..bc3e5046f 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -522,13 +522,14 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool: # The logged-in user always has full access to the cluster and its resources. This is especially # important if they flip on Den Auth without saving the cluster. - # configs.token is the token stored on the cluster itself - if configs.token: - if configs.token == token: + # configs.token is the token stored on the cluster itself, which is itself a hashed subtoken + config_token = configs.token + if config_token: + if config_token == token: return True - if ( - resource_uri - and rns_client.cluster_token(configs.token, resource_uri) == token + + if resource_uri and rns_client.validate_cluster_token( + cluster_token=token, cluster_uri=resource_uri ): return True diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index cee1ecec9..dd89221ef 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -425,13 +425,17 @@ def test_caller_token_propagated(self, cluster): remote_assume_caller_and_get_token.share( users=["info@run.house"], notify_users=False ) + current_username = rh.configs.username + current_den_token = rh.configs.token with friend_account(): unassumed_token, assumed_token = remote_assume_caller_and_get_token() # "Local token" is the token the cluster accesses in rh.configs.token; this is what will be used # in subsequent rns_client calls assert assumed_token == rh.globals.rns_client.cluster_token( - rh.configs.token, cluster.rns_address + cluster.rns_address, + username=current_username, + den_token=current_den_token, ) assert unassumed_token != rh.configs.token diff --git a/tests/test_resources/test_resource_sharing.py b/tests/test_resources/test_resource_sharing.py index 960bd4f10..0a1515db1 100644 --- a/tests/test_resources/test_resource_sharing.py +++ b/tests/test_resources/test_resource_sharing.py @@ -63,7 +63,6 @@ def call_cluster_methods(cluster, valid_token): @pytest.mark.level("local") def test_calling_shared_resource(self, resource): - current_token = rh.configs.token cluster = resource.system # Run commands on cluster with current token @@ -71,7 +70,7 @@ def test_calling_shared_resource(self, resource): assert return_codes[0][0] == 0 # Call function with current token via CURL - cluster_token = rns_client.cluster_token(current_token, cluster.rns_address) + cluster_token = rns_client.cluster_token(cluster.rns_address) res = self.call_func_with_curl( cluster, resource.name, cluster_token, **{"a": 1, "b": 2} ) @@ -119,14 +118,18 @@ def test_use_resource_apis(self, resource): # Use invalid token to confirm no function access rh.configs.token = "abc123" - try: - resource(2, 2) == 4 - except Exception as e: - assert "Unauthorized access to resource summer." in str(e) + with pytest.raises(Exception): + # cluster will throw error since the cluster token is invalid + # note: use the "call" method directly in order to pass new request headers with invalid token + cluster._http_client.call( + key=reloaded_func.name, + method_name="call", + headers=rns_client.request_headers(resource.rns_address), + ) # Reset back to valid token and confirm we can call function again rh.configs.token = current_token - cluster_token = rns_client.cluster_token(current_token, cluster.rns_address) + cluster_token = rns_client.cluster_token(cluster.rns_address) res = self.call_func_with_curl( cluster, resource.name, cluster_token, **{"a": 1, "b": 2} @@ -138,7 +141,6 @@ def test_calling_resource_with_cluster_write_access(self, resource): """Check that a user with write access to a cluster can call a function on that cluster, even without having explicit access to the function.""" current_username = rh.configs.username - current_token = rh.configs.token cluster = resource.system cluster_uri = rns_client.resource_uri(cluster.rns_address) @@ -170,7 +172,7 @@ def test_calling_resource_with_cluster_write_access(self, resource): ) # Confirm user can still call the function with write access to the cluster - cluster_token = rns_client.cluster_token(current_token, cluster.rns_address) + cluster_token = rns_client.cluster_token(cluster.rns_address) res = self.call_func_with_curl( cluster, resource.name, @@ -188,7 +190,6 @@ def test_calling_resource_with_no_cluster_access(self, resource): """Check that a user with no access to the cluster can still call a function on that cluster if they were given explicit access to the function.""" current_username = rh.configs.username - current_token = rh.configs.token cluster = resource.system cluster_uri = rns_client.resource_uri(cluster.rns_address) @@ -205,7 +206,7 @@ def test_calling_resource_with_no_cluster_access(self, resource): ), f"Failed to remove access to the cluster for user: {current_username}: {resp.text}" # Confirm current user can still call the function (which they still have explicit access to) - cluster_token = rns_client.cluster_token(current_token, cluster.rns_address) + cluster_token = rns_client.cluster_token(cluster.rns_address) res = self.call_func_with_curl( cluster, resource.name, cluster_token, **{"a": 1, "b": 2} ) @@ -219,7 +220,6 @@ def test_calling_resource_with_cluster_read_access(self, resource): """Check that a user with read only access to the cluster cannot call a function on that cluster if they do not explicitly have access to the function itself.""" current_username = rh.configs.username - current_token = rh.configs.token cluster = resource.system cluster_uri = rns_client.resource_uri(cluster.rns_address) @@ -259,7 +259,7 @@ def test_calling_resource_with_cluster_read_access(self, resource): cluster.enable_den_auth(flush=True) # Confirm user can no longer call the function with read only access to the cluster and no function access - cluster_token = rns_client.cluster_token(current_token, cluster.rns_address) + cluster_token = rns_client.cluster_token(cluster.rns_address) res = self.call_func_with_curl( cluster, resource.name, cluster_token, **{"a": 1, "b": 2} ) diff --git a/tests/test_servers/test_http_client.py b/tests/test_servers/test_http_client.py index 304a95c6f..a48b4094d 100644 --- a/tests/test_servers/test_http_client.py +++ b/tests/test_servers/test_http_client.py @@ -1,4 +1,5 @@ import json +from unittest.mock import patch import pytest @@ -82,7 +83,11 @@ def test_get_certificate(self, mocker): mock_file_open().write.assert_called_once_with(b"certificate_content") @pytest.mark.level("unit") - def test_use_cert_verification(self, mocker): + @patch("runhouse.globals.rns_client.request_headers") + def test_use_cert_verification(self, mock_request_headers, mocker): + # Mock the request_headers to avoid actual HTTP requests in the test for loading the cluster token + mock_request_headers.return_value = {"Authorization": "Bearer mock_token"} + # Mock a certificate where the issuer is different from the subject mock_cert = mocker.MagicMock() mock_cert.issuer = "issuer" @@ -136,6 +141,9 @@ def test_use_cert_verification(self, mocker): @pytest.mark.level("unit") def test_call_module_method(self, mocker): + expected_headers = rns_client.request_headers( + resource_address=self.local_cluster.rns_address + ) response_sequence = [ json.dumps({"output_type": "stdout", "data": "Log message"}), json.dumps( @@ -161,7 +169,6 @@ def test_call_module_method(self, mocker): result = self.client.call( module_name, method_name, - resource_address=self.local_cluster.rns_address, run_name="test_run_name", ) @@ -177,9 +184,7 @@ def test_call_module_method(self, mocker): "save": False, "remote": False, } - expected_headers = rns_client.request_headers( - resource_address=self.local_cluster.rns_address - ) + expected_verify = self.client.verify mock_post.assert_called_once_with( @@ -193,6 +198,8 @@ def test_call_module_method(self, mocker): @pytest.mark.level("unit") def test_call_module_method_with_args_kwargs(self, mocker): + expected_headers = rns_client.request_headers(self.local_cluster.rns_address) + mock_response = mocker.MagicMock() mock_response.status_code = 200 # Set up iter_lines to return an iterator @@ -212,7 +219,6 @@ def test_call_module_method_with_args_kwargs(self, mocker): module_name, method_name, data=data, - resource_address=self.local_cluster.rns_address, run_name="test_run_name", ) @@ -226,7 +232,6 @@ def test_call_module_method_with_args_kwargs(self, mocker): "remote": False, } expected_url = f"http://localhost:32300/{module_name}/{method_name}" - expected_headers = rns_client.request_headers(self.local_cluster.rns_address) expected_verify = self.client.verify mock_post.assert_called_with( @@ -246,12 +251,12 @@ def test_call_module_method_error_handling(self, mocker, local_cluster): mocker.patch("requests.Session.post", return_value=mock_response) with pytest.raises(ValueError): - self.client.call( - "module", "method", resource_address=local_cluster.rns_address - ) + self.client.call("module", "method") @pytest.mark.level("unit") def test_call_module_method_config(self, mocker, local_cluster): + request_headers = rns_client.request_headers(local_cluster.rns_address) + test_data = self.local_cluster.config() mock_response = mocker.Mock() mock_response.status_code = 200 @@ -265,7 +270,7 @@ def test_call_module_method_config(self, mocker, local_cluster): cluster = self.client.call( EMPTY_DEFAULT_ENV_NAME, "install", - resource_address=local_cluster.rns_address, + headers=request_headers, ) assert cluster.config() == test_data diff --git a/tests/test_servers/test_http_server.py b/tests/test_servers/test_http_server.py index 94298442c..16be8636a 100644 --- a/tests/test_servers/test_http_server.py +++ b/tests/test_servers/test_http_server.py @@ -611,7 +611,8 @@ def test_no_access_to_cluster(self, http_client, cluster): import requests - with friend_account(): # Test accounts with Den auth are created under test_account + with friend_account(): + # Test accounts with Den auth are created under test_account res = requests.get( f"{rns_client.api_server_url}/resource", headers=rns_client.request_headers(cluster.rns_address),