diff --git a/examples/sample_dotenv b/examples/sample_dotenv new file mode 100644 index 00000000000..ac56cc1c950 --- /dev/null +++ b/examples/sample_dotenv @@ -0,0 +1 @@ +TEST_ENV2="success" diff --git a/sky/cli.py b/sky/cli.py index 3136a16678a..1e27275d4f9 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -40,6 +40,7 @@ import click import colorama +import dotenv from rich import progress as rich_progress import yaml @@ -331,6 +332,16 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]: return ret[0], ret[1] +def _merge_env_vars(env_dict: Optional[Dict[str, str]], + env_list: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + """Merges all values from env_list into env_dict.""" + if not env_dict: + return env_list + for (key, value) in env_list: + env_dict[key] = value + return list(env_dict.items()) + + _TASK_OPTIONS = [ click.option('--name', '-n', @@ -382,6 +393,15 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]: default=None, help=('Custom image id for launching the instances. ' 'Passing "none" resets the config.')), + click.option('--env-file', + required=False, + type=dotenv.dotenv_values, + help="""\ + Path to a dotenv file with environment variables to set on the remote + node. + + If any values from ``--env-file`` conflict with values set by + ``--env``, the ``--env`` value will be preferred."""), click.option( '--env', required=False, @@ -1307,6 +1327,7 @@ def launch( num_nodes: Optional[int], use_spot: Optional[bool], image_id: Optional[str], + env_file: Optional[Dict[str, str]], env: List[Tuple[str, str]], disk_size: Optional[int], disk_tier: Optional[str], @@ -1326,6 +1347,7 @@ def launch( In both cases, the commands are run under the task's workdir (if specified) and they undergo job queue scheduling. """ + env = _merge_env_vars(env_file, env) backend_utils.check_cluster_name_not_reserved( cluster, operation_str='Launching tasks on it') if backend_name is None: @@ -1423,6 +1445,7 @@ def exec( num_nodes: Optional[int], use_spot: Optional[bool], image_id: Optional[str], + env_file: Optional[Dict[str, str]], env: List[Tuple[str, str]], ): # NOTE(dev): Keep the docstring consistent between the Python API and CLI. @@ -1483,6 +1506,7 @@ def exec( sky exec mycluster --env WANDB_API_KEY python train_gpu.py """ + env = _merge_env_vars(env_file, env) backend_utils.check_cluster_name_not_reserved( cluster, operation_str='Executing task on it') handle = global_user_state.get_handle_from_cluster_name(cluster) @@ -3459,6 +3483,7 @@ def spot_launch( use_spot: Optional[bool], image_id: Optional[str], spot_recovery: Optional[str], + env_file: Optional[Dict[str, str]], env: List[Tuple[str, str]], disk_size: Optional[int], disk_tier: Optional[str], @@ -3480,6 +3505,7 @@ def spot_launch( sky spot launch 'echo hello!' """ + env = _merge_env_vars(env_file, env) task_or_dag = _make_task_or_dag_from_entrypoint_with_overrides( entrypoint, name=name, @@ -3907,6 +3933,7 @@ def benchmark_launch( num_nodes: Optional[int], use_spot: Optional[bool], image_id: Optional[str], + env_file: Optional[Dict[str, str]], env: List[Tuple[str, str]], disk_size: Optional[int], disk_tier: Optional[str], @@ -3920,6 +3947,7 @@ def benchmark_launch( Alternatively, specify the benchmarking resources in your YAML (see doc), which allows benchmarking on many more resource fields. """ + env = _merge_env_vars(env_file, env) record = benchmark_state.get_benchmark_from_name(benchmark) if record is not None: raise click.BadParameter(f'Benchmark {benchmark} already exists. ' diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 6a738bdad38..562910d94d2 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -86,6 +86,7 @@ def parse_readme(readme: str) -> str: # PrettyTable with version >=2.0.0 is required for the support of # `add_rows` method. 'PrettyTable>=2.0.0', + 'python-dotenv', # Lower version of ray will cause dependency conflict for # click/grpcio/protobuf. 'ray[default]>=2.2.0,<=2.4.0', diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 3c54ef26f41..ecb313872a7 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -2195,6 +2195,23 @@ def test_inline_env(generic_cloud: str): run_one_test(test) +# ---------- Testing env file ---------- +def test_inline_env_file(generic_cloud: str): + """Test env""" + name = _get_cluster_name() + test = Test( + 'test-inline-env-file', + [ + f'sky launch -c {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_IPS\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_RANK\\" ]]) || exit 1"', + f'sky logs {name} 1 --status', + f'sky exec {name} --env-file examples/sample_dotenv "([[ ! -z \\"\$TEST_ENV2\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_IPS\\" ]] && [[ ! -z \\"\$SKYPILOT_NODE_RANK\\" ]]) || exit 1"', + f'sky logs {name} 2 --status', + ], + f'sky down -y {name}', + ) + run_one_test(test) + + # ---------- Testing custom image ---------- @pytest.mark.aws def test_custom_image():