Add integration tests for Distributing Cloud Tuner.
Fix a few issues:
- Duplicate epoch reports to Oracle
- Excessive cache discovery Info/Error log

PiperOrigin-RevId: 339757281
SinaChavoshi authored and Tensorflow Cloud maintainers committed Oct 29, 2020
1 parent 651ce41 commit a79009e
Showing 4 changed files with 208 additions and 20 deletions.
@@ -0,0 +1,161 @@
# Lint as: python3
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Distributing Cloud Tuner."""

import contextlib
import io
import os
import re
import kerastuner
import tensorflow as tf
from tensorflow import keras
from tensorflow_cloud.tuner import optimizer_client
from tensorflow_cloud.tuner.tuner import DistributingCloudTuner

# If input dataset is created outside tuner.search(),
# it requires eager execution even in TF 1.x.
if tf.version.VERSION.split(".")[0] == "1":
    tf.compat.v1.enable_eager_execution()

# The project id to use to run tests.
_PROJECT_ID = os.environ["PROJECT_ID"]

# The GCP region in which the end-to-end test is run.
_REGION = os.environ["REGION"]

# Study ID for testing
_STUDY_ID_BASE = "dct_{}".format((os.environ["BUILD_ID"]).replace("-", "_"))

# The base docker image to use for the remote environment.
_DOCKER_IMAGE = os.environ["DOCKER_IMAGE"]

# The staging bucket to use to copy the model and data for the remote run.
_REMOTE_DIR = os.path.join("gs://", os.environ["TEST_BUCKET"], _STUDY_ID_BASE)

# The search space for hyperparameters
_HPS = kerastuner.engine.hyperparameters.HyperParameters()
_HPS.Float("learning_rate", min_value=1e-4, max_value=1e-2, sampling="log")
_HPS.Int("num_layers", 2, 10)


def _load_data(dir_path=None):
    """Loads and prepares data."""

    mnist_file_path = None
    if dir_path:
        mnist_file_path = os.path.join(dir_path, "mnist.npz")

    (x, y), (val_x, val_y) = keras.datasets.mnist.load_data(mnist_file_path)
    x = x.astype("float32") / 255.0
    val_x = val_x.astype("float32") / 255.0

    return ((x[:10000], y[:10000]), (val_x, val_y))


def _build_model(hparams):
    # Note that CloudTuner does not support adding hyperparameters in
    # the model building function. Instead, the search space is configured
    # by passing a hyperparameters argument when instantiating (constructing)
    # the tuner.
    model = keras.Sequential()
    model.add(keras.layers.Flatten(input_shape=(28, 28)))

    # Build the model with the number of layers from the hyperparameters
    for _ in range(hparams.get("num_layers")):
        model.add(keras.layers.Dense(units=64, activation="relu"))
    model.add(keras.layers.Dense(10, activation="softmax"))

    # Compile the model with the learning rate from the hyperparameters
    model.compile(
        optimizer=keras.optimizers.Adam(lr=hparams.get("learning_rate")),
        loss="sparse_categorical_crossentropy",
        metrics=["acc"],
    )
    return model


class _DistributingCloudTunerIntegrationTestBase(tf.test.TestCase):

    def setUp(self):
        super(_DistributingCloudTunerIntegrationTestBase, self).setUp()
        self._study_id = None

    def _assert_output(self, fn, regex_str):
        stdout = io.StringIO()
        with contextlib.redirect_stdout(stdout):
            fn()
        output = stdout.getvalue()
        self.assertRegex(output, re.compile(regex_str, re.DOTALL))

    def _assert_results_summary(self, fn):
        self._assert_output(
            fn, ".*Results summary.*Trial summary.*Hyperparameters.*")

    def _delete_dir(self, path) -> None:
        """Deletes a directory if it exists."""
        if tf.io.gfile.isdir(path):
            tf.io.gfile.rmtree(path)

    def tearDown(self):
        super(_DistributingCloudTunerIntegrationTestBase, self).tearDown()

        # Delete the study used in the test, if present
        if self._study_id:
            service = optimizer_client.create_or_load_study(
                _PROJECT_ID, _REGION, self._study_id, None)
            service.delete_study()

        tf.keras.backend.clear_session()

        # Delete log files, saved_models and other training assets
        self._delete_dir(_REMOTE_DIR)


class DistributingCloudTunerIntegrationTest(
        _DistributingCloudTunerIntegrationTestBase):

    def setUp(self):
        super(DistributingCloudTunerIntegrationTest, self).setUp()
        (self._x, self._y), (self._val_x, self._val_y) = _load_data(
            self.get_temp_dir())

    def testCloudTunerHyperparameters(self):
        """Test case to configure Distributing Tuner with HyperParameters."""
        study_id = "{}_hyperparameters".format(_STUDY_ID_BASE)
        self._study_id = study_id

        tuner = DistributingCloudTuner(
            _build_model,
            project_id=_PROJECT_ID,
            region=_REGION,
            objective="acc",
            hyperparameters=_HPS,
            max_trials=2,
            study_id=study_id,
            directory=_REMOTE_DIR,
            container_uri=_DOCKER_IMAGE
        )

        tuner.search(
            x=self._x,
            y=self._y,
            epochs=2,
            validation_data=(self._val_x, self._val_y),
        )

        self._assert_results_summary(tuner.results_summary)


if __name__ == "__main__":
    tf.test.main()
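
The test above is driven entirely by environment variables. A minimal local setup might look like the following sketch; every value here is a placeholder for illustration, not part of the commit:

import os

# Placeholder values (assumptions) -- substitute a real GCP project, region,
# staging bucket, and container image. The test module reads these variables
# at import time.
os.environ.setdefault("PROJECT_ID", "my-gcp-project")
os.environ.setdefault("REGION", "us-central1")
os.environ.setdefault("BUILD_ID", "local-build-1")
os.environ.setdefault("DOCKER_IMAGE", "gcr.io/my-gcp-project/tuner-test:latest")
os.environ.setdefault("TEST_BUCKET", "my-staging-bucket")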
18 changes: 11 additions & 7 deletions src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py
@@ -539,8 +539,9 @@ def test_add_model_checkpoint_callback(self, mock_super_tuner):
                        auto_spec=True)
     @mock.patch.object(tf_utils, "get_tensorboard_log_watcher_from_path",
                        auto_spec=True)
+    @mock.patch.object(tf.io.gfile, "makedirs", auto_spec=True)
     def test_remote_run_trial_with_successful_job(
-        self, mock_log_watcher, mock_is_running, mock_super_tuner,
+        self, mock_tf_io, mock_log_watcher, mock_is_running, mock_super_tuner,
         mock_job_status, mock_cloud_fit):
         remote_tuner = self._remote_tuner(
             None, None, self._study_config, max_trials=10)
@@ -573,21 +574,23 @@ def test_remote_run_trial_with_successful_job(
             image_uri=self._container_uri,
             job_id=self._job_id)

-        log_path = remote_tuner._get_tensorboard_log_dir(
-            self._test_trial.trial_id)
+        log_path = os.path.join(remote_tuner._get_tensorboard_log_dir(
+            self._test_trial.trial_id), "train")
         mock_log_watcher.assert_called_with(log_path)
         self.assertEqual(
             2, remote_tuner._get_remote_training_metrics.call_count)
+        mock_tf_io.assert_called_with(log_path)

     @mock.patch.object(cloud_fit_client, "cloud_fit", auto_spec=True)
     @mock.patch.object(google_api_client,
                        "wait_for_api_training_job_completion", auto_spec=True)
     @mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
     @mock.patch.object(google_api_client, "is_api_training_job_running",
                        auto_spec=True)
+    @mock.patch.object(tf.io.gfile, "makedirs", auto_spec=True)
     def test_remote_run_trial_with_failed_job(
-        self, mock_is_running, mock_super_tuner,
-        mock_job_status, mock_cloud_fit):
+        self, mock_tf_io, mock_is_running, mock_super_tuner, mock_job_status,
+        mock_cloud_fit):

         remote_tuner = self._remote_tuner(
             None, None, self._study_config, max_trials=10)
@@ -609,8 +612,9 @@ def test_remote_run_trial_with_failed_job(
     @mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
     @mock.patch.object(google_api_client, "is_api_training_job_running",
                        auto_spec=True)
+    @mock.patch.object(tf.io.gfile, "makedirs", auto_spec=True)
     def test_remote_run_trial_with_oracle_canceling_job(
-        self, mock_is_running, mock_super_tuner,
+        self, mock_tf_io, mock_is_running, mock_super_tuner,
         mock_job_status, mock_cloud_fit, mock_stop_job):

         remote_tuner = self._remote_tuner(
@@ -656,7 +660,7 @@ def test_get_remote_training_metrics(self, mock_super_tuner):
         log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_dir)
         results = remote_tuner._get_remote_training_metrics(log_reader, {})

-        self.assertLen(results.completed_epoch_metrics, 3)
+        self.assertLen(results.completed_epoch_metrics, 2)
         self.assertIn("accuracy", results.completed_epoch_metrics[0])
         self.assertIn("loss", results.completed_epoch_metrics[0])
         self.assertEqual(
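
The expected count drops from 3 to 2 because _get_remote_training_metrics no longer appends the last epoch to completed_epoch_metrics (see the tuner.py diff below); the final epoch of a finished run now arrives via partial_epoch_metrics. A self-contained illustration with made-up metric values:

import collections

# Mirrors the _TrainingMetrics namedtuple defined in tuner.py.
_TrainingMetrics = collections.namedtuple(
    "_TrainingMetrics", ["completed_epoch_metrics", "partial_epoch_metrics"])

# Hypothetical result for a finished three-epoch run under the new behavior:
# two fully reported epochs, with the final epoch held separately so it is
# never reported twice.
metrics = _TrainingMetrics(
    completed_epoch_metrics=[
        {"accuracy": 0.81, "loss": 0.62},  # epoch 0
        {"accuracy": 0.90, "loss": 0.41},  # epoch 1
    ],
    partial_epoch_metrics={"accuracy": 0.94, "loss": 0.30},  # final epoch
)
assert len(metrics.completed_epoch_metrics) == 2  # matches the updated test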
37 changes: 27 additions & 10 deletions src/python/tensorflow_cloud/tuner/tuner.py
@@ -45,6 +45,9 @@
 # metrics from remote training Tensorboard logs during training with:
 # - 'completed_epoch_metrics'- a list of epoch metrics for completed epochs.
 # - 'partial_epoch_metrics' - Any incomplete epoch metrics for the last epoch.
+# If training has completed this will contain metrics for the final epoch of
+# training.
+
 _TrainingMetrics = collections.namedtuple("_TrainingMetrics", [
     "completed_epoch_metrics", "partial_epoch_metrics"])

@@ -568,7 +571,11 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):

         # Create an instance of tensorboard DirectoryWatcher to retrieve the
         # logs for this trial run
-        log_path = self._get_tensorboard_log_dir(trial.trial_id)
+        log_path = os.path.join(
+            self._get_tensorboard_log_dir(trial.trial_id), "train")
+
+        # Tensorboard log watcher expects the path to exist
+        tf.io.gfile.makedirs(log_path)

         # TODO(b/170687807) Switch from using "{}".format() to f-string
         tf.get_logger().info(
@@ -590,11 +597,12 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):

             for epoch_metrics in training_metrics.completed_epoch_metrics:
                 # TODO(b/169197272) Validate metrics contain oracle objective
-                trial.status = self.oracle.update_trial(
-                    trial_id=trial.trial_id,
-                    metrics=epoch_metrics,
-                    step=epoch)
-                epoch += 1
+                if epoch_metrics:
+                    trial.status = self.oracle.update_trial(
+                        trial_id=trial.trial_id,
+                        metrics=epoch_metrics,
+                        step=epoch)
+                    epoch += 1

             if trial.status == "STOPPED":
                 google_api_client.stop_aip_training_job(
@@ -617,11 +625,19 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
         for epoch_metrics in training_metrics.completed_epoch_metrics:
             # TODO(b/169197272) Validate metrics contain oracle objective
             # TODO(b/170907612) Support submit partial results to Oracle
-            self.oracle.update_trial(
-                trial_id=trial.trial_id,
-                metrics=epoch_metrics,
-                step=epoch)
-            epoch += 1
+            if epoch_metrics:
+                self.oracle.update_trial(
+                    trial_id=trial.trial_id,
+                    metrics=epoch_metrics,
+                    step=epoch)
+                epoch += 1
+
+        # submit final epoch metrics
+        if training_metrics.partial_epoch_metrics:
+            self.oracle.update_trial(
+                trial_id=trial.trial_id,
+                metrics=training_metrics.partial_epoch_metrics,
+                step=epoch)

     def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
         """Creates a request dictionary for the CAIP training service.
@@ -680,7 +696,9 @@ def _get_remote_training_metrics(
               - 'completed_epoch_metrics'- a list of epoch metrics for completed
                 epochs.
               - 'partial_epoch_metrics' - Any incomplete epoch metrics for the
-                last epoch.
+                last epoch. Once training completes, the final epoch metrics
+                will be stored here; they are not included in
+                completed_epoch_metrics.
         """
         completed_epoch_metrics = []
         for event in log_reader.Load():
@@ -699,7 +717,6 @@ def _get_remote_training_metrics(
                 # the unrelated Objectives.
                 partial_epoch_metrics[metric] = tf.make_ndarray(
                     event.summary.value[0].tensor)
-        completed_epoch_metrics.append(partial_epoch_metrics)
         return _TrainingMetrics(completed_epoch_metrics, partial_epoch_metrics)

     def load_model(self, trial):
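
Taken together, the run_trial changes reduce to the reporting pattern sketched below (condensed; the commit inlines this logic in run_trial itself, and the helper name here is hypothetical). Empty metric dicts are skipped, and the final epoch, which _get_remote_training_metrics no longer duplicates into completed_epoch_metrics, reaches the Oracle exactly once:

def _report_metrics_to_oracle(self, trial, training_metrics):
    # Hypothetical helper summarizing the post-training reporting flow.
    epoch = 0
    for epoch_metrics in training_metrics.completed_epoch_metrics:
        if epoch_metrics:  # skip empty dicts so no epoch is reported twice
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=epoch_metrics,
                step=epoch)
            epoch += 1

    # The final epoch arrives only via partial_epoch_metrics, so it is
    # submitted once rather than twice.
    if training_metrics.partial_epoch_metrics:
        self.oracle.update_trial(
            trial_id=trial.trial_id,
            metrics=training_metrics.partial_epoch_metrics,
            step=epoch)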
12 changes: 9 additions & 3 deletions src/python/tensorflow_cloud/utils/google_api_client.py
@@ -55,7 +55,9 @@ def wait_for_api_training_job_completion(job_id: Text, project_id: Text)->bool:
     """
     # Wait for AIP Training job to finish
     job_name = "projects/{}/jobs/{}".format(project_id, job_id)
-    api_client = discovery.build("ml", "v1")
+    # Disable cache_discovery to remove excessive info logs; see:
+    # https://github.com/googleapis/google-api-python-client/issues/299
+    api_client = discovery.build("ml", "v1", cache_discovery=False)

     request = api_client.projects().jobs().get(name=job_name)
@@ -94,7 +96,9 @@ def is_api_training_job_running(job_id: Text, project_id: Text)->bool:
         cancelled.
     """
     job_name = "projects/{}/jobs/{}".format(project_id, job_id)
-    api_client = discovery.build("ml", "v1")
+    # Disable cache_discovery to remove excessive info logs; see:
+    # https://github.com/googleapis/google-api-python-client/issues/299
+    api_client = discovery.build("ml", "v1", cache_discovery=False)

     logging.info("Retrieving status for job %s.", job_name)
@@ -112,7 +116,9 @@ def stop_aip_training_job(job_id: Text, project_id: Text):
         project_id: Project under which the AIP Training job is running.
     """
     job_name = "projects/{}/jobs/{}".format(project_id, job_id)
-    api_client = discovery.build("ml", "v1")
+    # Disable cache_discovery to remove excessive info logs; see:
+    # https://github.com/googleapis/google-api-python-client/issues/299
+    api_client = discovery.build("ml", "v1", cache_discovery=False)

     logging.info("Canceling the job %s.", job_name)
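
The same three-line change appears in all three helpers. Factored out (a sketch; the commit keeps the inline builds), it is just a discovery.build call with the cache disabled -- cache_discovery is an existing keyword argument of googleapiclient.discovery.build:

from googleapiclient import discovery

def _build_caip_client():
    # Hypothetical helper, not part of the commit. cache_discovery=False
    # skips the discovery-document file cache that produces the excessive
    # Info/Error log lines; see
    # https://github.com/googleapis/google-api-python-client/issues/299
    return discovery.build("ml", "v1", cache_discovery=False)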
