Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Key Vault] Add local-only mode to CryptographyClient #16565

Merged
merged 18 commits into from
Mar 10, 2021
Merged
113 changes: 99 additions & 14 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
from ._key_validity import raise_if_time_invalid
from ._providers import get_local_cryptography_provider, NoLocalCryptography
from .. import KeyOperation
from .._models import KeyVaultKey
from .._models import JsonWebKey, KeyVaultKey
from .._shared import KeyVaultClientBase, parse_key_vault_id

if TYPE_CHECKING:
# pylint:disable=unused-import
# pylint:disable=unused-import,ungrouped-imports
from datetime import datetime
from typing import Any, Optional, Union
from azure.core.credentials import TokenCredential
from . import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from .._shared import KeyVaultResourceId

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,33 +100,80 @@ class CryptographyClient(KeyVaultClientBase):

def __init__(self, key, credential, **kwargs):
# type: (Union[KeyVaultKey, str], TokenCredential, **Any) -> None
self._jwk = kwargs.pop("_jwk", False)
self._not_before = None # type: Optional[datetime]
self._expires_on = None # type: Optional[datetime]
self._key_id = None # type: Optional[KeyVaultResourceId]

if isinstance(key, KeyVaultKey):
self._key = key
self._key = key.key
self._key_id = parse_key_vault_id(key.id)
if key.properties._attributes: # pylint:disable=protected-access
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
self._not_before = key.properties.not_before
self._expires_on = key.properties.expires_on
elif isinstance(key, six.string_types):
self._key = None
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
elif self._jwk:
self._key = key
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not self._key_id.version:
if not (self._jwk or self._key_id.version):
raise ValueError("'key' must include a version")

self._local_provider = NoLocalCryptography()
self._initialized = False
if self._jwk:
try:
self._local_provider = get_local_cryptography_provider(self._key)
self._initialized = True
except Exception as ex: # pylint:disable=broad-except
six.raise_from(ValueError("The provided jwk is not valid for local cryptography"), ex)
else:
self._local_provider = NoLocalCryptography()
self._initialized = False

super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)
self._vault_url = None if self._jwk else self._key_id.vault_url
super(CryptographyClient, self).__init__(
vault_url=self._vault_url or "vault_url", credential=credential, **kwargs
)

@property
def key_id(self):
# type: () -> str
# type: () -> Optional[str]
"""The full identifier of the client's key.

:rtype: str
This property may be None when a client is constructed with :func:`from_jwk`.

:rtype: str or None
"""
return self._key_id.source_id
if not self._jwk:
return self._key_id.source_id
return self._key.kid

@property
def vault_url(self):
# type: () -> Optional[str]
"""The base vault URL of the client's key.

This property may be None when a client is constructed with :func:`from_jwk`.

:rtype: str or None
"""
return self._vault_url

@classmethod
def from_jwk(cls, jwk):
# type: (Union[JsonWebKey, dict]) -> CryptographyClient
"""Creates a client that can only perform cryptographic operations locally.

:param jwk: the key's cryptographic material, as a JsonWebKey or dictionary.
:type jwk: JsonWebKey or dict
:rtype: CryptographyClient
"""
if not isinstance(jwk, JsonWebKey):
jwk = JsonWebKey(**jwk)
return cls(jwk, object(), _jwk=True)

@distributed_trace
def _initialize(self, **kwargs):
Expand All @@ -138,7 +187,7 @@ def _initialize(self, **kwargs):
key_bundle = self._client.get_key(
self._key_id.vault_url, self._key_id.name, self._key_id.version, **kwargs
)
self._key = KeyVaultKey._from_key_bundle(key_bundle) # pylint:disable=protected-access
self._key = KeyVaultKey._from_key_bundle(key_bundle).key # pylint:disable=protected-access
except HttpResponseError as ex:
# if we got a 403, we don't have keys/get permission and won't try to get the key again
# (other errors may be transient)
Expand Down Expand Up @@ -181,11 +230,17 @@ def encrypt(self, algorithm, plaintext, **kwargs):
self._initialize(**kwargs)

if self._local_provider.supports(KeyOperation.encrypt, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.encrypt(algorithm, plaintext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local encrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "encrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.encrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -240,6 +295,12 @@ def decrypt(self, algorithm, ciphertext, **kwargs):
return self._local_provider.decrypt(algorithm, ciphertext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local decrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "decrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.decrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -272,11 +333,17 @@ def wrap_key(self, algorithm, key, **kwargs):
"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.wrap_key, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.wrap_key(algorithm, key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local wrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "wrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.wrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -311,6 +378,12 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs):
return self._local_provider.unwrap_key(algorithm, encrypted_key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local unwrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "unwrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.unwrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -340,11 +413,17 @@ def sign(self, algorithm, digest, **kwargs):
"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.sign, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.sign(algorithm, digest)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local sign operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "sign" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.sign(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -381,6 +460,12 @@ def verify(self, algorithm, digest, signature, **kwargs):
return self._local_provider.verify(algorithm, digest, signature)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local verify operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "verify" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.verify(
vault_base_url=self._key_id.vault_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from .. import KeyVaultKey
from typing import Optional


class _UTC_TZ(tzinfo):
Expand All @@ -28,20 +28,12 @@ def dst(self, dt):
_UTC = _UTC_TZ()


def raise_if_time_invalid(key):
# type: (KeyVaultKey) -> None
try:
nbf = key.properties.not_before
exp = key.properties.expires_on
except AttributeError:
# we consider the key valid because a user must have deliberately created it
# (if it came from Key Vault, it would have those attributes)
return

def raise_if_time_invalid(not_before, expires_on):
# type: (Optional[datetime], Optional[datetime]) -> None
now = datetime.now(_UTC)
if (nbf and exp) and not nbf <= now <= exp:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(nbf, exp))
if nbf and nbf > now:
raise ValueError("This client's key is not useable until {} (UTC)".format(nbf))
if exp and exp <= now:
raise ValueError("This client's key expired at {} (UTC)".format(exp))
if (not_before and expires_on) and not not_before <= now <= expires_on:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(not_before, expires_on))
if not_before and not_before > now:
raise ValueError("This client's key is not useable until {} (UTC)".format(not_before))
if expires_on and expires_on <= now:
raise ValueError("This client's key expires_onired at {} (UTC)".format(expires_on))
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@
from ... import KeyType

if TYPE_CHECKING:
from ... import KeyVaultKey
from ... import JsonWebKey


def get_local_cryptography_provider(key):
# type: (KeyVaultKey) -> LocalCryptographyProvider
if key.key_type in (KeyType.ec, KeyType.ec_hsm):
# type: (JsonWebKey) -> LocalCryptographyProvider
if key.kty in (KeyType.ec, KeyType.ec_hsm):
return EllipticCurveCryptographyProvider(key)
if key.key_type in (KeyType.rsa, KeyType.rsa_hsm):
if key.kty in (KeyType.rsa, KeyType.rsa_hsm):
return RsaCryptographyProvider(key)
if key.key_type in (KeyType.oct, KeyType.oct_hsm):
if key.kty in (KeyType.oct, KeyType.oct_hsm):
return SymmetricCryptographyProvider(key)

raise ValueError('Unsupported key type "{}"'.format(key.key_type))
raise ValueError('Unsupported key type "{}"'.format(key.kty))


class NoLocalCryptography(LocalCryptographyProvider):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# pylint:disable=unused-import
from .local_provider import Algorithm
from .._internal import Key
from ... import KeyVaultKey
from ... import JsonWebKey

_PRIVATE_KEY_OPERATIONS = frozenset((KeyOperation.decrypt, KeyOperation.sign, KeyOperation.unwrap_key))


class EllipticCurveCryptographyProvider(LocalCryptographyProvider):
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
if key.key_type not in (KeyType.ec, KeyType.ec_hsm):
# type: (JsonWebKey) -> Key
if key.kty not in (KeyType.ec, KeyType.ec_hsm):
raise ValueError('"key" must be an EC or EC-HSM key')
return EllipticCurveKey.from_jwk(key.key)
return EllipticCurveKey.from_jwk(key)

def supports(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Union
from typing import Any, Optional, Union
from .._internal.key import Key
from .. import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from ... import KeyVaultKey
from ... import JsonWebKey

Algorithm = Union[EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm]


class LocalCryptographyProvider(ABC):
def __init__(self, key):
# type: (KeyVaultKey) -> None
self._allowed_ops = frozenset(key.key_operations)
# type: (JsonWebKey, **Any) -> None
self._allowed_ops = frozenset(key.key_ops or [])
self._internal_key = self._get_internal_key(key)
self._key = key

@abc.abstractmethod
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
# type: (JsonWebKey) -> Key
pass

@abc.abstractmethod
Expand All @@ -44,12 +44,12 @@ def supports(self, operation, algorithm):

@property
def key_id(self):
# type: () -> str
# type: () -> Optional[str]
"""The full identifier of the provider's key.

:rtype: str
:rtype: str or None
"""
return self._key.id
return self._key.kid

def _raise_if_unsupported(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> None
Expand All @@ -64,34 +64,34 @@ def encrypt(self, algorithm, plaintext):
# type: (EncryptionAlgorithm, bytes) -> EncryptResult
self._raise_if_unsupported(KeyOperation.encrypt, algorithm)
ciphertext = self._internal_key.encrypt(plaintext, algorithm=algorithm.value)
return EncryptResult(key_id=self._key.id, algorithm=algorithm, ciphertext=ciphertext)
return EncryptResult(key_id=self._key.kid, algorithm=algorithm, ciphertext=ciphertext)

def decrypt(self, algorithm, ciphertext):
# type: (EncryptionAlgorithm, bytes) -> DecryptResult
self._raise_if_unsupported(KeyOperation.decrypt, algorithm)
plaintext = self._internal_key.decrypt(ciphertext, iv=None, algorithm=algorithm.value)
return DecryptResult(key_id=self._key.id, algorithm=algorithm, plaintext=plaintext)
return DecryptResult(key_id=self._key.kid, algorithm=algorithm, plaintext=plaintext)

def wrap_key(self, algorithm, key):
# type: (KeyWrapAlgorithm, bytes) -> WrapResult
self._raise_if_unsupported(KeyOperation.wrap_key, algorithm)
encrypted_key = self._internal_key.wrap_key(key, algorithm=algorithm.value)
return WrapResult(key_id=self._key.id, algorithm=algorithm, encrypted_key=encrypted_key)
return WrapResult(key_id=self._key.kid, algorithm=algorithm, encrypted_key=encrypted_key)

def unwrap_key(self, algorithm, encrypted_key):
# type: (KeyWrapAlgorithm, bytes) -> UnwrapResult
self._raise_if_unsupported(KeyOperation.unwrap_key, algorithm)
unwrapped_key = self._internal_key.unwrap_key(encrypted_key, algorithm=algorithm.value)
return UnwrapResult(key_id=self._key.id, algorithm=algorithm, key=unwrapped_key)
return UnwrapResult(key_id=self._key.kid, algorithm=algorithm, key=unwrapped_key)

def sign(self, algorithm, digest):
# type: (SignatureAlgorithm, bytes) -> SignResult
self._raise_if_unsupported(KeyOperation.sign, algorithm)
signature = self._internal_key.sign(digest, algorithm=algorithm.value)
return SignResult(key_id=self._key.id, algorithm=algorithm, signature=signature)
return SignResult(key_id=self._key.kid, algorithm=algorithm, signature=signature)

def verify(self, algorithm, digest, signature):
# type: (SignatureAlgorithm, bytes, bytes) -> VerifyResult
self._raise_if_unsupported(KeyOperation.verify, algorithm)
is_valid = self._internal_key.verify(digest, signature, algorithm=algorithm.value)
return VerifyResult(key_id=self._key.id, algorithm=algorithm, is_valid=is_valid)
return VerifyResult(key_id=self._key.kid, algorithm=algorithm, is_valid=is_valid)
Loading