diff --git a/google/auth/transport/_custom_tls_signer.py b/google/auth/transport/_custom_tls_signer.py index 57a563d03..9279158d4 100644 --- a/google/auth/transport/_custom_tls_signer.py +++ b/google/auth/transport/_custom_tls_signer.py @@ -46,10 +46,17 @@ # Cast SSL_CTX* to void* -def _cast_ssl_ctx_to_void_p(ssl_ctx): +def _cast_ssl_ctx_to_void_p_pyopenssl(ssl_ctx): return ctypes.cast(int(cffi.FFI().cast("intptr_t", ssl_ctx)), ctypes.c_void_p) +# Cast SSL_CTX* to void* +def _cast_ssl_ctx_to_void_p_stdlib(context): + return ctypes.c_void_p.from_address( + id(context) + ctypes.sizeof(ctypes.c_void_p) * 2 + ) + + # Load offload library and set up the function types. def load_offload_lib(offload_lib_path): _LOGGER.debug("loading offload library from %s", offload_lib_path) @@ -249,10 +256,15 @@ def set_up_custom_key(self): self._signer_lib, self._enterprise_cert_file_path ) - def attach_to_ssl_context(self, ctx): + def should_use_provider(self): if self._provider_lib: + return True + return False + + def attach_to_ssl_context(self, ctx): + if self.should_use_provider(): if not self._provider_lib.ECP_attach_to_ctx( - _cast_ssl_ctx_to_void_p(ctx._ctx._context), + _cast_ssl_ctx_to_void_p_stdlib(ctx), self._enterprise_cert_file_path.encode("ascii"), ): raise exceptions.MutualTLSChannelError( @@ -262,7 +274,7 @@ def attach_to_ssl_context(self, ctx): if not self._offload_lib.ConfigureSslContext( self._sign_callback, ctypes.c_char_p(self._cert), - _cast_ssl_ctx_to_void_p(ctx._ctx._context), + _cast_ssl_ctx_to_void_p_pyopenssl(ctx._ctx._context), ): raise exceptions.MutualTLSChannelError( "failed to configure ECP Offload SSL context" diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index aa1611322..63a2b4596 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -262,19 +262,16 @@ class _MutualTlsOffloadAdapter(requests.adapters.HTTPAdapter): def __init__(self, enterprise_cert_file_path): import certifi - import urllib3.contrib.pyopenssl - from google.auth.transport import _custom_tls_signer - # Call inject_into_urllib3 to activate certificate checking. See the - # following links for more info: - # (1) doc: https://github.com/urllib3/urllib3/blob/cb9ebf8aac5d75f64c8551820d760b72b619beff/src/urllib3/contrib/pyopenssl.py#L31-L32 - # (2) mTLS example: https://github.com/urllib3/urllib3/issues/474#issuecomment-253168415 - urllib3.contrib.pyopenssl.inject_into_urllib3() - self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path) self.signer.load_libraries() + if not self.signer.should_use_provider(): + import urllib3.contrib.pyopenssl + + urllib3.contrib.pyopenssl.inject_into_urllib3() + poolmanager = create_urllib3_context() poolmanager.load_verify_locations(cafile=certifi.where()) self.signer.attach_to_ssl_context(poolmanager) diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index 883ab7749..104243c2c 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/transport/test__custom_tls_signer.py b/tests/transport/test__custom_tls_signer.py index d2907bad2..3a33c2c02 100644 --- a/tests/transport/test__custom_tls_signer.py +++ b/tests/transport/test__custom_tls_signer.py @@ -195,6 +195,7 @@ def test_custom_tls_signer(): get_cert.assert_called_once() get_sign_callback.assert_called_once() offload_lib.ConfigureSslContext.assert_called_once() + assert not signer_object.should_use_provider() assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE assert signer_object._offload_lib == offload_lib assert signer_object._signer_lib == signer_lib @@ -216,6 +217,7 @@ def test_custom_tls_signer_provider(): signer_object.load_libraries() signer_object.attach_to_ssl_context(mock.MagicMock()) + assert signer_object.should_use_provider() assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE_PROVIDER assert signer_object._provider_lib == provider_lib load_provider_lib.assert_called_with("/path/to/provider/lib") diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index aadc1ddbf..0da3e36d9 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -568,3 +568,38 @@ def test_success( adapter.proxy_manager_for() mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager) + + @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager") + @mock.patch.object(requests.adapters.HTTPAdapter, "proxy_manager_for") + @mock.patch.object( + google.auth.transport._custom_tls_signer.CustomTlsSigner, "should_use_provider" + ) + @mock.patch.object( + google.auth.transport._custom_tls_signer.CustomTlsSigner, "load_libraries" + ) + @mock.patch.object( + google.auth.transport._custom_tls_signer.CustomTlsSigner, + "attach_to_ssl_context", + ) + def test_success_should_use_provider( + self, + mock_attach_to_ssl_context, + mock_load_libraries, + mock_should_use_provider, + mock_proxy_manager_for, + mock_init_poolmanager, + ): + enterprise_cert_file_path = "/path/to/enterprise/cert/json" + adapter = google.auth.transport.requests._MutualTlsOffloadAdapter( + enterprise_cert_file_path + ) + + mock_should_use_provider.side_effect = True + mock_load_libraries.assert_called_once() + assert mock_attach_to_ssl_context.call_count == 2 + + adapter.init_poolmanager() + mock_init_poolmanager.assert_called_with(ssl_context=adapter._ctx_poolmanager) + + adapter.proxy_manager_for() + mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager)