diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py
index 9a31f0651..60f2f7759 100644
--- a/runhouse/resources/hardware/on_demand_cluster.py
+++ b/runhouse/resources/hardware/on_demand_cluster.py
@@ -22,7 +22,7 @@
     LOCAL_HOSTS,
 )
 
-from runhouse.globals import configs, rns_client
+from runhouse.globals import configs, obj_store, rns_client
 from runhouse.resources.hardware.utils import ServerConnectionType
 
 from .cluster import Cluster
@@ -112,8 +112,8 @@ def autostop_mins(self):
     def autostop_mins(self, mins):
         self.check_server()
 
         if self.on_this_cluster():
-            raise ValueError("Cannot set autostop_mins live on the cluster.")
+            obj_store.set_cluster_config_value("autostop_mins", mins)
         else:
             if self.run_python(["import skypilot"])[0] != 0:
                 raise ImportError(
@@ -121,7 +121,8 @@ def autostop_mins(self, mins):
                 )
             self.client.set_settings({"autostop_mins": mins})
             sky.autostop(self.name, mins, down=True)
-            self._autostop_mins = mins
+        # Cache the value locally only after the update above has succeeded.
+        self._autostop_mins = mins
 
     def config(self, condensed=True):
         config = super().config(condensed)
@@ -448,14 +449,14 @@ def up(self):
 
         return self
 
-    def keep_warm(self, autostop_mins: int = -1):
+    def keep_warm(self, mins: int = -1):
         """Keep the cluster warm for given number of minutes after inactivity.
 
         Args:
-            autostop_mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
+            mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
                 If set to -1, keep cluster warm indefinitely. (Default: `-1`)
         """
-        self.autostop_mins = autostop_mins
+        self.autostop_mins = mins
         return self
 
     def teardown(self):
diff --git a/runhouse/servers/autostop_helper.py b/runhouse/servers/autostop_helper.py
new file mode 100644
index 000000000..f72dce67b
--- /dev/null
+++ b/runhouse/servers/autostop_helper.py
@@ -0,0 +1,102 @@
+import logging
+import os
+import subprocess
+
+logger = logging.getLogger(__name__)
+
+
+class AutostopHelper:
+    """A helper class strictly to run SkyPilot methods on OnDemandClusters inside SkyPilot's conda env.
+
+    Activity is only flagged in memory by ``set_last_active_time_to_now``; it is written to
+    SkyPilot when ``register_activity_if_needed`` is called.
+    """
+
+    SKY_VENV = "~/skypilot-runtime"
+
+    def __init__(self):
+        self._activity_registered = False
+
+    async def set_last_active_time_to_now(self):
+        """Flag that activity occurred; SkyPilot itself is not called here."""
+        self._activity_registered = True
+
+    def _run_python_in_sky_venv(self, cmd: str) -> str:
+        """Run ``cmd`` as Python source with the SkyPilot venv's interpreter and return its stdout.
+
+        The interpreter is invoked with an argument list rather than through a shell, so ``cmd``
+        is passed verbatim and needs no quoting.
+
+        Raises:
+            OSError: If the interpreter cannot be started (e.g. it does not exist).
+            subprocess.CalledProcessError: If the interpreter exits with a non-zero status.
+        """
+        sky_python = os.path.expanduser(f"{self.SKY_VENV}/bin/python")
+
+        logger.debug(f"Running command in SkyPilot's venv: {sky_python} -c {cmd!r}")
+        try:
+            result = subprocess.run(
+                [sky_python, "-c", cmd],
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+        except subprocess.CalledProcessError as e:
+            # stderr is captured above, so log it here or the failure reason is lost.
+            stderr = e.stderr.decode("utf-8", errors="replace")
+            logger.error(f"Command in SkyPilot's venv failed: {stderr}")
+            raise
+        return result.stdout.decode("utf-8")
+
+    async def get_autostop(self) -> int:
+        """Return ``autostop_idle_minutes`` from SkyPilot's autostop config."""
+        sky_get_autostop_cmd = (
+            "from sky.skylet.autostop_lib import get_autostop_config; "
+            "print(get_autostop_config().autostop_idle_minutes)"
+        )
+
+        return int(self._run_python_in_sky_venv(sky_get_autostop_cmd))
+
+    async def get_last_active_time(self) -> float:
+        """Return the value printed by SkyPilot's ``get_last_active_time`` as a float."""
+        sky_get_last_active_time_cmd = (
+            "from sky.skylet.autostop_lib import get_last_active_time; "
+            "print(get_last_active_time())"
+        )
+
+        return float(self._run_python_in_sky_venv(sky_get_last_active_time_cmd))
+
+    async def set_autostop(self, idle_minutes: int):
+        """Call SkyPilot's ``set_autostop`` with ``idle_minutes``.
+
+        Raises:
+            TypeError, ValueError: If ``idle_minutes`` cannot be converted with ``int()``.
+        """
+        # The value is interpolated into Python source below, so coerce it to an int first.
+        idle_minutes = int(idle_minutes)
+
+        # Filling in "cloudvmray" as the backend because it's the only backend supported by SkyPilot right now,
+        # if needed we can grab the backend from the autostop config with:
+        # `from sky.skylet.autostop_lib import get_autostop_config; get_autostop_config().backend`
+        sky_set_autostop_cmd = (
+            "from sky.skylet.autostop_lib import set_autostop; "
+            f'set_autostop({idle_minutes}, "cloudvmray", True)'
+        )
+
+        self._run_python_in_sky_venv(sky_set_autostop_cmd)
+
+    async def register_activity_if_needed(self):
+        """Update SkyPilot's last active time if activity was flagged since the last update."""
+        sky_register_activity_cmd = (
+            "from sky.skylet.autostop_lib import set_last_active_time_to_now; "
+            "set_last_active_time_to_now()"
+        )
+
+        if self._activity_registered:
+            logger.debug("Activity registered, updating last active time in SkyConfig")
+            self._run_python_in_sky_venv(sky_register_activity_cmd)
+            self._activity_registered = False
+        else:
+            logger.debug(
+                "No activity registered, not updating last active time in SkyConfig"
+            )
diff --git a/runhouse/servers/autostop_servlet.py b/runhouse/servers/autostop_servlet.py
deleted file mode 100644
index 7de705840..000000000
--- a/runhouse/servers/autostop_servlet.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import logging
-import subprocess
-
-logger = logging.getLogger(__name__)
-
-
-class AutostopServlet:
-    """A helper class strictly to run SkyPilot methods on OnDemandClusters inside SkyPilot's conda env."""
-
-    def __init__(self):
-        self._activity_registered = False
-
-    def set_last_active_time_to_now(self):
-        self._activity_registered = True
-
-    def update_autostop_in_sky_config(self):
-        SKY_VENV = "~/skypilot-runtime"
-        SKY_AUTOSTOP_CMD = "from sky.skylet.autostop_lib import set_last_active_time_to_now; set_last_active_time_to_now()"
-        SKY_CMD = f"{SKY_VENV}/bin/python -c '{SKY_AUTOSTOP_CMD}'"
-
-        if self._activity_registered:
-            logger.debug(
-                "Activity registered, updating last active time in SkyConfig with command: {SKY_CMD}"
-            )
-            subprocess.run(SKY_CMD, shell=True, check=True)
-            self._activity_registered = False
-        else:
-            logger.debug(
-                "No activity registered, not updating last active time in SkyConfig"
-            )
diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py
index 3c1d3df6a..a1e359e4f 100644
--- a/runhouse/servers/cluster_servlet.py
+++ b/runhouse/servers/cluster_servlet.py
@@ -19,7 +19,7 @@
 from runhouse.resources.hardware import load_cluster_config_from_file
 from runhouse.rns.rns_client import ResourceStatusData
 from runhouse.rns.utils.api import ResourceAccess
-from runhouse.servers.autostop_servlet import AutostopServlet
+from runhouse.servers.autostop_helper import AutostopHelper
 from runhouse.servers.http.auth import AuthCache
 from runhouse.utils import sync_function
 
@@ -48,25 +48,10 @@ async def __init__(
         self._initialized_env_servlet_names: Set[str] = set()
         self._key_to_env_servlet_name: Dict[Any, str] = {}
         self._auth_cache: AuthCache = AuthCache(cluster_config)
-        self.autostop_servlet = None
+        self.autostop_helper = None
 
         if cluster_config.get("resource_subtype", None) == "OnDemandCluster":
-            import ray
-
-            current_ip = ray.get_runtime_context().worker.node_ip_address
-            self.autostop_servlet = (
-                ray.remote(AutostopServlet)
-                .options(
-                    name="autostop_servlet",
-                    get_if_exists=True,
-                    lifetime="detached",
-                    namespace="runhouse",
-                    max_concurrency=1000,
-                    resources={f"node:{current_ip}": 0.001},
-                    num_cpus=0,
-                )
-                .remote()
-            )
+            self.autostop_helper = AutostopHelper()
 
         logger.info("Creating periodic_status_check thread.")
         post_status_thread = threading.Thread(
@@ -119,8 +104,8 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]):
         return self.cluster_config
 
     async def aset_cluster_config_value(self, key: str, value: Any):
-        if self.autostop_servlet and key == "autostop_mins" and value > -1:
-            await self.autostop_servlet.set_auto_stop.remote(value)
+        if self.autostop_helper and key == "autostop_mins":
+            await self.autostop_helper.set_autostop(value)
         self.cluster_config[key] = value
 
         # Propagate the changes to all other process's obj_stores
@@ -206,8 +191,8 @@ async def aget_key_to_env_servlet_name_dict(self) -> Dict[Any, str]:
         return self._key_to_env_servlet_name
 
     async def aget_env_servlet_name_for_key(self, key: Any) -> str:
-        if self.autostop_servlet:
-            await self.autostop_servlet.set_last_active_time_to_now.remote()
+        if self.autostop_helper:
+            await self.autostop_helper.set_last_active_time_to_now()
         return self._key_to_env_servlet_name.get(key, None)
 
     async def aput_env_servlet_name_for_key(self, key: Any, env_servlet_name: str):
@@ -257,7 +242,7 @@ async def aperiodic_status_check(self):
 
                 # Only if one of these is true, do we actually need to get the status from each EnvServlet
                 should_send_status_to_den = den_auth and interval_size != -1
-                should_update_autostop = self.autostop_servlet is not None
+                should_update_autostop = self.autostop_helper is not None
                 if should_send_status_to_den or should_update_autostop:
                     logger.info(
                         "Performing cluster status check: potentially sending to Den or updating autostop."
@@ -273,8 +258,12 @@ async def aperiodic_status_check(self):
                             for resources in status.env_resource_mapping.values()
                         )
                         if function_running:
-                            await self.autostop_servlet.set_last_active_time_to_now.remote()
-                        await self.autostop_servlet.update_autostop_in_sky_config.remote()
+                            await self.autostop_helper.set_last_active_time_to_now()
+                        # We do this separately from the set_last_active_time_to_now call above because
+                        # function_running will only reflect activity from functions which happen to be running during
+                        # the status check. We still need to attempt to register activity for functions which have
+                        # been called and completed.
+                        await self.autostop_helper.register_activity_if_needed()
 
                     if should_send_status_to_den:
                         cluster_rns_address = cluster_config.get("name")
diff --git a/tests/test_resources/test_clusters/test_on_demand_cluster.py b/tests/test_resources/test_clusters/test_on_demand_cluster.py
index aed1e6d73..3578ac36b 100644
--- a/tests/test_resources/test_clusters/test_on_demand_cluster.py
+++ b/tests/test_resources/test_clusters/test_on_demand_cluster.py
@@ -1,9 +1,47 @@
+import asyncio
+import time
+
 import pytest
 
+import runhouse as rh
+
 import tests.test_resources.test_clusters.test_cluster
 from tests.utils import friend_account
 
 
+def set_autostop_from_on_cluster_via_ah(mins):
+    ah = rh.servers.autostop_helper.AutostopHelper()
+
+    asyncio.run(ah.set_autostop(mins))
+
+
+def get_autostop_from_on_cluster():
+    ah = rh.servers.autostop_helper.AutostopHelper()
+
+    return asyncio.run(ah.get_autostop())
+
+
+def get_last_active_time_from_on_cluster():
+    ah = rh.servers.autostop_helper.AutostopHelper()
+
+    return asyncio.run(ah.get_last_active_time())
+
+
+def register_activity_from_on_cluster():
+    ah = rh.servers.autostop_helper.AutostopHelper()
+
+    asyncio.run(ah.set_last_active_time_to_now())
+    asyncio.run(ah.register_activity_if_needed())
+
+
+def set_autostop_from_on_cluster_via_cluster_obj(mins):
+    rh.here.autostop_mins = mins
+
+
+def set_autostop_from_on_cluster_via_cluster_keep_warm():
+    rh.here.keep_warm()
+
+
 class TestOnDemandCluster(tests.test_resources.test_clusters.test_cluster.TestCluster):
     MAP_FIXTURES = {"resource": "cluster"}
 
@@ -59,3 +97,61 @@ def test_restart_does_not_change_config_yaml(self, cluster):
         assert config_yaml_res_after_restart[0][0] == 0
         config_yaml_content_after_restart = config_yaml_res[0][1]
         assert config_yaml_content_after_restart == config_yaml_content
+
+    @pytest.mark.level("minimal")
+    def test_autostop(self, cluster):
+        rh.env(
+            working_dir="local:./", reqs=["pytest", "pandas"], name="autostop_env"
+        ).to(cluster)
+        get_autostop = rh.fn(get_autostop_from_on_cluster).to(
+            cluster, env="autostop_env"
+        )
+        # First check that the autostop is set to whatever the cluster set it to
+        assert get_autostop() == cluster.autostop_mins
+        original_autostop = cluster.autostop_mins
+
+        set_autostop = rh.fn(set_autostop_from_on_cluster_via_ah).to(
+            cluster, env="autostop_env"
+        )
+        set_autostop(5)
+        assert get_autostop() == 5
+
+        register_activity = rh.fn(register_activity_from_on_cluster).to(
+            cluster, env="autostop_env"
+        )
+        get_last_active = rh.fn(get_last_active_time_from_on_cluster).to(
+            cluster, env="autostop_env"
+        )
+
+        register_activity()
+        # Check that last active is within the last 2 seconds
+        assert get_last_active() > time.time() - 2
+
+        set_autostop_via_cluster_keep_warm = rh.fn(
+            set_autostop_from_on_cluster_via_cluster_keep_warm
+        ).to(cluster, env="autostop_env")
+        set_autostop_via_cluster_keep_warm()
+        assert get_autostop() == -1
+
+        set_autostop_via_cluster_obj = rh.fn(
+            set_autostop_from_on_cluster_via_cluster_obj
+        ).to(cluster, env="autostop_env")
+        # reset the autostop to the original value
+        set_autostop_via_cluster_obj(original_autostop)
+        assert get_autostop() == original_autostop
+
+        # TODO add a way to manually trigger the status loop to check that activity
+        # is actually registered after a call
+        # cluster.call("autostop_env", "config")
+        # cluster.status()
+        # assert get_last_active() > time.time() - 2
+
+        # TODO add a way to manually trigger the status loop to check that activity
+        # is actually registered during a long running function
+        # from .test_cluster import sleep_fn
+        # sleep_remote = rh.fn(sleep_fn).to(cluster, env="autostop_env")
+        # Thread(target=sleep_remote, args=(3,)).start()
+        # time.sleep(2)
+        # cluster.status()
+        # # Check that last active is within the last second, so we know the activity wasn't just from the call itself
+        # assert get_last_active() > time.time() - 1