PR #3460: Specify GCP compute project in BigQuery Pusher executor #3701

Merged 6 commits on May 8, 2021
1 change: 1 addition & 0 deletions RELEASE.md
@@ -92,6 +92,7 @@

## Bug Fixes and Other Changes

+ * GCP compute project in BigQuery Pusher executor can be specified.
* New extra dependencies for convenience.
- tfx[airflow] installs all Apache Airflow orchestrator dependencies.
- tfx[kfp] installs all Kubeflow Pipelines orchestrator dependencies.
38 changes: 34 additions & 4 deletions tfx/extensions/google_cloud_big_query/pusher/executor.py
@@ -37,6 +37,9 @@
_BQ_DATASET_ID_KEY = 'bq_dataset_id'
_MODEL_NAME_KEY = 'model_name'

+ # Project where query will be executed
+ _COMPUTE_PROJECT_ID_KEY = 'compute_project_id'

# Keys for custom_config.
_CUSTOM_CONFIG_KEY = 'custom_config'

@@ -66,7 +69,15 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
include the model in this push execution if the model was pushed.
exec_properties: Mostly a passthrough input dict for
tfx.components.Pusher.executor. custom_config.bigquery_serving_args is
- consumed by this class. For the full set of parameters supported by
+ consumed by this class, including:
+ - bq_dataset_id: ID of the dataset you're creating or replacing
+ - model_name: name of the model you're creating or replacing
+ - project_id: GCP project where the model will be stored. It is also
+ the project where the query is executed unless a compute_project_id
+ is provided.
+ - compute_project_id: GCP project where the query is executed. If not
+ provided, the query is executed in project_id.
+ For the full set of parameters supported by
Big Query ML, refer to https://cloud.google.com/bigquery-ml/

Returns:
@@ -76,6 +87,23 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
If bigquery_serving_args is not in exec_properties.custom_config.
If pipeline_root is not 'gs://...'
RuntimeError: if the Big Query job failed.

+ Example usage:
+ from tfx.extensions.google_cloud_big_query.pusher import executor
+
+ pusher = Pusher(
+ model=trainer.outputs['model'],
+ model_blessing=evaluator.outputs['blessing'],
+ custom_executor_spec=executor_spec.ExecutorClassSpec(executor.Executor),
+ custom_config={
+ 'bigquery_serving_args': {
+ 'model_name': 'your_model_name',
+ 'project_id': 'your_gcp_storage_project',
+ 'bq_dataset_id': 'your_dataset_id',
+ 'compute_project_id': 'your_gcp_compute_project',
+ },
+ },
+ )
"""
self._log_startup(input_dict, output_dict, exec_properties)
model_push = artifact_utils.get_single_instance(
@@ -122,15 +150,17 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
default_query_job_config = bigquery.job.QueryJobConfig(
labels=telemetry_utils.get_labels_dict())
# TODO(b/181368842) Add integration test for BQML Pusher + Managed Pipeline
+ project_id = (
+ bigquery_serving_args.get(_COMPUTE_PROJECT_ID_KEY) or
+ bigquery_serving_args[_PROJECT_ID_KEY])
client = bigquery.Client(
- default_query_job_config=default_query_job_config,
- project=bigquery_serving_args[_PROJECT_ID_KEY])
+ default_query_job_config=default_query_job_config, project=project_id)

try:
query_job = client.query(query)
query_job.result() # Waits for the query to finish
except Exception as e:
- raise RuntimeError('BigQuery ML Push failed: {}'.format(e))
+ raise RuntimeError('BigQuery ML Push failed: {}'.format(e)) from e

logging.info('Successfully deployed model %s serving from %s', bq_model_uri,
model_path)
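
Note on the hunk above: it makes two behavioral changes, the compute-project fallback and the chained exception. The sketch below isolates both so they can be checked without a BigQuery client. It is only an illustration under stated assumptions: `resolve_query_project` and `run_push` are hypothetical helpers that do not exist in the PR, and only the two key names and the error message are taken from the diff.

```python
from typing import Any, Callable, Dict

# Key names copied from the diff; everything else here is illustrative.
_PROJECT_ID_KEY = 'project_id'
_COMPUTE_PROJECT_ID_KEY = 'compute_project_id'


def resolve_query_project(bigquery_serving_args: Dict[str, Any]) -> str:
  """Hypothetical helper: picks the project the query job would run in."""
  # `or` falls back on any falsy value, so a missing key, None, and ''
  # all resolve to project_id. A missing project_id still raises KeyError.
  return (bigquery_serving_args.get(_COMPUTE_PROJECT_ID_KEY) or
          bigquery_serving_args[_PROJECT_ID_KEY])


def run_push(query_fn: Callable[[], Any]) -> None:
  """Hypothetical stand-in for the try/except around client.query()."""
  try:
    query_fn()
  except Exception as e:
    # `from e` stores the original error in __cause__, so the traceback
    # shows it as the direct cause of the RuntimeError.
    raise RuntimeError('BigQuery ML Push failed: {}'.format(e)) from e


if __name__ == '__main__':
  storage_only = {'project_id': 'storage-proj'}
  split = {'project_id': 'storage-proj', 'compute_project_id': 'compute-proj'}
  print(resolve_query_project(storage_only))
  print(resolve_query_project(split))
```

One consequence of using `or` rather than a default argument to `.get()`: an explicitly empty `compute_project_id` is treated the same as an absent one, which seems to match the docstring's "if not provided" wording.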
@@ -30,7 +30,7 @@
class ExecutorTest(tf.test.TestCase):

def setUp(self):
- super(ExecutorTest, self).setUp()
+ super().setUp()
self._source_data_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
@@ -59,6 +59,7 @@ def setUp(self):
'model_name': 'model_name',
'project_id': 'project_id',
'bq_dataset_id': 'bq_dataset_id',
+ 'compute_project_id': 'compute_project_id',
},
},
'push_destination': None,