Skip to content

Commit

Permalink
Change ObjStore methods to call by env servlet name instead of actor.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Apr 15, 2024
1 parent 02715f9 commit 575e711
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ def initialize(
##############################################
# Generic helpers
##############################################
@staticmethod
async def acall_env_servlet_method(servlet_name: str, method: str, *args, **kwargs):
env_servlet = ObjStore.get_env_servlet(servlet_name)
return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs)

@staticmethod
async def acall_actor_method(
actor: ray.actor.ActorHandle, method: str, *args, **kwargs
Expand Down Expand Up @@ -539,9 +544,7 @@ async def aremove_env_servlet_name(self, env_servlet_name: str):
##############################################
@staticmethod
async def akeys_for_env_servlet_name(env_servlet_name: str) -> List[Any]:
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name), "akeys_local"
)
return await ObjStore.acall_env_servlet_method(env_servlet_name, "akeys_local")

def keys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]:
return sync_function(self.akeys_for_env_servlet_name)(env_servlet_name)
Expand All @@ -568,8 +571,8 @@ def keys(self) -> List[Any]:
async def aput_for_env_servlet_name(
env_servlet_name: str, key: Any, data: Any, serialization: Optional[str] = None
):
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
return await ObjStore.acall_env_servlet_method(
env_servlet_name,
"aput_local",
key,
data=data,
Expand Down Expand Up @@ -644,8 +647,8 @@ async def aget_from_env_servlet_name(
remote: bool = False,
):
logger.info(f"Getting {key} from servlet {env_servlet_name}")
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
return await ObjStore.acall_env_servlet_method(
env_servlet_name,
"aget_local",
key,
default=default,
Expand Down Expand Up @@ -757,8 +760,8 @@ def get(
##############################################
@staticmethod
async def acontains_for_env_servlet_name(env_servlet_name: str, key: Any):
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name), "acontains_local", key
return await ObjStore.acall_env_servlet_method(
env_servlet_name, "acontains_local", key
)

def contains_local(self, key: Any):
Expand Down Expand Up @@ -792,8 +795,8 @@ def contains(self, key: Any):
async def apop_from_env_servlet_name(
env_servlet_name: str, key: Any, serialization: Optional[str] = "pickle", *args
) -> Any:
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
return await ObjStore.acall_env_servlet_method(
env_servlet_name,
"apop_local",
key,
serialization,
Expand Down Expand Up @@ -862,8 +865,8 @@ def pop(self, key: Any, serialization: Optional[str] = "pickle", *args) -> Any:
##############################################
@staticmethod
async def adelete_for_env_servlet_name(env_servlet_name: str, key: Any):
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name), "adelete_local", key
return await ObjStore.acall_env_servlet_method(
env_servlet_name, "adelete_local", key
)

async def adelete_local(self, key: Any):
Expand Down Expand Up @@ -925,9 +928,7 @@ def delete(self, key: Union[Any, List[Any]]):
##############################################
@staticmethod
async def aclear_for_env_servlet_name(env_servlet_name: str):
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name), "aclear_local"
)
return await ObjStore.acall_env_servlet_method(env_servlet_name, "aclear_local")

async def aclear_local(self):
if self.has_local_storage:
Expand All @@ -953,8 +954,8 @@ def clear(self):
async def arename_for_env_servlet_name(
env_servlet_name: str, old_key: Any, new_key: Any
):
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
return await ObjStore.acall_env_servlet_method(
env_servlet_name,
"arename_local",
old_key,
new_key,
Expand Down Expand Up @@ -1007,8 +1008,8 @@ async def acall_for_env_servlet_name(
stream_logs: bool = False,
remote: bool = False,
):
return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
return await ObjStore.acall_env_servlet_method(
env_servlet_name,
"acall_local",
key,
method_name=method_name,
Expand Down Expand Up @@ -1414,8 +1415,8 @@ async def aput_resource(
resources=resource_config.get("compute", None),
)

return await ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_name),
return await ObjStore.acall_env_servlet_method(
env_name,
"aput_resource_local",
data=serialized_data,
serialization=serialization,
Expand Down Expand Up @@ -1510,8 +1511,8 @@ async def astatus(self):
config_cluster.pop("creds", None)
cluster_servlets = {}
for env in await self.aget_all_initialized_env_servlet_names():
resources_in_env_modified = await self.acall_actor_method(
self.get_env_servlet(env), "astatus_local"
resources_in_env_modified = await self.acall_env_servlet_method(
env, "astatus_local"
)
cluster_servlets[env] = resources_in_env_modified
config_cluster["envs"] = cluster_servlets
Expand Down

0 comments on commit 575e711

Please sign in to comment.