Skip to content

Commit

Permalink
Change AutostopServlet into AutostopHelper, and test properly (#897)
Browse files Browse the repository at this point in the history
AutostopServlet is no longer in its own process because it's easier to just run
python commands in SkyPilot's venv through subprocess than start an actor inside
that interpreter. Here we're changing the AutostopServlet to AutostopHelper (a regular class,
not an Actor), fleshing out the available autostop commands it can make, and testing it thoroughly.

Previously, keep_warm and the cluster.autostop_mins setter did not work when called on the cluster itself;
now they do.
  • Loading branch information
dongreenberg committed Jun 16, 2024
1 parent 9a61f2c commit c4c65cc
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 61 deletions.
12 changes: 6 additions & 6 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
LOCAL_HOSTS,
)

from runhouse.globals import configs, rns_client
from runhouse.globals import configs, obj_store, rns_client
from runhouse.resources.hardware.utils import ServerConnectionType

from .cluster import Cluster
Expand Down Expand Up @@ -112,16 +112,16 @@ def autostop_mins(self):
def autostop_mins(self, mins):
self.check_server()

self._autostop_mins = mins
if self.on_this_cluster():
raise ValueError("Cannot set autostop_mins live on the cluster.")
obj_store.set_cluster_config_value("autostop_mins", mins)
else:
if self.run_python(["import skypilot"])[0] != 0:
raise ImportError(
"Skypilot must be installed on the cluster in order to set autostop."
)
self.client.set_settings({"autostop_mins": mins})
sky.autostop(self.name, mins, down=True)
self._autostop_mins = mins

def config(self, condensed=True):
config = super().config(condensed)
Expand Down Expand Up @@ -448,14 +448,14 @@ def up(self):

return self

def keep_warm(self, autostop_mins: int = -1):
def keep_warm(self, mins: int = -1):
"""Keep the cluster warm for given number of minutes after inactivity.
Args:
autostop_mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
If set to -1, keep cluster warm indefinitely. (Default: `-1`)
"""
self.autostop_mins = autostop_mins
self.autostop_mins = mins
return self

def teardown(self):
Expand Down
72 changes: 72 additions & 0 deletions runhouse/servers/autostop_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
import shlex
import subprocess

logger = logging.getLogger(__name__)


class AutostopHelper:
    """Run SkyPilot autostop commands for OnDemandClusters inside SkyPilot's own venv.

    Each SkyPilot call is executed as a ``python -c`` one-liner with the
    interpreter found under ``SKY_VENV``, via ``subprocess``.

    Activity tracking is two-phase: ``set_last_active_time_to_now`` only flips
    an in-memory flag, and ``register_activity_if_needed`` later pushes that
    activity to SkyPilot (and clears the flag) if the flag was set.
    """

    # Root of the virtualenv whose interpreter is used; "~" is expanded by the shell.
    SKY_VENV = "~/skypilot-runtime"

    def __init__(self):
        # True when activity has been seen since the last time it was pushed to SkyPilot.
        self._activity_registered = False

    async def set_last_active_time_to_now(self):
        """Mark that activity happened; does not run any SkyPilot command itself."""
        self._activity_registered = True

    def _run_python_in_sky_venv(self, cmd: str):
        """Run ``cmd`` with ``python -c`` in SkyPilot's venv and return its stdout.

        Args:
            cmd (str): Python source, already shell-quoted by the caller
                (e.g. with ``shlex.quote``), since it is spliced into a shell command.

        Returns:
            str: The command's stdout decoded as UTF-8.

        Raises:
            subprocess.CalledProcessError: If the command exits with a non-zero status.
                The captured stderr is logged before the exception is re-raised.
        """
        sky_python_cmd = f"{self.SKY_VENV}/bin/python -c {cmd}"

        logger.debug("Running command in SkyPilot's venv: %s", sky_python_cmd)
        try:
            completed = subprocess.run(
                sky_python_cmd,
                shell=True,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as e:
            # stderr is captured, so without this the reason for the failure is lost.
            stderr_text = (e.stderr or b"").decode("utf-8", errors="replace").strip()
            logger.error(
                "Command in SkyPilot's venv failed with exit code %s: %s",
                e.returncode,
                stderr_text,
            )
            raise
        return completed.stdout.decode("utf-8")

    async def get_autostop(self):
        """Return the autostop idle minutes printed by SkyPilot's autostop config, as an int."""
        sky_get_autostop_cmd = shlex.quote(
            "from sky.skylet.autostop_lib import get_autostop_config; "
            "print(get_autostop_config().autostop_idle_minutes)"
        )

        return int(self._run_python_in_sky_venv(sky_get_autostop_cmd))

    async def get_last_active_time(self):
        """Return the last active time printed by SkyPilot's autostop lib, as a float."""
        sky_get_last_active_time_cmd = shlex.quote(
            "from sky.skylet.autostop_lib import get_last_active_time; "
            "print(get_last_active_time())"
        )

        return float(self._run_python_in_sky_venv(sky_get_last_active_time_cmd))

    async def set_autostop(self, idle_minutes: int):
        """Set SkyPilot's autostop to ``idle_minutes``.

        Args:
            idle_minutes (int): Idle minutes to pass to SkyPilot's ``set_autostop``.
                Coerced with ``int()`` before use.

        Raises:
            TypeError, ValueError: If ``idle_minutes`` cannot be converted to an int.
            subprocess.CalledProcessError: If the SkyPilot command fails.
        """
        # The value is spliced into Python source that runs in the SkyPilot venv.
        # shlex.quote only protects the shell layer, so force a plain integer here
        # to rule out code injection through this argument.
        idle_minutes = int(idle_minutes)

        # Filling in "cloudvmray" as the backend because it's the only backend supported by SkyPilot right now,
        # if needed we can grab the backend from the autostop config with:
        # `from sky.skylet.autostop_lib import get_autostop_config; get_autostop_config().backend`
        sky_set_autostop_cmd = shlex.quote(
            "from sky.skylet.autostop_lib import set_autostop; "
            f'set_autostop({idle_minutes}, "cloudvmray", True)'
        )

        self._run_python_in_sky_venv(sky_set_autostop_cmd)

    async def register_activity_if_needed(self):
        """Push pending activity to SkyPilot, if any was flagged, and clear the flag."""
        sky_register_activity_cmd = shlex.quote(
            "from sky.skylet.autostop_lib import set_last_active_time_to_now; "
            "set_last_active_time_to_now()"
        )

        if self._activity_registered:
            logger.debug("Activity registered, updating last active time in SkyConfig")
            self._run_python_in_sky_venv(sky_register_activity_cmd)
            # Only cleared after the command succeeds, so a failure is retried next time.
            self._activity_registered = False
        else:
            logger.debug(
                "No activity registered, not updating last active time in SkyConfig"
            )
30 changes: 0 additions & 30 deletions runhouse/servers/autostop_servlet.py

This file was deleted.

39 changes: 14 additions & 25 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from runhouse.resources.hardware import load_cluster_config_from_file
from runhouse.rns.rns_client import ResourceStatusData
from runhouse.rns.utils.api import ResourceAccess
from runhouse.servers.autostop_servlet import AutostopServlet
from runhouse.servers.autostop_helper import AutostopHelper
from runhouse.servers.http.auth import AuthCache

from runhouse.utils import sync_function
Expand Down Expand Up @@ -48,25 +48,10 @@ async def __init__(
self._initialized_env_servlet_names: Set[str] = set()
self._key_to_env_servlet_name: Dict[Any, str] = {}
self._auth_cache: AuthCache = AuthCache(cluster_config)
self.autostop_servlet = None
self.autostop_helper = None

if cluster_config.get("resource_subtype", None) == "OnDemandCluster":
import ray

current_ip = ray.get_runtime_context().worker.node_ip_address
self.autostop_servlet = (
ray.remote(AutostopServlet)
.options(
name="autostop_servlet",
get_if_exists=True,
lifetime="detached",
namespace="runhouse",
max_concurrency=1000,
resources={f"node:{current_ip}": 0.001},
num_cpus=0,
)
.remote()
)
self.autostop_helper = AutostopHelper()

logger.info("Creating periodic_status_check thread.")
post_status_thread = threading.Thread(
Expand Down Expand Up @@ -119,8 +104,8 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]):
return self.cluster_config

async def aset_cluster_config_value(self, key: str, value: Any):
if self.autostop_servlet and key == "autostop_mins" and value > -1:
await self.autostop_servlet.set_auto_stop.remote(value)
if self.autostop_helper and key == "autostop_mins":
await self.autostop_helper.set_autostop(value)
self.cluster_config[key] = value

# Propagate the changes to all other process's obj_stores
Expand Down Expand Up @@ -206,8 +191,8 @@ async def aget_key_to_env_servlet_name_dict(self) -> Dict[Any, str]:
return self._key_to_env_servlet_name

async def aget_env_servlet_name_for_key(self, key: Any) -> str:
if self.autostop_servlet:
await self.autostop_servlet.set_last_active_time_to_now.remote()
if self.autostop_helper:
await self.autostop_helper.set_last_active_time_to_now()
return self._key_to_env_servlet_name.get(key, None)

async def aput_env_servlet_name_for_key(self, key: Any, env_servlet_name: str):
Expand Down Expand Up @@ -257,7 +242,7 @@ async def aperiodic_status_check(self):

# Only if one of these is true, do we actually need to get the status from each EnvServlet
should_send_status_to_den = den_auth and interval_size != -1
should_update_autostop = self.autostop_servlet is not None
should_update_autostop = self.autostop_helper is not None
if should_send_status_to_den or should_update_autostop:
logger.info(
"Performing cluster status check: potentially sending to Den or updating autostop."
Expand All @@ -273,8 +258,12 @@ async def aperiodic_status_check(self):
for resources in status.env_resource_mapping.values()
)
if function_running:
await self.autostop_servlet.set_last_active_time_to_now.remote()
await self.autostop_servlet.update_autostop_in_sky_config.remote()
await self.autostop_helper.set_last_active_time_to_now()
# We do this separately from the set_last_active_time_to_now call above because
# function_running will only reflect activity from functions which happen to be running during
# the status check. We still need to attempt to register activity for functions which have
# been called and completed.
await self.autostop_helper.register_activity_if_needed()

if should_send_status_to_den:
cluster_rns_address = cluster_config.get("name")
Expand Down
96 changes: 96 additions & 0 deletions tests/test_resources/test_clusters/test_on_demand_cluster.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,47 @@
import asyncio
import time

import pytest

import runhouse as rh

import tests.test_resources.test_clusters.test_cluster
from tests.utils import friend_account


def set_autostop_from_on_cluster_via_ah(mins):
    """Set autostop to ``mins`` by driving a fresh AutostopHelper directly."""
    helper = rh.servers.autostop_helper.AutostopHelper()
    pending = helper.set_autostop(mins)
    asyncio.run(pending)


def get_auotstop_from_on_cluster():
    """Return the autostop value reported by a fresh AutostopHelper."""
    helper = rh.servers.autostop_helper.AutostopHelper()
    autostop_mins = asyncio.run(helper.get_autostop())
    return autostop_mins


def get_last_active_time_from_on_cluster():
    """Return the last active time reported by a fresh AutostopHelper."""
    helper = rh.servers.autostop_helper.AutostopHelper()
    last_active = asyncio.run(helper.get_last_active_time())
    return last_active


def register_activity_from_on_cluster():
    """Flag activity on a fresh AutostopHelper, then ask it to register that activity."""
    helper = rh.servers.autostop_helper.AutostopHelper()

    # Two separate event-loop runs: first flag the activity, then flush it.
    flag_activity = helper.set_last_active_time_to_now()
    asyncio.run(flag_activity)
    flush_activity = helper.register_activity_if_needed()
    asyncio.run(flush_activity)


def set_autostop_from_on_cluster_via_cluster_obj(mins):
    """Set autostop to ``mins`` through the ``autostop_mins`` setter of ``rh.here``."""
    this_cluster = rh.here
    this_cluster.autostop_mins = mins


def set_autostop_from_on_cluster_via_cluster_keep_warm():
    """Call ``keep_warm()`` with its default arguments on ``rh.here``."""
    this_cluster = rh.here
    this_cluster.keep_warm()


class TestOnDemandCluster(tests.test_resources.test_clusters.test_cluster.TestCluster):

MAP_FIXTURES = {"resource": "cluster"}
Expand Down Expand Up @@ -59,3 +97,61 @@ def test_restart_does_not_change_config_yaml(self, cluster):
assert config_yaml_res_after_restart[0][0] == 0
config_yaml_content_after_restart = config_yaml_res[0][1]
assert config_yaml_content_after_restart == config_yaml_content

@pytest.mark.level("minimal")
def test_autostop(self, cluster):
    """Exercise getting/setting autostop and registering activity from on the cluster.

    Every helper is sent to the cluster in the same env and then called remotely;
    the steps are order-dependent because they mutate the cluster's autostop state.
    """
    env_name = "autostop_env"

    def on_cluster(fn):
        # Send a module-level helper to the cluster inside the test env.
        return rh.fn(fn).to(cluster, env=env_name)

    rh.env(working_dir="local:./", reqs=["pytest", "pandas"], name=env_name).to(
        cluster
    )
    remote_get_autostop = on_cluster(get_auotstop_from_on_cluster)
    # The on-cluster value should start out matching what the cluster object reports
    assert remote_get_autostop() == cluster.autostop_mins
    original_autostop = cluster.autostop_mins

    remote_set_autostop = on_cluster(set_autostop_from_on_cluster_via_ah)
    remote_set_autostop(5)
    assert remote_get_autostop() == 5

    remote_register_activity = on_cluster(register_activity_from_on_cluster)
    remote_get_last_active = on_cluster(get_last_active_time_from_on_cluster)

    remote_register_activity()
    # Last active time should fall within the last 2 seconds
    assert remote_get_last_active() > time.time() - 2

    remote_keep_warm = on_cluster(set_autostop_from_on_cluster_via_cluster_keep_warm)
    remote_keep_warm()
    assert remote_get_autostop() == -1

    remote_set_via_cluster_obj = on_cluster(
        set_autostop_from_on_cluster_via_cluster_obj
    )
    # Restore the autostop value the cluster had when the test started
    remote_set_via_cluster_obj(original_autostop)
    assert remote_get_autostop() == original_autostop

    # TODO add a way to manually trigger the status loop to check that activity
    # is actually registered after a call
    # cluster.call("autostop_env", "config")
    # cluster.status()
    # assert get_last_active() > time.time() - 2

    # TODO add a way to manually trigger the status loop to check that activity
    # is actually registered during a long running function
    # from .test_cluster import sleep_fn
    # sleep_remote = rh.fn(sleep_fn).to(cluster, env="autostop_env")
    # Thread(target=sleep_remote, args=(3,)).start()
    # time.sleep(2)
    # cluster.status()
    # # Check that last active is within the last second, so we know the activity wasn't just from the call itself
    # assert get_last_active() > time.time() - 1

0 comments on commit c4c65cc

Please sign in to comment.