diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py
index 4ab808484..4b2a1490a 100644
--- a/runhouse/servers/cluster_servlet.py
+++ b/runhouse/servers/cluster_servlet.py
@@ -199,5 +199,13 @@ async def aclear_key_to_env_servlet_name_dict(self):
     ##############################################
     # Remove Env Servlet
     ##############################################
-    async def aremove_env_servlet_name(self, env_servlet_name: str):
+    async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str):
         self._initialized_env_servlet_names.remove(env_servlet_name)
+        deleted_keys = [
+            key
+            for key, env in self._key_to_env_servlet_name.items()
+            if env == env_servlet_name
+        ]
+        for key in deleted_keys:
+            self._key_to_env_servlet_name.pop(key)
+        return deleted_keys
diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py
index 0690ec49b..e130bf583 100644
--- a/runhouse/servers/obj_store.py
+++ b/runhouse/servers/obj_store.py
@@ -263,15 +263,20 @@ def initialize(
     # Generic helpers
     ##############################################
     async def acall_env_servlet_method(
-        self,
-        servlet_name: str,
-        method: str,
-        *args,
-        use_env_servlet_cache: bool = True,
-        **kwargs,
+        self, servlet_name: str, method: str, *args, **kwargs
     ):
-        env_servlet = self.get_env_servlet(servlet_name, use_env_servlet_cache)
-        return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs)
+        env_servlet = self.get_env_servlet(servlet_name)
+        try:
+            return await ObjStore.acall_actor_method(
+                env_servlet, method, *args, **kwargs
+            )
+        except (ray.exceptions.RayActorError, ray.exceptions.OutOfMemoryError) as e:
+            if isinstance(
+                e, ray.exceptions.OutOfMemoryError
+            ) or "died unexpectedly before finishing this task" in str(e):
+                await self.adelete_env_contents(servlet_name)
+
+            raise e
 
     @staticmethod
     async def acall_actor_method(
@@ -542,9 +547,11 @@ async def _apop_env_servlet_name_for_key(self, key: Any, *args) -> str:
     ##############################################
     # Remove Env Servlet
     ##############################################
-    async def aremove_env_servlet_name(self, env_servlet_name: str):
+    async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str):
         return await self.acall_actor_method(
-            self.cluster_servlet, "aremove_env_servlet_name", env_servlet_name
+            self.cluster_servlet,
+            "aclear_all_references_to_env_servlet_name",
+            env_servlet_name,
         )
 
     ##############################################
@@ -885,18 +892,13 @@ async def adelete_local(self, key: Any):
 
     async def adelete_env_contents(self, env_name: Any):
 
-        # clear keys in the env servlet
-        deleted_keys = await self.akeys_for_env_servlet_name(env_name)
-        await self.aclear_for_env_servlet_name(env_name)
-
         # delete the env servlet actor and remove its references
         if env_name in self.env_servlet_cache:
             actor = self.env_servlet_cache[env_name]
             ray.kill(actor)
-
             del self.env_servlet_cache[env_name]
 
-        await self.aremove_env_servlet_name(env_name)
+        deleted_keys = await self.aclear_all_references_to_env_servlet_name(env_name)
         return deleted_keys
 
     def delete_env_contents(self, env_name: Any):