Skip to content

Commit

Permalink
Properly run commands in env servlet (#603)
Browse files — browse the repository at this point in the history
  • Loading branch information
carolineechen committed Mar 19, 2024
1 parent d3b2010 commit 1907551
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 57 deletions.
2 changes: 1 addition & 1 deletion runhouse/resources/envs/conda_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def install(self, force=False):
pkg._install(self)

return (
self.run([f"{self._activate_cmd} && {self.setup_cmds.join(' && ')}"])
self._run_command([f"{self.setup_cmds.join(' && ')}"])
if self.setup_cmds
else None
)
Expand Down
26 changes: 11 additions & 15 deletions runhouse/resources/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,19 @@ def install(self, force=False):

logger.debug(f"Installing package: {str(pkg)}")
pkg._install(self)
return self.run(self.setup_cmds) if self.setup_cmds else None
if self.setup_cmds:
for cmd in self.setup_cmds:
self._run_command(cmd)

def run(self, cmds: Union[List[str], str]):
def _run_command(self, command: str, **kwargs):
"""Run command locally inside the environment"""
ret_codes = []
for cmd in cmds:
if self._run_cmd:
cmd = f"{self._run_cmd} {cmd}"
logging.info(f"Running: {cmd}")
use_shell = any(shell_feat in cmd for shell_feat in [">", "|", "&&", "||"])
if use_shell:
# Example: "echo '<TOKEN>' > ~/.rh/config.yaml"
retcode = run_with_logs(cmd, shell=True)
else:
retcode = run_with_logs(cmd, shell=False)
ret_codes.append(retcode)
return ret_codes
if self._run_cmd:
command = f"{self._run_cmd} {command}"
logging.info(f"Running command in {self.name}: {command}")
use_shell = any(
shell_feat in command for shell_feat in [">", "|", "&&", "||", "$"]
)
return run_with_logs(command, shell=use_shell, **kwargs)

def to(
self, system: Union[str, Cluster], path=None, mount=False, force_install=False
Expand Down
61 changes: 30 additions & 31 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,6 @@ def creds_values(self) -> Dict:

return self._creds.values

def _get_env_activate_cmd(self, env=None):
if env:
from runhouse.resources.envs import _get_env_from

return _get_env_from(env)._activate_cmd
return None

def save_config_to_cluster(self, node: str = None):
config = self.config(condensed=False)
json_config = f"{json.dumps(config)}"
Expand Down Expand Up @@ -352,10 +345,10 @@ def _sync_runhouse_to_cluster(self, _install_url=None, env=None):
_install_url = f"runhouse=={runhouse.__version__}"
rh_install_cmd = f"python3 -m pip install {_install_url}"

install_cmd = f"{env._run_cmd} {rh_install_cmd}" if env else rh_install_cmd

for node in self.ips:
status_codes = self.run([install_cmd], node=node, stream_logs=True)
status_codes = self.run(
[rh_install_cmd], node=node, env=env, stream_logs=True
)

if status_codes[0][0] != 0:
raise ValueError(
Expand Down Expand Up @@ -751,10 +744,7 @@ def restart_server(
+ f" --port {self.server_port}"
)

env_activate_cmd = self._get_env_activate_cmd(env)
cmd = f"{env_activate_cmd} && {cmd}" if env_activate_cmd else cmd

status_codes = self.run(commands=[cmd])
status_codes = self.run(commands=[cmd], env=env)
if not status_codes[0][0] == 0:
raise ValueError(f"Failed to restart server {self.name}.")

Expand Down Expand Up @@ -782,11 +772,9 @@ def stop_server(self, stop_ray: bool = True, env: Union[str, "Env"] = None):
stop_ray (bool): Whether to stop Ray. (Default: `True`)
env (str or Env, optional): Specified environment to stop the server on. (Default: ``None``)
"""
env_activate_cmd = self._get_env_activate_cmd(env)
cmd = CLI_STOP_CMD if stop_ray else f"{CLI_STOP_CMD} --no-stop-ray"
cmd = f"{env_activate_cmd} && {cmd}" if env_activate_cmd else cmd

status_codes = self.run([cmd], stream_logs=False)
status_codes = self.run([cmd], env=env, stream_logs=False)
assert status_codes[0][0] == 1

@contextlib.contextmanager
Expand Down Expand Up @@ -1099,15 +1087,31 @@ def run(
return res_list

# TODO [DG] suspend autostop while running
from runhouse.resources.provenance import run

cmd_prefix = ""
if env:
if isinstance(env, str):
from runhouse.resources.envs import Env
from runhouse.resources.envs import Env

if env and not port_forward and not node:
env_name = (
env
if isinstance(env, str)
else env.name
if isinstance(env, Env)
else "base_env"
)
return_codes = []
for command in commands:
ret_code = self.call(
env_name,
"_run_command",
command,
require_outputs=require_outputs,
stream_logs=stream_logs,
)
return_codes.append(ret_code)
return return_codes

env = Env.from_name(env)
cmd_prefix = env._run_cmd
env = _get_env_from(env)
cmd_prefix = env._run_cmd if isinstance(env, Env) else ""

if self.on_this_cluster():
return_codes = []
Expand Down Expand Up @@ -1137,6 +1141,8 @@ def run(
)

# Create and save the Run locally
from runhouse.resources.provenance import run

with run(name=run_name, cmds=commands, overwrite=True) as r:
return_codes = self._run_commands_with_ssh(
commands,
Expand Down Expand Up @@ -1247,14 +1253,7 @@ def run_python(
try to wrap the outer quote with double quotes (") and the inner quotes with a single quote (').
"""
# If no node provided, assume the commands are to be run on the head node
node = node or self.address
cmd_prefix = "python3 -c"
if env:
if isinstance(env, str):
from runhouse.resources.envs import Env

env = Env.from_name(env)
cmd_prefix = f"{env._run_cmd} {cmd_prefix}"
command_str = "; ".join(commands)
command_str_repr = (
repr(repr(command_str))[2:-2]
Expand Down
7 changes: 5 additions & 2 deletions runhouse/resources/hardware/sagemaker/sagemaker_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,11 @@ def _get_env_activate_cmd(self, env=None):
"""Prefix for commands run on the cluster. Ensure we are running all commands in the conda environment
and not the system default python."""
# TODO [JL] Can SageMaker handle this for us?
cmd = super()._get_env_activate_cmd(env)
return cmd or "source /opt/conda/bin/activate"
if env:
from runhouse.resources.envs import _get_env_from

return _get_env_from(env)._activate_cmd
return "source /opt/conda/bin/activate"

def _set_boto_session(self, profile_name: str = None):
self._boto_session = boto3.Session(
Expand Down
2 changes: 1 addition & 1 deletion runhouse/resources/packages/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _pip_install(install_cmd: str, env: Union[str, "Env"] = ""):
from runhouse.resources.envs.utils import _get_env_from

env = _get_env_from(env)
env.run([pip_cmd])
env._run_command(pip_cmd)
else:
cmd = f"{sys.executable} -m {pip_cmd}"
retcode = run_with_logs(cmd)
Expand Down
22 changes: 15 additions & 7 deletions runhouse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,23 @@ def run_with_logs(cmd: Union[List[str], str], **kwargs) -> int:
Returns:
The returncode of the command.
"""
if isinstance(cmd, str) and not kwargs.get("shell", False):
cmd = shlex.split(cmd)
logging.info(f"Running command: {cmd} with kwargs: {kwargs}")
if isinstance(cmd, str):
cmd = shlex.split(cmd) if not kwargs.get("shell", False) else [cmd]
require_outputs = kwargs.pop("require_outputs", False)
stream_logs = kwargs.pop("stream_logs", True)

p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, **kwargs
)
for line in p.stdout:
print(line.decode("utf-8").strip())
return p.wait()
stdout, stderr = p.communicate()

if stream_logs:
print(stdout)

if require_outputs:
return p.returncode, stdout, stderr

return p.returncode


def install_conda():
Expand Down
12 changes: 12 additions & 0 deletions tests/test_resources/test_envs/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,15 @@ def test_secrets_env(self, env, cluster):
assert get_env_var_cpu(var)

named_secret.delete()

@pytest.mark.level("local")
def test_env_run_cmd(self, env, cluster):
    """Check that a command passed to ``cluster.run`` with ``env=`` sees the env's variables.

    An environment variable is set on ``env``, the env is sent to ``cluster``,
    and a shell command echoing that variable is run with ``env=env``. The test
    passes when the command exits with 0 and the variable's value shows up in
    the captured stdout.
    """
    test_env_var = "ENV_VAR"
    test_value = "env_val"
    env.env_vars = {test_env_var: test_value}

    env.to(cluster)

    # Build the command and the expectation from the variables declared above,
    # so the name/value set on the env and the ones asserted on cannot drift.
    # The f-string renders to the literal shell command ``echo $ENV_VAR``.
    res = cluster.run([f"echo ${test_env_var}"], env=env)

    # One result entry per command; index 0 is the return code, index 1 stdout.
    returncode, stdout = res[0][0], res[0][1]
    assert returncode == 0
    assert test_value in stdout

0 comments on commit 1907551

Please sign in to comment.