From e555c555a0666a6b7213337e9f8e81300407df16 Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Mon, 15 Apr 2024 14:39:57 -0400 Subject: [PATCH] Clean up an `EnvServlet` from global Runhouse daemon state if it dies unexpectedly. --- runhouse/servers/cluster_servlet.py | 10 +++++++++- runhouse/servers/obj_store.py | 25 ++++++++++++++++--------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 4ab808484..4b2a1490a 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -199,5 +199,13 @@ async def aclear_key_to_env_servlet_name_dict(self): ############################################## # Remove Env Servlet ############################################## - async def aremove_env_servlet_name(self, env_servlet_name: str): + async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str): self._initialized_env_servlet_names.remove(env_servlet_name) + deleted_keys = [ + key + for key, env in self._key_to_env_servlet_name.items() + if env == env_servlet_name + ] + for key in deleted_keys: + self._key_to_env_servlet_name.pop(key) + return deleted_keys diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index acbb994c1..672aa9d35 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -273,7 +273,17 @@ async def acall_env_servlet_method( env_servlet = self.get_env_servlet( servlet_name, use_env_servlet_cache=use_env_servlet_cache ) - return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs) + try: + return await ObjStore.acall_actor_method( + env_servlet, method, *args, **kwargs + ) + except (ray.exceptions.RayActorError, ray.exceptions.OutOfMemoryError) as e: + if isinstance( + e, ray.exceptions.OutOfMemoryError + ) or "died unexpectedly before finishing this task" in str(e): + await self.adelete_env_contents(servlet_name) + + raise e @staticmethod async def acall_actor_method( @@ -544,9 +554,11 @@ async def _apop_env_servlet_name_for_key(self, key: Any, *args) -> str: ############################################## # Remove Env Servlet ############################################## - async def aremove_env_servlet_name(self, env_servlet_name: str): + async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str): return await self.acall_actor_method( - self.cluster_servlet, "aremove_env_servlet_name", env_servlet_name + self.cluster_servlet, + "aclear_all_references_to_env_servlet_name", + env_servlet_name, ) ############################################## @@ -887,18 +899,13 @@ async def adelete_local(self, key: Any): async def adelete_env_contents(self, env_name: Any): - # clear keys in the env servlet - deleted_keys = await self.akeys_for_env_servlet_name(env_name) - await self.aclear_for_env_servlet_name(env_name) - # delete the env servlet actor and remove its references if env_name in self.env_servlet_cache: actor = self.env_servlet_cache[env_name] ray.kill(actor) - del self.env_servlet_cache[env_name] - await self.aremove_env_servlet_name(env_name) + deleted_keys = await self.aclear_all_references_to_env_servlet_name(env_name) return deleted_keys def delete_env_contents(self, env_name: Any):