diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 3943de2a3..df0511f25 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -31,6 +31,7 @@ import copy from dataclasses import dataclass import datetime +import functools import io import json import re @@ -394,6 +395,12 @@ def get_project_id(self, request): def refresh(self, request): scopes = self._scopes if self._scopes is not None else self._default_scopes + # Inject client certificate into request. + if self._mtls_required(): + request = functools.partial( + request, cert=self._get_mtls_cert_and_key_paths() + ) + if self._should_initialize_impersonated_credentials(): self._impersonated_credentials = self._initialize_impersonated_credentials() @@ -523,6 +530,33 @@ def _create_default_metrics_options(self): return metrics_options + def _mtls_required(self): + """Returns a boolean representing whether the current credential is configured + for mTLS and should add a certificate to the outgoing calls to the sts and service + account impersonation endpoint. + + Returns: + bool: True if the credential is configured for mTLS, False if it is not. + """ + return False + + def _get_mtls_cert_and_key_paths(self): + """Gets the file locations for a certificate and private key file + to be used for configuring mTLS for the sts and service account + impersonation calls. Currently only expected to return a value when using + X509 workload identity federation. + + Returns: + Tuple[str, str]: The cert and key file locations as strings in a tuple. + + Raises: + NotImplementedError: When the current credential is not configured for + mTLS. + """ + raise NotImplementedError( + "_get_mtls_cert_and_key_location must be implemented." + ) + @classmethod def from_info(cls, info, **kwargs): """Creates a Credentials instance from parsed external account info. diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 1c97885a4..47f9a5571 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -48,6 +48,7 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import external_account +from google.auth.transport import _mtls_helper class SubjectTokenSupplier(metaclass=abc.ABCMeta): @@ -141,6 +142,14 @@ def get_subject_token(self, context, request): ) +class _X509Supplier(SubjectTokenSupplier): + """Internal supplier for X509 workload credentials. This class is used internally and always returns an empty string as the subject token.""" + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + return "" + + def _parse_token_data(token_content, format_type="text", subject_token_field_name=None): if format_type == "text": token = token_content.content @@ -247,6 +256,7 @@ def __init__( self._subject_token_supplier = subject_token_supplier self._credential_source_file = None self._credential_source_url = None + self._credential_source_certificate = None else: if not isinstance(credential_source, Mapping): self._credential_source_executable = None @@ -255,45 +265,22 @@ def __init__( ) self._credential_source_file = credential_source.get("file") self._credential_source_url = credential_source.get("url") - self._credential_source_headers = credential_source.get("headers") - credential_source_format = credential_source.get("format", {}) - # Get credential_source format type. When not provided, this - # defaults to text. - self._credential_source_format_type = ( - credential_source_format.get("type") or "text" - ) + self._credential_source_certificate = credential_source.get("certificate") + # environment_id is only supported in AWS or dedicated future external # account credentials. if "environment_id" in credential_source: raise exceptions.MalformedError( "Invalid Identity Pool credential_source field 'environment_id'" ) - if self._credential_source_format_type not in ["text", "json"]: - raise exceptions.MalformedError( - "Invalid credential_source format '{}'".format( - self._credential_source_format_type - ) - ) - # For JSON types, get the required subject_token field name. - if self._credential_source_format_type == "json": - self._credential_source_field_name = credential_source_format.get( - "subject_token_field_name" - ) - if self._credential_source_field_name is None: - raise exceptions.MalformedError( - "Missing subject_token_field_name for JSON credential_source format" - ) - else: - self._credential_source_field_name = None - if self._credential_source_file and self._credential_source_url: - raise exceptions.MalformedError( - "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." - ) - if not self._credential_source_file and not self._credential_source_url: - raise exceptions.MalformedError( - "Missing credential_source. A 'file' or 'url' must be provided." - ) + # check that only one of file, url, or certificate are provided. + self._validate_single_source() + + if self._credential_source_certificate: + self._validate_certificate_config() + else: + self._validate_file_or_url_config(credential_source) if self._credential_source_file: self._subject_token_supplier = _FileSupplier( @@ -301,13 +288,15 @@ def __init__( self._credential_source_format_type, self._credential_source_field_name, ) - else: + elif self._credential_source_url: self._subject_token_supplier = _UrlSupplier( self._credential_source_url, self._credential_source_format_type, self._credential_source_field_name, self._credential_source_headers, ) + else: # self._credential_source_certificate + self._subject_token_supplier = _X509Supplier() @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): @@ -315,16 +304,31 @@ def retrieve_subject_token(self, request): self._supplier_context, request ) + def _get_mtls_cert_and_key_paths(self): + if self._credential_source_certificate is None: + raise exceptions.RefreshError( + 'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.' + ) + else: + return _mtls_helper._get_workload_cert_and_key_paths( + self._certificate_config_location + ) + + def _mtls_required(self): + return self._credential_source_certificate is not None + def _create_default_metrics_options(self): metrics_options = super(Credentials, self)._create_default_metrics_options() - # Check that credential source is a dict before checking for file vs url. This check needs to be done + # Check that credential source is a dict before checking for credential type. This check needs to be done # here because the external_account credential constructor needs to pass the metrics options to the # impersonated credential object before the identity_pool credentials are validated. if isinstance(self._credential_source, Mapping): if self._credential_source.get("file"): metrics_options["source"] = "file" - else: + elif self._credential_source.get("url"): metrics_options["source"] = "url" + else: + metrics_options["source"] = "x509" else: metrics_options["source"] = "programmatic" return metrics_options @@ -339,6 +343,67 @@ def _constructor_args(self): args.update({"subject_token_supplier": self._subject_token_supplier}) return args + def _validate_certificate_config(self): + self._certificate_config_location = self._credential_source_certificate.get( + "certificate_config_location" + ) + use_default = self._credential_source_certificate.get( + "use_default_certificate_config" + ) + if self._certificate_config_location and use_default: + raise exceptions.MalformedError( + "Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true." + ) + if not self._certificate_config_location and not use_default: + raise exceptions.MalformedError( + "Invalid certificate configuration, use_default_certificate_config should be true if no certificate_config_location is provided." + ) + + def _validate_file_or_url_config(self, credential_source): + self._credential_source_headers = credential_source.get("headers") + credential_source_format = credential_source.get("format", {}) + # Get credential_source format type. When not provided, this + # defaults to text. + self._credential_source_format_type = ( + credential_source_format.get("type") or "text" + ) + if self._credential_source_format_type not in ["text", "json"]: + raise exceptions.MalformedError( + "Invalid credential_source format '{}'".format( + self._credential_source_format_type + ) + ) + # For JSON types, get the required subject_token field name. + if self._credential_source_format_type == "json": + self._credential_source_field_name = credential_source_format.get( + "subject_token_field_name" + ) + if self._credential_source_field_name is None: + raise exceptions.MalformedError( + "Missing subject_token_field_name for JSON credential_source format" + ) + else: + self._credential_source_field_name = None + + def _validate_single_source(self): + credential_sources = [ + self._credential_source_file, + self._credential_source_url, + self._credential_source_certificate, + ] + valid_credential_sources = list( + filter(lambda source: source is not None, credential_sources) + ) + + if len(valid_credential_sources) > 1: + raise exceptions.MalformedError( + "Ambiguous credential_source. 'file', 'url', and 'certificate' are mutually exclusive.." + ) + if len(valid_credential_sources) != 1: + raise exceptions.MalformedError( + "Missing credential_source. A 'file', 'url', or 'certificate' must be provided." + ) + @classmethod def from_info(cls, info, **kwargs): """Creates an Identity Pool Credentials instance from parsed external account info. diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index e95b953a1..6299e2bde 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -105,9 +105,50 @@ def _get_workload_cert_and_key(certificate_config_path=None): google.auth.exceptions.ClientCertError: if problems occurs when retrieving the certificate or key information. """ - absolute_path = _get_cert_config_path(certificate_config_path) + + cert_path, key_path = _get_workload_cert_and_key_paths(certificate_config_path) + + if cert_path is None and key_path is None: + return None, None + + return _read_cert_and_key_files(cert_path, key_path) + + +def _get_cert_config_path(certificate_config_path=None): + """Get the certificate configuration path based on the following order: + + 1: Explicit override, if set + 2: Environment variable, if set + 3: Well-known location + + Returns "None" if the selected config file does not exist. + + Args: + certificate_config_path (string): The certificate config path. If provided, the well known + location and environment variable will be ignored. + + Returns: + The absolute path of the certificate config file, and None if the file does not exist. + """ + + if certificate_config_path is None: + env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None) + if env_path is not None and env_path != "": + certificate_config_path = env_path + else: + certificate_config_path = _CERTIFICATE_CONFIGURATION_DEFAULT_PATH + + certificate_config_path = path.expanduser(certificate_config_path) + if not path.exists(certificate_config_path): + return None + return certificate_config_path + + +def _get_workload_cert_and_key_paths(config_path): + absolute_path = _get_cert_config_path(config_path) if absolute_path is None: return None, None + data = _load_json_file(absolute_path) if "cert_configs" not in data: @@ -142,37 +183,7 @@ def _get_workload_cert_and_key(certificate_config_path=None): ) key_path = workload["key_path"] - return _read_cert_and_key_files(cert_path, key_path) - - -def _get_cert_config_path(certificate_config_path=None): - """Gets the certificate configuration full path using the following order of precedence: - - 1: Explicit override, if set - 2: Environment variable, if set - 3: Well-known location - - Returns "None" if the selected config file does not exist. - - Args: - certificate_config_path (string): The certificate config path. If provided, the well known - location and environment variable will be ignored. - - Returns: - The absolute path of the certificate config file, and None if the file does not exist. - """ - - if certificate_config_path is None: - env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None) - if env_path is not None and env_path != "": - certificate_config_path = env_path - else: - certificate_config_path = _CERTIFICATE_CONFIGURATION_DEFAULT_PATH - - certificate_config_path = path.expanduser(certificate_config_path) - if not path.exists(certificate_config_path): - return None - return certificate_config_path + return cert_path, key_path def _read_cert_and_key_files(cert_path, key_path): diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index 254396d33..94db780f1 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/test_external_account.py b/tests/test_external_account.py index c458b21b6..3c372e629 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -235,10 +235,16 @@ def make_mock_request( return request @classmethod - def assert_token_request_kwargs(cls, request_kwargs, headers, request_data): + def assert_token_request_kwargs( + cls, request_kwargs, headers, request_data, cert=None + ): assert request_kwargs["url"] == cls.TOKEN_URL assert request_kwargs["method"] == "POST" assert request_kwargs["headers"] == headers + if cert is not None: + assert request_kwargs["cert"] == cert + else: + assert "cert" not in request_kwargs assert request_kwargs["body"] is not None body_tuples = urllib.parse.parse_qsl(request_kwargs["body"]) for (k, v) in body_tuples: @@ -246,10 +252,16 @@ def assert_token_request_kwargs(cls, request_kwargs, headers, request_data): assert len(body_tuples) == len(request_data.keys()) @classmethod - def assert_impersonation_request_kwargs(cls, request_kwargs, headers, request_data): + def assert_impersonation_request_kwargs( + cls, request_kwargs, headers, request_data, cert=None + ): assert request_kwargs["url"] == cls.SERVICE_ACCOUNT_IMPERSONATION_URL assert request_kwargs["method"] == "POST" assert request_kwargs["headers"] == headers + if cert is not None: + assert request_kwargs["cert"] == cert + else: + assert "cert" not in request_kwargs assert request_kwargs["body"] is not None body_json = json.loads(request_kwargs["body"].decode("utf-8")) assert body_json == request_data @@ -665,6 +677,56 @@ def test_refresh_without_client_auth_success( assert not credentials.expired assert credentials.token == response["access_token"] + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + @mock.patch( + "google.auth.external_account.Credentials._mtls_required", return_value=True + ) + @mock.patch( + "google.auth.external_account.Credentials._get_mtls_cert_and_key_paths", + return_value=("path/to/cert.pem", "path/to/key.pem"), + ) + def test_refresh_with_mtls( + self, + mock_get_mtls_cert_and_key_paths, + mock_mtls_required, + unused_utcnow, + mock_auth_lib_value, + ): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/false config-lifetime/false", + } + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request(status=http_client.OK, data=response) + credentials = self.make_credentials() + + credentials.refresh(request) + + expected_cert_path = ("path/to/cert.pem", "path/to/key.pem") + self.assert_token_request_kwargs( + request.call_args[1], headers, request_data, expected_cert_path + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + @mock.patch( "google.auth.metrics.python_and_auth_lib_version", return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, @@ -869,6 +931,101 @@ def test_refresh_impersonation_without_client_auth_success( assert not credentials.expired assert credentials.token == impersonation_response["accessToken"] + @mock.patch( + "google.auth.metrics.token_request_access_token_impersonate", + return_value=IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, + ) + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch( + "google.auth.external_account.Credentials._mtls_required", return_value=True + ) + @mock.patch( + "google.auth.external_account.Credentials._get_mtls_cert_and_key_paths", + return_value=("path/to/cert.pem", "path/to/key.pem"), + ) + def test_refresh_impersonation_with_mtls_success( + self, + mock_get_mtls_cert_and_key_paths, + mock_mtls_required, + mock_metrics_header_value, + mock_auth_lib_value, + ): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/true config-lifetime/false", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + "x-goog-api-client": IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, + "x-allowed-locations": "0x0", + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + expected_cert_paths = ("path/to/cert.pem", "path/to/key.pem") + self.assert_token_request_kwargs( + request.call_args_list[0][1], + token_headers, + token_request_data, + expected_cert_paths, + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1][1], + impersonation_headers, + impersonation_request_data, + expected_cert_paths, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + @mock.patch( "google.auth.metrics.token_request_access_token_impersonate", return_value=IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index a11b1e70f..b2b0063ec 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -179,6 +179,12 @@ class TestCredentials(object): "url": CREDENTIAL_URL, "format": {"type": "json", "subject_token_field_name": "access_token"}, } + CREDENTIAL_SOURCE_CERTIFICATE = { + "certificate": {"use_default_certificate_config": "true"} + } + CREDENTIAL_SOURCE_CERTIFICATE_NOT_DEFAULT = { + "certificate": {"certificate_config_location": "path/to/config"} + } SUCCESS_RESPONSE = { "access_token": "ACCESS_TOKEN", "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", @@ -677,6 +683,40 @@ def test_constructor_invalid_options_url_and_file(self): assert excinfo.match(r"Ambiguous credential_source") + def test_constructor_invalid_options_url_and_certificate(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "certificate": {"certificate": {"use_default_certificate_config": True}}, + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Ambiguous credential_source") + + def test_constructor_invalid_options_file_and_certificate(self): + credential_source = { + "file": SUBJECT_TOKEN_TEXT_FILE, + "certificate": {"certificate": {"use_default_certificate": True}}, + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Ambiguous credential_source") + + def test_constructor_invalid_options_url_file_and_certificate(self): + credential_source = { + "file": SUBJECT_TOKEN_TEXT_FILE, + "url": self.CREDENTIAL_URL, + "certificate": {"certificate": {"use_default_certificate": True}}, + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Ambiguous credential_source") + def test_constructor_invalid_options_environment_id(self): credential_source = {"url": self.CREDENTIAL_URL, "environment_id": "aws1"} @@ -716,7 +756,7 @@ def test_constructor_invalid_both_credential_source_and_supplier(self): ) def test_constructor_invalid_credential_source_format_type(self): - credential_source = {"format": {"type": "xml"}} + credential_source = {"file": "test.txt", "format": {"type": "xml"}} with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source=credential_source) @@ -724,7 +764,7 @@ def test_constructor_invalid_credential_source_format_type(self): assert excinfo.match(r"Invalid credential_source format 'xml'") def test_constructor_missing_subject_token_field_name(self): - credential_source = {"format": {"type": "json"}} + credential_source = {"file": "test.txt", "format": {"type": "json"}} with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source=credential_source) @@ -733,6 +773,27 @@ def test_constructor_missing_subject_token_field_name(self): r"Missing subject_token_field_name for JSON credential_source format" ) + def test_constructor_default_and_file_location_certificate(self): + credential_source = { + "certificate": { + "use_default_certificate_config": True, + "certificate_config_location": "test", + } + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Invalid certificate configuration") + + def test_constructor_no_default_or_file_location_certificate(self): + credential_source = {"certificate": {"use_default_certificate_config": False}} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Invalid certificate configuration") + def test_info_with_workforce_pool_user_project(self): credentials = self.make_credentials( audience=WORKFORCE_AUDIENCE, @@ -782,6 +843,36 @@ def test_info_with_url_credential_source(self): "universe_domain": DEFAULT_UNIVERSE_DOMAIN, } + def test_info_with_certificate_credential_source(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE.copy() + ) + + assert credentials.info == { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "token_info_url": TOKEN_INFO_URL, + "credential_source": self.CREDENTIAL_SOURCE_CERTIFICATE, + "universe_domain": DEFAULT_UNIVERSE_DOMAIN, + } + + def test_info_with_non_default_certificate_credential_source(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE_NOT_DEFAULT.copy() + ) + + assert credentials.info == { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "token_info_url": TOKEN_INFO_URL, + "credential_source": self.CREDENTIAL_SOURCE_CERTIFICATE_NOT_DEFAULT, + "universe_domain": DEFAULT_UNIVERSE_DOMAIN, + } + def test_info_with_default_token_url(self): credentials = identity_pool.Credentials( audience=AUDIENCE, @@ -845,6 +936,15 @@ def test_retrieve_subject_token_json_file(self): assert subject_token == JSON_FILE_SUBJECT_TOKEN + def test_retrieve_subject_token_certificate(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE + ) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == "" + def test_retrieve_subject_token_json_file_invalid_field_name(self): credential_source = { "file": SUBJECT_TOKEN_JSON_FILE, @@ -1485,3 +1585,28 @@ def test_refresh_success_supplier_without_impersonation_url(self): scopes=SCOPES, default_scopes=None, ) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=("cert", "key"), + ) + def test_get_mtls_certs(self, mock_get_workload_cert_and_key_paths): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE.copy() + ) + + cert, key = credentials._get_mtls_cert_and_key_paths() + assert cert == "cert" + assert key == "key" + + def test_get_mtls_certs_invalid(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT.copy() + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials._get_mtls_cert_and_key_paths() + + assert excinfo.match( + 'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.' + )