From 39536fdb032b66f789af3f4b5f382f082aa63ff9 Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Tue, 16 Apr 2024 14:12:31 -0400 Subject: [PATCH] Make `env_servlets` cache an instance field instead of global. (#736) We don't need this to be global, it's only used in the object store. In general, minimizing the amount of global stuff we have referencing important objects is better to avoid weird memory issues. --- runhouse/globals.py | 1 - runhouse/servers/cluster_servlet.py | 12 +-- runhouse/servers/http/http_server.py | 5 +- runhouse/servers/obj_store.py | 96 ++++++++++++--------- tests/test_servers/test_server_obj_store.py | 6 +- tests/utils.py | 2 +- 6 files changed, 67 insertions(+), 55 deletions(-) diff --git a/runhouse/globals.py b/runhouse/globals.py index 138250f24..89c6b23fc 100644 --- a/runhouse/globals.py +++ b/runhouse/globals.py @@ -32,4 +32,3 @@ def clean_up_ssh_connections(): # Note: this initalizes a dummy global object. The obj_store must # be properly initialized by a servlet via initialize. obj_store = ObjStore() -env_servlets = {} diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 4cae0682b..4ab808484 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -4,7 +4,7 @@ import time from typing import Any, Dict, List, Optional, Set, Union -from runhouse.globals import configs, ObjStore, rns_client +from runhouse.globals import configs, obj_store, rns_client from runhouse.resources.hardware import load_cluster_config_from_file from runhouse.rns.utils.api import ResourceAccess from runhouse.servers.http.auth import AuthCache @@ -76,10 +76,11 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]): # Propagate the changes to all other process's obj_stores await asyncio.gather( *[ - ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + obj_store.acall_env_servlet_method( + env_servlet_name, "aset_cluster_config", cluster_config, + use_env_servlet_cache=False, ) for env_servlet_name in await self.aget_all_initialized_env_servlet_names() ] @@ -98,11 +99,12 @@ async def aset_cluster_config_value(self, key: str, value: Any): # Propagate the changes to all other process's obj_stores await asyncio.gather( *[ - ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + obj_store.acall_env_servlet_method( + env_servlet_name, "aset_cluster_config_value", key, value, + use_env_servlet_cache=False, ) for env_servlet_name in await self.aget_all_initialized_env_servlet_names() ] diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index 0e4d5fcb6..342a35dee 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -46,7 +46,6 @@ ) from runhouse.servers.obj_store import ( ClusterServletSetupOption, - ObjStore, ObjStoreError, RaySetupOption, ) @@ -199,7 +198,7 @@ async def _add_username_to_span(request: Request, call_next): # TODO: We aren't sure _exactly_ where this is or isn't used. # There are a few spots where we do `env_name or "base"`, and # this allows that base env to be pre-initialized. - _ = ObjStore.get_env_servlet( + _ = obj_store.get_env_servlet( env_name="base", create=True, runtime_env=runtime_env, @@ -679,7 +678,7 @@ async def get_keys(request: Request, env_name: Optional[str] = None): if not env_name: output = await obj_store.akeys() else: - output = await ObjStore.akeys_for_env_servlet_name(env_name) + output = await obj_store.akeys_for_env_servlet_name(env_name) # Expicitly tell the client not to attempt to deserialize the output return Response( diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index f3d31ca98..acbb994c1 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -134,6 +134,7 @@ def __init__(self): self.imported_modules = {} self.installed_envs = {} # TODO: consider deleting it? self._kv_store: Dict[Any, Any] = None + self.env_servlet_cache = {} # Ray defaults to setting OMP_NUM_THREADS to 1, which unexpectedly limit parallelism in user programs. # We delete it by default, but if we find that the user explicitly set it to another value, we respect that. @@ -261,9 +262,17 @@ def initialize( ############################################## # Generic helpers ############################################## - @staticmethod - async def acall_env_servlet_method(servlet_name: str, method: str, *args, **kwargs): - env_servlet = ObjStore.get_env_servlet(servlet_name) + async def acall_env_servlet_method( + self, + servlet_name: str, + method: str, + *args, + use_env_servlet_cache: bool = True, + **kwargs, + ): + env_servlet = self.get_env_servlet( + servlet_name, use_env_servlet_cache=use_env_servlet_cache + ) return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs) @staticmethod @@ -281,25 +290,26 @@ def call_actor_method(actor: ray.actor.ActorHandle, method: str, *args, **kwargs return ray.get(getattr(actor, method).remote(*args, **kwargs)) - @staticmethod def get_env_servlet( + self, env_name: str, create: bool = False, raise_ex_if_not_found: bool = False, resources: Optional[Dict[str, Any]] = None, + use_env_servlet_cache: bool = True, **kwargs, ): # Need to import these here to avoid circular imports - from runhouse.globals import env_servlets from runhouse.servers.env_servlet import EnvServlet - if env_name in env_servlets: - return env_servlets[env_name] + if use_env_servlet_cache and env_name in self.env_servlet_cache: + return self.env_servlet_cache[env_name] # It may not have been cached, but does exist try: existing_actor = ray.get_actor(env_name, namespace="runhouse") - env_servlets[env_name] = existing_actor + if use_env_servlet_cache: + self.env_servlet_cache[env_name] = existing_actor return existing_actor except ValueError: # ValueError: Failed to look up actor with name ... @@ -339,8 +349,8 @@ def get_env_servlet( # Make sure env_servlet is actually initialized # ray.get(new_env_actor.register_activity.remote()) - - env_servlets[env_name] = new_env_actor + if use_env_servlet_cache: + self.env_servlet_cache[env_name] = new_env_actor return new_env_actor else: @@ -542,9 +552,8 @@ async def aremove_env_servlet_name(self, env_servlet_name: str): ############################################## # KV Store: Keys ############################################## - @staticmethod - async def akeys_for_env_servlet_name(env_servlet_name: str) -> List[Any]: - return await ObjStore.acall_env_servlet_method(env_servlet_name, "akeys_local") + async def akeys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]: + return await self.acall_env_servlet_method(env_servlet_name, "akeys_local") def keys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]: return sync_function(self.akeys_for_env_servlet_name)(env_servlet_name) @@ -567,11 +576,14 @@ def keys(self) -> List[Any]: ############################################## # KV Store: Put ############################################## - @staticmethod async def aput_for_env_servlet_name( - env_servlet_name: str, key: Any, data: Any, serialization: Optional[str] = None + self, + env_servlet_name: str, + key: Any, + data: Any, + serialization: Optional[str] = None, ): - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "aput_local", key, @@ -638,8 +650,8 @@ def put( ############################################## # KV Store: Get ############################################## - @staticmethod async def aget_from_env_servlet_name( + self, env_servlet_name: str, key: Any, default: Optional[Any] = None, @@ -647,7 +659,7 @@ async def aget_from_env_servlet_name( remote: bool = False, ): logger.info(f"Getting {key} from servlet {env_servlet_name}") - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "aget_local", key, @@ -656,15 +668,15 @@ async def aget_from_env_servlet_name( remote=remote, ) - @staticmethod def get_from_env_servlet_name( + self, env_servlet_name: str, key: Any, default: Optional[Any] = None, serialization: Optional[str] = None, remote: bool = False, ): - return sync_function(ObjStore.aget_from_env_servlet_name)( + return sync_function(self.aget_from_env_servlet_name)( env_servlet_name, key, default, serialization, remote ) @@ -758,9 +770,8 @@ def get( ############################################## # KV Store: Contains ############################################## - @staticmethod - async def acontains_for_env_servlet_name(env_servlet_name: str, key: Any): - return await ObjStore.acall_env_servlet_method( + async def acontains_for_env_servlet_name(self, env_servlet_name: str, key: Any): + return await self.acall_env_servlet_method( env_servlet_name, "acontains_local", key ) @@ -791,11 +802,14 @@ def contains(self, key: Any): ############################################## # KV Store: Pop ############################################## - @staticmethod async def apop_from_env_servlet_name( - env_servlet_name: str, key: Any, serialization: Optional[str] = "pickle", *args + self, + env_servlet_name: str, + key: Any, + serialization: Optional[str] = "pickle", + *args, ) -> Any: - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "apop_local", key, @@ -863,9 +877,8 @@ def pop(self, key: Any, serialization: Optional[str] = "pickle", *args) -> Any: ############################################## # KV Store: Delete ############################################## - @staticmethod - async def adelete_for_env_servlet_name(env_servlet_name: str, key: Any): - return await ObjStore.acall_env_servlet_method( + async def adelete_for_env_servlet_name(self, env_servlet_name: str, key: Any): + return await self.acall_env_servlet_method( env_servlet_name, "adelete_local", key ) @@ -873,18 +886,17 @@ async def adelete_local(self, key: Any): await self.apop_local(key) async def adelete_env_contents(self, env_name: Any): - from runhouse.globals import env_servlets # clear keys in the env servlet deleted_keys = await self.akeys_for_env_servlet_name(env_name) await self.aclear_for_env_servlet_name(env_name) # delete the env servlet actor and remove its references - if env_name in env_servlets: - actor = env_servlets[env_name] + if env_name in self.env_servlet_cache: + actor = self.env_servlet_cache[env_name] ray.kill(actor) - del env_servlets[env_name] + del self.env_servlet_cache[env_name] await self.aremove_env_servlet_name(env_name) return deleted_keys @@ -926,9 +938,8 @@ def delete(self, key: Union[Any, List[Any]]): ############################################## # KV Store: Clear ############################################## - @staticmethod - async def aclear_for_env_servlet_name(env_servlet_name: str): - return await ObjStore.acall_env_servlet_method(env_servlet_name, "aclear_local") + async def aclear_for_env_servlet_name(self, env_servlet_name: str): + return await self.acall_env_servlet_method(env_servlet_name, "aclear_local") async def aclear_local(self): if self.has_local_storage: @@ -951,11 +962,10 @@ def clear(self): ############################################## # KV Store: Rename ############################################## - @staticmethod async def arename_for_env_servlet_name( - env_servlet_name: str, old_key: Any, new_key: Any + self, env_servlet_name: str, old_key: Any, new_key: Any ): - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "arename_local", old_key, @@ -998,8 +1008,8 @@ def rename(self, old_key: Any, new_key: Any): ############################################## # KV Store: Call ############################################## - @staticmethod async def acall_for_env_servlet_name( + self, env_servlet_name: str, key: Any, method_name: str, @@ -1009,7 +1019,7 @@ async def acall_for_env_servlet_name( stream_logs: bool = False, remote: bool = False, ): - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_servlet_name, "acall_local", key, @@ -1409,14 +1419,14 @@ async def aput_resource( else {} ) - _ = ObjStore.get_env_servlet( + _ = self.get_env_servlet( env_name=env_name, create=True, runtime_env=runtime_env, resources=resource_config.get("compute", None), ) - return await ObjStore.acall_env_servlet_method( + return await self.acall_env_servlet_method( env_name, "aput_resource_local", data=serialized_data, diff --git a/tests/test_servers/test_server_obj_store.py b/tests/test_servers/test_server_obj_store.py index ea4fab681..f80e675dd 100644 --- a/tests/test_servers/test_server_obj_store.py +++ b/tests/test_servers/test_server_obj_store.py @@ -1,6 +1,6 @@ import pytest -from runhouse.servers.obj_store import ObjStore, ObjStoreError +from runhouse.servers.obj_store import ObjStoreError from tests.utils import friend_account, get_ray_servlet_and_obj_store @@ -341,7 +341,9 @@ def test_delete_env_servlet(self, obj_store): # check that corresponding Ray actor is killed with pytest.raises(ObjStoreError): - ObjStore.get_env_servlet(env_name=env_to_delete, raise_ex_if_not_found=True) + obj_store.get_env_servlet( + env_name=env_to_delete, raise_ex_if_not_found=True + ) @pytest.mark.servertest diff --git a/tests/utils.py b/tests/utils.py index b04907628..c8396b678 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,7 +18,7 @@ def get_ray_servlet_and_obj_store(env_name): test_obj_store = ObjStore() test_obj_store.initialize(env_name, setup_ray=RaySetupOption.GET_OR_FAIL) - servlet = ObjStore.get_env_servlet( + servlet = test_obj_store.get_env_servlet( env_name=env_name, create=True, )