From f18c1520e4f1716ec160c2337527ff1b97589f0a Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Tue, 16 Apr 2024 11:34:17 -0400 Subject: [PATCH] Make `env_servlets` cache an instance field instead of global. --- runhouse/globals.py | 1 - runhouse/servers/cluster_servlet.py | 12 +-- runhouse/servers/http/http_server.py | 5 +- runhouse/servers/obj_store.py | 94 +++++++++++---------- tests/test_servers/test_server_obj_store.py | 6 +- tests/utils.py | 2 +- 6 files changed, 65 insertions(+), 55 deletions(-) diff --git a/runhouse/globals.py b/runhouse/globals.py index 138250f24..89c6b23fc 100644 --- a/runhouse/globals.py +++ b/runhouse/globals.py @@ -32,4 +32,3 @@ def clean_up_ssh_connections(): # Note: this initalizes a dummy global object. The obj_store must # be properly initialized by a servlet via initialize. obj_store = ObjStore() -env_servlets = {} diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 4cae0682b..4ab808484 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -4,7 +4,7 @@ import time from typing import Any, Dict, List, Optional, Set, Union -from runhouse.globals import configs, ObjStore, rns_client +from runhouse.globals import configs, obj_store, rns_client from runhouse.resources.hardware import load_cluster_config_from_file from runhouse.rns.utils.api import ResourceAccess from runhouse.servers.http.auth import AuthCache @@ -76,10 +76,11 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]): # Propagate the changes to all other process's obj_stores await asyncio.gather( *[ - ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + obj_store.acall_env_servlet_method( + env_servlet_name, "aset_cluster_config", cluster_config, + use_env_servlet_cache=False, ) for env_servlet_name in await self.aget_all_initialized_env_servlet_names() ] @@ -98,11 +99,12 @@ async def aset_cluster_config_value(self, key: str, value: Any): # Propagate the changes to all other process's obj_stores await asyncio.gather( *[ - ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + obj_store.acall_env_servlet_method( + env_servlet_name, "aset_cluster_config_value", key, value, + use_env_servlet_cache=False, ) for env_servlet_name in await self.aget_all_initialized_env_servlet_names() ] diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index 0e4d5fcb6..342a35dee 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -46,7 +46,6 @@ ) from runhouse.servers.obj_store import ( ClusterServletSetupOption, - ObjStore, ObjStoreError, RaySetupOption, ) @@ -199,7 +198,7 @@ async def _add_username_to_span(request: Request, call_next): # TODO: We aren't sure _exactly_ where this is or isn't used. # There are a few spots where we do `env_name or "base"`, and # this allows that base env to be pre-initialized. - _ = ObjStore.get_env_servlet( + _ = obj_store.get_env_servlet( env_name="base", create=True, runtime_env=runtime_env, @@ -679,7 +678,7 @@ async def get_keys(request: Request, env_name: Optional[str] = None): if not env_name: output = await obj_store.akeys() else: - output = await ObjStore.akeys_for_env_servlet_name(env_name) + output = await obj_store.akeys_for_env_servlet_name(env_name) # Expicitly tell the client not to attempt to deserialize the output return Response( diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index f3d31ca98..0690ec49b 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -134,6 +134,7 @@ def __init__(self): self.imported_modules = {} self.installed_envs = {} # TODO: consider deleting it? self._kv_store: Dict[Any, Any] = None + self.env_servlet_cache = {} # Ray defaults to setting OMP_NUM_THREADS to 1, which unexpectedly limit parallelism in user programs. # We delete it by default, but if we find that the user explicitly set it to another value, we respect that. @@ -261,9 +262,15 @@ def initialize( ############################################## # Generic helpers ############################################## - @staticmethod - async def acall_env_servlet_method(servlet_name: str, method: str, *args, **kwargs): - env_servlet = ObjStore.get_env_servlet(servlet_name) + async def acall_env_servlet_method( + self, + servlet_name: str, + method: str, + *args, + use_env_servlet_cache: bool = True, + **kwargs, + ): + env_servlet = self.get_env_servlet(servlet_name, use_env_servlet_cache) return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs) @staticmethod @@ -281,25 +288,26 @@ def call_actor_method(actor: ray.actor.ActorHandle, method: str, *args, **kwargs return ray.get(getattr(actor, method).remote(*args, **kwargs)) - @staticmethod def get_env_servlet( + self, env_name: str, create: bool = False, raise_ex_if_not_found: bool = False, resources: Optional[Dict[str, Any]] = None, + use_env_servlet_cache: bool = True, **kwargs, ): # Need to import these here to avoid circular imports - from runhouse.globals import env_servlets from runhouse.servers.env_servlet import EnvServlet - if env_name in env_servlets: - return env_servlets[env_name] + if use_env_servlet_cache and env_name in self.env_servlet_cache: + return self.env_servlet_cache[env_name] # It may not have been cached, but does exist try: existing_actor = ray.get_actor(env_name, namespace="runhouse") - env_servlets[env_name] = existing_actor + if use_env_servlet_cache: + self.env_servlet_cache[env_name] = existing_actor return existing_actor except ValueError: # ValueError: Failed to look up actor with name ... @@ -339,8 +347,8 @@ def get_env_servlet( # Make sure env_servlet is actually initialized # ray.get(new_env_actor.register_activity.remote()) - - env_servlets[env_name] = new_env_actor + if use_env_servlet_cache: + self.env_servlet_cache[env_name] = new_env_actor return new_env_actor else: @@ -542,9 +550,8 @@ async def aremove_env_servlet_name(self, env_servlet_name: str): ############################################## # KV Store: Keys ############################################## - @staticmethod - async def akeys_for_env_servlet_name(env_servlet_name: str) -> List[Any]: - return await ObjStore.acall_env_servlet_method(env_servlet_name, "akeys_local") + async def akeys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]: + return await self.acall_env_servlet_method(env_servlet_name, "akeys_local") def keys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]: return sync_function(self.akeys_for_env_servlet_name)(env_servlet_name) @@ -567,11 +574,14 @@ def keys(self) -> List[Any]: ############################################## # KV Store: Put ############################################## - @staticmethod async def aput_for_env_servlet_name( - env_servlet_name: str, key: Any, data: Any, serialization: Optional[str] = None + self, + env_servlet_name: str, + key: Any, + data: Any, + serialization: Optional[str] = None, ): - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "aput_local", key, @@ -638,8 +648,8 @@ def put( ############################################## # KV Store: Get ############################################## - @staticmethod async def aget_from_env_servlet_name( + self, env_servlet_name: str, key: Any, default: Optional[Any] = None, @@ -647,7 +657,7 @@ async def aget_from_env_servlet_name( remote: bool = False, ): logger.info(f"Getting {key} from servlet {env_servlet_name}") - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "aget_local", key, @@ -656,15 +666,15 @@ async def aget_from_env_servlet_name( remote=remote, ) - @staticmethod def get_from_env_servlet_name( + self, env_servlet_name: str, key: Any, default: Optional[Any] = None, serialization: Optional[str] = None, remote: bool = False, ): - return sync_function(ObjStore.aget_from_env_servlet_name)( + return sync_function(self.aget_from_env_servlet_name)( env_servlet_name, key, default, serialization, remote ) @@ -758,9 +768,8 @@ def get( ############################################## # KV Store: Contains ############################################## - @staticmethod - async def acontains_for_env_servlet_name(env_servlet_name: str, key: Any): - return await ObjStore.acall_env_servlet_method( + async def acontains_for_env_servlet_name(self, env_servlet_name: str, key: Any): + return await self.acall_env_servlet_method( env_servlet_name, "acontains_local", key ) @@ -791,11 +800,14 @@ def contains(self, key: Any): ############################################## # KV Store: Pop ############################################## - @staticmethod async def apop_from_env_servlet_name( - env_servlet_name: str, key: Any, serialization: Optional[str] = "pickle", *args + self, + env_servlet_name: str, + key: Any, + serialization: Optional[str] = "pickle", + *args, ) -> Any: - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "apop_local", key, @@ -863,9 +875,8 @@ def pop(self, key: Any, serialization: Optional[str] = "pickle", *args) -> Any: ############################################## # KV Store: Delete ############################################## - @staticmethod - async def adelete_for_env_servlet_name(env_servlet_name: str, key: Any): - return await ObjStore.acall_env_servlet_method( + async def adelete_for_env_servlet_name(self, env_servlet_name: str, key: Any): + return await self.acall_env_servlet_method( env_servlet_name, "adelete_local", key ) @@ -873,18 +884,17 @@ async def adelete_local(self, key: Any): await self.apop_local(key) async def adelete_env_contents(self, env_name: Any): - from runhouse.globals import env_servlets # clear keys in the env servlet deleted_keys = await self.akeys_for_env_servlet_name(env_name) await self.aclear_for_env_servlet_name(env_name) # delete the env servlet actor and remove its references - if env_name in env_servlets: - actor = env_servlets[env_name] + if env_name in self.env_servlet_cache: + actor = self.env_servlet_cache[env_name] ray.kill(actor) - del env_servlets[env_name] + del self.env_servlet_cache[env_name] await self.aremove_env_servlet_name(env_name) return deleted_keys @@ -926,9 +936,8 @@ def delete(self, key: Union[Any, List[Any]]): ############################################## # KV Store: Clear ############################################## - @staticmethod - async def aclear_for_env_servlet_name(env_servlet_name: str): - return await ObjStore.acall_env_servlet_method(env_servlet_name, "aclear_local") + async def aclear_for_env_servlet_name(self, env_servlet_name: str): + return await self.acall_env_servlet_method(env_servlet_name, "aclear_local") async def aclear_local(self): if self.has_local_storage: @@ -951,11 +960,10 @@ def clear(self): ############################################## # KV Store: Rename ############################################## - @staticmethod async def arename_for_env_servlet_name( - env_servlet_name: str, old_key: Any, new_key: Any + self, env_servlet_name: str, old_key: Any, new_key: Any ): - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "arename_local", old_key, @@ -998,8 +1006,8 @@ def rename(self, old_key: Any, new_key: Any): ############################################## # KV Store: Call ############################################## - @staticmethod async def acall_for_env_servlet_name( + self, env_servlet_name: str, key: Any, method_name: str, @@ -1009,7 +1017,7 @@ async def acall_for_env_servlet_name( stream_logs: bool = False, remote: bool = False, ): - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "acall_local", key, @@ -1409,14 +1417,14 @@ async def aput_resource( else {} ) - _ = ObjStore.get_env_servlet( + _ = self.get_env_servlet( env_name=env_name, create=True, runtime_env=runtime_env, resources=resource_config.get("compute", None), ) - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_name, "aput_resource_local", data=serialized_data, diff --git a/tests/test_servers/test_server_obj_store.py b/tests/test_servers/test_server_obj_store.py index ea4fab681..f80e675dd 100644 --- a/tests/test_servers/test_server_obj_store.py +++ b/tests/test_servers/test_server_obj_store.py @@ -1,6 +1,6 @@ import pytest -from runhouse.servers.obj_store import ObjStore, ObjStoreError +from runhouse.servers.obj_store import ObjStoreError from tests.utils import friend_account, get_ray_servlet_and_obj_store @@ -341,7 +341,9 @@ def test_delete_env_servlet(self, obj_store): # check that corresponding Ray actor is killed with pytest.raises(ObjStoreError): - ObjStore.get_env_servlet(env_name=env_to_delete, raise_ex_if_not_found=True) + obj_store.get_env_servlet( + env_name=env_to_delete, raise_ex_if_not_found=True + ) @pytest.mark.servertest diff --git a/tests/utils.py b/tests/utils.py index 316be9cf6..3a22d6cd1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,7 +18,7 @@ def get_ray_servlet_and_obj_store(env_name): test_obj_store = ObjStore() test_obj_store.initialize(env_name, setup_ray=RaySetupOption.GET_OR_FAIL) - servlet = ObjStore.get_env_servlet( + servlet = test_obj_store.get_env_servlet( env_name=env_name, create=True, )