Skip to content

Commit

Permalink
Cleanup studies after integration test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 334680911
  • Loading branch information
chongyouquan authored and Tensorflow Cloud maintainers committed Sep 30, 2020
1 parent bf71178 commit da3b051
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/python/tensorflow_cloud/tuner/optimizer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def delete_study(self, study_name: Text = None) -> None:
.format(study_name))
tf.get_logger().info("DeleteStudy failed.")
raise
tf.get_logger().info("Study deleted: {}.".format(study_name))

def _obtain_long_running_operation(self, resp):
"""Obtain the long-running operation."""
Expand Down Expand Up @@ -467,8 +468,9 @@ def _get_study(
study_should_exist: Indicates whether it should be assumed that the
study with the given study_id exists.
"""
tf.get_logger().info("Study already exists. Load existing study...")
study_name = "{}/studies/{}".format(study_parent, study_id)
tf.get_logger().info(
"Study already exists: {}.\nLoad existing study...".format(study_name))
num_tries = 0
while True:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow_cloud import CloudTuner
from tensorflow_cloud.tuner import optimizer_client

# If input dataset is created outside tuner.search(),
# it requires eager execution even in TF 1.x.
Expand Down Expand Up @@ -123,6 +124,10 @@ def _dist_search_fn_wrapper(args):

class _CloudTunerIntegrationTestBase(tf.test.TestCase):

  def setUp(self):
    """Initializes per-test state shared by all integration test cases."""
    super(_CloudTunerIntegrationTestBase, self).setUp()
    # Study to be deleted in tearDown. Each test method overwrites this with
    # the study_id it creates; None means no study cleanup is needed.
    self._study_id = None

def _assert_output(self, fn, regex_str):
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
Expand All @@ -136,6 +141,13 @@ def _assert_results_summary(self, fn):

  def tearDown(self):
    """Cleans up per-test resources.

    Deletes the Optimizer study recorded in ``self._study_id`` (set by the
    individual test methods; ``None`` means no study was created), then
    clears the Keras backend session so model/graph state does not leak
    between test cases.
    """
    super(_CloudTunerIntegrationTestBase, self).tearDown()

    # Delete the study used in the test, if present.
    if self._study_id:
      # NOTE(review): create_or_load_study is used here only to obtain a
      # client handle bound to the study; study_config=None presumably loads
      # the existing study rather than creating one — confirm against
      # optimizer_client.create_or_load_study.
      service = optimizer_client.create_or_load_study(
          _PROJECT_ID, _REGION, self._study_id, None)
      service.delete_study()

    tf.keras.backend.clear_session()


Expand All @@ -149,6 +161,8 @@ def setUp(self):
def testCloudTunerHyperparameters(self):
"""Test case to configure Tuner with HyperParameters object."""
study_id = "{}_hyperparameters".format(_STUDY_ID_BASE)
self._study_id = study_id

tuner = CloudTuner(
_build_model,
project_id=_PROJECT_ID,
Expand Down Expand Up @@ -196,6 +210,8 @@ def testCloudTunerDatasets(self):
)

study_id = "{}_dataset".format(_STUDY_ID_BASE)
self._study_id = study_id

tuner = CloudTuner(
_build_model,
project_id=_PROJECT_ID,
Expand Down Expand Up @@ -253,6 +269,8 @@ def testCloudTunerStudyConfig(self):
}

study_id = "{}_study_config".format(_STUDY_ID_BASE)
self._study_id = study_id

tuner = CloudTuner(
_build_model,
project_id=_PROJECT_ID,
Expand Down Expand Up @@ -286,6 +304,7 @@ class CloudTunerInDistributedIntegrationTest(_CloudTunerIntegrationTestBase):
def testCloudTunerInProcessDistributedTuning(self):
"""Test case to simulate multiple parallel tuning workers."""
study_id = "{}_dist".format(_STUDY_ID_BASE)
self._study_id = study_id

with multiprocessing.Pool(processes=_NUM_PARALLEL_TRIALS) as pool:
results = pool.map(
Expand All @@ -300,6 +319,7 @@ def testCloudTunerInProcessDistributedTuning(self):

  def testCloudTunerAIPlatformTrainingDistributedTuning(self):
    """Test case of parallel tuning using CAIP Training as flock manager."""
    # TODO(b/169697464): Implement test for tuning with CAIP Training
    study_id = "{}_caip_dist".format(_STUDY_ID_BASE)
    # Stub: the study_id is computed but unused until the TODO above is
    # implemented; `del` silences the unused-variable lint warning.
    del study_id

Expand Down
11 changes: 5 additions & 6 deletions src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def setUp(self):
self._region = "us-central1"
self._project_id = "project-a"
self._trial_parent = "projects/{}/locations/{}/studies/{}".format(
self._project_id, self._region,
"CloudTuner_study_{}".format(self._study_id)
self._project_id, self._region, self._study_id
)
self._container_uri = "test_container_uri",
hps = hp_module.HyperParameters()
Expand Down Expand Up @@ -138,23 +137,23 @@ def test_tuner_initialization_with_hparams(self):
(self.mock_optimizer_client_module.create_or_load_study
.assert_called_with(self._project_id,
self._region,
"CloudTuner_study_{}".format(self._study_id),
self._study_id,
self._study_config))

def test_tuner_initialization_with_study_config(self):
self.tuner = self._tuner(None, None, self._study_config)
(self.mock_optimizer_client_module.create_or_load_study
.assert_called_with(self._project_id,
self._region,
"CloudTuner_study_{}".format(self._study_id),
self._study_id,
self._study_config))

def test_remote_tuner_initialization_with_study_config(self):
self._remote_tuner(None, None, self._study_config)
(self.mock_optimizer_client_module.create_or_load_study
.assert_called_with(self._project_id,
self._region,
"CloudTuner_study_{}".format(self._study_id),
self._study_id,
self._study_config))

def test_tuner_initialization_neither_hparam_nor_study_config(self):
Expand All @@ -178,7 +177,7 @@ def test_tuner_initialization_with_study_config_and_max_trials(self):
(self.mock_optimizer_client_module.create_or_load_study
.assert_called_with(self._project_id,
self._region,
"CloudTuner_study_{}".format(self._study_id),
self._study_id,
self._study_config))

def test_create_trial_initially(self):
Expand Down
2 changes: 1 addition & 1 deletion src/python/tensorflow_cloud/tuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.max_trials = max_trials

if study_id:
self.study_id = "CloudTuner_study_{}".format(study_id)
self.study_id = study_id
else:
self.study_id = "CloudTuner_study_{}".format(
datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
Expand Down

0 comments on commit da3b051

Please sign in to comment.