Skip to content

Commit

Permalink
Add support for env files. (#2296)
Browse files Browse the repository at this point in the history
* Add support for env files.  Env args will override any env file values and be passed on to override any yaml values

* Review comments

* Adding an env-file smoke test

* Fix pylint failure
  • Loading branch information
fozziethebeat committed Jul 24, 2023
1 parent 4fa6378 commit 24a5622
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/sample_dotenv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TEST_ENV2="success"
28 changes: 28 additions & 0 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import click
import colorama
import dotenv
from rich import progress as rich_progress
import yaml

Expand Down Expand Up @@ -331,6 +332,16 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]:
return ret[0], ret[1]


def _merge_env_vars(env_dict: Optional[Dict[str, str]],
env_list: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""Merges all values from env_list into env_dict."""
if not env_dict:
return env_list
for (key, value) in env_list:
env_dict[key] = value
return list(env_dict.items())


_TASK_OPTIONS = [
click.option('--name',
'-n',
Expand Down Expand Up @@ -382,6 +393,15 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]:
default=None,
help=('Custom image id for launching the instances. '
'Passing "none" resets the config.')),
click.option('--env-file',
required=False,
type=dotenv.dotenv_values,
help="""\
Path to a dotenv file with environment variables to set on the remote
node.
If any values from ``--env-file`` conflict with values set by
``--env``, the ``--env`` value will be preferred."""),
click.option(
'--env',
required=False,
Expand Down Expand Up @@ -1307,6 +1327,7 @@ def launch(
num_nodes: Optional[int],
use_spot: Optional[bool],
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
disk_size: Optional[int],
disk_tier: Optional[str],
Expand All @@ -1326,6 +1347,7 @@ def launch(
In both cases, the commands are run under the task's workdir (if specified)
and they undergo job queue scheduling.
"""
env = _merge_env_vars(env_file, env)
backend_utils.check_cluster_name_not_reserved(
cluster, operation_str='Launching tasks on it')
if backend_name is None:
Expand Down Expand Up @@ -1423,6 +1445,7 @@ def exec(
num_nodes: Optional[int],
use_spot: Optional[bool],
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
Expand Down Expand Up @@ -1483,6 +1506,7 @@ def exec(
sky exec mycluster --env WANDB_API_KEY python train_gpu.py
"""
env = _merge_env_vars(env_file, env)
backend_utils.check_cluster_name_not_reserved(
cluster, operation_str='Executing task on it')
handle = global_user_state.get_handle_from_cluster_name(cluster)
Expand Down Expand Up @@ -3459,6 +3483,7 @@ def spot_launch(
use_spot: Optional[bool],
image_id: Optional[str],
spot_recovery: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
disk_size: Optional[int],
disk_tier: Optional[str],
Expand All @@ -3480,6 +3505,7 @@ def spot_launch(
sky spot launch 'echo hello!'
"""
env = _merge_env_vars(env_file, env)
task_or_dag = _make_task_or_dag_from_entrypoint_with_overrides(
entrypoint,
name=name,
Expand Down Expand Up @@ -3907,6 +3933,7 @@ def benchmark_launch(
num_nodes: Optional[int],
use_spot: Optional[bool],
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
disk_size: Optional[int],
disk_tier: Optional[str],
Expand All @@ -3920,6 +3947,7 @@ def benchmark_launch(
Alternatively, specify the benchmarking resources in your YAML (see doc),
which allows benchmarking on many more resource fields.
"""
env = _merge_env_vars(env_file, env)
record = benchmark_state.get_benchmark_from_name(benchmark)
if record is not None:
raise click.BadParameter(f'Benchmark {benchmark} already exists. '
Expand Down
1 change: 1 addition & 0 deletions sky/setup_files/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def parse_readme(readme: str) -> str:
# PrettyTable with version >=2.0.0 is required for the support of
# `add_rows` method.
'PrettyTable>=2.0.0',
'python-dotenv',
# Lower version of ray will cause dependency conflict for
# click/grpcio/protobuf.
'ray[default]>=2.2.0,<=2.4.0',
Expand Down
17 changes: 17 additions & 0 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,23 @@ def test_inline_env(generic_cloud: str):
run_one_test(test)


# ---------- Testing env file ----------
def test_inline_env_file(generic_cloud: str):
"""Test env"""
name = _get_cluster_name()
test = Test(
'test-inline-env-file',
[
f'sky launch -c {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_IPS\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_RANK\\" ]]) || exit 1"',
f'sky logs {name} 1 --status',
f'sky exec {name} --env-file examples/sample_dotenv "([[ ! -z \\"\$TEST_ENV2\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_IPS\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_RANK\\" ]]) || exit 1"',
f'sky logs {name} 2 --status',
],
f'sky down -y {name}',
)
run_one_test(test)


# ---------- Testing custom image ----------
@pytest.mark.aws
def test_custom_image():
Expand Down

0 comments on commit 24a5622

Please sign in to comment.