diff --git a/RELEASE.md b/RELEASE.md index 2168d8731d..b64e696037 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -92,6 +92,7 @@ ## Bug Fixes and Other Changes +* GCP compute project in BigQuery Pusher executor can be specified. * New extra dependencies for convenience. - tfx[airflow] installs all Apache Airflow orchestrator dependencies. - tfx[kfp] installs all Kubeflow Pipelines orchestrator dependencies. diff --git a/tfx/extensions/google_cloud_big_query/pusher/executor.py b/tfx/extensions/google_cloud_big_query/pusher/executor.py index 223d375bc6..43cc8ec55f 100644 --- a/tfx/extensions/google_cloud_big_query/pusher/executor.py +++ b/tfx/extensions/google_cloud_big_query/pusher/executor.py @@ -37,6 +37,9 @@ _BQ_DATASET_ID_KEY = 'bq_dataset_id' _MODEL_NAME_KEY = 'model_name' +# Project where query will be executed +_COMPUTE_PROJECT_ID_KEY = 'compute_project_id' + # Keys for custom_config. _CUSTOM_CONFIG_KEY = 'custom_config' @@ -66,7 +69,15 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]], include the model in this push execution if the model was pushed. exec_properties: Mostly a passthrough input dict for tfx.components.Pusher.executor. custom_config.bigquery_serving_args is - consumed by this class. For the full set of parameters supported by + consumed by this class, including: + - bq_dataset_id: ID of the dataset you're creating or replacing + - model_name: name of the model you're creating or replacing + - project_id: GCP project where the model will be stored. It is also + the project where the query is executed unless a compute_project_id + is provided. + - compute_project_id: GCP project where the query is executed. If not + provided, the query is executed in project_id. + For the full set of parameters supported by Big Query ML, refer to https://cloud.google.com/bigquery-ml/ Returns: @@ -76,6 +87,23 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]], If bigquery_serving_args is not in exec_properties.custom_config. If pipeline_root is not 'gs://...' RuntimeError: if the Big Query job failed. + + Example usage: + from tfx.extensions.google_cloud_big_query.pusher import executor + + pusher = Pusher( + model=trainer.outputs['model'], + model_blessing=evaluator.outputs['blessing'], + custom_executor_spec=executor_spec.ExecutorClassSpec(executor.Executor), + custom_config={ + 'bigquery_serving_args': { + 'model_name': 'your_model_name', + 'project_id': 'your_gcp_storage_project', + 'bq_dataset_id': 'your_dataset_id', + 'compute_project_id': 'your_gcp_compute_project', + }, + }, + ) """ self._log_startup(input_dict, output_dict, exec_properties) model_push = artifact_utils.get_single_instance( @@ -122,15 +150,17 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]], default_query_job_config = bigquery.job.QueryJobConfig( labels=telemetry_utils.get_labels_dict()) # TODO(b/181368842) Add integration test for BQML Pusher + Managed Pipeline + project_id = ( + bigquery_serving_args.get(_COMPUTE_PROJECT_ID_KEY) or + bigquery_serving_args[_PROJECT_ID_KEY]) client = bigquery.Client( - default_query_job_config=default_query_job_config, - project=bigquery_serving_args[_PROJECT_ID_KEY]) + default_query_job_config=default_query_job_config, project=project_id) try: query_job = client.query(query) query_job.result() # Waits for the query to finish except Exception as e: - raise RuntimeError('BigQuery ML Push failed: {}'.format(e)) + raise RuntimeError('BigQuery ML Push failed: {}'.format(e)) from e logging.info('Successfully deployed model %s serving from %s', bq_model_uri, model_path) diff --git a/tfx/extensions/google_cloud_big_query/pusher/executor_test.py b/tfx/extensions/google_cloud_big_query/pusher/executor_test.py index 7e0b95edc3..4df4f09ec6 100644 --- a/tfx/extensions/google_cloud_big_query/pusher/executor_test.py +++ b/tfx/extensions/google_cloud_big_query/pusher/executor_test.py @@ -30,7 +30,7 @@ class ExecutorTest(tf.test.TestCase): def setUp(self): - super(ExecutorTest, self).setUp() + super().setUp() self._source_data_dir = os.path.join( os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), @@ -59,6 +59,7 @@ def setUp(self): 'model_name': 'model_name', 'project_id': 'project_id', 'bq_dataset_id': 'bq_dataset_id', + 'compute_project_id': 'compute_project_id', }, }, 'push_destination': None,