Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Add Azure ML Compute Instance Support #3905

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions sky/adaptors/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ def get_client(name: str,
from azure.mgmt import authorization
return authorization.AuthorizationManagementClient(
credential, subscription_id)
elif name == 'ml':
rg = kwargs.pop('resource_group', None)
ws = kwargs.pop('workspace_name', None)
assert rg is not None, ('Must provide resource_group keyword '
'arguments for ML client.')
assert ws is not None, ('Must provide workspace keyword '
'arguments for ML client.')
from azure.ai.ml import MLClient
return MLClient(credential, subscription_id, rg, ws)
elif name == 'graph':
import msgraph
return msgraph.GraphServiceClient(credential)
Expand Down Expand Up @@ -459,6 +468,30 @@ def create_security_rule(**kwargs):
return models.SecurityRule(**kwargs)


@common.load_lazy_modules(modules=_LAZY_MODULES)
def create_az_ml_workspace(**kwargs):
from azure.ai.ml import entities
return entities.Workspace(**kwargs)


@common.load_lazy_modules(modules=_LAZY_MODULES)
def create_az_ml_compute_instance(**kwargs):
from azure.ai.ml import entities
return entities.ComputeInstance(**kwargs)


@common.load_lazy_modules(modules=_LAZY_MODULES)
def create_az_ml_compute_instance_ssh_settings(**kwargs):
from azure.ai.ml import entities
return entities.ComputeInstanceSshSettings(**kwargs)


@common.load_lazy_modules(modules=_LAZY_MODULES)
def create_az_ml_network_settings(**kwargs):
from azure.ai.ml import entities
return entities.NetworkSettings(**kwargs)


@common.load_lazy_modules(modules=_LAZY_MODULES)
def deployment_mode():
"""Azure deployment mode."""
Expand Down
30 changes: 30 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sky import resources as resources_lib
from sky import serve as serve_lib
from sky import sky_logging
from sky import skypilot_config
from sky import status_lib
from sky import task as task_lib
from sky.backends import backend_utils
Expand Down Expand Up @@ -2616,6 +2617,21 @@ def check_resources_fit_cluster(
# was handled by ResourceHandle._update_cluster_region.
assert launched_resources.region is not None, handle

# Check whether Azure cluster uses Azure ML API
if launched_resources.cloud.is_same_cloud(clouds.Azure()):
task_use_az_ml = skypilot_config.get_nested(('azure', 'use_az_ml'),
False)
cluster_use_az_ml = launched_resources.use_az_ml
if cluster_use_az_ml != task_use_az_ml:
task_str = 'uses' if task_use_az_ml else 'does not use'
cluster_str = 'uses' if cluster_use_az_ml else 'does not use'
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesMismatchError(
f'Task requirements {task_str} Azure ML API, but the '
f'specified cluster {cluster_name} {cluster_str} it. '
f'Please set azure.use_az_ml to {cluster_use_az_ml} in '
'~/.sky/config.yaml.')

mismatch_str = (f'To fix: specify a new cluster name, or down the '
f'existing cluster first: sky down {cluster_name}')
valid_resource = None
Expand Down Expand Up @@ -3510,6 +3526,20 @@ def _teardown(self,
else:
raise

if handle.launched_resources.cloud.is_same_cloud(clouds.Azure()):
task_use_az_ml = skypilot_config.get_nested(('azure', 'use_az_ml'),
False)
cluster_use_az_ml = handle.launched_resources.use_az_ml
if cluster_use_az_ml != task_use_az_ml:
task_str = 'uses' if task_use_az_ml else 'does not use'
cluster_str = 'uses' if cluster_use_az_ml else 'does not use'
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesMismatchError(
f'Current setup {task_str} Azure ML API, but the '
f'specified cluster {cluster_name} to terminate '
f'{cluster_str} it. Please set azure.use_az_ml '
f'to {cluster_use_az_ml} in ~/.sky/config.yaml.')

lock_path = os.path.expanduser(
backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name))

Expand Down
7 changes: 7 additions & 0 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sky import clouds
from sky import exceptions
from sky import sky_logging
from sky import skypilot_config
from sky.adaptors import azure
from sky.clouds import service_catalog
from sky.utils import common_utils
Expand Down Expand Up @@ -58,6 +59,10 @@ class Azure(clouds.Cloud):
# names, so the limit is 64 - 4 - 7 - 10 = 43.
# Reference: https://azure.github.io/PSRule.Rules.Azure/en/rules/Azure.ResourceGroup.Name/ # pylint: disable=line-too-long
_MAX_CLUSTER_NAME_LEN_LIMIT = 42
# Azure ML has a 24 character limit for instance names. Here we reserve 4
# characters for the multi-node suffix '-n{c}'. Please reference to
# sky/provision/azure/instance.py::_create_instances for more details.
_AZ_ML_MAX_NAME_LEN_LIMIT = 20
_BEST_DISK_TIER = resources_utils.DiskTier.HIGH
_DEFAULT_DISK_TIER = resources_utils.DiskTier.MEDIUM
# Azure does not support high disk and ultra disk tier.
Expand Down Expand Up @@ -85,6 +90,8 @@ def _unsupported_features_for_resources(

@classmethod
def max_cluster_name_length(cls) -> int:
if skypilot_config.get_nested(('azure', 'use_az_ml'), False):
return cls._AZ_ML_MAX_NAME_LEN_LIMIT
return cls._MAX_CLUSTER_NAME_LEN_LIMIT

def instance_type_to_hourly_cost(self,
Expand Down
16 changes: 13 additions & 3 deletions sky/clouds/service_catalog/azure_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, List, Optional, Tuple

from sky import clouds as cloud_lib
from sky import skypilot_config
from sky.clouds import Azure
from sky.clouds.service_catalog import common
from sky.utils import resources_utils
Expand All @@ -18,7 +19,14 @@
# user is using the latest catalog.
_PULL_FREQUENCY_HOURS = 7

_df = common.read_catalog('azure/vms.csv',
_USE_AZ_ML = skypilot_config.get_nested(('azure', 'use_az_ml'), False)

if _USE_AZ_ML:
catalog_file = 'azure/az_ml_vms.csv'
else:
catalog_file = 'azure/vms.csv'

_df = common.read_catalog(catalog_file,
pull_frequency_hours=_PULL_FREQUENCY_HOURS)

_image_df = common.read_catalog('azure/images.csv',
Expand All @@ -29,11 +37,13 @@
# The latest general-purpose instance family as of Mar. 2023.
# CPU: Intel Ice Lake 8370C.
# Memory: 4 GiB RAM per 1 vCPU;
'Ds_v5',
# Azure ML only supports up to Ds_v3.
'Ds_v5' if not _USE_AZ_ML else 'Ds_v3',
# The latest memory-optimized instance family as of Mar. 2023.
# CPU: Intel Ice Lake 8370C.
# Memory: 8 GiB RAM per 1 vCPU.
'Es_v5',
# Azure ML only supports up to Es_v3.
'Es_v5' if not _USE_AZ_ML else 'Es_v3',
# The latest compute-optimized instance family as of Mar 2023.
# CPU: Intel Ice Lake 8370C, Cascade Lake 8272CL, or Skylake 8168.
# Memory: 2 GiB RAM per 1 vCPU.
Expand Down
29 changes: 20 additions & 9 deletions sky/clouds/service_catalog/data_fetchers/fetch_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def get_additional_columns(row):
return df_ret


if __name__ == '__main__':
parser = argparse.ArgumentParser()
def get_arg_parser(description: str) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=description)
group = parser.add_mutually_exclusive_group()
group.add_argument('--all-regions',
action='store_true',
Expand All @@ -327,23 +327,34 @@ def get_additional_columns(row):
'running in github action, as the multiprocessing '
'does not work well with the azure client due '
'to ssl issues.')
args = parser.parse_args()
return parser

SINGLE_THREADED = args.single_threaded

if args.regions:
region_filter = set(args.regions) - EXCLUDED_REGIONS
elif args.all_regions:
def get_region_filter(all_regions: bool, regions: Optional[List[str]],
exclude: Optional[List[str]]) -> Set[str]:
if regions:
region_filter = set(regions) - EXCLUDED_REGIONS
elif all_regions:
region_filter = set(get_regions()) - EXCLUDED_REGIONS
else:
region_filter = US_REGIONS
region_filter = region_filter - set(
args.exclude) if args.exclude else region_filter
exclude) if exclude is not None else region_filter

if not region_filter:
raise ValueError('No regions to fetch. Please check your arguments.')

instance_df = get_all_regions_instance_types_df(region_filter)
return region_filter


if __name__ == '__main__':
az_parser = get_arg_parser('Fetch Azure pricing data.')
args = az_parser.parse_args()

SINGLE_THREADED = args.single_threaded

instance_df = get_all_regions_instance_types_df(
get_region_filter(args.all_regions, args.regions, args.exclude))
os.makedirs('azure', exist_ok=True)
instance_df.to_csv('azure/vms.csv', index=False)
print('Azure Service Catalog saved to azure/vms.csv')
120 changes: 120 additions & 0 deletions sky/clouds/service_catalog/data_fetchers/fetch_azure_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""A script to fetch Azure ML pricing data.

Requires running fetch_azure.py first to get the pricing data.
"""
from multiprocessing import pool as mp_pool
import os
import typing
from typing import Dict, Set

from azure.ai import ml
from azure.ai.ml import entities

from sky.adaptors import azure
from sky.adaptors import common as adaptors_common
from sky.clouds.service_catalog import common
from sky.clouds.service_catalog.data_fetchers import fetch_azure

if typing.TYPE_CHECKING:
import pandas as pd
else:
pd = adaptors_common.LazyImport('pandas')

SUBSCRIPTION_ID = azure.get_subscription_id()

SINGLE_THREADED = False

az_df = common.read_catalog('azure/vms.csv')


def init_ml_client(region: str) -> ml.MLClient:
resource_client = azure.get_client('resource', SUBSCRIPTION_ID)
resource_group_name = f'az-ml-fetcher-{region}'
workspace_name = f'az-ml-fetcher-{region}-ws'
resource_client.resource_groups.create_or_update(resource_group_name,
{'location': region})
ml_client: ml.MLClient = azure.get_client(
'ml',
SUBSCRIPTION_ID,
resource_group=resource_group_name,
workspace_name=workspace_name)
try:
ml_client.workspaces.get(workspace_name)
except azure.exceptions().ResourceNotFoundError:
print(f'Creating workspace {workspace_name} in {region}')
ws = ml_client.workspaces.begin_create(
entities.Workspace(name=workspace_name, location=region)).result()
print(f'Created workspace {ws.name} in {ws.location}.')
return ml_client


def get_supported_instance_type(region: str) -> Dict[str, bool]:
ml_client = init_ml_client(region)
supported_instance_types = {}
for sz in ml_client.compute.list_sizes():
if sz.supported_compute_types is None:
continue
if 'ComputeInstance' not in sz.supported_compute_types:
continue
supported_instance_types[sz.name] = sz.low_priority_capable
return supported_instance_types


def get_instance_type_df(region: str) -> 'pd.DataFrame':
supported_instance_type = get_supported_instance_type(region)
df_filtered = az_df[az_df['Region'] == region].copy()
df_filtered = df_filtered[df_filtered['InstanceType'].isin(
supported_instance_type.keys())]

def _get_spot_price(row):
ins_type = row['InstanceType']
assert ins_type in supported_instance_type, (
f'Instance type {ins_type} not in supported_instance_type')
if supported_instance_type[ins_type]:
return row['SpotPrice']
return None

df_filtered['SpotPrice'] = df_filtered.apply(_get_spot_price, axis=1)

supported_set = set(supported_instance_type.keys())
df_set = set(az_df[az_df['Region'] == region]['InstanceType'])
missing_instance_types = supported_set - df_set
missing_str = ', '.join(missing_instance_types)
if missing_instance_types:
print(f'Missing instance types for {region}: {missing_str}')
else:
print(f'All supported instance types for {region} are in the catalog.')

return df_filtered


def get_all_regions_instance_types_df(region_set: Set[str]) -> 'pd.DataFrame':
if SINGLE_THREADED:
dfs = [get_instance_type_df(region) for region in region_set]
else:
with mp_pool.Pool() as pool:
dfs_result = pool.map_async(get_instance_type_df, region_set)
dfs = dfs_result.get()
df = pd.concat(dfs, ignore_index=True)
df = df.sort_values(by='InstanceType').reset_index(drop=True)
return df


if __name__ == '__main__':
az_ml_parser = fetch_azure.get_arg_parser('Fetch Azure ML pricing data.')
# TODO(tian): Support cleanup after fetching the data.
az_ml_parser.add_argument(
'--cleanup',
action='store_true',
help='Cleanup the resource group and workspace after '
'fetching the data.')
args = az_ml_parser.parse_args()

SINGLE_THREADED = args.single_threaded

instance_df = get_all_regions_instance_types_df(
fetch_azure.get_region_filter(args.all_regions, args.regions,
args.exclude))
os.makedirs('azure', exist_ok=True)
instance_df.to_csv('azure/az_ml_vms.csv', index=False)
print('Azure ML Service Catalog saved to azure/az_ml_vms.csv')
Loading
Loading