Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds support for X509 workload credential type #1541

Merged
merged 11 commits into from
Jul 2, 2024
18 changes: 16 additions & 2 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import copy
from dataclasses import dataclass
import datetime
import functools
import io
import json
import re
Expand All @@ -40,6 +41,7 @@
from google.auth import exceptions
from google.auth import impersonated_credentials
from google.auth import metrics
from google.auth.transport.requests import _MutualTlsAdapter
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
from google.oauth2 import sts
from google.oauth2 import utils

Expand Down Expand Up @@ -393,12 +395,18 @@ def get_project_id(self, request):
@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
scopes = self._scopes if self._scopes is not None else self._default_scopes
auth_request = request
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

# if mtls is required, wrap the incoming request in a partial to set the cert.
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
if self._should_add_mtls():
print("mtls yeah")
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
auth_request = functools.partial(request, cert=self._get_mtls_cert())

if self._should_initialize_impersonated_credentials():
self._impersonated_credentials = self._initialize_impersonated_credentials()

if self._impersonated_credentials:
self._impersonated_credentials.refresh(request)
self._impersonated_credentials.refresh(auth_request)
self.token = self._impersonated_credentials.token
self.expiry = self._impersonated_credentials.expiry
else:
Expand All @@ -414,7 +422,7 @@ def refresh(self, request):
)
}
response_data = self._sts_client.exchange_token(
request=request,
request=auth_request,
grant_type=_STS_GRANT_TYPE,
subject_token=self.retrieve_subject_token(request),
subject_token_type=self._subject_token_type,
Expand Down Expand Up @@ -523,6 +531,12 @@ def _create_default_metrics_options(self):

return metrics_options

def _should_add_mtls(self):
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
return False

def _get_mtls_cert(self):
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("_get_mtls_cert must be implemented.")
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def from_info(cls, info, **kwargs):
"""Creates a Credentials instance from parsed external account info.
Expand Down
131 changes: 101 additions & 30 deletions google/auth/identity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth import external_account
from google.auth.transport import _mtls_helper


class SubjectTokenSupplier(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -141,6 +142,14 @@ def get_subject_token(self, context, request):
)


class _X509Supplier(SubjectTokenSupplier):
""" Internal implementation of subject token supplier for X509 workload credentials, always returns an empty string."""
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

@_helpers.copy_docstring(SubjectTokenSupplier)
def get_subject_token(self, context, request):
return ""


def _parse_token_data(token_content, format_type="text", subject_token_field_name=None):
if format_type == "text":
token = token_content.content
Expand Down Expand Up @@ -247,6 +256,7 @@ def __init__(
self._subject_token_supplier = subject_token_supplier
self._credential_source_file = None
self._credential_source_url = None
self._credential_source_certificate = None
else:
if not isinstance(credential_source, Mapping):
self._credential_source_executable = None
Expand All @@ -255,76 +265,93 @@ def __init__(
)
self._credential_source_file = credential_source.get("file")
self._credential_source_url = credential_source.get("url")
clundin25 marked this conversation as resolved.
Show resolved Hide resolved
self._credential_source_headers = credential_source.get("headers")
credential_source_format = credential_source.get("format", {})
# Get credential_source format type. When not provided, this
# defaults to text.
self._credential_source_format_type = (
credential_source_format.get("type") or "text"
)
self._credential_source_certificate = credential_source.get("certificate")

# environment_id is only supported in AWS or dedicated future external
# account credentials.
if "environment_id" in credential_source:
raise exceptions.MalformedError(
"Invalid Identity Pool credential_source field 'environment_id'"
)
if self._credential_source_format_type not in ["text", "json"]:
raise exceptions.MalformedError(
"Invalid credential_source format '{}'".format(
self._credential_source_format_type

# check that only one of file, url, or certificate are provided.
if (
sum(
map(
bool,
[
self._credential_source_file,
self._credential_source_url,
self._credential_source_certificate,
],
)
)
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
# For JSON types, get the required subject_token field name.
if self._credential_source_format_type == "json":
self._credential_source_field_name = credential_source_format.get(
"subject_token_field_name"
)
if self._credential_source_field_name is None:
raise exceptions.MalformedError(
"Missing subject_token_field_name for JSON credential_source format"
)
else:
self._credential_source_field_name = None

if self._credential_source_file and self._credential_source_url:
> 1
):
raise exceptions.MalformedError(
"Ambiguous credential_source. 'file' is mutually exclusive with 'url'."
"Ambiguous credential_source. 'file', 'url', and 'certificate' are mutually exclusive.."
)
if not self._credential_source_file and not self._credential_source_url:
if (
not self._credential_source_file
and not self._credential_source_url
and not self._credential_source_certificate
):
raise exceptions.MalformedError(
"Missing credential_source. A 'file' or 'url' must be provided."
"Missing credential_source. A 'file', 'url', or 'certificate' must be provided."
)

if self._credential_source_certificate:
self._validate_certificate_credential_source()
else:
self._validate_file_url_credential_source(credential_source)

if self._credential_source_file:
self._subject_token_supplier = _FileSupplier(
self._credential_source_file,
self._credential_source_format_type,
self._credential_source_field_name,
)
else:
elif self._credential_source_url:
self._subject_token_supplier = _UrlSupplier(
self._credential_source_url,
self._credential_source_format_type,
self._credential_source_field_name,
self._credential_source_headers,
)
else:
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
self._subject_token_supplier = _X509Supplier()

@_helpers.copy_docstring(external_account.Credentials)
def retrieve_subject_token(self, request):
return self._subject_token_supplier.get_subject_token(
self._supplier_context, request
)

def _get_mtls_cert(self):
if self._credential_source_certificate == None:
raise exceptions.RefreshError(
'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.'
)
else:
return _mtls_helper._get_workload_cert_and_key_paths(
self._certificate_config_location
)

def _should_add_mtls(self):
return self._credential_source_certificate is not None

def _create_default_metrics_options(self):
metrics_options = super(Credentials, self)._create_default_metrics_options()
# Check that credential source is a dict before checking for file vs url. This check needs to be done
# Check that credential source is a dict before checking for credential type. This check needs to be done
# here because the external_account credential constructor needs to pass the metrics options to the
# impersonated credential object before the identity_pool credentials are validated.
if isinstance(self._credential_source, Mapping):
if self._credential_source.get("file"):
metrics_options["source"] = "file"
else:
elif self._credential_source.get("url"):
metrics_options["source"] = "url"
else:
metrics_options["source"] = "x509"
else:
metrics_options["source"] = "programmatic"
return metrics_options
Expand All @@ -339,6 +366,50 @@ def _constructor_args(self):
args.update({"subject_token_supplier": self._subject_token_supplier})
return args

def _validate_certificate_credential_source(self):
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
self._certificate_config_location = self._credential_source_certificate.get(
"certificate_config_location"
)
use_default = self._credential_source_certificate.get(
"use_default_certificate_config"
)
if self._certificate_config_location:
if use_default:
raise exceptions.MalformedError(
"Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true."
)
else:
if not use_default:
raise exceptions.MalformedError(
"Invalid certificate configuration, use_default_certificate_config should be true if no certificate_config_location is provided."
)
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

def _validate_file_url_credential_source(self, credential_source):
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
self._credential_source_headers = credential_source.get("headers")
credential_source_format = credential_source.get("format", {})
# Get credential_source format type. When not provided, this
# defaults to text.
self._credential_source_format_type = (
credential_source_format.get("type") or "text"
)
if self._credential_source_format_type not in ["text", "json"]:
raise exceptions.MalformedError(
"Invalid credential_source format '{}'".format(
self._credential_source_format_type
)
)
# For JSON types, get the required subject_token field name.
if self._credential_source_format_type == "json":
self._credential_source_field_name = credential_source_format.get(
"subject_token_field_name"
)
if self._credential_source_field_name is None:
raise exceptions.MalformedError(
"Missing subject_token_field_name for JSON credential_source format"
)
else:
self._credential_source_field_name = None

@classmethod
def from_info(cls, info, **kwargs):
"""Creates an Identity Pool Credentials instance from parsed external account info.
Expand Down
75 changes: 43 additions & 32 deletions google/auth/transport/_mtls_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,50 @@ def _get_workload_cert_and_key(certificate_config_path=None):
google.auth.exceptions.ClientCertError: if problems occurs when retrieving
the certificate or key information.
"""
absolute_path = _get_cert_config_path(certificate_config_path)

cert_path, key_path = _get_workload_cert_and_key_paths(certificate_config_path)

if cert_path is None and key_path is None:
return None, None

return _read_cert_and_key_files(cert_path, key_path)


def _get_cert_config_path(certificate_config_path=None):
"""Gets the certificate configuration full path using the following order of precedence:
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

1: Explicit override, if set
2: Environment variable, if set
3: Well-known location

Returns "None" if the selected config file does not exist.

Args:
certificate_config_path (string): The certificate config path. If provided, the well known
location and environment variable will be ignored.

Returns:
The absolute path of the certificate config file, and None if the file does not exist.
"""

if certificate_config_path is None:
env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None)
if env_path is not None and env_path != "":
certificate_config_path = env_path
else:
certificate_config_path = _CERTIFICATE_CONFIGURATION_DEFAULT_PATH

certificate_config_path = path.expanduser(certificate_config_path)
if not path.exists(certificate_config_path):
return None
return certificate_config_path


def _get_workload_cert_and_key_paths(config_path):
absolute_path = _get_cert_config_path(config_path)
if absolute_path is None:
return None, None

data = _load_json_file(absolute_path)

if "cert_configs" not in data:
Expand Down Expand Up @@ -142,37 +183,7 @@ def _get_workload_cert_and_key(certificate_config_path=None):
)
key_path = workload["key_path"]

return _read_cert_and_key_files(cert_path, key_path)


def _get_cert_config_path(certificate_config_path=None):
"""Gets the certificate configuration full path using the following order of precedence:

1: Explicit override, if set
2: Environment variable, if set
3: Well-known location

Returns "None" if the selected config file does not exist.

Args:
certificate_config_path (string): The certificate config path. If provided, the well known
location and environment variable will be ignored.

Returns:
The absolute path of the certificate config file, and None if the file does not exist.
"""

if certificate_config_path is None:
env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None)
if env_path is not None and env_path != "":
certificate_config_path = env_path
else:
certificate_config_path = _CERTIFICATE_CONFIGURATION_DEFAULT_PATH

certificate_config_path = path.expanduser(certificate_config_path)
if not path.exists(certificate_config_path):
return None
return certificate_config_path
return cert_path, key_path


def _read_cert_and_key_files(cert_path, key_path):
Expand Down
Loading