Skip to content

Commit

Permalink
load cluster token from Den instead of generating client side (#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 committed Sep 12, 2024
1 parent 2893b29 commit 192b5d0
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 75 deletions.
4 changes: 1 addition & 3 deletions runhouse/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,7 @@ def status(
current_cluster = cluster_or_local # cluster_or_local = rh.here

try:
cluster_status = current_cluster.status(
resource_address=current_cluster.rns_address, send_to_den=send_to_den
)
cluster_status = current_cluster.status(send_to_den=send_to_den)

except ValueError:
console.print("Failed to load status for cluster.")
Expand Down
7 changes: 2 additions & 5 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,11 +832,10 @@ def connect_server_client(self, force_reconnect=False):
system=self,
)

def status(self, resource_address: str = None, send_to_den: bool = False):
def status(self, send_to_den: bool = False):
"""Load the status of the Runhouse daemon running on a cluster.
Args:
resource_address (str, optional):
send_to_den (bool, optional): Whether to send and update the status in Den. Only applies to
clusters that are saved to Den. (Default: ``False``)
"""
Expand All @@ -848,7 +847,6 @@ def status(self, resource_address: str = None, send_to_den: bool = False):
else:
status, den_resp_status_code = self.call_client_method(
"status",
resource_address=resource_address or self.rns_address,
send_to_den=send_to_den,
)

Expand Down Expand Up @@ -1041,7 +1039,7 @@ def restart_server(
user_config = yaml.safe_dump(
{
"token": rns_client.cluster_token(
rns_client.token, rns_client.username
resource_address=rns_client.username
),
"username": rns_client.username,
"default_folder": rns_client.default_folder,
Expand Down Expand Up @@ -1174,7 +1172,6 @@ def call(
method_to_call,
module_name,
method_name,
resource_address=self.rns_address,
stream_logs=stream_logs,
data={"args": args, "kwargs": kwargs},
run_name=run_name,
Expand Down
46 changes: 39 additions & 7 deletions runhouse/rns/rns_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import hashlib
import importlib
import json
import os
Expand Down Expand Up @@ -197,7 +196,6 @@ def request_headers(
headers: dict = self._configs.request_headers

if not headers:
# TODO: allow this? means we failed to load token from configs
return None

if "Authorization" not in headers:
Expand All @@ -220,18 +218,52 @@ def request_headers(
"Failed to extract token from request auth header. Expected in format: Bearer <token>"
)

hashed_token = self.cluster_token(den_token, resource_address)
hashed_token = self.cluster_token(resource_address)

return {"Authorization": f"Bearer {hashed_token}"}

def cluster_token(self, den_token: str, resource_address: str):
def cluster_token(
self, resource_address: str, username: str = None, den_token: str = None
):
"""Load the hashed token as generated in Den. Cache the token value in-memory for a given resource address.
Optionally provide a username and den token instead of using the default values stored in local configs."""
if resource_address and "/" in resource_address:
# If provided as a full rns address, extract the top level directory
resource_address = self.base_folder(resource_address)

hash_input = (den_token + resource_address).encode("utf-8")
hash_hex = hashlib.sha256(hash_input).hexdigest()
return f"{hash_hex}+{resource_address}+{self._configs.username}"
uri = f"{self.api_server_url}/auth/token/cluster"
token_payload = {
"resource_address": resource_address,
"username": username or self._configs.username,
}

headers = (
{"Authorization": f"Bearer {den_token}"}
if den_token
else self._configs.request_headers
)
resp = self.session.post(
uri,
data=json.dumps(token_payload),
headers=headers,
)
if resp.status_code != 200:
raise Exception(
f"Received [{resp.status_code}] from Den POST '{uri}': Failed to load cluster token: {load_resp_content(resp)}"
)

resp_data = read_resp_data(resp)
return resp_data.get("token")

def validate_cluster_token(self, cluster_token: str, cluster_uri: str) -> bool:
"""Checks whether a particular cluster token is valid for the given cluster URI"""
request_uri = self.resource_uri(cluster_uri)
uri = f"{self.api_server_url}/auth/token/cluster/{request_uri}"
resp = self.session.get(
uri,
headers={"Authorization": f"Bearer {cluster_token}"},
)
return resp.status_code == 200

def resource_request_payload(self, payload) -> dict:
payload = remove_null_values_from_dict(payload)
Expand Down
15 changes: 10 additions & 5 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,16 @@ async def aresource_access_level(
) -> Union[str, None]:
# If the token in this request matches that of the owner of the cluster,
# they have access to everything
if configs.token and (
configs.token == token
or rns_client.cluster_token(configs.token, resource_uri) == token
):
return ResourceAccess.WRITE
config_token = configs.token
if config_token:
if config_token == token:
return ResourceAccess.WRITE

if resource_uri and rns_client.validate_cluster_token(
cluster_token=token, cluster_uri=resource_uri
):
return ResourceAccess.WRITE

return self._auth_cache.lookup_access_level(token, resource_uri)

async def aget_username(self, token: str) -> str:
Expand Down
9 changes: 5 additions & 4 deletions runhouse/servers/http/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,14 @@ async def averify_cluster_access(
from runhouse.globals import configs, obj_store

# The logged-in user always has full access to the cluster. This is especially important if they flip on
# Den Auth without saving the cluster. We may need to generate a subtoken here to check.
# Den Auth without saving the cluster. Note: The token saved in the cluster config is a hashed cluster token,
# which may match the token provided in the request headers.
if configs.token:
if configs.token == token:
return True
if (
cluster_uri
and rns_client.cluster_token(configs.token, cluster_uri) == token

if cluster_uri and rns_client.validate_cluster_token(
cluster_token=token, cluster_uri=cluster_uri
):
return True

Expand Down
29 changes: 10 additions & 19 deletions runhouse/servers/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
)

self.log_formatter = ClusterLogsFormatter(self.system)
self._request_headers = rns_client.request_headers(self.resource_address)

def _certs_are_self_signed(self) -> bool:
"""Checks whether the cert provided is self-signed. If it is, all client requests will include the path
Expand Down Expand Up @@ -165,7 +166,6 @@ def request(
self,
endpoint,
req_type="post",
resource_address=None,
data=None,
env=None,
stream_logs=True,
Expand All @@ -175,7 +175,7 @@ def request(
timeout=None,
headers: Union[Dict, None] = None,
):
headers = rns_client.request_headers(resource_address, headers)
headers = headers or self._request_headers
json_dict = {
"data": data,
"env": env,
Expand All @@ -202,11 +202,7 @@ def request_json(
headers: Union[Dict, None] = None,
):
# Support use case where we explicitly do not want to provide headers (e.g. requesting a cert)
headers = (
rns_client.request_headers(self.resource_address)
if headers != {}
else headers
)
headers = self._request_headers if headers != {} else headers
req_fn = (
session.get
if req_type == "get"
Expand Down Expand Up @@ -276,13 +272,11 @@ def check_server(self):
f"but local Runhouse version is ({runhouse.__version__})"
)

def status(self, resource_address: str, send_to_den: bool = False):
def status(self, send_to_den: bool = False):
"""Load the remote cluster's status."""
# Note: Resource address must be specified in order to construct the cluster subtoken
return self.request(
f"status?send_to_den={send_to_den}",
req_type="get",
resource_address=resource_address,
)

def folder_ls(self, path: Union[str, Path], full_paths: bool, sort: bool):
Expand Down Expand Up @@ -390,24 +384,24 @@ def call(
method_name: str,
data: Any = None,
serialization: Optional[str] = None,
resource_address=None,
run_name: Optional[str] = None,
stream_logs: bool = True,
remote: bool = False,
save=False,
headers=None,
):
"""wrapper to temporarily support cluster's call signature"""
return self.call_module_method(
key,
method_name,
data=data,
serialization=serialization,
resource_address=resource_address or self.resource_address,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
save=save,
system=self.system,
headers=headers,
)

def call_module_method(
Expand All @@ -416,12 +410,12 @@ def call_module_method(
method_name: str,
data: Any = None,
serialization: Optional[str] = None,
resource_address=None,
run_name: Optional[str] = None,
stream_logs: bool = True,
remote: bool = False,
save=False,
system=None,
headers=None,
):
"""
Client function to call the rpc for call_module_method
Expand Down Expand Up @@ -451,7 +445,7 @@ def call_module_method(
remote=remote,
).model_dump(),
stream=True,
headers=rns_client.request_headers(resource_address),
headers=headers or self._request_headers,
auth=self.auth,
verify=self.verify,
)
Expand Down Expand Up @@ -504,7 +498,6 @@ async def acall(
method_name: str,
data: Any = None,
serialization: Optional[str] = None,
resource_address=None,
run_name: Optional[str] = None,
stream_logs: bool = True,
remote: bool = False,
Expand All @@ -517,7 +510,6 @@ async def acall(
method_name,
data=data,
serialization=serialization,
resource_address=resource_address or self.resource_address,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
Expand All @@ -532,7 +524,6 @@ async def acall_module_method(
method_name: str,
data: Any = None,
serialization: Optional[str] = None,
resource_address=None,
run_name: Optional[str] = None,
stream_logs: bool = True,
remote: bool = False,
Expand Down Expand Up @@ -569,7 +560,7 @@ async def acall_module_method(
remote=remote,
run_async=run_async,
).model_dump(),
headers=rns_client.request_headers(resource_address),
headers=self._request_headers,
) as res:
if res.status_code != 200:
raise ValueError(
Expand Down Expand Up @@ -675,7 +666,7 @@ def set_settings(self, new_settings: Dict[str, Any]):
res = retry_with_exponential_backoff(session.post)(
self._formatted_url("settings"),
json=new_settings,
headers=rns_client.request_headers(self.resource_address),
headers=self._request_headers,
auth=self.auth,
verify=self.verify,
)
Expand Down
13 changes: 7 additions & 6 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,13 +522,14 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:
# The logged-in user always has full access to the cluster and its resources. This is especially
# important if they flip on Den Auth without saving the cluster.

# configs.token is the token stored on the cluster itself
if configs.token:
if configs.token == token:
# configs.token is the token stored on the cluster itself, which is itself a hashed subtoken
config_token = configs.token
if config_token:
if config_token == token:
return True
if (
resource_uri
and rns_client.cluster_token(configs.token, resource_uri) == token

if resource_uri and rns_client.validate_cluster_token(
cluster_token=token, cluster_uri=resource_uri
):
return True

Expand Down
6 changes: 5 additions & 1 deletion tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,17 @@ def test_caller_token_propagated(self, cluster):
remote_assume_caller_and_get_token.share(
users=["info@run.house"], notify_users=False
)
current_username = rh.configs.username
current_den_token = rh.configs.token

with friend_account():
unassumed_token, assumed_token = remote_assume_caller_and_get_token()
# "Local token" is the token the cluster accesses in rh.configs.token; this is what will be used
# in subsequent rns_client calls
assert assumed_token == rh.globals.rns_client.cluster_token(
rh.configs.token, cluster.rns_address
cluster.rns_address,
username=current_username,
den_token=current_den_token,
)
assert unassumed_token != rh.configs.token

Expand Down
Loading

0 comments on commit 192b5d0

Please sign in to comment.