Skip to content

Commit

Permalink
[Core][AWS] Allow specification of Security Groups for resources. (#3501
Browse files Browse the repository at this point in the history
)

* feat: allow security group specification

* feat: refactor format of sg names

* refactor: update readme for security group name update

* fix: format

* refactor: use ClusterName data class

* fix: move warning

* fix: clean code

* fix: clean code

* fix: schema constant

* refactor: add sg test

* fix: pylint

* refactor: updates to use display name

* fix: bug in remote identity and update tests

* fix: formatting

* fix: remove

* fix: format

* fix: missing resources_utils ClusterName

* fix: tests

* fix: bug

* fix: clone_disk_from reference
  • Loading branch information
JGSweets committed Jul 11, 2024
1 parent 5e23f16 commit efe4625
Show file tree
Hide file tree
Showing 26 changed files with 405 additions and 140 deletions.
18 changes: 17 additions & 1 deletion docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,27 @@ Available fields and semantics:
# Security group (optional).
#
# The name of the security group to use for all instances. If not specified,
# Security group name to use for AWS instances. If not specified,
# SkyPilot will use the default name for the security group: sky-sg-<hash>
# Note: please ensure the security group name specified exists in the
# regions the instances are going to be launched or the AWS account has the
# permission to create a security group.
#
# Some example use cases are shown below. All fields are optional.
# - <string>: apply the service account with the specified name to all instances.
# Example:
# security_group_name: my-security-group
# - <list of single-element dict>: A list of single-element dict mapping from the cluster name (pattern)
# to the security group name to use. The matching of the cluster name is done in the same order
# as the list.
# NOTE: If none of the wildcard expressions in the dict match the cluster name, SkyPilot will use the default
# security group name as mentioned above: sky-sg-<hash>
# To specify your default, use "*" as the wildcard expression.
# Example:
# security_group_name:
# - my-cluster-name: my-security-group-1
# - sky-serve-controller-*: my-security-group-2
# - "*": my-default-security-group
security_group_name: my-security-group
# Identity to use for AWS instances (optional).
Expand Down
21 changes: 14 additions & 7 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@
# we need to take this field from the new yaml.
('provider', 'tpu_node'),
('provider', 'security_group', 'GroupName'),
('available_node_types', 'ray.head.default', 'node_config',
'IamInstanceProfile'),
('available_node_types', 'ray.head.default', 'node_config', 'UserData'),
('available_node_types', 'ray.worker.default', 'node_config', 'UserData'),
]
Expand Down Expand Up @@ -793,8 +795,11 @@ def write_cluster_config(
# move the check out of this function, i.e. the caller should be responsible
# for the validation.
# TODO(tian): Move more cloud agnostic vars to resources.py.
resources_vars = to_provision.make_deploy_variables(cluster_name_on_cloud,
region, zones, dryrun)
resources_vars = to_provision.make_deploy_variables(
resources_utils.ClusterName(
cluster_name,
cluster_name_on_cloud,
), region, zones, dryrun)
config_dict = {}

specific_reservations = set(
Expand All @@ -803,11 +808,13 @@ def write_cluster_config(

assert cluster_name is not None
excluded_clouds = []
remote_identity = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'),
schemas.get_default_remote_identity(str(cloud).lower()))
if remote_identity is not None and not isinstance(remote_identity, str):
for profile in remote_identity:
remote_identity_config = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'), None)
remote_identity = schemas.get_default_remote_identity(str(cloud).lower())
if isinstance(remote_identity_config, str):
remote_identity = remote_identity_config
if isinstance(remote_identity_config, list):
for profile in remote_identity_config:
if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]):
remote_identity = list(profile.values())[0]
break
Expand Down
16 changes: 9 additions & 7 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,8 +1566,8 @@ def _retry_zones(
to_provision.cloud,
region,
zones,
provisioner.ClusterName(cluster_name,
handle.cluster_name_on_cloud),
resources_utils.ClusterName(
cluster_name, handle.cluster_name_on_cloud),
num_nodes=num_nodes,
cluster_yaml=handle.cluster_yaml,
prev_cluster_ever_up=prev_cluster_ever_up,
Expand All @@ -1577,8 +1577,10 @@ def _retry_zones(
# caller.
resources_vars = (
to_provision.cloud.make_deploy_resources_variables(
to_provision, handle.cluster_name_on_cloud, region,
zones))
to_provision,
resources_utils.ClusterName(
cluster_name, handle.cluster_name_on_cloud),
region, zones))
config_dict['provision_record'] = provision_record
config_dict['resources_vars'] = resources_vars
config_dict['handle'] = handle
Expand Down Expand Up @@ -2898,8 +2900,8 @@ def _provision(
# 4. Starting ray cluster and skylet.
cluster_info = provisioner.post_provision_runtime_setup(
repr(handle.launched_resources.cloud),
provisioner.ClusterName(handle.cluster_name,
handle.cluster_name_on_cloud),
resources_utils.ClusterName(handle.cluster_name,
handle.cluster_name_on_cloud),
handle.cluster_yaml,
provision_record=provision_record,
custom_resource=resources_vars.get('custom_resources'),
Expand Down Expand Up @@ -3877,7 +3879,7 @@ def teardown_no_lock(self,

try:
provisioner.teardown_cluster(repr(cloud),
provisioner.ClusterName(
resources_utils.ClusterName(
cluster_name,
cluster_name_on_cloud),
terminate=terminate,
Expand Down
2 changes: 1 addition & 1 deletion sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3868,7 +3868,7 @@ def _generate_task_with_service(
env: List[Tuple[str, str]],
gpus: Optional[str],
instance_type: Optional[str],
ports: Tuple[str],
ports: Optional[Tuple[str]],
cpus: Optional[str],
memory: Optional[str],
disk_size: Optional[int],
Expand Down
64 changes: 41 additions & 23 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Amazon Web Services."""
import enum
import fnmatch
import functools
import json
import os
Expand Down Expand Up @@ -370,12 +371,13 @@ def get_vcpus_mem_from_instance_type(
return service_catalog.get_vcpus_mem_from_instance_type(instance_type,
clouds='aws')

def make_deploy_resources_variables(self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
del dryrun # unused
assert zones is not None, (region, zones)

Expand All @@ -397,18 +399,32 @@ def make_deploy_resources_variables(self,
image_id = self._get_image_id(image_id_to_use, region_name,
r.instance_type)

user_security_group = skypilot_config.get_nested(
user_security_group_config = skypilot_config.get_nested(
('aws', 'security_group_name'), None)
if resources.ports is not None:
# Already checked in Resources._try_validate_ports
assert user_security_group is None
security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
cluster_name_on_cloud)
elif user_security_group is not None:
assert resources.ports is None
security_group = user_security_group
else:
user_security_group = None
if isinstance(user_security_group_config, str):
user_security_group = user_security_group_config
elif isinstance(user_security_group_config, list):
for profile in user_security_group_config:
if fnmatch.fnmatchcase(cluster_name.display_name,
list(profile.keys())[0]):
user_security_group = list(profile.values())[0]
break
security_group = user_security_group
if security_group is None:
security_group = DEFAULT_SECURITY_GROUP_NAME
if resources.ports is not None:
# Already checked in Resources._try_validate_ports
security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
cluster_name.display_name)
elif resources.ports is not None:
with ux_utils.print_exception_no_traceback():
logger.warning(
f'Skip opening ports {resources.ports} for cluster {cluster_name!r}, '
'as `aws.security_group_name` in `~/.sky/config.yaml` is specified as '
f' {security_group!r}. Please make sure the specified security group '
'has requested ports setup; or, leave out `aws.security_group_name` '
'in `~/.sky/config.yaml`.')

return {
'instance_type': r.instance_type,
Expand Down Expand Up @@ -840,22 +856,24 @@ def query_status(cls, name: str, tag_filters: Dict[str, str],
assert False, 'This code path should not be used.'

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
cluster_name_on_cloud: str,
def create_image_from_cluster(cls,
cluster_name: resources_utils.ClusterName,
region: Optional[str],
zone: Optional[str]) -> str:
assert region is not None, (cluster_name, cluster_name_on_cloud, region)
assert region is not None, (cluster_name.display_name,
cluster_name.name_on_cloud, region)
del zone # unused

image_name = f'skypilot-{cluster_name}-{int(time.time())}'
image_name = f'skypilot-{cluster_name.display_name}-{int(time.time())}'

status = provision_lib.query_instances('AWS', cluster_name_on_cloud,
status = provision_lib.query_instances('AWS',
cluster_name.name_on_cloud,
{'region': region})
instance_ids = list(status.keys())
if not instance_ids:
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
f'Failed to find the source cluster {cluster_name!r} on '
f'Failed to find the source cluster {cluster_name.display_name!r} on '
'AWS.')

if len(instance_ids) != 1:
Expand All @@ -882,7 +900,7 @@ def create_image_from_cluster(cls, cluster_name: str,
stream_logs=True)

rich_utils.force_update_status(
f'Waiting for the source image {cluster_name!r} from {region} to be available on AWS.'
f'Waiting for the source image {cluster_name.display_name!r} from {region} to be available on AWS.'
)
# Wait for the image to be available
wait_image_cmd = (
Expand Down
15 changes: 8 additions & 7 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ def get_vcpus_mem_from_instance_type(
def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
def make_deploy_resources_variables(
self,
resources: 'resources.Resources',
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
assert zones is None, ('Azure does not support zones', zones)

region_name = region.name
Expand Down Expand Up @@ -374,7 +375,7 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]:
'disk_tier': Azure._get_disk_type(_failover_disk_tier()),
'cloud_init_setup_commands': cloud_init_setup_commands,
'azure_subscription_id': self.get_project_id(dryrun),
'resource_group': f'{cluster_name_on_cloud}-{region_name}',
'resource_group': f'{cluster_name.name_on_cloud}-{region_name}',
}

def _get_feasible_launchable_resources(
Expand Down
6 changes: 3 additions & 3 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def is_same_cloud(self, other: 'Cloud') -> bool:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'Region',
zones: Optional[List['Zone']],
dryrun: bool = False,
Expand Down Expand Up @@ -726,8 +726,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str],
# cloud._cloud_unsupported_features().

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
cluster_name_on_cloud: str,
def create_image_from_cluster(cls,
cluster_name: resources_utils.ClusterName,
region: Optional[str],
zone: Optional[str]) -> str:
"""Creates an image from the cluster.
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/cudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False,
) -> Dict[str, Optional[str]]:
del zones
del zones, cluster_name # unused
r = resources
acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
if acc_dict is not None:
Expand Down
5 changes: 3 additions & 2 deletions sky/clouds/fluidstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sky import status_lib
from sky.clouds import service_catalog
from sky.provision.fluidstack import fluidstack_utils
from sky.utils import resources_utils
from sky.utils.resources_utils import DiskTier

_CREDENTIAL_FILES = [
Expand Down Expand Up @@ -174,7 +175,7 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: clouds.Region,
zones: Optional[List[clouds.Zone]],
dryrun: bool = False,
Expand All @@ -189,7 +190,7 @@ def make_deploy_resources_variables(
else:
custom_resources = None
cuda_installation_commands = """
sudo wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb -O /usr/local/cuda-keyring_1.1-1_all.deb;
sudo wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb -O /usr/local/cuda-keyring_1.1-1_all.deb;
sudo dpkg -i /usr/local/cuda-keyring_1.1-1_all.deb;
sudo apt-get update;
sudo apt-get -y install cuda-toolkit-12-3;
Expand Down
Loading

0 comments on commit efe4625

Please sign in to comment.