Skip to content

Commit

Permalink
Merge branch 'pub-main'
Browse files Browse the repository at this point in the history
  • Loading branch information
mattrasmus committed Jun 21, 2023
2 parents 34b8400 + fb83b0b commit 47fd162
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 25 deletions.
48 changes: 37 additions & 11 deletions redun/executors/gcp_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def __init__(
docker_config["include_aws_env"] = "False"
self._docker_executor = DockerExecutor(name + "_debug", scheduler, config=docker_config)

self.gcp_client = gcp_utils.get_gcp_client()
self.gcp_batch_client = gcp_utils.get_gcp_batch_client()
self.gcp_compute_client = gcp_utils.get_gcp_compute_client()

# Required config.
self.project = config["project"]
Expand Down Expand Up @@ -99,8 +100,6 @@ def __init__(
# Default task options.
self.default_task_options: Dict[str, Any] = {
"machine_type": config.get("machine_type", fallback="e2-standard-4"),
"vcpus": config.getint("vcpus", fallback=2),
"memory": config.getint("memory", fallback=16),
"task_count": config.getint("task_count", fallback=1),
"retries": config.getint("retries", fallback=2),
"priority": config.getint("priority", fallback=30),
Expand Down Expand Up @@ -128,7 +127,7 @@ def _is_debug_job(self, job: RedunJob) -> bool:

def gather_inflight_jobs(self) -> None:
batch_jobs = gcp_utils.list_jobs(
client=self.gcp_client, project_id=self.project, region=self.region
client=self.gcp_batch_client, project_id=self.project, region=self.region
)
inflight_job_statuses = [
JobStatus.State.SCHEDULED,
Expand Down Expand Up @@ -170,7 +169,7 @@ def gather_inflight_jobs(self) -> None:
eval_hashes = cast(str, eval_file.read("r")).splitlines()

batch_tasks = gcp_utils.list_tasks(
client=self.gcp_client, group_name=job.task_groups[0].name
client=self.gcp_batch_client, group_name=job.task_groups[0].name
)
for array_index, task in enumerate(batch_tasks):
# Skip job if it is not in one of the 'inflight' states
Expand Down Expand Up @@ -199,6 +198,31 @@ def _get_job_options(self, job: RedunJob, array_uuid: Optional[str] = None) -> d
REDUN_JOB_TYPE_LABEL_KEY: "script" if job.task.script else "container",
}

project = self.project
region = self.region
machine_type = task_options.get("machine_type")

        # If either memory or vCPUs is not provided, use the maximum available for the machine
        # type; otherwise use the min of what is available and what is requested.
compute_machine_type = gcp_utils.get_compute_machine_type(
self.gcp_compute_client, project, region, machine_type
)

requested_memory = task_options.get("memory", 0)
# MiB to GB
max_memory = compute_machine_type.memory_mb / 1024
if requested_memory <= 0:
task_options["memory"] = max_memory
else:
task_options["memory"] = min(requested_memory, max_memory)

requested_vcpus = task_options.get("vcpus", 0)
max_vcpus = compute_machine_type.guest_cpus
if requested_vcpus <= 0:
task_options["vcpus"] = max_vcpus
else:
task_options["vcpus"] = min(requested_vcpus, max_vcpus)

# Merge labels if needed.
labels: Dict[str, str] = {
**self.default_task_options.get("labels", {}),
Expand Down Expand Up @@ -267,7 +291,9 @@ def _submit(self, job: RedunJob) -> None:
batch_task_name = self.preexisting_batch_tasks.pop(job.eval_hash)

job_dir = get_job_scratch_dir(self.gcs_scratch_prefix, job)
existing_task = gcp_utils.get_task(client=self.gcp_client, task_name=batch_task_name)
existing_task = gcp_utils.get_task(
client=self.gcp_batch_client, task_name=batch_task_name
)

if existing_task:
self.log(
Expand Down Expand Up @@ -358,7 +384,7 @@ def _submit_array_job(self, jobs: List[RedunJob]) -> str:
)

gcp_job = gcp_utils.batch_submit(
client=self.gcp_client,
client=self.gcp_batch_client,
job_name=f"{REDUN_ARRAY_JOB_PREFIX}{array_uuid}",
project=project,
region=region,
Expand Down Expand Up @@ -428,7 +454,7 @@ def _submit_single_job(self, job: RedunJob) -> None:
)

gcp_job = gcp_utils.batch_submit(
client=self.gcp_client,
client=self.gcp_batch_client,
job_name=f"{REDUN_JOB_PREFIX}{job.id}",
project=project,
region=region,
Expand Down Expand Up @@ -462,7 +488,7 @@ def _submit_single_job(self, job: RedunJob) -> None:
# GCP Batch takes script as a string and requires quoting of -c argument
script_command = ["bash", script_path]
gcp_job = gcp_utils.batch_submit(
client=self.gcp_client,
client=self.gcp_batch_client,
job_name=f"redun-{job.id}",
project=project,
region=region,
Expand Down Expand Up @@ -520,7 +546,7 @@ def _monitor(self) -> None:
assert self._scheduler

# Need new client for thread safety
gcp_client = gcp_utils.get_gcp_client()
gcp_batch_client = gcp_utils.get_gcp_batch_client()

try:
while self.is_running and (self.pending_batch_tasks or self.arrayer.num_pending):
Expand All @@ -539,7 +565,7 @@ def _monitor(self) -> None:
task_names = list(self.pending_batch_tasks.keys())
for name in task_names:
try:
task = gcp_utils.get_task(client=gcp_client, task_name=name)
task = gcp_utils.get_task(client=gcp_batch_client, task_name=name)
self._process_task_status(task)
except NotFound:
# Batch Job has not instantiated tasks yet so ignore this NotFound error
Expand Down
27 changes: 25 additions & 2 deletions redun/executors/gcp_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from enum import Enum
from functools import lru_cache
from typing import Dict, Iterable, List, Tuple, Union

from google.api_core import gapic_v1
from google.cloud import batch_v1
from google.cloud import batch_v1, compute_v1

from redun.version import version

Expand All @@ -20,7 +21,7 @@ class MinCPUPlatform(Enum):
EPYC_MILAN = "AMD Milan"


def get_gcp_client(
def get_gcp_batch_client(
sync: bool = True,
) -> Union[batch_v1.BatchServiceClient, batch_v1.BatchServiceAsyncClient]:
# TODO: Integrate redun version here later.
Expand Down Expand Up @@ -263,3 +264,25 @@ def get_task(client: batch_v1.BatchServiceClient, task_name: str) -> batch_v1.Ta
A Task object representing the specified task.
"""
return client.get_task(name=task_name)


def get_gcp_compute_client() -> compute_v1.MachineTypesClient:
    """
    Create a new GCP Compute Engine MachineTypes client.

    Returns:
        A ``compute_v1.MachineTypesClient`` constructed with default settings
        (application-default credentials).
    """
    client = compute_v1.MachineTypesClient()
    return client


@lru_cache(maxsize=None)
def get_compute_machine_type(
    client: compute_v1.MachineTypesClient, project: str, region: str, machine_type: str
) -> compute_v1.types.MachineType:
    """
    Retrieve information about a GCP MachineType.

    Results are memoized for the life of the process via ``lru_cache``. Note
    that ``client`` is part of the cache key, so the same client instance must
    be reused for the cache to hit, and cached entries keep that client alive.

    Args:
        client: MachineTypesClient used to query the Compute Engine API.
        project: Project ID or project number of the Cloud project to use.
        region: Region to look the machine type up in. The query is issued
            against the region's ``a`` zone (``f"{region}-a"``) — assumes the
            machine type is offered in that zone; TODO confirm this holds for
            all regions of interest.
        machine_type: The machine type to get details about.

    Returns:
        A MachineType which provides information about a machine's available
        vCPUs and memory.
    """

    return client.get(project=project, zone=f"{region}-a", machine_type=machine_type)
45 changes: 34 additions & 11 deletions redun/tests/test_gcp_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import Mock, patch

import boto3
from google.cloud import batch_v1
from google.cloud import batch_v1, compute_v1
from google.protobuf.json_format import MessageToDict # type: ignore

from redun import File, task
Expand Down Expand Up @@ -71,21 +71,25 @@ def task1(x: int) -> int:

@use_tempdir
@mock_s3
@patch("redun.executors.gcp_utils.get_gcp_client")
@patch("redun.file.get_filesystem_class", get_filesystem_class_mock)
@patch("redun.executors.gcp_utils.get_gcp_compute_client")
@patch("redun.executors.gcp_utils.get_compute_machine_type")
@patch("redun.executors.gcp_utils.get_gcp_batch_client")
@patch("redun.executors.gcp_utils.list_jobs")
@patch("redun.executors.gcp_utils.get_task")
def test_executor(
get_task_mock: Mock,
list_jobs_mock: Mock,
get_gcp_client_mock: Mock,
get_gcp_batch_client_mock: Mock,
get_compute_machine_type_mock: Mock,
get_gcp_compute_client_mock: Mock,
) -> None:
"""
GCPBatchExecutor should run jobs.
"""
scheduler = mock_scheduler()
executor = mock_executor(scheduler)
client = get_gcp_client_mock()
client = get_gcp_batch_client_mock()

# Prepare API mocks for job submission.
batch_job_id = "123"
Expand All @@ -96,6 +100,10 @@ def test_executor(
name=f"{batch_job_id}/tasks/0",
status=batch_v1.TaskStatus(state=batch_v1.TaskStatus.State.SUCCEEDED),
)
get_compute_machine_type_mock.return_value = compute_v1.types.MachineType(
memory_mb=16384,
guest_cpus=2,
)

# Create and submit a job.
expr = cast(TaskExpression[int], task1(10))
Expand Down Expand Up @@ -199,20 +207,24 @@ def test_executor(
@use_tempdir
@mock_s3
@patch("redun.file.get_filesystem_class", get_filesystem_class_mock)
@patch("redun.executors.gcp_utils.get_gcp_client")
@patch("redun.executors.gcp_utils.get_gcp_compute_client")
@patch("redun.executors.gcp_utils.get_compute_machine_type")
@patch("redun.executors.gcp_utils.get_gcp_batch_client")
@patch("redun.executors.gcp_utils.list_jobs")
@patch("redun.executors.gcp_utils.get_task")
def test_executor_array(
get_task_mock: Mock,
list_jobs_mock: Mock,
get_gcp_client_mock: Mock,
get_gcp_batch_client_mock: Mock,
get_compute_machine_type_mock: Mock,
get_gcp_compute_client_mock: Mock,
) -> None:
"""
GCPBatchExecutor should be able to submit array jobs.
"""
scheduler = mock_scheduler()
executor = mock_executor(scheduler)
client = get_gcp_client_mock()
client = get_gcp_batch_client_mock()

# Suppress inflight jobs check.
list_jobs_mock.return_value = []
Expand All @@ -239,6 +251,10 @@ def get_task(client, task_name):
)

get_task_mock.side_effect = get_task
get_compute_machine_type_mock.return_value = compute_v1.types.MachineType(
memory_mb=16384,
guest_cpus=2,
)

# Submit two jobs in order to trigger array submission.
# Create and submit a job.
Expand Down Expand Up @@ -329,20 +345,24 @@ def get_task(client, task_name):
@use_tempdir
@mock_s3
@patch("redun.file.get_filesystem_class", get_filesystem_class_mock)
@patch("redun.executors.gcp_utils.get_gcp_client")
@patch("redun.executors.gcp_utils.get_gcp_compute_client")
@patch("redun.executors.gcp_utils.get_compute_machine_type")
@patch("redun.executors.gcp_utils.get_gcp_batch_client")
@patch("redun.executors.gcp_utils.list_jobs")
@patch("redun.executors.gcp_utils.get_task")
def test_executor_script(
get_task_mock: Mock,
list_jobs_mock: Mock,
get_gcp_client_mock: Mock,
get_gcp_batch_client_mock: Mock,
get_compute_machine_type_mock: Mock,
get_gcp_compute_client_mock: Mock,
) -> None:
"""
GCPBatchExecutor should run script jobs.
"""
scheduler = mock_scheduler()
executor = mock_executor(scheduler)
client = get_gcp_client_mock()
client = get_gcp_batch_client_mock()

# Prepare API mocks for job submission.
batch_job_id = "123"
Expand All @@ -353,7 +373,10 @@ def test_executor_script(
name=f"{batch_job_id}/tasks/0",
status=batch_v1.TaskStatus(state=batch_v1.TaskStatus.State.SUCCEEDED),
)

get_compute_machine_type_mock.return_value = compute_v1.types.MachineType(
memory_mb=16384,
guest_cpus=2,
)
# Create and submit a script job.
# Simulate the call to script_task().
expr = script(
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ vcrpy==4.1.1
pygraphviz
kubernetes==22.6
google-cloud-batch==0.9.0
google-cloud-compute==1.11.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"glue": ["pandas", "pyarrow", "pyspark"],
"k8s": "kubernetes>=22.6",
"viz": "pygraphviz",
"google-batch": "google-cloud-batch>=0.2.0",
"google-batch": ["google-cloud-batch>=0.2.0", "google-cloud-compute>=1.11.0"],
}

if REQUIRE_POSTGRES:
Expand Down

0 comments on commit 47fd162

Please sign in to comment.