Skip to content

Commit

Permalink
[k8s] Attach custom metadata to objects created by SkyPilot (#3333)
Browse files Browse the repository at this point in the history
* Add support for custom metadata - untested

* Add support for custom metadata

* disallow name and namespace

* lint

* spaces
  • Loading branch information
romilbhardwaj committed Mar 22, 2024
1 parent 0b323bb commit 8f4f6f8
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 50 deletions.
12 changes: 12 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,18 @@ Available fields and semantics:
# for details on deploying the NGINX ingress controller.
ports: loadbalancer
# Attach custom metadata to Kubernetes objects created by SkyPilot
#
# Uses the same schema as Kubernetes metadata object: https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.26/#objectmeta-v1-meta
#
# Since metadata is applied to all all objects created by SkyPilot,
# specifying 'name' and 'namespace' fields here is not allowed.
custom_metadata:
labels:
mylabel: myvalue
annotations:
myannotation: myvalue
# Additional fields to override the pod fields used by SkyPilot (optional)
#
# Any key:value pairs added here would get added to the pod spec used to
Expand Down
16 changes: 13 additions & 3 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,20 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
public_key = f.read()
if not public_key.endswith('\n'):
public_key += '\n'
secret_metadata = k8s.client.V1ObjectMeta(name=secret_name,
labels={'parent': 'skypilot'})

# Generate metadata
secret_metadata = {
'name': secret_name,
'labels': {
'parent': 'skypilot'
}
}
custom_metadata = skypilot_config.get_nested(
('kubernetes', 'custom_metadata'), {})
kubernetes_utils.merge_dicts(custom_metadata, secret_metadata)

secret = k8s.client.V1Secret(
metadata=secret_metadata,
metadata=k8s.client.V1ObjectMeta(**secret_metadata),
string_data={secret_field_name: public_key})
if kubernetes_utils.check_secret_exists(secret_name, namespace):
logger.debug(f'Key {secret_name} exists in the cluster, patching it...')
Expand Down
1 change: 1 addition & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,7 @@ def write_cluster_config(
# Add kubernetes config fields from ~/.sky/config
if isinstance(cloud, clouds.Kubernetes):
kubernetes_utils.combine_pod_config_fields(tmp_yaml_path)
kubernetes_utils.combine_metadata_fields(tmp_yaml_path)

# Restore the old yaml content for backward compatibility.
if os.path.exists(yaml_path) and keep_launch_fields_in_existing_config:
Expand Down
8 changes: 8 additions & 0 deletions sky/provision/kubernetes/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sky.adaptors import kubernetes
from sky.provision import common
from sky.provision.kubernetes import network_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.utils import kubernetes_enums
from sky.utils.resources_utils import port_ranges_to_set

Expand Down Expand Up @@ -46,6 +47,10 @@ def _open_ports_using_loadbalancer(
selector_key='skypilot-cluster',
selector_value=cluster_name_on_cloud,
)

# Update metadata from config
kubernetes_utils.merge_custom_metadata(content['service_spec']['metadata'])

network_utils.create_or_replace_namespaced_service(
namespace=provider_config.get('namespace', 'default'),
service_name=service_name,
Expand Down Expand Up @@ -93,12 +98,15 @@ def _open_ports_using_ingress(

# Create or update services based on the generated specs
for service_name, service_spec in content['services_spec'].items():
# Update metadata from config
kubernetes_utils.merge_custom_metadata(service_spec['metadata'])
network_utils.create_or_replace_namespaced_service(
namespace=provider_config.get('namespace', 'default'),
service_name=service_name,
service_spec=service_spec,
)

kubernetes_utils.merge_custom_metadata(content['ingress_spec']['metadata'])
# Create or update the single ingress for all services
network_utils.create_or_replace_namespaced_ingress(
namespace=provider_config.get('namespace', 'default'),
Expand Down
143 changes: 98 additions & 45 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,10 @@ def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
# Fill in template - ssh_key_secret and ssh_jump_image are not required for
# the service spec, so we pass in empty strs.
content = fill_ssh_jump_template('', '', ssh_jump_name, service_type.value)

# Add custom metadata from config
merge_custom_metadata(content['service_spec']['metadata'])

# Create service
try:
kubernetes.core_api().create_namespaced_service(namespace,
Expand Down Expand Up @@ -885,6 +889,11 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
# required, so we pass in empty str.
content = fill_ssh_jump_template(ssh_key_secret, ssh_jump_image,
ssh_jump_name, '')

# Add custom metadata to all objects
for object_type in content.keys():
merge_custom_metadata(content[object_type]['metadata'])

# ServiceAccount
try:
kubernetes.core_api().create_namespaced_service_account(
Expand Down Expand Up @@ -1059,7 +1068,48 @@ def get_endpoint_debug_message() -> str:
debug_cmd=debug_cmd)


def combine_pod_config_fields(config_yaml_path: str) -> None:
def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]):
"""Merge two dictionaries into the destination dictionary.
Updates nested dictionaries instead of replacing them.
If a list is encountered, it will be appended to the destination list.
An exception is when the key is 'containers', in which case the
first container in the list will be fetched and merge_dict will be
called on it with the first container in the destination list.
"""
for key, value in source.items():
if isinstance(value, dict) and key in destination:
merge_dicts(value, destination[key])
elif isinstance(value, list) and key in destination:
assert isinstance(destination[key], list), \
f'Expected {key} to be a list, found {destination[key]}'
if key == 'containers':
# If the key is 'containers', we take the first and only
# container in the list and merge it.
assert len(value) == 1, \
f'Expected only one container, found {value}'
merge_dicts(value[0], destination[key][0])
elif key in ['volumes', 'volumeMounts']:
# If the key is 'volumes' or 'volumeMounts', we search for
# item with the same name and merge it.
for new_volume in value:
new_volume_name = new_volume.get('name')
if new_volume_name is not None:
destination_volume = next(
(v for v in destination[key]
if v.get('name') == new_volume_name), None)
if destination_volume is not None:
merge_dicts(new_volume, destination_volume)
else:
destination[key].append(new_volume)
else:
destination[key].extend(value)
else:
destination[key] = value


def combine_pod_config_fields(cluster_yaml_path: str) -> None:
"""Adds or updates fields in the YAML with fields from the ~/.sky/config's
kubernetes.pod_spec dict.
This can be used to add fields to the YAML that are not supported by
Expand Down Expand Up @@ -1098,60 +1148,63 @@ def combine_pod_config_fields(config_yaml_path: str) -> None:
- name: my-secret
```
"""

def _merge_dicts(source, destination):
"""Merge two dictionaries.
Updates nested dictionaries instead of replacing them.
If a list is encountered, it will be appended to the destination list.
An exception is when the key is 'containers', in which case the
first container in the list will be fetched and _merge_dict will be
called on it with the first container in the destination list.
"""
for key, value in source.items():
if isinstance(value, dict) and key in destination:
_merge_dicts(value, destination[key])
elif isinstance(value, list) and key in destination:
assert isinstance(destination[key], list), \
f'Expected {key} to be a list, found {destination[key]}'
if key == 'containers':
# If the key is 'containers', we take the first and only
# container in the list and merge it.
assert len(value) == 1, \
f'Expected only one container, found {value}'
_merge_dicts(value[0], destination[key][0])
elif key in ['volumes', 'volumeMounts']:
# If the key is 'volumes' or 'volumeMounts', we search for
# item with the same name and merge it.
for new_volume in value:
new_volume_name = new_volume.get('name')
if new_volume_name is not None:
destination_volume = next(
(v for v in destination[key]
if v.get('name') == new_volume_name), None)
if destination_volume is not None:
_merge_dicts(new_volume, destination_volume)
else:
destination[key].append(new_volume)
else:
destination[key].extend(value)
else:
destination[key] = value

with open(config_yaml_path, 'r', encoding='utf-8') as f:
with open(cluster_yaml_path, 'r', encoding='utf-8') as f:
yaml_content = f.read()
yaml_obj = yaml.safe_load(yaml_content)
kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'),
{})

# Merge the kubernetes config into the YAML for both head and worker nodes.
_merge_dicts(
merge_dicts(
kubernetes_config,
yaml_obj['available_node_types']['ray_head_default']['node_config'])

# Write the updated YAML back to the file
common_utils.dump_yaml(config_yaml_path, yaml_obj)
common_utils.dump_yaml(cluster_yaml_path, yaml_obj)


def combine_metadata_fields(cluster_yaml_path: str) -> None:
"""Updates the metadata for all Kubernetes objects created by SkyPilot with
fields from the ~/.sky/config's kubernetes.custom_metadata dict.
Obeys the same add or update semantics as combine_pod_config_fields().
"""

with open(cluster_yaml_path, 'r', encoding='utf-8') as f:
yaml_content = f.read()
yaml_obj = yaml.safe_load(yaml_content)
custom_metadata = skypilot_config.get_nested(
('kubernetes', 'custom_metadata'), {})

# List of objects in the cluster YAML to be updated
combination_destinations = [
# Service accounts
yaml_obj['provider']['autoscaler_service_account']['metadata'],
yaml_obj['provider']['autoscaler_role']['metadata'],
yaml_obj['provider']['autoscaler_role_binding']['metadata'],
yaml_obj['provider']['autoscaler_service_account']['metadata'],
# Pod spec
yaml_obj['available_node_types']['ray_head_default']['node_config']
['metadata'],
# Services for pods
*[svc['metadata'] for svc in yaml_obj['provider']['services']]
]

for destination in combination_destinations:
merge_dicts(custom_metadata, destination)

# Write the updated YAML back to the file
common_utils.dump_yaml(cluster_yaml_path, yaml_obj)


def merge_custom_metadata(original_metadata: Dict[str, Any]) -> None:
"""Merges original metadata with custom_metadata from config
Merge is done in-place, so return is not required
"""
custom_metadata = skypilot_config.get_nested(
('kubernetes', 'custom_metadata'), {})
merge_dicts(custom_metadata, original_metadata)


def check_nvidia_runtime_class() -> bool:
Expand Down
8 changes: 6 additions & 2 deletions sky/templates/kubernetes-ssh-jump.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ service_account:
kind: ServiceAccount
metadata:
name: sky-ssh-jump-sa
parent: skypilot
labels:
parent: skypilot
role:
kind: Role
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: sky-ssh-jump-role
labels:
parent: skypilot
rules:
- apiGroups: [""]
resources: ["pods", "pods/status", "pods/exec", "services"]
Expand All @@ -80,7 +83,8 @@ role_binding:
kind: RoleBinding
metadata:
name: sky-ssh-jump-rb
parent: skypilot
labels:
parent: skypilot
subjects:
- kind: ServiceAccount
name: sky-ssh-jump-sa
Expand Down
14 changes: 14 additions & 0 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,20 @@ def get_config_schema():
'required': [],
# Allow arbitrary keys since validating pod spec is hard
'additionalProperties': True,
},
'custom_metadata': {
'type': 'object',
'required': [],
# Allow arbitrary keys since validating metadata is hard
'additionalProperties': True,
# Disallow 'name' and 'namespace' keys in this dict
'not': {
'anyOf': [{
'required': ['name']
}, {
'required': ['namespace']
}]
}
}
}
},
Expand Down

0 comments on commit 8f4f6f8

Please sign in to comment.