diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index f0cb81ef2..a2fe3ff6d 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -261,6 +261,11 @@ def initialize( ############################################## # Generic helpers ############################################## + @staticmethod + async def acall_env_servlet_method(servlet_name: str, method: str, *args, **kwargs): + env_servlet = ObjStore.get_env_servlet(servlet_name) + return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs) + @staticmethod async def acall_actor_method( actor: ray.actor.ActorHandle, method: str, *args, **kwargs @@ -539,9 +544,7 @@ async def aremove_env_servlet_name(self, env_servlet_name: str): ############################################## @staticmethod async def akeys_for_env_servlet_name(env_servlet_name: str) -> List[Any]: - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), "akeys_local" - ) + return await ObjStore.acall_env_servlet_method(env_servlet_name, "akeys_local") def keys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]: return sync_function(self.akeys_for_env_servlet_name)(env_servlet_name) @@ -568,8 +571,8 @@ def keys(self) -> List[Any]: async def aput_for_env_servlet_name( env_servlet_name: str, key: Any, data: Any, serialization: Optional[str] = None ): - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "aput_local", key, data=data, @@ -644,8 +647,8 @@ async def aget_from_env_servlet_name( remote: bool = False, ): logger.info(f"Getting {key} from servlet {env_servlet_name}") - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "aget_local", key, default=default, @@ -757,8 +760,8 @@ def get( ############################################## @staticmethod async def acontains_for_env_servlet_name(env_servlet_name: str, key: Any): - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), "acontains_local", key + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "acontains_local", key ) def contains_local(self, key: Any): @@ -792,8 +795,8 @@ def contains(self, key: Any): async def apop_from_env_servlet_name( env_servlet_name: str, key: Any, serialization: Optional[str] = "pickle", *args ) -> Any: - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "apop_local", key, serialization, @@ -862,8 +865,8 @@ def pop(self, key: Any, serialization: Optional[str] = "pickle", *args) -> Any: ############################################## @staticmethod async def adelete_for_env_servlet_name(env_servlet_name: str, key: Any): - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), "adelete_local", key + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "adelete_local", key ) async def adelete_local(self, key: Any): @@ -925,9 +928,7 @@ def delete(self, key: Union[Any, List[Any]]): ############################################## @staticmethod async def aclear_for_env_servlet_name(env_servlet_name: str): - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), "aclear_local" - ) + return await ObjStore.acall_env_servlet_method(env_servlet_name, "aclear_local") async def aclear_local(self): if self.has_local_storage: @@ -953,8 +954,8 @@ def clear(self): async def arename_for_env_servlet_name( env_servlet_name: str, old_key: Any, new_key: Any ): - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "arename_local", old_key, new_key, @@ -1007,8 +1008,8 @@ async def acall_for_env_servlet_name( stream_logs: bool = False, remote: bool = False, ): - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_servlet_name), + return await ObjStore.acall_env_servlet_method( + env_servlet_name, "acall_local", key, method_name=method_name, @@ -1414,8 +1415,8 @@ async def aput_resource( resources=resource_config.get("compute", None), ) - return await ObjStore.acall_actor_method( - ObjStore.get_env_servlet(env_name), + return await ObjStore.acall_env_servlet_method( + env_name, "aput_resource_local", data=serialized_data, serialization=serialization, @@ -1510,8 +1511,8 @@ async def astatus(self): config_cluster.pop("creds", None) cluster_servlets = {} for env in await self.aget_all_initialized_env_servlet_names(): - resources_in_env_modified = await self.acall_actor_method( - self.get_env_servlet(env), "astatus_local" + resources_in_env_modified = await self.acall_env_servlet_method( + env, "astatus_local" ) cluster_servlets[env] = resources_in_env_modified config_cluster["envs"] = cluster_servlets