diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a1d46d3efe..d251e0f62a 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -497,6 +497,7 @@ def connect( credentials=None, pool=None, user_agent=None, + client=None, ): """Creates a connection to a Google Cloud Spanner database. @@ -529,25 +530,31 @@ def connect( :param user_agent: (Optional) User agent to be used with this connection's requests. + :type client: Concrete subclass of + :class:`~google.cloud.spanner_v1.Client`. + :param client: (Optional) Custom user provided Client Object + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` :returns: Connection object associated with the given Google Cloud Spanner resource. """ - - client_info = ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, - python_version=PY_VERSION, - client_library_version=spanner.__version__, - ) - - if isinstance(credentials, str): - client = spanner.Client.from_service_account_json( - credentials, project=project, client_info=client_info + if client is None: + client_info = ClientInfo( + user_agent=user_agent or DEFAULT_USER_AGENT, + python_version=PY_VERSION, + client_library_version=spanner.__version__, ) + if isinstance(credentials, str): + client = spanner.Client.from_service_account_json( + credentials, project=project, client_info=client_info + ) + else: + client = spanner.Client( + project=project, credentials=credentials, client_info=client_info + ) else: - client = spanner.Client( - project=project, credentials=credentials, client_info=client_info - ) + if project is not None and client.project != project: + raise ValueError("project in url does not match client object project") instance = client.instance(instance_id) conn = Connection(instance, instance.database(database_id, pool=pool)) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 090def3519..b077c1feba 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -18,6 +18,7 @@ import mock import unittest import warnings +import pytest PROJECT = "test-project" INSTANCE = "test-instance" @@ -915,7 +916,52 @@ def test_request_priority(self): sql, params, param_types=param_types, request_options=None ) + @mock.patch("google.cloud.spanner_v1.Client") + def test_custom_client_connection(self, mock_client): + from google.cloud.spanner_dbapi import connect + + client = _Client() + connection = connect("test-instance", "test-database", client=client) + self.assertTrue(connection.instance._client == client) + + @mock.patch("google.cloud.spanner_v1.Client") + def test_invalid_custom_client_connection(self, mock_client): + from google.cloud.spanner_dbapi import connect + + client = _Client() + with pytest.raises(ValueError): + connect( + "test-instance", + "test-database", + project="invalid_project", + client=client, + ) + def exit_ctx_func(self, exc_type, exc_value, traceback): """Context __exit__ method mock.""" pass + + +class _Client(object): + def __init__(self, project="project_id"): + self.project = project + self.project_name = "projects/" + self.project + + def instance(self, instance_id="instance_id"): + return _Instance(name=instance_id, client=self) + + +class _Instance(object): + def __init__(self, name="instance_id", client=None): + self.name = name + self._client = client + + def database(self, database_id="database_id", pool=None): + return _Database(database_id, pool) + + +class _Database(object): + def __init__(self, database_id="database_id", pool=None): + self.name = database_id + self.pool = pool