diff --git a/prototype/README.md b/prototype/README.md
index 7dabc0969ab..da04d54aa18 100644
--- a/prototype/README.md
+++ b/prototype/README.md
@@ -60,7 +60,7 @@ ray attach config/gcp.yml
 ray down config/gcp.yml
 ```
 
-**Azure**. Install the Azure CLI (`pip install azure-cli`) then login using `az login`. Set the subscription to use from the command line (`az account set -s <subscription_id>`) or by modifying the provider section of the Azure template (`config/azure.yml.j2`). Ray Autoscaler does not work with the latest version of `azure-cli`. Hotfix: `pip install azure-cli-core==2.22.0` (this will make Ray work but at the cost of making the `az` CLI tool unusable).
+**Azure**. Install the Azure CLI (`pip install azure-cli==2.22.0`), then log in using `az login`. Set the subscription to use from the command line (`az account set -s <subscription_id>`). Ray Autoscaler does not work with the latest version of `azure-cli` as of Ray 1.9.1, hence the pinned `azure-cli` version.
 
 ## Open issues
diff --git a/prototype/examples/resnet_distributed_tf_app.py b/prototype/examples/resnet_distributed_tf_app.py
index 13379f5772f..c524eb0cd4f 100644
--- a/prototype/examples/resnet_distributed_tf_app.py
+++ b/prototype/examples/resnet_distributed_tf_app.py
@@ -9,9 +9,12 @@
 with sky.Dag() as dag:
     # The working directory contains all code and will be synced to remote.
     workdir = '~/Downloads/tpu'
-    subprocess.run(f'cd {workdir} && git checkout 222cc86',
-                   shell=True,
-                   check=True)
+    subprocess.run(
+        'cd ~/Downloads; '
+        '(git clone https://github.com/concretevitamin/tpu || true); '
+        f'cd {workdir} && git checkout 9459fee',
+        shell=True,
+        check=True)
 
     docker_image = None  # 'rayproject/ray-ml:latest-gpu'
diff --git a/prototype/setup.py b/prototype/setup.py
index 0d08793651a..701d6d8e4e6 100644
--- a/prototype/setup.py
+++ b/prototype/setup.py
@@ -24,7 +24,9 @@
 extras_require = {
     'aws': ['awscli==1.22.17', 'boto3'],
-    'azure': ['azure-cli'],
+    # ray <= 1.9.1 requires an older version of azure-cli. We can get rid of
+    # this version requirement once ray 1.10 is adopted as our local version.
+    'azure': ['azure-cli==2.22.0'],
     'gcp': ['google-api-python-client', 'google-cloud-storage'],
 }
diff --git a/prototype/sky/__init__.py b/prototype/sky/__init__.py
index efc83abdd2a..f706878765d 100644
--- a/prototype/sky/__init__.py
+++ b/prototype/sky/__init__.py
@@ -9,7 +9,6 @@
 from sky.execution import launch, exec  # pylint: disable=redefined-builtin
 from sky.resources import Resources
 from sky.task import Task
-from sky.registry import fill_in_launchable_resources
 from sky.optimizer import Optimizer, OptimizeTarget
 from sky.data import Storage, StorageType
 
@@ -34,7 +33,6 @@
     'backends',
     'launch',
     'exec',
-    'fill_in_launchable_resources',
     'list_accelerators',
     '__root_dir__',
     'Storage',
diff --git a/prototype/sky/backends/backend_utils.py b/prototype/sky/backends/backend_utils.py
index fd798cd4114..7a3d151be6a 100644
--- a/prototype/sky/backends/backend_utils.py
+++ b/prototype/sky/backends/backend_utils.py
@@ -237,6 +237,10 @@ def add_cluster(
                 host_name = ip
                 logger.warning(f'Using {ip} to identify host instead.')
                 break
+        else:
+            config = ['\n']
+            with open(config_path, 'w') as f:
+                f.writelines(config)
 
         codegen = cls._get_generated_config(sky_autogen_comment, host_name,
                                             ip, username, key_path)
@@ -252,7 +256,7 @@ def add_cluster(
                 f.write('\n')
         else:
             with open(config_path, 'a') as f:
-                if not config[-1].endswith('\n'):
+                if len(config) > 0 and not config[-1].endswith('\n'):
                     # Add trailing newline if it doesn't exist.
                     f.write('\n')
                 f.write('\n')
@@ -791,9 +795,9 @@ def generate_cluster_name():
     return f'sky-{uuid.uuid4().hex[:4]}-{getpass.getuser()}'
 
 
-def get_backend_from_handle(handle: backends.Backend.ResourceHandle):
-    """
-    Get a backend object from a handle.
+def get_backend_from_handle(
+        handle: backends.Backend.ResourceHandle) -> backends.Backend:
+    """Gets a Backend object corresponding to a handle.
 
     Inspects handle type to infer the backend used for the resource.
     """
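Aside: the `len(config) > 0` guard above fixes an `IndexError` when `~/.ssh/config` exists but is empty. A minimal sketch of the append logic in isolation (the helper name and standalone structure are illustrative, not the actual `add_cluster` code):

```python
import os


def append_with_newline_guard(config_path: str, codegen: str) -> None:
    """Appends a generated block, ensuring a separating newline exists."""
    path = os.path.expanduser(config_path)
    with open(path) as f:
        config = f.readlines()
    with open(path, 'a') as f:
        # Guard against an empty file: config[-1] would raise IndexError.
        if len(config) > 0 and not config[-1].endswith('\n'):
            f.write('\n')  # Terminate the last existing line.
        f.write('\n')  # Blank line separating the new block.
        f.write(codegen)
```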
diff --git a/prototype/sky/cli.py b/prototype/sky/cli.py
index 8a66fe49d10..7e6759d0ad3 100644
--- a/prototype/sky/cli.py
+++ b/prototype/sky/cli.py
@@ -42,6 +42,7 @@
 import sky
 from sky import backends
 from sky import global_user_state
+from sky import init as sky_init
 from sky import sky_logging
 from sky import clouds
 from sky.backends import backend as backend_lib
@@ -79,7 +80,7 @@ def _truncate_long_string(s: str, max_length: int = 50) -> str:
         return s
     splits = s.split(' ')
     if len(splits[0]) > max_length:
-        return splits[0][:max_length] + '...'
+        return splits[0][:max_length] + '...'  # Use '…'?
     # Truncate on word boundary.
     i = 0
     total = 0
@@ -279,6 +280,7 @@ def _create_and_ssh_into_node(
         run='',
     )
     task.set_resources(resources)
+    task.update_file_mounts(sky_init.get_cloud_credential_file_mounts())
 
     backend = backend if backend is not None else backends.CloudVmRayBackend()
     handle = global_user_state.get_handle_from_cluster_name(cluster_name)
@@ -1056,7 +1058,8 @@ def _terminate_or_stop_clusters(names: Tuple[str], apply_to_all: Optional[bool],
         name = record['name']
         handle = record['handle']
         backend = backend_utils.get_backend_from_handle(handle)
-        if handle.launched_resources.use_spot and not terminate:
+        if (isinstance(backend, backends.CloudVmRayBackend) and
+                handle.launched_resources.use_spot and not terminate):
             # TODO(suquark): enable GCP+spot to be stopped in the future.
             click.secho(
                 f'Stopping cluster {name}... skipped, because spot instances '
@@ -1293,6 +1296,17 @@ def tpunode(cluster: str, port_forward: Optional[List[int]],
 
 
 @cli.command()
+def init():
+    """Determines a set of clouds that Sky will use.
+
+    It checks access credentials for AWS, Azure and GCP. Sky tasks will only
+    run in clouds that you have access to. After configuring access for a
+    cloud, rerun `sky init` to reflect the changes.
+    """
+    sky_init.init()
+
+
+@cli.command()
 @click.argument('gpu_name', required=False)
 @click.option('--all',
               '-a',
diff --git a/prototype/sky/cloud_stores.py b/prototype/sky/cloud_stores.py
index 9c28d09244c..cbfd619c519 100644
--- a/prototype/sky/cloud_stores.py
+++ b/prototype/sky/cloud_stores.py
@@ -6,8 +6,6 @@
 TODO:
 * Better interface.
 * Better implementation (e.g., fsspec, smart_open, using each cloud's SDK).
-  The full-blown impl should handle authentication so each user's private
-  datasets can be accessed.
 """
 import subprocess
 import urllib.parse
diff --git a/prototype/sky/clouds/__init__.py b/prototype/sky/clouds/__init__.py
index 93465f43562..8dc40a54bc1 100644
--- a/prototype/sky/clouds/__init__.py
+++ b/prototype/sky/clouds/__init__.py
@@ -9,10 +9,12 @@
 __all__ = [
     'AWS',
     'Azure',
+    'CLOUD_REGISTRY',
     'Cloud',
     'GCP',
     'Region',
     'Zone',
+    'from_str',
 ]
 
 CLOUD_REGISTRY = {
@@ -20,3 +22,7 @@
     'gcp': GCP(),
     'azure': Azure(),
 }
+
+
+def from_str(name: str) -> 'Cloud':
+    return CLOUD_REGISTRY[name.lower()]
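For reference, `from_str` simply resolves a case-insensitive name to the shared instance in `CLOUD_REGISTRY`; an unregistered name raises `KeyError`. A quick usage sketch:

```python
from sky import clouds

aws = clouds.from_str('AWS')  # Case-insensitive lookup.
assert aws is clouds.CLOUD_REGISTRY['aws']  # Same shared instance.
# clouds.from_str('oracle')  # Would raise KeyError: not registered.
```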
""" import subprocess import urllib.parse diff --git a/prototype/sky/clouds/__init__.py b/prototype/sky/clouds/__init__.py index 93465f43562..8dc40a54bc1 100644 --- a/prototype/sky/clouds/__init__.py +++ b/prototype/sky/clouds/__init__.py @@ -9,10 +9,12 @@ __all__ = [ 'AWS', 'Azure', + 'CLOUD_REGISTRY', 'Cloud', 'GCP', 'Region', 'Zone', + 'from_str', ] CLOUD_REGISTRY = { @@ -20,3 +22,7 @@ 'gcp': GCP(), 'azure': Azure(), } + + +def from_str(name: str) -> 'Cloud': + return CLOUD_REGISTRY[name.lower()] diff --git a/prototype/sky/clouds/aws.py b/prototype/sky/clouds/aws.py index 3753bf7ba23..4efe6c1e2dd 100644 --- a/prototype/sky/clouds/aws.py +++ b/prototype/sky/clouds/aws.py @@ -1,6 +1,8 @@ """Amazon Web Services.""" import copy import json +import os +import subprocess from typing import Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING from sky import clouds @@ -11,6 +13,15 @@ from sky import resources as resources_lib +def _run_output(cmd): + proc = subprocess.run(cmd, + shell=True, + check=True, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE) + return proc.stdout.decode('ascii') + + class AWS(clouds.Cloud): """Amazon Web Services.""" @@ -187,3 +198,46 @@ def _make(instance_type): if instance_type is None: return [] return _make(instance_type) + + def check_credentials(self) -> Tuple[bool, Optional[str]]: + """Checks if the user has access credentials to this cloud.""" + help_str = ( + '\n For more info: ' + 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html' # pylint: disable=line-too-long + ) + # This file is required because it will be synced to remote VMs for + # `aws` to access private storage buckets. + # `aws configure list` does not guarantee this file exists. + if not os.path.isfile(os.path.expanduser('~/.aws/credentials')): + return (False, + '~/.aws/credentials does not exist. Run `aws configure`.' + + help_str) + try: + output = _run_output('aws configure list') + except subprocess.CalledProcessError: + return False, 'AWS CLI not installed properly.' + # Configured correctly, the AWS output should look like this: + # ... + # access_key ******************** shared-credentials-file + # secret_key ******************** shared-credentials-file + # ... + # Otherwise, one or both keys will show as ''. + lines = output.split('\n') + if len(lines) < 2: + return False, 'AWS CLI output invalid.' + access_key_ok = False + secret_key_ok = False + for line in lines[2:]: + line = line.lstrip() + if line.startswith('access_key'): + if '' not in line: + access_key_ok = True + elif line.startswith('secret_key'): + if '' not in line: + secret_key_ok = True + if access_key_ok and secret_key_ok: + return True, None + return False, 'AWS credentials not set. Run `aws configure`.' 
diff --git a/prototype/sky/clouds/azure.py b/prototype/sky/clouds/azure.py
index e3335fe6d70..2520d4e493a 100644
--- a/prototype/sky/clouds/azure.py
+++ b/prototype/sky/clouds/azure.py
@@ -1,12 +1,23 @@
 """Azure."""
 import copy
 import json
+import os
+import subprocess
 from typing import Dict, Iterator, List, Optional, Tuple
 
 from sky import clouds
 from sky.clouds.service_catalog import azure_catalog
 
 
+def _run_output(cmd):
+    proc = subprocess.run(cmd,
+                          shell=True,
+                          check=True,
+                          stderr=subprocess.PIPE,
+                          stdout=subprocess.PIPE)
+    return proc.stdout.decode('ascii')
+
+
 class Azure(clouds.Cloud):
     """Azure."""
 
@@ -147,3 +158,32 @@ def _make(instance_type):
         if instance_type is None:
             return []
         return _make(instance_type)
+
+    def check_credentials(self) -> Tuple[bool, Optional[str]]:
+        """Checks if the user has access credentials to this cloud."""
+        help_str = (
+            '\n    For more info: '
+            'https://docs.microsoft.com/en-us/cli/azure/get-started-with-azure-cli'  # pylint: disable=line-too-long
+        )
+        # This file is required because it will be synced to remote VMs for
+        # `az` to access private storage buckets.
+        # `az account show` does not guarantee this file exists.
+        if not os.path.isfile(
+                os.path.expanduser('~/.azure/accessTokens.json')):
+            return (
+                False,
+                '~/.azure/accessTokens.json does not exist. Run `az login`.' +
+                help_str)
+        try:
+            output = _run_output('az account show --output=json')
+        except subprocess.CalledProcessError:
+            return False, 'Azure CLI returned error.'
+        # If Azure is properly logged in, this will return something like:
+        #   {"id": ..., "user": ...}
+        # and if not, it will return:
+        #   Please run 'az login' to setup account.
+        if output.startswith('{'):
+            return True, None
+        return False, 'Azure credentials not set. Run `az login`.' + help_str
+
+    def get_credential_file_mounts(self) -> Dict[str, str]:
+        return {'~/.azure': '~/.azure'}
diff --git a/prototype/sky/clouds/cloud.py b/prototype/sky/clouds/cloud.py
index 74a379ee5da..67a389c8a75 100644
--- a/prototype/sky/clouds/cloud.py
+++ b/prototype/sky/clouds/cloud.py
@@ -125,3 +125,17 @@ def get_feasible_launchable_resources(self, resources):
         Launchable resources require a cloud and an instance type be assigned.
         """
         raise NotImplementedError
+
+    def check_credentials(self) -> Tuple[bool, Optional[str]]:
+        """Checks if the user has access credentials to this cloud.
+
+        Returns a boolean of whether the user can access this cloud, and a
+        string describing the reason if the user cannot access.
+        """
+        raise NotImplementedError
+
+    def get_credential_file_mounts(self) -> Dict[str, str]:
+        """Returns the files necessary to access this cloud.
+
+        Returns a dictionary that will be added to a task's file mounts."""
+        raise NotImplementedError
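Any future cloud backend would implement this pair of methods. A hypothetical sketch (`FakeCloud` and its file paths are invented for illustration; a real implementation would subclass `Cloud` and follow the AWS/Azure/GCP patterns above):

```python
import os
from typing import Dict, Optional, Tuple


class FakeCloud:  # Would subclass clouds.Cloud in practice.

    def check_credentials(self) -> Tuple[bool, Optional[str]]:
        # Check for a local credential file, like the AWS/Azure checks above.
        if os.path.isfile(os.path.expanduser('~/.fakecloud/token')):
            return True, None
        return False, '~/.fakecloud/token does not exist. Run `fake login`.'

    def get_credential_file_mounts(self) -> Dict[str, str]:
        # Synced to remote VMs so this cloud's CLI works there too.
        return {'~/.fakecloud': '~/.fakecloud'}
```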
diff --git a/prototype/sky/clouds/gcp.py b/prototype/sky/clouds/gcp.py
index 3626781d0d2..64c26b77120 100644
--- a/prototype/sky/clouds/gcp.py
+++ b/prototype/sky/clouds/gcp.py
@@ -1,8 +1,11 @@
 """Google Cloud Platform."""
 import copy
 import json
+import os
 from typing import Dict, Iterator, List, Optional, Tuple
 
+from google import auth
+
 from sky import clouds
 from sky.clouds.service_catalog import gcp_catalog
 
@@ -217,3 +220,37 @@ def get_accelerators_from_instance_type(
         # GCP handles accelerators separately from regular instance types,
         # hence return none here.
         return None
+
+    def check_credentials(self) -> Tuple[bool, Optional[str]]:
+        """Checks if the user has access credentials to this cloud."""
+        try:
+            # These files are required because they will be synced to remote
+            # VMs for `gsutil` to access private storage buckets.
+            # `auth.default()` does not guarantee these files exist.
+            for file in [
+                    '~/.config/gcloud/access_tokens.db',
+                    '~/.config/gcloud/credentials.db'
+            ]:
+                assert os.path.isfile(os.path.expanduser(file))
+            # Calling `auth.default()` ensures the GCP client library works,
+            # which is used by Ray Autoscaler to launch VMs.
+            auth.default()
+        except (AssertionError, auth.exceptions.DefaultCredentialsError):
+            # See also: https://stackoverflow.com/a/53307505/1165051
+            return False, (
+                'GCP credentials not set. Run the following commands:\n    '
+                # This authenticates the CLI to make `gsutil` work:
+                '$ gcloud auth login\n    '
+                '$ gcloud config set project <project-id>\n    '
+                # These two commands setup the client library to make
+                # Ray Autoscaler work:
+                '$ gcloud auth application-default login\n    '
+                '$ gcloud auth application-default set-quota-project '
+                '<project-id>\n    '
+                'For more info: '
+                'https://googleapis.dev/python/google-api-core/latest/auth.html'
+            )
+        return True, None
+
+    def get_credential_file_mounts(self) -> Dict[str, str]:
+        return {'~/.config/gcloud': '~/.config/gcloud'}
diff --git a/prototype/sky/execution.py b/prototype/sky/execution.py
index 43a82034573..dec12dcd504 100644
--- a/prototype/sky/execution.py
+++ b/prototype/sky/execution.py
@@ -19,6 +19,7 @@
 
 import sky
 from sky import backends
+from sky import init
 from sky import global_user_state
 from sky import sky_logging
 from sky import optimizer
@@ -84,16 +85,25 @@
                                                              cluster_name)
     cluster_exists = existing_handle is not None
 
+    backend = backend if backend is not None else backends.CloudVmRayBackend()
+
     if not cluster_exists and (stages is None or Stage.OPTIMIZE in stages):
         if task.best_resources is None:
             # TODO: fix this for the situation where number of requested
             # accelerators is not an integer.
-            dag = sky.optimize(dag, minimize=optimize_target)
+            if isinstance(backend, backends.CloudVmRayBackend):
+                # TODO: adding this check because docker backend on a
+                # no-credential machine should not enter optimize(), which
+                # would directly error out ('No cloud is enabled...'). Fix by
+                # moving sky init checks out of optimize()?
+                dag = sky.optimize(dag, minimize=optimize_target)
             task = dag.tasks[0]  # Keep: dag may have been deep-copied.
 
-    backend = backend if backend is not None else backends.CloudVmRayBackend()
     backend.register_info(dag=dag, optimize_target=optimize_target)
 
+    # FIXME: test on some node where the mounts do not exist.
+    task.update_file_mounts(init.get_cloud_credential_file_mounts())
+
     if task.storage_mounts is not None:
         # Optimizer should eventually choose where to store bucket
         task.add_storage_mounts()
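The `task.update_file_mounts(...)` call merges credential paths into whatever mounts the task already declares; per `Task.update_file_mounts`, this is effectively a dict update followed by re-validation. A sketch of the merged result when AWS and GCP are enabled (values taken from the `get_credential_file_mounts` implementations in this diff):

```python
task_mounts = {'~/data': '~/local/data'}  # User-specified mounts.
credential_mounts = {
    '~/.aws': '~/.aws',                      # From AWS.
    '~/.config/gcloud': '~/.config/gcloud',  # From GCP.
}
# What update_file_mounts() effectively does under the hood.
task_mounts.update(credential_mounts)
assert '~/.aws' in task_mounts and '~/data' in task_mounts
```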
""" import enum +import json import os import pathlib import pickle @@ -16,6 +17,9 @@ from typing import Any, Dict, List, Optional from sky import backends +from sky import clouds + +_ENABLED_CLOUDS_KEY = 'enabled_clouds' _DB_PATH = os.path.expanduser('~/.sky/state.db') os.makedirs(pathlib.Path(_DB_PATH).parents[0], exist_ok=True) @@ -23,17 +27,16 @@ _CONN = sqlite3.connect(_DB_PATH) _CURSOR = _CONN.cursor() -try: - _CURSOR.execute('select * from clusters limit 0') -except sqlite3.OperationalError: - # Tables do not exist, create them. - _CURSOR.execute("""\ - CREATE TABLE clusters ( - name TEXT PRIMARY KEY, - launched_at INTEGER, - handle BLOB, - last_use TEXT, - status TEXT)""") +_CURSOR.execute("""\ + CREATE TABLE IF NOT EXISTS clusters ( + name TEXT PRIMARY KEY, + lauched_at INTEGER, + handle BLOB, + last_use TEXT, + status TEXT)""") +_CURSOR.execute("""\ + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, value TEXT)""") _CONN.commit() @@ -167,3 +170,19 @@ def get_clusters() -> List[Dict[str, Any]]: 'status': ClusterStatus[status], }) return records + + +def get_enabled_clouds() -> List[clouds.Cloud]: + rows = _CURSOR.execute('SELECT value FROM config WHERE key = ?', + (_ENABLED_CLOUDS_KEY,)) + ret = [] + for (value,) in rows: + ret = json.loads(value) + break + return [clouds.from_str(cloud) for cloud in ret] + + +def set_enabled_clouds(enabled_clouds: List[str]) -> None: + _CURSOR.execute('INSERT OR REPLACE INTO config VALUES (?, ?)', + (_ENABLED_CLOUDS_KEY, json.dumps(enabled_clouds))) + _CONN.commit() diff --git a/prototype/sky/init.py b/prototype/sky/init.py new file mode 100644 index 00000000000..5f7390c1ce6 --- /dev/null +++ b/prototype/sky/init.py @@ -0,0 +1,53 @@ +"""Sky Initialization: check cloud credentials and enable clouds.""" +from typing import Dict + +import click + +from sky import clouds +from sky import global_user_state + + +def init(quiet: bool = False) -> None: + echo = (lambda *_args, **_kwargs: None) if quiet else click.echo + echo('Checking credentials to enable clouds for Sky.') + + enabled_clouds = [] + for cloud in clouds.CLOUD_REGISTRY.values(): + echo(f' Checking {cloud}...', nl=False) + ok, reason = cloud.check_credentials() + echo('\r', nl=False) + status_msg = 'enabled' if ok else 'disabled' + status_color = 'green' if ok else 'red' + echo(' ' + + click.style(f'{cloud}: {status_msg}', fg=status_color, bold=True) + + ' ' * 10) + if ok: + enabled_clouds.append(str(cloud)) + else: + echo(f' Reason: {reason}') + + if len(enabled_clouds) == 0: + click.echo( + click.style( + 'No cloud is enabled. Sky will not be able to run any task. ' + 'Run `sky init` for more info.', + fg='red', + bold=True)) + raise SystemExit() + else: + echo('\nSky will use only the enabled clouds to run tasks. ' + 'To change this, configure cloud credentials, ' + 'and run ' + click.style('sky init', bold=True) + '.') + + global_user_state.set_enabled_clouds(enabled_clouds) + + +def get_cloud_credential_file_mounts() -> Dict[str, str]: + """Returns the files necessary to access all enabled clouds. 
diff --git a/prototype/sky/init.py b/prototype/sky/init.py
new file mode 100644
index 00000000000..5f7390c1ce6
--- /dev/null
+++ b/prototype/sky/init.py
@@ -0,0 +1,53 @@
+"""Sky Initialization: check cloud credentials and enable clouds."""
+from typing import Dict
+
+import click
+
+from sky import clouds
+from sky import global_user_state
+
+
+def init(quiet: bool = False) -> None:
+    echo = (lambda *_args, **_kwargs: None) if quiet else click.echo
+    echo('Checking credentials to enable clouds for Sky.')
+
+    enabled_clouds = []
+    for cloud in clouds.CLOUD_REGISTRY.values():
+        echo(f'  Checking {cloud}...', nl=False)
+        ok, reason = cloud.check_credentials()
+        echo('\r', nl=False)
+        status_msg = 'enabled' if ok else 'disabled'
+        status_color = 'green' if ok else 'red'
+        echo('  ' +
+             click.style(f'{cloud}: {status_msg}', fg=status_color, bold=True)
+             + ' ' * 10)
+        if ok:
+            enabled_clouds.append(str(cloud))
+        else:
+            echo(f'    Reason: {reason}')
+
+    if len(enabled_clouds) == 0:
+        click.echo(
+            click.style(
+                'No cloud is enabled. Sky will not be able to run any task. '
+                'Run `sky init` for more info.',
+                fg='red',
+                bold=True))
+        raise SystemExit()
+    else:
+        echo('\nSky will use only the enabled clouds to run tasks. '
+             'To change this, configure cloud credentials, '
+             'and run ' + click.style('sky init', bold=True) + '.')
+
+    global_user_state.set_enabled_clouds(enabled_clouds)
+
+
+def get_cloud_credential_file_mounts() -> Dict[str, str]:
+    """Returns the files necessary to access all enabled clouds.
+
+    Returns a dictionary that will be added to a task's file mounts."""
+    enabled_clouds = global_user_state.get_enabled_clouds()
+    ret = {}
+    for cloud in enabled_clouds:
+        ret.update(cloud.get_credential_file_mounts())
+    return ret
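For a machine with only AWS credentials configured, the `sky init` output assembled by the `echo` calls above would look roughly like this (illustrative; clouds print in `CLOUD_REGISTRY` order, and the exact reason text depends on which checks fail):

```
$ sky init
Checking credentials to enable clouds for Sky.
  AWS: enabled
  GCP: disabled
    Reason: GCP credentials not set. Run the following commands:
    ...
  Azure: disabled
    Reason: ~/.azure/accessTokens.json does not exist. Run `az login`.
    ...

Sky will use only the enabled clouds to run tasks. To change this, configure cloud credentials, and run sky init.
```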
diff --git a/prototype/sky/optimizer.py b/prototype/sky/optimizer.py
index e17b8c6dd85..b83429159ce 100644
--- a/prototype/sky/optimizer.py
+++ b/prototype/sky/optimizer.py
@@ -2,24 +2,26 @@
 import collections
 import enum
 import pprint
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import networkx as nx
 import numpy as np
 import tabulate
 
-import sky
 from sky import clouds
 from sky import dag as dag_lib
 from sky import exceptions
-from sky import sky_logging
+from sky import global_user_state
+from sky import init
 from sky import resources as resources_lib
-from sky import task
+from sky import sky_logging
+from sky import task as task_lib
 
 logger = sky_logging.init_logger(__name__)
 
 Dag = dag_lib.Dag
 Resources = resources_lib.Resources
+Task = task_lib.Task
 
 _DUMMY_SOURCE_NAME = 'sky-dummy-source'
 _DUMMY_SINK_NAME = 'sky-dummy-sink'
@@ -112,7 +114,7 @@ def _add_dummy_source_sink_nodes(dag: Dag):
                 zero_outdegree_nodes.append(node)
 
         def make_dummy(name):
-            dummy = task.Task(name)
+            dummy = Task(name)
             dummy.set_resources({DummyResources(DummyCloud(), None)})
             dummy.set_time_estimator(lambda _: 0)
             return dummy
@@ -137,8 +139,8 @@ def _remove_dummy_source_sink_nodes(dag: Dag):
         return dag
 
     @staticmethod
-    def _egress_cost_or_time(minimize_cost: bool, parent: task.Task,
-                             parent_resources: Resources, node: task.Task,
+    def _egress_cost_or_time(minimize_cost: bool, parent: Task,
+                             parent_resources: Resources, node: Task,
                              resources: Resources):
         """Computes the egress cost or time depending on 'minimize_cost'."""
         if isinstance(parent_resources.cloud, DummyCloud):
@@ -190,7 +192,7 @@
             if node_i < len(topo_order) - 1:
                 # Convert partial resource labels to launchable resources.
                 launchable_resources = \
-                    sky.registry.fill_in_launchable_resources(
+                    _fill_in_launchable_resources(
                         node,
                         blocked_launchable_resources
                     )
@@ -205,8 +207,8 @@
             for orig_resources, launchable_list in launchable_resources.items():
                 if not launchable_list:
                     raise exceptions.ResourcesUnavailableError(
-                        f'No launchable resource found for task\n{node}; '
-                        f'To fix: relax its Resources() requirements.')
+                        f'No launchable resource found for task {node}. '
+                        'To fix: relax its resource requirements.')
                 if num_resources == 1 and node.time_estimator_func is None:
                     logger.info('Defaulting estimated time to 1 hr. '
                                 'Call Task.set_time_estimator() to override.')
@@ -348,3 +350,62 @@
 class DummyCloud(clouds.Cloud):
     """A dummy Cloud that has zero egress cost from/to."""
     pass
+
+
+def _cloud_in_list(cloud: clouds.Cloud, lst: List[clouds.Cloud]) -> bool:
+    return any(cloud.is_same_cloud(c) for c in lst)
+
+
+def _filter_out_blocked_launchable_resources(
+        launchable_resources: List[Resources],
+        blocked_launchable_resources: List[Resources]):
+    """Filters out launchable resources that are blocked."""
+    available_resources = []
+    for resources in launchable_resources:
+        for blocked_resources in blocked_launchable_resources:
+            if resources.is_launchable_fuzzy_equal(blocked_resources):
+                break
+        else:  # non-blocked launchable resources. (no break)
+            available_resources.append(resources)
+    return available_resources
+
+
+def _fill_in_launchable_resources(
+        task: Task,
+        blocked_launchable_resources: Optional[List[Resources]],
+        try_fix_with_sky_init: bool = True,
+) -> Dict[Resources, List[Resources]]:
+    enabled_clouds = global_user_state.get_enabled_clouds()
+    if len(enabled_clouds) == 0 and try_fix_with_sky_init:
+        init.init(quiet=True)
+        return _fill_in_launchable_resources(task,
+                                             blocked_launchable_resources,
+                                             False)
+    launchable = collections.defaultdict(list)
+    if blocked_launchable_resources is None:
+        blocked_launchable_resources = []
+    for resources in task.get_resources():
+        if resources.cloud is not None and not _cloud_in_list(
+                resources.cloud, enabled_clouds):
+            if try_fix_with_sky_init:
+                init.init(quiet=True)
+                return _fill_in_launchable_resources(
+                    task, blocked_launchable_resources, False)
+            raise exceptions.ResourcesUnavailableError(
+                f'Task {task} requires {resources.cloud} which is not '
+                'enabled. Run `sky init` to enable access to it, '
+                'or change the cloud requirement.')
+        elif resources.is_launchable():
+            launchable[resources] = [resources]
+        elif resources.cloud is not None:
+            launchable[
+                resources] = resources.cloud.get_feasible_launchable_resources(
+                    resources)
+        else:
+            for cloud in enabled_clouds:
+                feasible_resources = cloud.get_feasible_launchable_resources(
+                    resources)
+                launchable[resources].extend(feasible_resources)
+        launchable[resources] = _filter_out_blocked_launchable_resources(
+            launchable[resources], blocked_launchable_resources)
+
+    return launchable
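The `for`/`else` in `_filter_out_blocked_launchable_resources` is easy to misread: the `else` body runs only when the inner loop finishes without `break`, i.e. when no blocked entry fuzzily matched. A stripped-down sketch of the same pattern, with a hypothetical `matches` predicate standing in for `is_launchable_fuzzy_equal`:

```python
def filter_unblocked(candidates, blocked, matches):
    """Keeps candidates that match nothing in `blocked`."""
    kept = []
    for cand in candidates:
        for b in blocked:
            if matches(cand, b):
                break  # Blocked: skip this candidate.
        else:  # Runs only if the inner loop did not break.
            kept.append(cand)
    return kept


assert filter_unblocked([1, 2, 3], [2], lambda a, b: a == b) == [1, 3]
```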
diff --git a/prototype/sky/task.py b/prototype/sky/task.py
index 43f0ccd43fc..6b09c94acb3 100644
--- a/prototype/sky/task.py
+++ b/prototype/sky/task.py
@@ -128,9 +128,6 @@ def __init__(
         # Filled in by the optimizer. If None, this Task is not planned.
         self.best_resources = None
 
-        # Block some of the clouds.
-        self.blocked_clouds = set()
-
         # Semantics.
         if num_nodes is not None and num_nodes > 1 and isinstance(
                 self.run, str):
@@ -348,7 +345,6 @@ def add_storage_mounts(self) -> None:
             storage_type = storage_plans[store]
             if storage_type is storage_lib.StorageType.S3:
                 # TODO: allow for Storage mounting of different clouds
-                self.update_file_mounts({'~/.aws': '~/.aws'})
                 self.update_file_mounts({
                     mnt_path: 's3://' + store.name,
                 })
@@ -365,6 +361,7 @@
                     mnt_path: 'gs://' + store.name,
                 })
             elif storage_type is storage_lib.StorageType.AZURE:
+                # TODO when Azure Blob is done: sync ~/.azure
                 assert False, 'TODO: Azure Blob not mountable yet'
             else:
                 raise ValueError(f'Storage Type {storage_type} \
@@ -426,11 +423,6 @@ def update_file_mounts(self, file_mounts: Dict[str, str]):
         # For validation logic:
         return self.set_file_mounts(self.file_mounts)
 
-    def set_blocked_clouds(self, blocked_clouds: Set[clouds.Cloud]):
-        """Sets the clouds that this task should not run on."""
-        self.blocked_clouds = blocked_clouds
-        return self
-
     def get_local_to_remote_file_mounts(self) -> Optional[Dict[str, str]]:
         """Returns file mounts of the form (dst=VM path, src=local path).
diff --git a/prototype/tests/test_global_user_state.py b/prototype/tests/test_global_user_state.py
new file mode 100644
index 00000000000..4b20f30d08f
--- /dev/null
+++ b/prototype/tests/test_global_user_state.py
@@ -0,0 +1,10 @@
+import pytest
+import sys
+
+import sky
+
+
+@pytest.mark.skipif(sys.platform != 'linux', reason='Only test in CI.')
+def test_enabled_clouds_empty():
+    # In test environment, no cloud should be enabled.
+    assert sky.global_user_state.get_enabled_clouds() == []
diff --git a/prototype/tests/test_optimizer_dryruns.py b/prototype/tests/test_optimizer_dryruns.py
index 38c64438306..83a570b8b49 100644
--- a/prototype/tests/test_optimizer_dryruns.py
+++ b/prototype/tests/test_optimizer_dryruns.py
@@ -1,62 +1,105 @@
 import pytest
 
 import sky
-
-
-def _test_resources(resources):
+from sky import clouds
+from sky import exceptions
+
+
+# Monkey-patching is required because in the test environment, no cloud is
+# enabled. The optimizer checks the environment to find enabled clouds, and
+# only generates plans within these clouds. The tests assume that all three
+# clouds are enabled, so we monkeypatch the `sky.global_user_state` module
+# to return all three clouds. We also monkeypatch `sky.init.init` so that
+# when the optimizer tries calling it to update enabled_clouds, it does not
+# raise SystemExit.
+def _test_resources(monkeypatch, resources, enabled_clouds=None):
+    if enabled_clouds is None:
+        enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
+    monkeypatch.setattr(
+        'sky.global_user_state.get_enabled_clouds',
+        lambda: enabled_clouds,
+    )
+    monkeypatch.setattr('sky.init.init', lambda *_args, **_kwargs: None)
     with sky.Dag() as dag:
-        train = sky.Task('train')
-        train.set_resources({resources})
+        task = sky.Task('test_task')
+        task.set_resources({resources})
     sky.launch(dag, dryrun=True)
+    assert True
 
 
-def test_resources_aws():
-    _test_resources(sky.Resources(sky.AWS(), 'p3.2xlarge'))
+def test_resources_aws(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(clouds.AWS(), 'p3.2xlarge'))
 
 
-def test_resources_azure():
-    _test_resources(sky.Resources(sky.Azure(), 'Standard_NC24s_v3'))
+def test_resources_azure(monkeypatch):
+    _test_resources(monkeypatch,
+                    sky.Resources(clouds.Azure(), 'Standard_NC24s_v3'))
 
 
-def test_resources_gcp():
-    _test_resources(sky.Resources(sky.GCP(), 'n1-standard-16'))
+def test_resources_gcp(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(clouds.GCP(), 'n1-standard-16'))
 
 
-def test_partial_k80():
-    _test_resources(sky.Resources(accelerators='K80'))
+def test_partial_k80(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(accelerators='K80'))
 
 
-def test_partial_m60():
-    _test_resources(sky.Resources(accelerators='M60'))
+def test_partial_m60(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(accelerators='M60'))
 
 
-def test_partial_p100():
-    _test_resources(sky.Resources(accelerators='P100'))
+def test_partial_p100(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(accelerators='P100'))
 
 
-def test_partial_t4():
-    _test_resources(sky.Resources(accelerators='T4'))
-    _test_resources(sky.Resources(accelerators={'T4': 8}, use_spot=True))
+def test_partial_t4(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(accelerators='T4'))
+    _test_resources(monkeypatch,
+                    sky.Resources(accelerators={'T4': 8}, use_spot=True))
 
 
-def test_partial_tpu():
-    _test_resources(sky.Resources(accelerators='tpu-v3-8'))
+def test_partial_tpu(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(accelerators='tpu-v3-8'))
 
 
-def test_invalid_cloud_tpu():
+def test_partial_v100(monkeypatch):
+    _test_resources(monkeypatch, sky.Resources(sky.AWS(), accelerators='V100'))
+    _test_resources(
+        monkeypatch, sky.Resources(sky.AWS(),
+                                   accelerators='V100',
+                                   use_spot=True))
+    _test_resources(monkeypatch,
+                    sky.Resources(sky.AWS(), accelerators={'V100': 8}))
+
+
+def test_invalid_cloud_tpu(monkeypatch):
     with pytest.raises(AssertionError) as e:
-        _test_resources(sky.Resources(cloud=sky.AWS(), accelerators='tpu-v3-8'))
+        _test_resources(monkeypatch,
+                        sky.Resources(cloud=sky.AWS(), accelerators='tpu-v3-8'))
     assert 'Cloud must be GCP' in str(e.value)
 
 
-def test_partial_v100():
-    _test_resources(sky.Resources(sky.AWS(), accelerators='V100'))
-    _test_resources(sky.Resources(sky.AWS(), accelerators='V100',
-                    use_spot=True))
-    _test_resources(sky.Resources(sky.AWS(), accelerators={'V100': 8}))
+def test_clouds_not_enabled(monkeypatch):
+    with pytest.raises(exceptions.ResourcesUnavailableError):
+        _test_resources(monkeypatch,
+                        sky.Resources(clouds.AWS()),
+                        enabled_clouds=[
+                            clouds.Azure(),
+                            clouds.GCP(),
+                        ])
+
+    with pytest.raises(exceptions.ResourcesUnavailableError):
+        _test_resources(monkeypatch,
+                        sky.Resources(clouds.Azure()),
+                        enabled_clouds=[clouds.AWS()])
+
+    with pytest.raises(exceptions.ResourcesUnavailableError):
+        _test_resources(monkeypatch,
+                        sky.Resources(clouds.GCP()),
+                        enabled_clouds=[clouds.AWS()])
 
 
-def test_instance_type_mistmatches_accelerators():
+def test_instance_type_mismatches_accelerators(monkeypatch):
     bad_instance_and_accs = [
         # Actual: V100
         ('p3.2xlarge', 'K80'),
@@ -66,23 +109,27 @@
     for instance, acc in bad_instance_and_accs:
         with pytest.raises(ValueError) as e:
             _test_resources(
+                monkeypatch,
                 sky.Resources(sky.AWS(), instance_type=instance,
                               accelerators=acc))
         assert 'Infeasible resource demands found' in str(e.value)
 
 
-def test_instance_type_matches_accelerators():
+def test_instance_type_matches_accelerators(monkeypatch):
     _test_resources(
+        monkeypatch,
         sky.Resources(sky.AWS(), instance_type='p3.2xlarge',
                       accelerators='V100'))
     _test_resources(
+        monkeypatch,
         sky.Resources(sky.GCP(), instance_type='n1-standard-2',
                       accelerators='V100'))
     # Partial use: Instance has 8 V100s, while the task needs 1 of them.
     _test_resources(
+        monkeypatch,
        sky.Resources(sky.AWS(), instance_type='p3.16xlarge',
                      accelerators={'V100': 1}))