diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ac723f35fc2..3faf75acf8d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -53,7 +53,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ".[all]" + pip install -e ".[all]" pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0 - name: Run tests with pytest diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst new file mode 100644 index 00000000000..0d3e3444372 --- /dev/null +++ b/docs/source/cloud-setup/policy.rst @@ -0,0 +1,195 @@ +.. _advanced-policy-config: + +Admin Policy Enforcement +======================== + + +SkyPilot provides an **admin policy** mechanism that admins can use to enforce certain policies on users' SkyPilot usage. An admin policy applies +custom validation and mutation logic to a user's tasks and SkyPilot config. + +Example usage: + +- :ref:`kubernetes-labels-policy` +- :ref:`disable-public-ip-policy` +- :ref:`use-spot-for-gpu-policy` +- :ref:`enforce-autostop-policy` + + +To implement and use an admin policy: + +- Admins writes a simple Python package with a policy class that implements SkyPilot's ``sky.AdminPolicy`` interface; +- Admins distributes this package to users; +- Users simply set the ``admin_policy`` field in the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. + + +Overview +-------- + + + +User-Side +~~~~~~~~~~ + +To apply the policy, a user needs to set the ``admin_policy`` field in the SkyPilot config +``~/.sky/config.yaml`` to the path of the Python package that implements the policy. +For example: + +.. code-block:: yaml + + admin_policy: mypackage.subpackage.MyPolicy + + +.. hint:: + + SkyPilot loads the policy from the given package in the same Python environment. + You can test the existence of the policy by running: + + .. code-block:: bash + + python -c "from mypackage.subpackage import MyPolicy" + + +Admin-Side +~~~~~~~~~~ + +An admin can distribute the Python package to users with a pre-defined policy. The +policy should implement the ``sky.AdminPolicy`` `interface `_: + + +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: AdminPolicy + :caption: `AdminPolicy Interface `_ + + +Your custom admin policy should look like this: + +.. code-block:: python + + import sky + + class MyPolicy(sky.AdminPolicy): + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + # Logic for validate and modify user requests. + ... + return sky.MutatedUserRequest(user_request.task, + user_request.skypilot_config) + + +``UserRequest`` and ``MutatedUserRequest`` are defined as follows (see `source code `_ for more details): + + +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: UserRequest + :caption: `UserRequest Class `_ + +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: MutatedUserRequest + :caption: `MutatedUserRequest Class `_ + + +In other words, an ``AdminPolicy`` can mutate any fields of a user request, including +the :ref:`task ` and the :ref:`global skypilot config `, +giving admins a lot of flexibility to control user's SkyPilot usage. + +An ``AdminPolicy`` can be used to both validate and mutate user requests. If +a request should be rejected, the policy should raise an exception. + + +The ``sky.Config`` and ``sky.RequestOptions`` classes are defined as follows: + +.. literalinclude:: ../../../sky/skypilot_config.py + :language: python + :pyobject: Config + :caption: `Config Class `_ + + +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: RequestOptions + :caption: `RequestOptions Class `_ + + +Example Policies +---------------- + +We have provided a few example policies in `examples/admin_policy/example_policy `_. You can test these policies by installing the example policy package in your Python environment. + +.. code-block:: bash + + git clone https://github.com/skypilot-org/skypilot.git + cd skypilot + pip install examples/admin_policy/example_policy + +Reject All +~~~~~~~~~~ + +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: RejectAllPolicy + :caption: `RejectAllPolicy `_ + +.. literalinclude:: ../../../examples/admin_policy/reject_all.yaml + :language: yaml + :caption: `Config YAML for using RejectAllPolicy `_ + +.. _kubernetes-labels-policy: + +Add Labels for all Tasks on Kubernetes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: AddLabelsPolicy + :caption: `AddLabelsPolicy `_ + +.. literalinclude:: ../../../examples/admin_policy/add_labels.yaml + :language: yaml + :caption: `Config YAML for using AddLabelsPolicy `_ + + +.. _disable-public-ip-policy: + +Always Disable Public IP for AWS Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: DisablePublicIpPolicy + :caption: `DisablePublicIpPolicy `_ + +.. literalinclude:: ../../../examples/admin_policy/disable_public_ip.yaml + :language: yaml + :caption: `Config YAML for using DisablePublicIpPolicy `_ + +.. _use-spot-for-gpu-policy: + +Use Spot for all GPU Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: UseSpotForGpuPolicy + :caption: `UseSpotForGpuPolicy `_ + +.. literalinclude:: ../../../examples/admin_policy/use_spot_for_gpu.yaml + :language: yaml + :caption: `Config YAML for using UseSpotForGpuPolicy `_ + +.. _enforce-autostop-policy: + +Enforce Autostop for all Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: EnforceAutostopPolicy + :caption: `EnforceAutostopPolicy `_ + +.. literalinclude:: ../../../examples/admin_policy/enforce_autostop.yaml + :language: yaml + :caption: `Config YAML for using EnforceAutostopPolicy `_ diff --git a/docs/source/docs/index.rst b/docs/source/docs/index.rst index eeef2386337..00a645a3834 100644 --- a/docs/source/docs/index.rst +++ b/docs/source/docs/index.rst @@ -201,7 +201,8 @@ Read the research: ../cloud-setup/cloud-permissions/index ../cloud-setup/cloud-auth ../cloud-setup/quota - + ../cloud-setup/policy + .. toctree:: :hidden: :maxdepth: 1 diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 6c2fe2569a6..ebe8db6751f 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -87,6 +87,17 @@ Available fields and semantics: # Default: false. disable_ecc: false + # Admin policy to be applied to all tasks. (optional). + # + # The policy class to be applied to all tasks, which can be used to validate + # and mutate user requests. + # + # This is useful for enforcing certain policies on all tasks, e.g., + # add custom labels; enforce certain resource limits; etc. + # + # The policy class should implement the sky.AdminPolicy interface. + admin_policy: my_package.SkyPilotPolicyV1 + # Advanced AWS configurations (optional). # Apply to all new instances but not existing ones. aws: diff --git a/examples/admin_policy/add_labels.yaml b/examples/admin_policy/add_labels.yaml new file mode 100644 index 00000000000..113b3b78044 --- /dev/null +++ b/examples/admin_policy/add_labels.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.AddLabelsPolicy diff --git a/examples/admin_policy/disable_public_ip.yaml b/examples/admin_policy/disable_public_ip.yaml new file mode 100644 index 00000000000..cef910cbdaf --- /dev/null +++ b/examples/admin_policy/disable_public_ip.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.DisablePublicIpPolicy diff --git a/examples/admin_policy/enforce_autostop.yaml b/examples/admin_policy/enforce_autostop.yaml new file mode 100644 index 00000000000..f0194fb994e --- /dev/null +++ b/examples/admin_policy/enforce_autostop.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.EnforceAutostopPolicy diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py new file mode 100644 index 00000000000..12ca4e952e2 --- /dev/null +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -0,0 +1,6 @@ +"""Example admin policy moduleĀ and prebuilt policies.""" +from example_policy.skypilot_policy import AddLabelsPolicy +from example_policy.skypilot_policy import DisablePublicIpPolicy +from example_policy.skypilot_policy import EnforceAutostopPolicy +from example_policy.skypilot_policy import RejectAllPolicy +from example_policy.skypilot_policy import UseSpotForGpuPolicy diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py new file mode 100644 index 00000000000..dc4e4b873fb --- /dev/null +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -0,0 +1,121 @@ +"""Example prebuilt admin policies.""" +import sky + + +class RejectAllPolicy(sky.AdminPolicy): + """Example policy: rejects all user requests.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Rejects all user requests.""" + raise RuntimeError('Reject all policy') + + +class AddLabelsPolicy(sky.AdminPolicy): + """Example policy: adds a kubernetes label for skypilot_config.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + config = user_request.skypilot_config + labels = config.get_nested(('kubernetes', 'custom_metadata', 'labels'), + {}) + labels['app'] = 'skypilot' + config.set_nested(('kubernetes', 'custom_metadata', 'labels'), labels) + return sky.MutatedUserRequest(user_request.task, config) + + +class DisablePublicIpPolicy(sky.AdminPolicy): + """Example policy: disables public IP for all AWS tasks.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + config = user_request.skypilot_config + config.set_nested(('aws', 'use_internal_ip'), True) + if config.get_nested(('aws', 'vpc_name'), None) is None: + # If no VPC name is specified, it is likely a mistake. We should + # reject the request + raise RuntimeError('VPC name should be set. Check organization ' + 'wiki for more information.') + return sky.MutatedUserRequest(user_request.task, config) + + +class UseSpotForGpuPolicy(sky.AdminPolicy): + """Example policy: use spot instances for all GPU tasks.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Sets use_spot to True for all GPU tasks.""" + task = user_request.task + new_resources = [] + for r in task.resources: + if r.accelerators: + new_resources.append(r.copy(use_spot=True)) + else: + new_resources.append(r) + + task.set_resources(type(task.resources)(new_resources)) + + return sky.MutatedUserRequest( + task=task, skypilot_config=user_request.skypilot_config) + + +class EnforceAutostopPolicy(sky.AdminPolicy): + """Example policy: enforce autostop for all tasks.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Enforces autostop for all tasks. + + Note that with this policy enforced, users can still change the autostop + setting for an existing cluster by using `sky autostop`. + + Since we refresh the cluster status with `sky.status` whenever this + policy is applied, we should expect a few seconds latency when a user + run a request. + """ + request_options = user_request.request_options + + # Request options is None when a task is executed with `jobs launch` or + # `sky serve up`. + if request_options is None: + return sky.MutatedUserRequest( + task=user_request.task, + skypilot_config=user_request.skypilot_config) + + # Get the cluster record to operate on. + cluster_name = request_options.cluster_name + cluster_records = [] + if cluster_name is not None: + cluster_records = sky.status(cluster_name, refresh=True) + + # Check if the user request should specify autostop settings. + need_autostop = False + if not cluster_records: + # Cluster does not exist + need_autostop = True + elif cluster_records[0]['status'] == sky.ClusterStatus.STOPPED: + # Cluster is stopped + need_autostop = True + elif cluster_records[0]['autostop'] < 0: + # Cluster is running but autostop is not set + need_autostop = True + + # Check if the user request is setting autostop settings. + is_setting_autostop = False + idle_minutes_to_autostop = request_options.idle_minutes_to_autostop + is_setting_autostop = (idle_minutes_to_autostop is not None and + idle_minutes_to_autostop >= 0) + + # If the cluster requires autostop but the user request is not setting + # autostop settings, raise an error. + if need_autostop and not is_setting_autostop: + raise RuntimeError('Autostop/down must be set for all clusters.') + + return sky.MutatedUserRequest( + task=user_request.task, + skypilot_config=user_request.skypilot_config) diff --git a/examples/admin_policy/example_policy/pyproject.toml b/examples/admin_policy/example_policy/pyproject.toml new file mode 100644 index 00000000000..b4aa56be4b2 --- /dev/null +++ b/examples/admin_policy/example_policy/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "example_policy" +version = "0.0.1" diff --git a/examples/admin_policy/reject_all.yaml b/examples/admin_policy/reject_all.yaml new file mode 100644 index 00000000000..fe6632089d9 --- /dev/null +++ b/examples/admin_policy/reject_all.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.RejectAllPolicy diff --git a/examples/admin_policy/task.yaml b/examples/admin_policy/task.yaml new file mode 100644 index 00000000000..065b4cbfb11 --- /dev/null +++ b/examples/admin_policy/task.yaml @@ -0,0 +1,12 @@ +resources: + cloud: aws + cpus: 2 + labels: + other_labels: test + + +setup: | + echo "setup" + +run: | + echo "run" diff --git a/examples/admin_policy/use_spot_for_gpu.yaml b/examples/admin_policy/use_spot_for_gpu.yaml new file mode 100644 index 00000000000..45f257017a4 --- /dev/null +++ b/examples/admin_policy/use_spot_for_gpu.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.UseSpotForGpuPolicy diff --git a/sky/__init__.py b/sky/__init__.py index a077fb8966a..37b5a1caf08 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -82,6 +82,9 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky import backends from sky import benchmark from sky import clouds +from sky.admin_policy import AdminPolicy +from sky.admin_policy import MutatedUserRequest +from sky.admin_policy import UserRequest from sky.clouds.service_catalog import list_accelerators from sky.core import autostop from sky.core import cancel @@ -112,6 +115,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.optimizer import OptimizeTarget from sky.resources import Resources from sky.skylet.job_lib import JobStatus +from sky.skypilot_config import Config from sky.status_lib import ClusterStatus from sky.task import Task @@ -185,4 +189,9 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): # core APIs Storage Management 'storage_ls', 'storage_delete', + # Admin Policy + 'UserRequest', + 'MutatedUserRequest', + 'AdminPolicy', + 'Config', ] diff --git a/sky/admin_policy.py b/sky/admin_policy.py new file mode 100644 index 00000000000..304285d04b7 --- /dev/null +++ b/sky/admin_policy.py @@ -0,0 +1,101 @@ +"""Interface for admin-defined policy for user requests.""" +import abc +import dataclasses +import typing +from typing import Optional + +if typing.TYPE_CHECKING: + import sky + + +@dataclasses.dataclass +class RequestOptions: + """Request options for admin policy. + + Args: + cluster_name: Name of the cluster to create/reuse. It is None if not + specified by the user. + idle_minutes_to_autostop: Autostop setting requested by a user. The + cluster will be set to autostop after this many minutes of idleness. + down: If true, use autodown rather than autostop. + dryrun: Is the request a dryrun? + """ + cluster_name: Optional[str] + idle_minutes_to_autostop: Optional[int] + down: bool + dryrun: bool + + +@dataclasses.dataclass +class UserRequest: + """A user request. + + A "user request" is defined as a `sky launch / exec` command or its API + equivalent. + + `sky jobs launch / serve up` involves multiple launch requests, including + the launch of controller and clusters for a job (which can have multiple + tasks if it is a pipeline) or service replicas. Each launch is a separate + request. + + This class wraps the underlying task, the global skypilot config used to run + a task, and the request options. + + Args: + task: User specified task. + skypilot_config: Global skypilot config to be used in this request. + request_options: Request options. It is None for jobs and services. + """ + task: 'sky.Task' + skypilot_config: 'sky.Config' + request_options: Optional['RequestOptions'] = None + + +@dataclasses.dataclass +class MutatedUserRequest: + task: 'sky.Task' + skypilot_config: 'sky.Config' + + +# pylint: disable=line-too-long +class AdminPolicy: + """Abstract interface of an admin-defined policy for all user requests. + + Admins can implement a subclass of AdminPolicy with the following signature: + + import sky + + class SkyPilotPolicyV1(sky.AdminPolicy): + def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: + ... + return MutatedUserRequest(task=..., skypilot_config=...) + + The policy can mutate both task and skypilot_config. Admins then distribute + a simple module that contains this implementation, installable in a way + that it can be imported by users from the same Python environment where + SkyPilot is running. + + Users can register a subclass of AdminPolicy in the SkyPilot config file + under the key 'admin_policy', e.g. + + admin_policy: my_package.SkyPilotPolicyV1 + """ + + @classmethod + @abc.abstractmethod + def validate_and_mutate(cls, + user_request: UserRequest) -> MutatedUserRequest: + """Validates and mutates the user request and returns mutated request. + + Args: + user_request: The user request to validate and mutate. + UserRequest contains (sky.Task, sky.Config) + + Returns: + MutatedUserRequest: The mutated user request. + + Raises: + Exception to throw if the user request failed the validation. + """ + raise NotImplementedError( + 'Your policy must implement validate_and_mutate') diff --git a/sky/dag.py b/sky/dag.py index d1904eb9fcc..4af5adc76b5 100644 --- a/sky/dag.py +++ b/sky/dag.py @@ -1,8 +1,12 @@ """DAGs: user applications to be run.""" import pprint import threading +import typing from typing import List, Optional +if typing.TYPE_CHECKING: + from sky import task + class Dag: """Dag: a user application, represented as a DAG of Tasks. @@ -13,37 +17,37 @@ class Dag: >>> task = sky.Task(...) """ - def __init__(self): - self.tasks = [] + def __init__(self) -> None: + self.tasks: List['task.Task'] = [] import networkx as nx # pylint: disable=import-outside-toplevel self.graph = nx.DiGraph() - self.name = None + self.name: Optional[str] = None - def add(self, task): + def add(self, task: 'task.Task') -> None: self.graph.add_node(task) self.tasks.append(task) - def remove(self, task): + def remove(self, task: 'task.Task') -> None: self.tasks.remove(task) self.graph.remove_node(task) - def add_edge(self, op1, op2): + def add_edge(self, op1: 'task.Task', op2: 'task.Task') -> None: assert op1 in self.graph.nodes assert op2 in self.graph.nodes self.graph.add_edge(op1, op2) - def __len__(self): + def __len__(self) -> int: return len(self.tasks) - def __enter__(self): + def __enter__(self) -> 'Dag': push_dag(self) return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: pop_dag() - def __repr__(self): + def __repr__(self) -> str: pformat = pprint.pformat(self.tasks) return f'DAG:\n{pformat}' @@ -70,15 +74,15 @@ def is_chain(self) -> bool: class _DagContext(threading.local): """A thread-local stack of Dags.""" - _current_dag = None + _current_dag: Optional[Dag] = None _previous_dags: List[Dag] = [] - def push_dag(self, dag): + def push_dag(self, dag: Dag): if self._current_dag is not None: self._previous_dags.append(self._current_dag) self._current_dag = dag - def pop_dag(self): + def pop_dag(self) -> Optional[Dag]: old_dag = self._current_dag if self._previous_dags: self._current_dag = self._previous_dags.pop() diff --git a/sky/exceptions.py b/sky/exceptions.py index 15f3ea3f34e..04c50ad4e08 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -286,3 +286,8 @@ class ServeUserTerminatedError(Exception): class PortDoesNotExistError(Exception): """Raised when the port does not exist.""" + + +class UserRequestRejectedByPolicy(Exception): + """Raised when a user request is rejected by an admin policy.""" + pass diff --git a/sky/execution.py b/sky/execution.py index 1f6bd09f9c3..792ca5fffc0 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -9,6 +9,7 @@ import colorama import sky +from sky import admin_policy from sky import backends from sky import clouds from sky import global_user_state @@ -16,6 +17,7 @@ from sky import sky_logging from sky.backends import backend_utils from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import controller_utils from sky.utils import dag_utils from sky.utils import env_options @@ -158,7 +160,16 @@ def _execute( handle: Optional[backends.ResourceHandle]; the handle to the cluster. None if dryrun. """ + dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + dag, _ = admin_policy_utils.apply( + dag, + request_options=admin_policy.RequestOptions( + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun, + )) assert len(dag) == 1, f'We support 1 task for now. {dag}' task = dag.tasks[0] @@ -170,9 +181,8 @@ def _execute( cluster_exists = False if cluster_name is not None: - existing_handle = global_user_state.get_handle_from_cluster_name( - cluster_name) - cluster_exists = existing_handle is not None + cluster_record = global_user_state.get_cluster_from_name(cluster_name) + cluster_exists = cluster_record is not None # TODO(woosuk): If the cluster exists, print a warning that # `cpus` and `memory` are not used as a job scheduling constraint, # unlike `gpus`. diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 39c89d2784b..f3cd81576e2 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -64,6 +64,7 @@ def __init__(self, job_id: int, dag_yaml: str, if len(self._dag.tasks) <= 1: task_name = self._dag_name else: + assert task.name is not None, task task_name = task.name # This is guaranteed by the spot_launch API, where we fill in # the task.name with @@ -447,6 +448,7 @@ def _cleanup(job_id: int, dag_yaml: str): # controller, we should keep it in sync with JobsController.__init__() dag, _ = _get_dag_and_name(dag_yaml) for task in dag.tasks: + assert task.name is not None, task cluster_name = managed_job_utils.generate_managed_job_cluster_name( task.name, job_id) recovery_strategy.terminate_cluster(cluster_name) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 561d47f4b25..c4f59f65eca 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -18,6 +18,7 @@ from sky.jobs import utils as managed_job_utils from sky.skylet import constants as skylet_constants from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils @@ -54,6 +55,8 @@ def launch( dag_uuid = str(uuid.uuid4().hex[:4]) dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + dag, mutated_user_config = admin_policy_utils.apply( + dag, use_mutated_config_in_current_request=False) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' @@ -103,6 +106,7 @@ def launch( **controller_utils.shared_controller_vars_to_fill( controller_utils.Controllers.JOBS_CONTROLLER, remote_user_config_path=remote_user_config_path, + local_user_config=mutated_user_config, ), } diff --git a/sky/serve/core.py b/sky/serve/core.py index 4f15413cf7f..2bb6e1384ee 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -17,6 +17,7 @@ from sky.serve import serve_utils from sky.skylet import constants from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import resources_utils @@ -124,6 +125,10 @@ def up( _validate_service_task(task) + dag, mutated_user_config = admin_policy_utils.apply( + task, use_mutated_config_in_current_request=False) + task = dag.tasks[0] + controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, path='serve') @@ -158,6 +163,7 @@ def up( **controller_utils.shared_controller_vars_to_fill( controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER, remote_user_config_path=remote_config_yaml_path, + local_user_config=mutated_user_config, ), } common_utils.fill_template(serve_constants.CONTROLLER_TEMPLATE, diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 52e1d0ae3d9..aae62afc616 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -61,6 +61,8 @@ from sky.utils import schemas from sky.utils import ux_utils +logger = sky_logging.init_logger(__name__) + # The config path is discovered in this order: # # (1) (Used internally) If env var {ENV_VAR_SKYPILOT_CONFIG} exists, use its @@ -78,11 +80,57 @@ # Path to the local config file. CONFIG_PATH = '~/.sky/config.yaml' -logger = sky_logging.init_logger(__name__) + +class Config(Dict[str, Any]): + """SkyPilot config that supports setting/getting values with nested keys.""" + + def get_nested(self, + keys: Tuple[str, ...], + default_value: Any, + override_configs: Optional[Dict[str, Any]] = None) -> Any: + """Gets a nested key. + + If any key is not found, or any intermediate key does not point to a + dict value, returns 'default_value'. + + Args: + keys: A tuple of strings representing the nested keys. + default_value: The default value to return if the key is not found. + override_configs: A dict of override configs with the same schema as + the config file, but only containing the keys to override. + + Returns: + The value of the nested key, or 'default_value' if not found. + """ + config = copy.deepcopy(self) + if override_configs is not None: + config = _recursive_update(config, override_configs) + return _get_nested(config, keys, default_value) + + def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: + """In-place sets a nested key to value. + + Like get_nested(), if any key is not found, this will not raise an + error. + """ + override = {} + for i, key in enumerate(reversed(keys)): + if i == 0: + override = {key: value} + else: + override = {key: override} + _recursive_update(self, override) + + @classmethod + def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config': + if config is None: + return cls() + return cls(**config) + # The loaded config. -_dict: Optional[Dict[str, Any]] = None -_loaded_config_path = None +_dict = Config() +_loaded_config_path: Optional[str] = None def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], @@ -131,17 +179,11 @@ def get_nested(keys: Tuple[str, ...], ), (f'Override configs must not be provided when keys {keys} is not within ' 'constants.OVERRIDEABLE_CONFIG_KEYS: ' f'{constants.OVERRIDEABLE_CONFIG_KEYS}') - config: Dict[str, Any] = {} - if _dict is not None: - config = copy.deepcopy(_dict) - if override_configs is None: - override_configs = {} - config = _recursive_update(config, override_configs) - return _get_nested(config, keys, default_value) + return _dict.get_nested(keys, default_value, override_configs) -def _recursive_update(base_config: Dict[str, Any], - override_config: Dict[str, Any]) -> Dict[str, Any]: +def _recursive_update(base_config: Config, + override_config: Dict[str, Any]) -> Config: """Recursively updates base configuration with override configuration""" for key, value in override_config.items(): if (isinstance(value, dict) and key in base_config and @@ -157,22 +199,14 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: Like get_nested(), if any key is not found, this will not raise an error. """ - _check_loaded_or_die() - assert _dict is not None - override = {} - for i, key in enumerate(reversed(keys)): - if i == 0: - override = {key: value} - else: - override = {key: override} - return _recursive_update(copy.deepcopy(_dict), override) + copied_dict = copy.deepcopy(_dict) + copied_dict.set_nested(keys, value) + return dict(**copied_dict) -def to_dict() -> Dict[str, Any]: +def to_dict() -> Config: """Returns a deep-copied version of the current config.""" - if _dict is not None: - return copy.deepcopy(_dict) - return {} + return copy.deepcopy(_dict) def _try_load_config() -> None: @@ -192,13 +226,14 @@ def _try_load_config() -> None: config_path = os.path.expanduser(config_path) if os.path.exists(config_path): logger.debug(f'Using config path: {config_path}') - _loaded_config_path = config_path try: - _dict = common_utils.read_yaml(config_path) + config = common_utils.read_yaml(config_path) + _dict = Config.from_dict(config) + _loaded_config_path = config_path logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}') except yaml.YAMLError as e: logger.error(f'Error in loading config file ({config_path}):', e) - if _dict is not None: + if _dict: common_utils.validate_schema( _dict, schemas.get_config_schema(), @@ -219,14 +254,6 @@ def loaded_config_path() -> Optional[str]: _try_load_config() -def _check_loaded_or_die(): - """Checks loaded() is true; otherwise raises RuntimeError.""" - if _dict is None: - raise RuntimeError( - f'No user configs loaded. Check {CONFIG_PATH} exists and ' - 'can be loaded.') - - def loaded() -> bool: """Returns if the user configurations are loaded.""" - return _dict is not None + return bool(_dict) diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 51083e84a59..45cdb5141d4 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -4,7 +4,9 @@ name: {{dag_name}} file_mounts: {{remote_user_yaml_path}}: {{user_yaml_path}} - {{remote_user_config_path}}: skypilot:local_skypilot_config_path + {%- if local_user_config_path is not none %} + {{remote_user_config_path}}: {{local_user_config_path}} + {%- endif %} {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} diff --git a/sky/templates/sky-serve-controller.yaml.j2 b/sky/templates/sky-serve-controller.yaml.j2 index a20c2d680aa..507a6e3a325 100644 --- a/sky/templates/sky-serve-controller.yaml.j2 +++ b/sky/templates/sky-serve-controller.yaml.j2 @@ -23,7 +23,9 @@ setup: | file_mounts: {{remote_task_yaml_path}}: {{local_task_yaml_path}} - {{remote_user_config_path}}: skypilot:local_skypilot_config_path + {%- if local_user_config_path is not none %} + {{remote_user_config_path}}: {{local_user_config_path}} + {%- endif %} {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py new file mode 100644 index 00000000000..09db2fc4be8 --- /dev/null +++ b/sky/utils/admin_policy_utils.py @@ -0,0 +1,145 @@ +"""Admin policy utils.""" +import copy +import importlib +import os +import tempfile +from typing import Optional, Tuple, Union + +import colorama + +from sky import admin_policy +from sky import dag as dag_lib +from sky import exceptions +from sky import sky_logging +from sky import skypilot_config +from sky import task as task_lib +from sky.utils import common_utils +from sky.utils import ux_utils + +logger = sky_logging.init_logger(__name__) + + +def _get_policy_cls( + policy: Optional[str]) -> Optional[admin_policy.AdminPolicy]: + """Gets admin-defined policy.""" + if policy is None: + return None + try: + module_path, class_name = policy.rsplit('.', 1) + module = importlib.import_module(module_path) + except ImportError as e: + with ux_utils.print_exception_no_traceback(): + raise ImportError( + f'Failed to import policy module: {policy}. ' + 'Please check if the module is installed in your Python ' + 'environment.') from e + + try: + policy_cls = getattr(module, class_name) + except AttributeError as e: + with ux_utils.print_exception_no_traceback(): + raise AttributeError( + f'Could not find {class_name} class in module {module_path}. ' + 'Please check with your policy admin for details.') from e + + # Check if the module implements the AdminPolicy interface. + if not issubclass(policy_cls, admin_policy.AdminPolicy): + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Policy class {policy!r} does not implement the AdminPolicy ' + 'interface. Please check with your policy admin for details.') + return policy_cls + + +def apply( + entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], + use_mutated_config_in_current_request: bool = True, + request_options: Optional[admin_policy.RequestOptions] = None, +) -> Tuple['dag_lib.Dag', skypilot_config.Config]: + """Applies an admin policy (if registered) to a DAG or a task. + + It mutates a Dag by applying any registered admin policy and also + potentially updates (controlled by `use_mutated_config_in_current_request`) + the global SkyPilot config if there is any changes made by the policy. + + Args: + dag: The dag to be mutated by the policy. + use_mutated_config_in_current_request: Whether to use the mutated + config in the current request. + request_options: Additional options user passed for the current request. + + Returns: + - The new copy of dag after applying the policy + - The new copy of skypilot config after applying the policy. + """ + if isinstance(entrypoint, task_lib.Task): + dag = dag_lib.Dag() + dag.add(entrypoint) + else: + dag = entrypoint + + policy = skypilot_config.get_nested(('admin_policy',), None) + policy_cls = _get_policy_cls(policy) + if policy_cls is None: + return dag, skypilot_config.to_dict() + + logger.info(f'Applying policy: {policy}') + original_config = skypilot_config.to_dict() + config = copy.deepcopy(original_config) + mutated_dag = dag_lib.Dag() + mutated_dag.name = dag.name + + mutated_config = None + for task in dag.tasks: + user_request = admin_policy.UserRequest(task, config, request_options) + try: + mutated_user_request = policy_cls.validate_and_mutate(user_request) + except Exception as e: # pylint: disable=broad-except + with ux_utils.print_exception_no_traceback(): + raise exceptions.UserRequestRejectedByPolicy( + f'{colorama.Fore.RED}User request rejected by policy ' + f'{policy!r}{colorama.Fore.RESET}: ' + f'{common_utils.format_exception(e, use_bracket=True)}' + ) from e + if mutated_config is None: + mutated_config = mutated_user_request.skypilot_config + else: + if mutated_config != mutated_user_request.skypilot_config: + # In the case of a pipeline of tasks, the mutated config + # generated should remain the same for all tasks for now for + # simplicity. + # TODO(zhwu): We should support per-task mutated config or + # allowing overriding required global config in task YAML. + with ux_utils.print_exception_no_traceback(): + raise exceptions.UserRequestRejectedByPolicy( + 'All tasks must have the same SkyPilot config after ' + 'applying the policy. Please check with your policy ' + 'admin for details.') + mutated_dag.add(mutated_user_request.task) + assert mutated_config is not None, dag + + # Update the new_dag's graph with the old dag's graph + for u, v in dag.graph.edges: + u_idx = dag.tasks.index(u) + v_idx = dag.tasks.index(v) + mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], + mutated_dag.tasks[v_idx]) + + if (use_mutated_config_in_current_request and + original_config != mutated_config): + with tempfile.NamedTemporaryFile( + delete=False, + mode='w', + prefix='policy-mutated-skypilot-config-', + suffix='.yaml') as temp_file: + + common_utils.dump_yaml(temp_file.name, dict(**mutated_config)) + os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name + logger.debug(f'Updated SkyPilot config: {temp_file.name}') + # TODO(zhwu): This is not a clean way to update the SkyPilot config, + # because we are resetting the global context for a single DAG, + # which is conceptually weird. + importlib.reload(skypilot_config) + + logger.debug(f'Mutated user request: {mutated_user_request}') + return mutated_dag, mutated_config diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index a9227fb4c20..dffe784cc33 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -300,7 +300,7 @@ def user_and_hostname_hash() -> str: return f'{getpass.getuser()}-{hostname_hash}' -def read_yaml(path) -> Dict[str, Any]: +def read_yaml(path: str) -> Dict[str, Any]: with open(path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) return config @@ -316,12 +316,13 @@ def read_yaml_all(path: str) -> List[Dict[str, Any]]: return configs -def dump_yaml(path, config) -> None: +def dump_yaml(path: str, config: Union[List[Dict[str, Any]], + Dict[str, Any]]) -> None: with open(path, 'w', encoding='utf-8') as f: f.write(dump_yaml_str(config)) -def dump_yaml_str(config): +def dump_yaml_str(config: Union[List[Dict[str, Any]], Dict[str, Any]]) -> str: # https://github.com/yaml/pyyaml/issues/127 class LineBreakDumper(yaml.SafeDumper): @@ -331,9 +332,9 @@ def write_line_break(self, data=None): super().write_line_break() if isinstance(config, list): - dump_func = yaml.dump_all + dump_func = yaml.dump_all # type: ignore else: - dump_func = yaml.dump + dump_func = yaml.dump # type: ignore return dump_func(config, Dumper=LineBreakDumper, sort_keys=False, diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 866aaf1ee1a..118f9a2b718 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -44,8 +44,12 @@ '{controller_type}.controller.resources is a valid resources spec. ' 'Details:\n {err}') -# The placeholder for the local skypilot config path in file mounts. -LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER = 'skypilot:local_skypilot_config_path' +# The suffix for local skypilot config path for a job/service in file mounts +# that tells the controller logic to update the config with specific settings, +# e.g., removing the ssh_proxy_command when a job/service is launched in a same +# cloud as controller. +_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX = ( + '__skypilot:local_skypilot_config_path.yaml') @dataclasses.dataclass @@ -350,8 +354,21 @@ def download_and_stream_latest_job_log( def shared_controller_vars_to_fill( - controller: Controllers, - remote_user_config_path: str) -> Dict[str, str]: + controller: Controllers, remote_user_config_path: str, + local_user_config: Dict[str, Any]) -> Dict[str, str]: + if not local_user_config: + local_user_config_path = None + else: + # Remove admin_policy from local_user_config so that it is not applied + # again on the controller. This is required since admin_policy is not + # installed on the controller. + local_user_config.pop('admin_policy', None) + with tempfile.NamedTemporaryFile( + delete=False, + suffix=_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX) as temp_file: + common_utils.dump_yaml(temp_file.name, dict(**local_user_config)) + local_user_config_path = temp_file.name + vars_to_fill: Dict[str, Any] = { 'cloud_dependencies_installation_commands': _get_cloud_dependencies_installation_commands(controller), @@ -360,6 +377,7 @@ def shared_controller_vars_to_fill( # accessed. 'sky_activate_python_env': constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV, 'sky_python_cmd': constants.SKY_PYTHON_CMD, + 'local_user_config_path': local_user_config_path, } env_vars: Dict[str, str] = { env.value: '1' for env in env_options.Options if env.get() @@ -481,7 +499,8 @@ def get_controller_resources( def _setup_proxy_command_on_controller( - controller_launched_cloud: 'clouds.Cloud') -> Dict[str, Any]: + controller_launched_cloud: 'clouds.Cloud', + user_config: Dict[str, Any]) -> skypilot_config.Config: """Sets up proxy command on the controller. This function should be called on the controller (remote cluster), which @@ -515,21 +534,20 @@ def _setup_proxy_command_on_controller( # (or name). It may not be a sufficient check (as it's always # possible that peering is not set up), but it may catch some # obvious errors. + config = skypilot_config.Config.from_dict(user_config) proxy_command_key = (str(controller_launched_cloud).lower(), 'ssh_proxy_command') - ssh_proxy_command = skypilot_config.get_nested(proxy_command_key, None) - config_dict = skypilot_config.to_dict() + ssh_proxy_command = config.get_nested(proxy_command_key, None) if isinstance(ssh_proxy_command, str): - config_dict = skypilot_config.set_nested(proxy_command_key, None) + config.set_nested(proxy_command_key, None) elif isinstance(ssh_proxy_command, dict): # Instead of removing the key, we set the value to empty string # so that the controller will only try the regions specified by # the keys. ssh_proxy_command = {k: None for k in ssh_proxy_command} - config_dict = skypilot_config.set_nested(proxy_command_key, - ssh_proxy_command) + config.set_nested(proxy_command_key, ssh_proxy_command) - return config_dict + return config def replace_skypilot_config_path_in_file_mounts( @@ -543,25 +561,20 @@ def replace_skypilot_config_path_in_file_mounts( if file_mounts is None: return replaced = False - to_replace = True - with tempfile.NamedTemporaryFile('w', delete=False) as f: - if skypilot_config.loaded(): - new_skypilot_config = _setup_proxy_command_on_controller(cloud) - common_utils.dump_yaml(f.name, new_skypilot_config) - to_replace = True - else: - # Empty config. Remove the placeholder below. - to_replace = False - for remote_path, local_path in list(file_mounts.items()): - if local_path == LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER: - if to_replace: - file_mounts[remote_path] = f.name - replaced = True - else: - del file_mounts[remote_path] + for remote_path, local_path in list(file_mounts.items()): + if local_path is None: + del file_mounts[remote_path] + continue + if local_path.endswith(_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX): + with tempfile.NamedTemporaryFile('w', delete=False) as f: + user_config = common_utils.read_yaml(local_path) + config = _setup_proxy_command_on_controller(cloud, user_config) + common_utils.dump_yaml(f.name, dict(**config)) + file_mounts[remote_path] = f.name + replaced = True if replaced: - logger.debug(f'Replaced {LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER} with ' - f'the real path in file mounts: {file_mounts}') + logger.debug(f'Replaced {_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX} ' + f'with the real path in file mounts: {file_mounts}') def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 7a4fe90e7fb..e6b491c3168 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -36,30 +36,33 @@ def convert_entrypoint_to_dag(entrypoint: Any) -> 'dag_lib.Dag': - """Convert the entrypoint to a sky.Dag. + """Converts the entrypoint to a sky.Dag and applies the policy. Raises TypeError if 'entrypoint' is not a 'sky.Task' or 'sky.Dag'. """ # Not suppressing stacktrace: when calling this via API user may want to # see their own program in the stacktrace. Our CLI impl would not trigger # these errors. + converted_dag: 'dag_lib.Dag' if isinstance(entrypoint, str): with ux_utils.print_exception_no_traceback(): raise TypeError(_ENTRYPOINT_STRING_AS_DAG_MESSAGE) elif isinstance(entrypoint, dag_lib.Dag): - return copy.deepcopy(entrypoint) + converted_dag = copy.deepcopy(entrypoint) elif isinstance(entrypoint, task_lib.Task): entrypoint = copy.deepcopy(entrypoint) with dag_lib.Dag() as dag: dag.add(entrypoint) dag.name = entrypoint.name - return dag + converted_dag = dag else: with ux_utils.print_exception_no_traceback(): raise TypeError( 'Expected a sky.Task or sky.Dag but received argument of type: ' f'{type(entrypoint)}') + return converted_dag + def load_chain_dag_from_yaml( path: str, diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 01dc14f617c..a50c400b805 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -848,6 +848,13 @@ def get_config_schema(): }, } + admin_policy_schema = { + 'type': 'string', + # Check regex to be a valid python module path + 'pattern': (r'^[a-zA-Z_][a-zA-Z0-9_]*' + r'(\.[a-zA-Z_][a-zA-Z0-9_]*)+$'), + } + allowed_clouds = { # A list of cloud names that are allowed to be used 'type': 'array', @@ -905,6 +912,7 @@ def get_config_schema(): 'spot': controller_resources_schema, 'serve': controller_resources_schema, 'allowed_clouds': allowed_clouds, + 'admin_policy': admin_policy_schema, 'docker': docker_configs, 'nvidia_gpus': gpu_configs, **cloud_configs, diff --git a/tests/test_config.py b/tests/test_config.py index 0cae5f9befb..5789214dc61 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import copy +import importlib import pathlib import textwrap @@ -21,19 +22,19 @@ def _reload_config() -> None: - skypilot_config._dict = None + skypilot_config._dict = skypilot_config.Config() + skypilot_config._loaded_config_path = None skypilot_config._try_load_config() def _check_empty_config() -> None: """Check that the config is empty.""" - assert not skypilot_config.loaded() + assert not skypilot_config.loaded(), (skypilot_config._dict, + skypilot_config._loaded_config_path) assert skypilot_config.get_nested( ('aws', 'ssh_proxy_command'), None) is None assert skypilot_config.get_nested(('aws', 'ssh_proxy_command'), 'default') == 'default' - with pytest.raises(RuntimeError): - skypilot_config.set_nested(('aws', 'ssh_proxy_command'), 'value') def _create_config_file(config_file_path: pathlib.Path) -> None: @@ -98,6 +99,22 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: """)) +def test_nested_config(monkeypatch) -> None: + """Test that the nested config works.""" + config = skypilot_config.Config() + config.set_nested(('aws', 'ssh_proxy_command'), 'value') + assert config == {'aws': {'ssh_proxy_command': 'value'}} + + assert config.get_nested(('admin_policy',), 'default') == 'default' + config.set_nested(('aws', 'use_internal_ips'), True) + assert config == { + 'aws': { + 'ssh_proxy_command': 'value', + 'use_internal_ips': True + } + } + + def test_no_config(monkeypatch) -> None: """Test that the config is not loaded if the config file does not exist.""" monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', '/tmp/does_not_exist') diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py new file mode 100644 index 00000000000..96b666493d3 --- /dev/null +++ b/tests/unit_tests/test_admin_policy.py @@ -0,0 +1,172 @@ +import importlib +import os +import sys +from typing import Optional, Tuple +from unittest import mock + +import pytest + +import sky +from sky import exceptions +from sky import sky_logging +from sky import skypilot_config +from sky.utils import admin_policy_utils + +logger = sky_logging.init_logger(__name__) + +POLICY_PATH = os.path.join(os.path.dirname(os.path.dirname(sky.__file__)), + 'examples', 'admin_policy') + + +@pytest.fixture +def add_example_policy_paths(): + # Add to path to be able to import + sys.path.append(os.path.join(POLICY_PATH, 'example_policy')) + + +@pytest.fixture +def task(): + return sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) + + +def _load_task_and_apply_policy( + task: sky.Task, + config_path: str, + idle_minutes_to_autostop: Optional[int] = None, +) -> Tuple[sky.Dag, skypilot_config.Config]: + os.environ['SKYPILOT_CONFIG'] = config_path + importlib.reload(skypilot_config) + return admin_policy_utils.apply( + task, + request_options=sky.admin_policy.RequestOptions( + cluster_name='test', + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=False, + dryrun=False, + )) + + +def test_use_spot_for_all_gpus_policy(add_example_policy_paths, task): + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) + assert not any(r.use_spot for r in dag.tasks[0].resources), ( + 'use_spot should be False as GPU is not specified') + + task.set_resources([ + sky.Resources(cloud='gcp', accelerators={'A100': 1}), + sky.Resources(accelerators={'L4': 1}) + ]) + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) + assert all( + r.use_spot for r in dag.tasks[0].resources), 'use_spot should be True' + + task.set_resources([ + sky.Resources(accelerators={'A100': 1}), + sky.Resources(accelerators={'L4': 1}, use_spot=True), + sky.Resources(cpus='2+'), + ]) + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) + for r in dag.tasks[0].resources: + if r.accelerators: + assert r.use_spot, 'use_spot should be True' + else: + assert not r.use_spot, 'use_spot should be False' + + +def test_add_labels_policy(add_example_policy_paths, task): + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'add_labels.yaml')) + assert 'app' in skypilot_config.get_nested( + ('kubernetes', 'custom_metadata', 'labels'), + {}), ('label should be set') + + +def test_reject_all_policy(add_example_policy_paths, task): + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Reject all policy'): + _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'reject_all.yaml')) + + +def test_enforce_autostop_policy(add_example_policy_paths, task): + + def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: + return { + 'name': 'test', + 'status': status, + 'autostop': autostop, + } + + # Cluster does not exist + with mock.patch('sky.status', return_value=[]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is stopped + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.STOPPED, 10)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is running but autostop is not set + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.UP, -1)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is init but autostop is not set + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.INIT, -1)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is running and autostop is set + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.UP, 10)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) diff --git a/tests/unit_tests/test_backend_utils.py b/tests/unit_tests/test_backend_utils.py index cb1b83f1999..5da4410abb9 100644 --- a/tests/unit_tests/test_backend_utils.py +++ b/tests/unit_tests/test_backend_utils.py @@ -1,34 +1,31 @@ +import os import pathlib -from typing import Dict -from unittest.mock import Mock -from unittest.mock import patch - -import pytest +from unittest import mock from sky import clouds from sky import skypilot_config from sky.backends import backend_utils from sky.resources import Resources -from sky.resources import resources_utils -@patch.object(skypilot_config, 'CONFIG_PATH', - './tests/test_yamls/test_aws_config.yaml') -@patch.object(skypilot_config, '_dict', None) -@patch.object(skypilot_config, '_loaded_config_path', None) -@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) -@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', - return_value={'fake-acc': 2}) -@patch('sky.clouds.service_catalog.get_image_id_from_tag', - return_value='fake-image') -@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') -@patch('sky.check.get_cloud_credential_file_mounts', - return_value='~/.aws/credentials') -@patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', - return_value='/tmp/fake/path') -@patch('sky.utils.common_utils.fill_template') +# Set env var to test config file. +@mock.patch.object(skypilot_config, '_dict', None) +@mock.patch.object(skypilot_config, '_loaded_config_path', None) +@mock.patch('sky.clouds.service_catalog.instance_type_exists', + return_value=True) +@mock.patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', + return_value={'fake-acc': 2}) +@mock.patch('sky.clouds.service_catalog.get_image_id_from_tag', + return_value='fake-image') +@mock.patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') +@mock.patch('sky.check.get_cloud_credential_file_mounts', + return_value='~/.aws/credentials') +@mock.patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', + return_value='/tmp/fake/path') +@mock.patch('sky.utils.common_utils.fill_template') def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None: + os.environ['SKYPILOT_CONFIG'] = './tests/test_yamls/test_aws_config.yaml' skypilot_config._try_load_config() cloud = clouds.AWS() diff --git a/tests/unit_tests/test_common_utils.py b/tests/unit_tests/test_common_utils.py index f38e14069e5..38c31263baa 100644 --- a/tests/unit_tests/test_common_utils.py +++ b/tests/unit_tests/test_common_utils.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest import mock import pytest @@ -33,18 +33,18 @@ def test_check_when_none(self): class TestMakeClusterNameOnCloud: - @patch('sky.utils.common_utils.get_user_hash') + @mock.patch('sky.utils.common_utils.get_user_hash') def test_make(self, mock_get_user_hash): mock_get_user_hash.return_value = MOCKED_USER_HASH assert "lora-ab12" == common_utils.make_cluster_name_on_cloud("lora") - @patch('sky.utils.common_utils.get_user_hash') + @mock.patch('sky.utils.common_utils.get_user_hash') def test_make_with_hyphen(self, mock_get_user_hash): mock_get_user_hash.return_value = MOCKED_USER_HASH assert "seed-1-ab12" == common_utils.make_cluster_name_on_cloud( "seed-1") - @patch('sky.utils.common_utils.get_user_hash') + @mock.patch('sky.utils.common_utils.get_user_hash') def test_make_with_characters_to_transform(self, mock_get_user_hash): mock_get_user_hash.return_value = MOCKED_USER_HASH assert "cuda-11-8-ab12" == common_utils.make_cluster_name_on_cloud( diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 70da0532e9b..01b83132a1b 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -1,6 +1,7 @@ +import importlib +import os from typing import Dict -from unittest.mock import Mock -from unittest.mock import patch +from unittest import mock import pytest @@ -23,12 +24,12 @@ def test_get_reservations_available_resources(): - mock = Mock() - r = Resources(cloud=mock, instance_type="instance_type") + mock_cloud = mock.Mock() + r = Resources(cloud=mock_cloud, instance_type="instance_type") r._region = "region" r._zone = "zone" r.get_reservations_available_resources() - mock.get_reservations_available_resources.assert_called_once_with( + mock_cloud.get_reservations_available_resources.assert_called_once_with( "instance_type", "region", "zone", set()) @@ -91,18 +92,16 @@ def test_kubernetes_labels_resources(): _run_label_test(allowed_labels, invalid_labels, cloud) -@patch.object(skypilot_config, 'CONFIG_PATH', - './tests/test_yamls/test_aws_config.yaml') -@patch.object(skypilot_config, '_dict', None) -@patch.object(skypilot_config, '_loaded_config_path', None) -@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) -@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', - return_value={'fake-acc': 2}) -@patch('sky.clouds.service_catalog.get_image_id_from_tag', - return_value='fake-image') -@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') +@mock.patch('sky.clouds.service_catalog.instance_type_exists', + return_value=True) +@mock.patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', + return_value={'fake-acc': 2}) +@mock.patch('sky.clouds.service_catalog.get_image_id_from_tag', + return_value='fake-image') +@mock.patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') def test_aws_make_deploy_variables(*mocks) -> None: - skypilot_config._try_load_config() + os.environ['SKYPILOT_CONFIG'] = './tests/test_yamls/test_aws_config.yaml' + importlib.reload(skypilot_config) cloud = clouds.AWS() cluster_name = resources_utils.ClusterName(display_name='display',