Skip to content

Commit

Permalink
Make env_servlets cache an instance field instead of global. (#736)
Browse files Browse the repository at this point in the history
We don't need this to be global, it's only used in the object store. 

In general, minimizing the amount of global stuff we have referencing important objects is better to avoid weird memory issues.
  • Loading branch information
rohinb2 committed Apr 16, 2024
1 parent f677f28 commit 39536fd
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 55 deletions.
1 change: 0 additions & 1 deletion runhouse/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ def clean_up_ssh_connections():
# Note: this initalizes a dummy global object. The obj_store must
# be properly initialized by a servlet via initialize.
obj_store = ObjStore()
env_servlets = {}
12 changes: 7 additions & 5 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from typing import Any, Dict, List, Optional, Set, Union

from runhouse.globals import configs, ObjStore, rns_client
from runhouse.globals import configs, obj_store, rns_client
from runhouse.resources.hardware import load_cluster_config_from_file
from runhouse.rns.utils.api import ResourceAccess
from runhouse.servers.http.auth import AuthCache
Expand Down Expand Up @@ -76,10 +76,11 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]):
# Propagate the changes to all other process's obj_stores
await asyncio.gather(
*[
ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
obj_store.acall_env_servlet_method(
env_servlet_name,
"aset_cluster_config",
cluster_config,
use_env_servlet_cache=False,
)
for env_servlet_name in await self.aget_all_initialized_env_servlet_names()
]
Expand All @@ -98,11 +99,12 @@ async def aset_cluster_config_value(self, key: str, value: Any):
# Propagate the changes to all other process's obj_stores
await asyncio.gather(
*[
ObjStore.acall_actor_method(
ObjStore.get_env_servlet(env_servlet_name),
obj_store.acall_env_servlet_method(
env_servlet_name,
"aset_cluster_config_value",
key,
value,
use_env_servlet_cache=False,
)
for env_servlet_name in await self.aget_all_initialized_env_servlet_names()
]
Expand Down
5 changes: 2 additions & 3 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
)
from runhouse.servers.obj_store import (
ClusterServletSetupOption,
ObjStore,
ObjStoreError,
RaySetupOption,
)
Expand Down Expand Up @@ -199,7 +198,7 @@ async def _add_username_to_span(request: Request, call_next):
# TODO: We aren't sure _exactly_ where this is or isn't used.
# There are a few spots where we do `env_name or "base"`, and
# this allows that base env to be pre-initialized.
_ = ObjStore.get_env_servlet(
_ = obj_store.get_env_servlet(
env_name="base",
create=True,
runtime_env=runtime_env,
Expand Down Expand Up @@ -679,7 +678,7 @@ async def get_keys(request: Request, env_name: Optional[str] = None):
if not env_name:
output = await obj_store.akeys()
else:
output = await ObjStore.akeys_for_env_servlet_name(env_name)
output = await obj_store.akeys_for_env_servlet_name(env_name)

# Expicitly tell the client not to attempt to deserialize the output
return Response(
Expand Down
96 changes: 53 additions & 43 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self):
self.imported_modules = {}
self.installed_envs = {} # TODO: consider deleting it?
self._kv_store: Dict[Any, Any] = None
self.env_servlet_cache = {}

# Ray defaults to setting OMP_NUM_THREADS to 1, which unexpectedly limit parallelism in user programs.
# We delete it by default, but if we find that the user explicitly set it to another value, we respect that.
Expand Down Expand Up @@ -261,9 +262,17 @@ def initialize(
##############################################
# Generic helpers
##############################################
@staticmethod
async def acall_env_servlet_method(servlet_name: str, method: str, *args, **kwargs):
env_servlet = ObjStore.get_env_servlet(servlet_name)
async def acall_env_servlet_method(
self,
servlet_name: str,
method: str,
*args,
use_env_servlet_cache: bool = True,
**kwargs,
):
env_servlet = self.get_env_servlet(
servlet_name, use_env_servlet_cache=use_env_servlet_cache
)
return await ObjStore.acall_actor_method(env_servlet, method, *args, **kwargs)

@staticmethod
Expand All @@ -281,25 +290,26 @@ def call_actor_method(actor: ray.actor.ActorHandle, method: str, *args, **kwargs

return ray.get(getattr(actor, method).remote(*args, **kwargs))

@staticmethod
def get_env_servlet(
self,
env_name: str,
create: bool = False,
raise_ex_if_not_found: bool = False,
resources: Optional[Dict[str, Any]] = None,
use_env_servlet_cache: bool = True,
**kwargs,
):
# Need to import these here to avoid circular imports
from runhouse.globals import env_servlets
from runhouse.servers.env_servlet import EnvServlet

if env_name in env_servlets:
return env_servlets[env_name]
if use_env_servlet_cache and env_name in self.env_servlet_cache:
return self.env_servlet_cache[env_name]

# It may not have been cached, but does exist
try:
existing_actor = ray.get_actor(env_name, namespace="runhouse")
env_servlets[env_name] = existing_actor
if use_env_servlet_cache:
self.env_servlet_cache[env_name] = existing_actor
return existing_actor
except ValueError:
# ValueError: Failed to look up actor with name ...
Expand Down Expand Up @@ -339,8 +349,8 @@ def get_env_servlet(

# Make sure env_servlet is actually initialized
# ray.get(new_env_actor.register_activity.remote())

env_servlets[env_name] = new_env_actor
if use_env_servlet_cache:
self.env_servlet_cache[env_name] = new_env_actor
return new_env_actor

else:
Expand Down Expand Up @@ -542,9 +552,8 @@ async def aremove_env_servlet_name(self, env_servlet_name: str):
##############################################
# KV Store: Keys
##############################################
@staticmethod
async def akeys_for_env_servlet_name(env_servlet_name: str) -> List[Any]:
return await ObjStore.acall_env_servlet_method(env_servlet_name, "akeys_local")
async def akeys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]:
return await self.acall_env_servlet_method(env_servlet_name, "akeys_local")

def keys_for_env_servlet_name(self, env_servlet_name: str) -> List[Any]:
return sync_function(self.akeys_for_env_servlet_name)(env_servlet_name)
Expand All @@ -567,11 +576,14 @@ def keys(self) -> List[Any]:
##############################################
# KV Store: Put
##############################################
@staticmethod
async def aput_for_env_servlet_name(
env_servlet_name: str, key: Any, data: Any, serialization: Optional[str] = None
self,
env_servlet_name: str,
key: Any,
data: Any,
serialization: Optional[str] = None,
):
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"aput_local",
key,
Expand Down Expand Up @@ -638,16 +650,16 @@ def put(
##############################################
# KV Store: Get
##############################################
@staticmethod
async def aget_from_env_servlet_name(
self,
env_servlet_name: str,
key: Any,
default: Optional[Any] = None,
serialization: Optional[str] = None,
remote: bool = False,
):
logger.info(f"Getting {key} from servlet {env_servlet_name}")
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"aget_local",
key,
Expand All @@ -656,15 +668,15 @@ async def aget_from_env_servlet_name(
remote=remote,
)

@staticmethod
def get_from_env_servlet_name(
self,
env_servlet_name: str,
key: Any,
default: Optional[Any] = None,
serialization: Optional[str] = None,
remote: bool = False,
):
return sync_function(ObjStore.aget_from_env_servlet_name)(
return sync_function(self.aget_from_env_servlet_name)(
env_servlet_name, key, default, serialization, remote
)

Expand Down Expand Up @@ -758,9 +770,8 @@ def get(
##############################################
# KV Store: Contains
##############################################
@staticmethod
async def acontains_for_env_servlet_name(env_servlet_name: str, key: Any):
return await ObjStore.acall_env_servlet_method(
async def acontains_for_env_servlet_name(self, env_servlet_name: str, key: Any):
return await self.acall_env_servlet_method(
env_servlet_name, "acontains_local", key
)

Expand Down Expand Up @@ -791,11 +802,14 @@ def contains(self, key: Any):
##############################################
# KV Store: Pop
##############################################
@staticmethod
async def apop_from_env_servlet_name(
env_servlet_name: str, key: Any, serialization: Optional[str] = "pickle", *args
self,
env_servlet_name: str,
key: Any,
serialization: Optional[str] = "pickle",
*args,
) -> Any:
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"apop_local",
key,
Expand Down Expand Up @@ -863,28 +877,26 @@ def pop(self, key: Any, serialization: Optional[str] = "pickle", *args) -> Any:
##############################################
# KV Store: Delete
##############################################
@staticmethod
async def adelete_for_env_servlet_name(env_servlet_name: str, key: Any):
return await ObjStore.acall_env_servlet_method(
async def adelete_for_env_servlet_name(self, env_servlet_name: str, key: Any):
return await self.acall_env_servlet_method(
env_servlet_name, "adelete_local", key
)

async def adelete_local(self, key: Any):
await self.apop_local(key)

async def adelete_env_contents(self, env_name: Any):
from runhouse.globals import env_servlets

# clear keys in the env servlet
deleted_keys = await self.akeys_for_env_servlet_name(env_name)
await self.aclear_for_env_servlet_name(env_name)

# delete the env servlet actor and remove its references
if env_name in env_servlets:
actor = env_servlets[env_name]
if env_name in self.env_servlet_cache:
actor = self.env_servlet_cache[env_name]
ray.kill(actor)

del env_servlets[env_name]
del self.env_servlet_cache[env_name]
await self.aremove_env_servlet_name(env_name)

return deleted_keys
Expand Down Expand Up @@ -926,9 +938,8 @@ def delete(self, key: Union[Any, List[Any]]):
##############################################
# KV Store: Clear
##############################################
@staticmethod
async def aclear_for_env_servlet_name(env_servlet_name: str):
return await ObjStore.acall_env_servlet_method(env_servlet_name, "aclear_local")
async def aclear_for_env_servlet_name(self, env_servlet_name: str):
return await self.acall_env_servlet_method(env_servlet_name, "aclear_local")

async def aclear_local(self):
if self.has_local_storage:
Expand All @@ -951,11 +962,10 @@ def clear(self):
##############################################
# KV Store: Rename
##############################################
@staticmethod
async def arename_for_env_servlet_name(
env_servlet_name: str, old_key: Any, new_key: Any
self, env_servlet_name: str, old_key: Any, new_key: Any
):
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"arename_local",
old_key,
Expand Down Expand Up @@ -998,8 +1008,8 @@ def rename(self, old_key: Any, new_key: Any):
##############################################
# KV Store: Call
##############################################
@staticmethod
async def acall_for_env_servlet_name(
self,
env_servlet_name: str,
key: Any,
method_name: str,
Expand All @@ -1009,7 +1019,7 @@ async def acall_for_env_servlet_name(
stream_logs: bool = False,
remote: bool = False,
):
return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_servlet_name,
"acall_local",
key,
Expand Down Expand Up @@ -1409,14 +1419,14 @@ async def aput_resource(
else {}
)

_ = ObjStore.get_env_servlet(
_ = self.get_env_servlet(
env_name=env_name,
create=True,
runtime_env=runtime_env,
resources=resource_config.get("compute", None),
)

return await ObjStore.acall_env_servlet_method(
return await self.acall_env_servlet_method(
env_name,
"aput_resource_local",
data=serialized_data,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_servers/test_server_obj_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from runhouse.servers.obj_store import ObjStore, ObjStoreError
from runhouse.servers.obj_store import ObjStoreError

from tests.utils import friend_account, get_ray_servlet_and_obj_store

Expand Down Expand Up @@ -341,7 +341,9 @@ def test_delete_env_servlet(self, obj_store):

# check that corresponding Ray actor is killed
with pytest.raises(ObjStoreError):
ObjStore.get_env_servlet(env_name=env_to_delete, raise_ex_if_not_found=True)
obj_store.get_env_servlet(
env_name=env_to_delete, raise_ex_if_not_found=True
)


@pytest.mark.servertest
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_ray_servlet_and_obj_store(env_name):
test_obj_store = ObjStore()
test_obj_store.initialize(env_name, setup_ray=RaySetupOption.GET_OR_FAIL)

servlet = ObjStore.get_env_servlet(
servlet = test_obj_store.get_env_servlet(
env_name=env_name,
create=True,
)
Expand Down

0 comments on commit 39536fd

Please sign in to comment.