Skip to content

Commit

Permalink
Make env_servlets cache an instance field instead of global.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Apr 16, 2024
1 parent 908fa1a commit 109a475
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 55 deletions.
1 change: 0 additions & 1 deletion runhouse/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ def clean_up_ssh_connections():
# Note: this initalizes a dummy global object. The obj_store must
# be properly initialized by a servlet via initialize.
obj_store = ObjStore()
env_servlets = {}
12 changes: 7 additions & 5 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from typing import Any, Dict, List, Optional, Set, Union

from runhouse.globals import configs, ObjStore, rns_client
from runhouse.globals import configs, obj_store, rns_client
from runhouse.resources.hardware import load_cluster_config_from_file
from runhouse.rns.utils.api import ResourceAccess
from runhouse.servers.http.auth import AuthCache
Expand Down Expand Up @@ -76,10 +76,11 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]):
# Propagate the changes to all other process's obj_stores
await asyncio.gather(
*[
ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
obj_store.acall_env_servlet_method(
env_servlet_name,
"aset_cluster_config",
cluster_config,
use_env_servlet_cache=False,
)
for env_servlet_name in await self.aget_all_initialized_env_servlet_names()
]
Expand All @@ -98,11 +99,12 @@ async def aset_cluster_config_value(self, key: str, value: Any):
# Propagate the changes to all other process's obj_stores
await asyncio.gather(
*[
ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
obj_store.acall_env_servlet_method(
env_servlet_name,
"aset_cluster_config_value",
key,
value,
use_env_servlet_cache=False,
)
for env_servlet_name in await self.aget_all_initialized_env_servlet_names()
]
Expand Down
5 changes: 2 additions & 3 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
)
from runhouse.servers.obj_store import (
ClusterServletSetupOption,
ObjStore,
ObjStoreError,
RaySetupOption,
)
Expand Down Expand Up @@ -199,7 +198,7 @@ async def _add_username_to_span(request: Request, call_next):
# TODO: We aren't sure _exactly_ where this is or isn't used.
# There are a few spots where we do `env_name or "base"`, and
# this allows that base env to be pre-initialized.
_ = ObjStore.get_env_servlet(
_ = obj_store.get_env_servlet(
env_name="base",
create=True,
runtime_env=runtime_env,
Expand Down Expand Up @@ -679,7 +678,7 @@ async def get_keys(request: Request, env_name: Optional[str] = None):
if not env_name:
output = await obj_store.akeys()
else:
output = await ObjStore.akeys_for_env_servlet_name(env_name)
output = await obj_store.akeys_for_env_servlet_name(env_name)

# Expicitly tell the client not to attempt to deserialize the output
return Response(
Expand Down
94 changes: 51 additions & 43 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self):
self.imported_modules = {}
self.installed_envs = {} # TODO: consider deleting it?
self._kv_store: Dict[Any, Any] = None
self.env_servlet_cache = {}

# Ray defaults to setting OMP_NUM_THREADS to 1, which unexpectedly limit parallelism in user programs.
# We delete it by default, but if we find that the user explicitly set it to another value, we respect that.
Expand Down Expand Up @@ -261,9 +262,15 @@ def initialize(
##############################################
# Generic helpers
##############################################
@staticmethod
async def acall_env_servlet_method(servlet_name: str, method: str, *args, **kwargs):
env_servlet = ObjStore.get_env_servlet(servlet_name)
async def acall_env_servlet_method(
self,
servlet_name: str,
method: str,
*args,
use_env_servlet_cache: bool = True,
**kwargs,
):
env_servlet = self.get_env_servlet(servlet_name, use_env_servlet_cache)
return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs)

@staticmethod
Expand All @@ -281,25 +288,26 @@ def call_actor_method(actor: ray.actor.ActorHandle, method: str, *args, **kwargs

return ray.get(getattr(actor, method).remote(*args, **kwargs))

@staticmethod
def get_env_servlet(
self,
env_name: str,
create: bool = False,
raise_ex_if_not_found: bool = False,
resources: Optional[Dict[str, Any]] = None,
use_env_servlet_cache: bool = True,
**kwargs,
):
# Need to import these here to avoid circular imports
from runhouse.globals import env_servlets
from runhouse.servers.env_servlet import EnvServlet

if env_name in env_servlets:
return env_servlets[env_name]
if use_env_servlet_cache and env_name in self.env_servlet_cache:
return self.env_servlet_cache[env_name]

# It may not have been cached, but does exist
try:
existing_actor = ray.get_actor(env_name, namespace="runhouse")
env_servlets[env_name] = existing_actor
if use_env_servlet_cache:
self.env_servlet_cache[env_name] = existing_actor
return existing_actor
except ValueError:
# ValueError: Failed to look up actor with name ...
Expand Down Expand Up @@ -339,8 +347,8 @@ def get_env_servlet(

# Make sure env_servlet is actually initialized
# ray.get(new_env_actor.register_activity.remote())

env_servlets[env_name] = new_env_actor
if use_env_servlet_cache:
self.env_servlet_cache[env_name] = new_env_actor
return new_env_actor

else:
Expand Down Expand Up @@ -542,9 +550,8 @@ async def aremove_env_servlet_name(self, env_servlet_name: str):
##############################################
# KV Store: Keys
##############################################
@staticmethod
async def akeys_for_env_servlet_name(env_servlet_name: str) -> List[Any]:
return await ObjStore.acall_env_servlet_method(env_servlet_name, "akeys_local")
async def akeys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]:
return await self.acall_env_servlet_method(env_servlet_name, "akeys_local")

def keys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]:
return sync_function(self.akeys_for_env_servlet_name)(env_servlet_name)
Expand All @@ -567,11 +574,14 @@ def keys(self) -> List[Any]:
##############################################
# KV Store: Put
##############################################
@staticmethod
async def aput_for_env_servlet_name(
env_servlet_name: str, key: Any, data: Any, serialization: Optional[str] = None
self,
env_servlet_name: str,
key: Any,
data: Any,
serialization: Optional[str] = None,
):
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"aput_local",
key,
Expand Down Expand Up @@ -638,16 +648,16 @@ def put(
##############################################
# KV Store: Get
##############################################
@staticmethod
async def aget_from_env_servlet_name(
self,
env_servlet_name: str,
key: Any,
default: Optional[Any] = None,
serialization: Optional[str] = None,
remote: bool = False,
):
logger.info(f"Getting {key} from servlet {env_servlet_name}")
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"aget_local",
key,
Expand All @@ -656,15 +666,15 @@ async def aget_from_env_servlet_name(
remote=remote,
)

@staticmethod
def get_from_env_servlet_name(
self,
env_servlet_name: str,
key: Any,
default: Optional[Any] = None,
serialization: Optional[str] = None,
remote: bool = False,
):
return sync_function(ObjStore.aget_from_env_servlet_name)(
return sync_function(self.aget_from_env_servlet_name)(
env_servlet_name, key, default, serialization, remote
)

Expand Down Expand Up @@ -758,9 +768,8 @@ def get(
##############################################
# KV Store: Contains
##############################################
@staticmethod
async def acontains_for_env_servlet_name(env_servlet_name: str, key: Any):
return await ObjStore.acall_env_servlet_method(
async def acontains_for_env_servlet_name(self, env_servlet_name: str, key: Any):
return await self.acall_env_servlet_method(
env_servlet_name, "acontains_local", key
)

Expand Down Expand Up @@ -791,11 +800,14 @@ def contains(self, key: Any):
##############################################
# KV Store: Pop
##############################################
@staticmethod
async def apop_from_env_servlet_name(
env_servlet_name: str, key: Any, serialization: Optional[str] = "pickle", *args
self,
env_servlet_name: str,
key: Any,
serialization: Optional[str] = "pickle",
*args,
) -> Any:
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"apop_local",
key,
Expand Down Expand Up @@ -863,28 +875,26 @@ def pop(self, key: Any, serialization: Optional[str] = "pickle", *args) -> Any:
##############################################
# KV Store: Delete
##############################################
@staticmethod
async def adelete_for_env_servlet_name(env_servlet_name: str, key: Any):
return await ObjStore.acall_env_servlet_method(
async def adelete_for_env_servlet_name(self, env_servlet_name: str, key: Any):
return await self.acall_env_servlet_method(
env_servlet_name, "adelete_local", key
)

async def adelete_local(self, key: Any):
await self.apop_local(key)

async def adelete_env_contents(self, env_name: Any):
from runhouse.globals import env_servlets

# clear keys in the env servlet
deleted_keys = await self.akeys_for_env_servlet_name(env_name)
await self.aclear_for_env_servlet_name(env_name)

# delete the env servlet actor and remove its references
if env_name in env_servlets:
actor = env_servlets[env_name]
if env_name in self.env_servlet_cache:
actor = self.env_servlet_cache[env_name]
ray.kill(actor)

del env_servlets[env_name]
del self.env_servlet_cache[env_name]
await self.aremove_env_servlet_name(env_name)

return deleted_keys
Expand Down Expand Up @@ -926,9 +936,8 @@ def delete(self, key: Union[Any, List[Any]]):
##############################################
# KV Store: Clear
##############################################
@staticmethod
async def aclear_for_env_servlet_name(env_servlet_name: str):
return await ObjStore.acall_env_servlet_method(env_servlet_name, "aclear_local")
async def aclear_for_env_servlet_name(self, env_servlet_name: str):
return await self.acall_env_servlet_method(env_servlet_name, "aclear_local")

async def aclear_local(self):
if self.has_local_storage:
Expand All @@ -951,11 +960,10 @@ def clear(self):
##############################################
# KV Store: Rename
##############################################
@staticmethod
async def arename_for_env_servlet_name(
env_servlet_name: str, old_key: Any, new_key: Any
self, env_servlet_name: str, old_key: Any, new_key: Any
):
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"arename_local",
old_key,
Expand Down Expand Up @@ -998,8 +1006,8 @@ def rename(self, old_key: Any, new_key: Any):
##############################################
# KV Store: Call
##############################################
@staticmethod
async def acall_for_env_servlet_name(
self,
env_servlet_name: str,
key: Any,
method_name: str,
Expand All @@ -1009,7 +1017,7 @@ async def acall_for_env_servlet_name(
stream_logs: bool = False,
remote: bool = False,
):
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"acall_local",
key,
Expand Down Expand Up @@ -1409,14 +1417,14 @@ async def aput_resource(
else {}
)

_ = ObjStore.get_env_servlet(
_ = self.get_env_servlet(
env_name=env_name,
create=True,
runtime_env=runtime_env,
resources=resource_config.get("compute", None),
)

return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_name,
"aput_resource_local",
data=serialized_data,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_servers/test_server_obj_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from runhouse.servers.obj_store import ObjStore, ObjStoreError
from runhouse.servers.obj_store import ObjStoreError

from tests.utils import friend_account, get_ray_servlet_and_obj_store

Expand Down Expand Up @@ -341,7 +341,9 @@ def test_delete_env_servlet(self, obj_store):

# check that corresponding Ray actor is killed
with pytest.raises(ObjStoreError):
ObjStore.get_env_servlet(env_name=env_to_delete, raise_ex_if_not_found=True)
obj_store.get_env_servlet(
env_name=env_to_delete, raise_ex_if_not_found=True
)


@pytest.mark.servertest
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_ray_servlet_and_obj_store(env_name):
test_obj_store = ObjStore()
test_obj_store.initialize(env_name, setup_ray=RaySetupOption.GET_OR_FAIL)

servlet = ObjStore.get_env_servlet(
servlet = test_obj_store.get_env_servlet(
env_name=env_name,
create=True,
)
Expand Down

0 comments on commit 109a475

Please sign in to comment.