diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py b/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py index b564e8c1..4b7b0c8d 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -37,6 +38,7 @@ from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore + try: OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER @@ -250,7 +252,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DatastoreAdminTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, DatastoreAdminTransport, Callable[..., DatastoreAdminTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -262,9 +266,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DatastoreAdminTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatastoreAdminTransport,Callable[..., DatastoreAdminTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatastoreAdminTransport constructor. + If set to None, a transport is chosen automatically. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -429,8 +435,8 @@ async def sample_export_entities(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [project_id, labels, entity_filter, output_url_prefix] ) @@ -440,7 +446,10 @@ async def sample_export_entities(): "the individual field arguments should be set." ) - request = datastore_admin.ExportEntitiesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore_admin.ExportEntitiesRequest): + request = datastore_admin.ExportEntitiesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -456,11 +465,9 @@ async def sample_export_entities(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_entities, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_entities + ] # Certain fields should be provided within the metadata header; # add these here. @@ -613,8 +620,8 @@ async def sample_import_entities(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, labels, input_url, entity_filter]) if request is not None and has_flattened_params: raise ValueError( @@ -622,7 +629,10 @@ async def sample_import_entities(): "the individual field arguments should be set." ) - request = datastore_admin.ImportEntitiesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore_admin.ImportEntitiesRequest): + request = datastore_admin.ImportEntitiesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -638,11 +648,9 @@ async def sample_import_entities(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_entities, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_entities + ] # Certain fields should be provided within the metadata header; # add these here. @@ -747,15 +755,16 @@ async def sample_create_index(): """ # Create or coerce a protobuf request object. - request = datastore_admin.CreateIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore_admin.CreateIndexRequest): + request = datastore_admin.CreateIndexRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_index, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -859,15 +868,16 @@ async def sample_delete_index(): """ # Create or coerce a protobuf request object. - request = datastore_admin.DeleteIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore_admin.DeleteIndexRequest): + request = datastore_admin.DeleteIndexRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_index, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -952,25 +962,16 @@ async def sample_get_index(): Datastore composite index definition. """ # Create or coerce a protobuf request object. - request = datastore_admin.GetIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, datastore_admin.GetIndexRequest): + request = datastore_admin.GetIndexRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_index, - default_retry=retries.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1056,25 +1057,16 @@ async def sample_list_indexes(): """ # Create or coerce a protobuf request object. - request = datastore_admin.ListIndexesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore_admin.ListIndexesRequest): + request = datastore_admin.ListIndexesRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_indexes, - default_retry=retries.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_indexes + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/client.py b/google/cloud/datastore_admin_v1/services/datastore_admin/client.py index de174f58..e6f35ba3 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/client.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -565,7 +566,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DatastoreAdminTransport]] = None, + transport: Optional[ + Union[str, DatastoreAdminTransport, Callable[..., DatastoreAdminTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -577,9 +580,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, DatastoreAdminTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatastoreAdminTransport,Callable[..., DatastoreAdminTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatastoreAdminTransport constructor. + If set to None, a transport is chosen automatically. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. 
@@ -688,8 +693,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[DatastoreAdminTransport], Callable[..., DatastoreAdminTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., DatastoreAdminTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -822,8 +834,8 @@ def sample_export_entities(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [project_id, labels, entity_filter, output_url_prefix] ) @@ -833,10 +845,8 @@ def sample_export_entities(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore_admin.ExportEntitiesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore_admin.ExportEntitiesRequest): request = datastore_admin.ExportEntitiesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1005,8 +1015,8 @@ def sample_import_entities(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, labels, input_url, entity_filter]) if request is not None and has_flattened_params: raise ValueError( @@ -1014,10 +1024,8 @@ def sample_import_entities(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore_admin.ImportEntitiesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore_admin.ImportEntitiesRequest): request = datastore_admin.ImportEntitiesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1138,10 +1146,8 @@ def sample_create_index(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a datastore_admin.CreateIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore_admin.CreateIndexRequest): request = datastore_admin.CreateIndexRequest(request) @@ -1251,10 +1257,8 @@ def sample_delete_index(): """ # Create or coerce a protobuf request object. 
- # Minor optimization to avoid making a copy if the user passes - # in a datastore_admin.DeleteIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore_admin.DeleteIndexRequest): request = datastore_admin.DeleteIndexRequest(request) @@ -1345,10 +1349,8 @@ def sample_get_index(): Datastore composite index definition. """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a datastore_admin.GetIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore_admin.GetIndexRequest): request = datastore_admin.GetIndexRequest(request) @@ -1440,10 +1442,8 @@ def sample_list_indexes(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a datastore_admin.ListIndexesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore_admin.ListIndexesRequest): request = datastore_admin.ListIndexesRequest(request) diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py index bddab490..8c3a00f3 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py @@ -88,6 +88,8 @@ def __init__( # Save the scopes. self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False # If no credentials are provided, then determine the appropriate # defaults. 
@@ -100,7 +102,7 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: + elif credentials is None and not self._ignore_credentials: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id ) diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py index 68867594..4d08d9c4 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py @@ -106,7 +106,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,14 +126,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -143,11 +146,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -174,9 +177,10 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. - credentials = False + credentials = None + self._ignore_credentials = True # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None @@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py index 367a5ab6..7526fc5c 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -121,7 +123,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -151,7 +152,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -171,15 +172,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -189,11 +193,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -220,9 +224,10 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. - credentials = False + credentials = None + self._ignore_credentials = True # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None @@ -260,7 +265,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -517,6 +524,61 @@ def list_indexes( ) return self._stubs["list_indexes"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.export_entities: gapic_v1.method_async.wrap_method( + self.export_entities, + default_timeout=60.0, + client_info=client_info, + ), + self.import_entities: gapic_v1.method_async.wrap_method( + self.import_entities, + default_timeout=60.0, + client_info=client_info, + ), + self.create_index: gapic_v1.method_async.wrap_method( + self.create_index, + default_timeout=60.0, + client_info=client_info, + ), + self.delete_index: gapic_v1.method_async.wrap_method( + self.delete_index, + default_timeout=60.0, + client_info=client_info, + ), + self.get_index: gapic_v1.method_async.wrap_method( + self.get_index, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + self.list_indexes: gapic_v1.method_async.wrap_method( + self.list_indexes, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git 
a/google/cloud/datastore_v1/__init__.py b/google/cloud/datastore_v1/__init__.py index 8c9d09fe..b1855aff 100644 --- a/google/cloud/datastore_v1/__init__.py +++ b/google/cloud/datastore_v1/__init__.py @@ -33,6 +33,7 @@ from .types.datastore import LookupResponse from .types.datastore import Mutation from .types.datastore import MutationResult +from .types.datastore import PropertyMask from .types.datastore import ReadOptions from .types.datastore import ReserveIdsRequest from .types.datastore import ReserveIdsResponse @@ -98,6 +99,7 @@ "PlanSummary", "Projection", "PropertyFilter", + "PropertyMask", "PropertyOrder", "PropertyReference", "Query", diff --git a/google/cloud/datastore_v1/services/datastore/async_client.py b/google/cloud/datastore_v1/services/datastore/async_client.py index e911a362..d6c347f6 100644 --- a/google/cloud/datastore_v1/services/datastore/async_client.py +++ b/google/cloud/datastore_v1/services/datastore/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -37,6 +38,7 @@ from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore + try: OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER @@ -197,7 +199,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DatastoreTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, DatastoreTransport, Callable[..., DatastoreTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -209,9 +213,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DatastoreTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatastoreTransport,Callable[..., DatastoreTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatastoreTransport constructor. + If set to None, a transport is chosen automatically. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -330,8 +336,8 @@ async def sample_lookup(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, read_options, keys]) if request is not None and has_flattened_params: raise ValueError( @@ -339,7 +345,10 @@ async def sample_lookup(): "the individual field arguments should be set." ) - request = datastore.LookupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.LookupRequest): + request = datastore.LookupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -352,21 +361,7 @@ async def sample_lookup(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.lookup, - default_retry=retries.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.lookup] # Certain fields should be provided within the metadata header; # add these here. @@ -443,25 +438,16 @@ async def sample_run_query(): """ # Create or coerce a protobuf request object. - request = datastore.RunQueryRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.RunQueryRequest): + request = datastore.RunQueryRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.run_query, - default_retry=retries.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.run_query + ] # Certain fields should be provided within the metadata header; # add these here. @@ -538,25 +524,16 @@ async def sample_run_aggregation_query(): """ # Create or coerce a protobuf request object. - request = datastore.RunAggregationQueryRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.RunAggregationQueryRequest): + request = datastore.RunAggregationQueryRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.run_aggregation_query, - default_retry=retries.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.run_aggregation_query + ] # Certain fields should be provided within the metadata header; # add these here. @@ -641,8 +618,8 @@ async def sample_begin_transaction(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id]) if request is not None and has_flattened_params: raise ValueError( @@ -650,7 +627,10 @@ async def sample_begin_transaction(): "the individual field arguments should be set." ) - request = datastore.BeginTransactionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, datastore.BeginTransactionRequest): + request = datastore.BeginTransactionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -659,11 +639,9 @@ async def sample_begin_transaction(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.begin_transaction, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.begin_transaction + ] # Certain fields should be provided within the metadata header; # add these here. @@ -788,8 +766,8 @@ async def sample_commit(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, mode, transaction, mutations]) if request is not None and has_flattened_params: raise ValueError( @@ -797,7 +775,10 @@ async def sample_commit(): "the individual field arguments should be set." ) - request = datastore.CommitRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.CommitRequest): + request = datastore.CommitRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -812,11 +793,7 @@ async def sample_commit(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.commit, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.commit] # Certain fields should be provided within the metadata header; # add these here. @@ -912,8 +889,8 @@ async def sample_rollback(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, transaction]) if request is not None and has_flattened_params: raise ValueError( @@ -921,7 +898,10 @@ async def sample_rollback(): "the individual field arguments should be set." ) - request = datastore.RollbackRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.RollbackRequest): + request = datastore.RollbackRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -932,11 +912,7 @@ async def sample_rollback(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.rollback, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.rollback] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1032,8 +1008,8 @@ async def sample_allocate_ids(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, keys]) if request is not None and has_flattened_params: raise ValueError( @@ -1041,7 +1017,10 @@ async def sample_allocate_ids(): "the individual field arguments should be set." ) - request = datastore.AllocateIdsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.AllocateIdsRequest): + request = datastore.AllocateIdsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1052,11 +1031,9 @@ async def sample_allocate_ids(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.allocate_ids, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.allocate_ids + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1151,8 +1128,8 @@ async def sample_reserve_ids(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, keys]) if request is not None and has_flattened_params: raise ValueError( @@ -1160,7 +1137,10 @@ async def sample_reserve_ids(): "the individual field arguments should be set." ) - request = datastore.ReserveIdsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, datastore.ReserveIdsRequest): + request = datastore.ReserveIdsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1171,21 +1151,9 @@ async def sample_reserve_ids(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.reserve_ids, - default_retry=retries.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.reserve_ids + ] # Certain fields should be provided within the metadata header; # add these here. 
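Note: the async_client hunks above (and the mirrored client.py changes below) let callers supply a Callable in place of a transport and, at the transport layer, in place of a gRPC channel; per the new docstrings, the Callable receives the same initialization arguments the transport constructor (or ``create_channel``) would. The sketch below is illustrative only and not part of this diff; it assumes Application Default Credentials are available and uses functools.partial purely for brevity.

    import functools

    from google.cloud.datastore_v1 import DatastoreClient
    from google.cloud.datastore_v1.services.datastore.transports import (
        DatastoreGrpcTransport,
    )


    def make_channel(host, **kwargs):
        # The transport invokes this with the same arguments it would otherwise
        # pass to DatastoreGrpcTransport.create_channel (credentials, scopes,
        # ssl_credentials, options, ...), so simply delegate to it here.
        return DatastoreGrpcTransport.create_channel(host, **kwargs)


    # ``transport`` may be a Callable that builds the transport; the client calls
    # it with the usual DatastoreTransport constructor arguments. Likewise,
    # ``channel`` may be a Callable that builds the channel rather than a
    # channel instance.
    client = DatastoreClient(
        transport=functools.partial(DatastoreGrpcTransport, channel=make_channel)
    )

The same pattern applies to the DatastoreAdminClient / DatastoreAdminTransport surface changed earlier in this diff.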
diff --git a/google/cloud/datastore_v1/services/datastore/client.py b/google/cloud/datastore_v1/services/datastore/client.py index 0a498175..6c3cb802 100644 --- a/google/cloud/datastore_v1/services/datastore/client.py +++ b/google/cloud/datastore_v1/services/datastore/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -516,7 +517,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DatastoreTransport]] = None, + transport: Optional[ + Union[str, DatastoreTransport, Callable[..., DatastoreTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -528,9 +531,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, DatastoreTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatastoreTransport,Callable[..., DatastoreTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatastoreTransport constructor. + If set to None, a transport is chosen automatically. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -636,8 +641,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[DatastoreTransport], Callable[..., DatastoreTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., DatastoreTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -724,8 +736,8 @@ def sample_lookup(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, read_options, keys]) if request is not None and has_flattened_params: raise ValueError( @@ -733,10 +745,8 @@ def sample_lookup(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore.LookupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.LookupRequest): request = datastore.LookupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -832,10 +842,8 @@ def sample_run_query(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a datastore.RunQueryRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.RunQueryRequest): request = datastore.RunQueryRequest(request) @@ -923,10 +931,8 @@ def sample_run_aggregation_query(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a datastore.RunAggregationQueryRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.RunAggregationQueryRequest): request = datastore.RunAggregationQueryRequest(request) @@ -1022,8 +1028,8 @@ def sample_begin_transaction(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1031,10 +1037,8 @@ def sample_begin_transaction(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore.BeginTransactionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.BeginTransactionRequest): request = datastore.BeginTransactionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1174,8 +1178,8 @@ def sample_commit(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, mode, transaction, mutations]) if request is not None and has_flattened_params: raise ValueError( @@ -1183,10 +1187,8 @@ def sample_commit(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore.CommitRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.CommitRequest): request = datastore.CommitRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1303,8 +1305,8 @@ def sample_rollback(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([project_id, transaction]) if request is not None and has_flattened_params: raise ValueError( @@ -1312,10 +1314,8 @@ def sample_rollback(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore.RollbackRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.RollbackRequest): request = datastore.RollbackRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1428,8 +1428,8 @@ def sample_allocate_ids(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, keys]) if request is not None and has_flattened_params: raise ValueError( @@ -1437,10 +1437,8 @@ def sample_allocate_ids(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore.AllocateIdsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.AllocateIdsRequest): request = datastore.AllocateIdsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1552,8 +1550,8 @@ def sample_reserve_ids(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([project_id, keys]) if request is not None and has_flattened_params: raise ValueError( @@ -1561,10 +1559,8 @@ def sample_reserve_ids(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a datastore.ReserveIdsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, datastore.ReserveIdsRequest): request = datastore.ReserveIdsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/datastore_v1/services/datastore/transports/base.py b/google/cloud/datastore_v1/services/datastore/transports/base.py index 3c31a4a7..db08f5b4 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/base.py +++ b/google/cloud/datastore_v1/services/datastore/transports/base.py @@ -86,6 +86,8 @@ def __init__( # Save the scopes. self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False # If no credentials are provided, then determine the appropriate # defaults. 
@@ -98,7 +100,7 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: + elif credentials is None and not self._ignore_credentials: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id ) diff --git a/google/cloud/datastore_v1/services/datastore/transports/grpc.py b/google/cloud/datastore_v1/services/datastore/transports/grpc.py index ebc16b21..620576e2 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/grpc.py +++ b/google/cloud/datastore_v1/services/datastore/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -124,9 +127,10 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. 
- credentials = False + credentials = None + self._ignore_credentials = True # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None @@ -165,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py b/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py index 7b3997dd..b826d7c6 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py +++ b/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -170,9 +174,10 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. - credentials = False + credentials = None + self._ignore_credentials = True # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None @@ -210,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -460,6 +467,91 @@ def reserve_ids( ) return self._stubs["reserve_ids"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.lookup: gapic_v1.method_async.wrap_method( + self.lookup, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + self.run_query: gapic_v1.method_async.wrap_method( + self.run_query, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + self.run_aggregation_query: gapic_v1.method_async.wrap_method( + self.run_aggregation_query, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + self.begin_transaction: gapic_v1.method_async.wrap_method( + self.begin_transaction, + default_timeout=60.0, + client_info=client_info, + ), + self.commit: gapic_v1.method_async.wrap_method( + self.commit, + default_timeout=60.0, + client_info=client_info, + ), + self.rollback: gapic_v1.method_async.wrap_method( + self.rollback, + default_timeout=60.0, + client_info=client_info, + ), + self.allocate_ids: gapic_v1.method_async.wrap_method( + self.allocate_ids, + default_timeout=60.0, + 
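# Illustrative sketch (not part of the diff): the default_retry entries in the
# wrapped-method table above retry only on DeadlineExceeded and
# ServiceUnavailable, backing off from 0.1s up to 60s with multiplier 1.3 and
# an overall 60s deadline. An equivalent policy can be built standalone and
# passed per call via ``retry=`` to override the cached default; the client
# and request objects below are placeholders.
from google.api_core import exceptions as core_exceptions
from google.api_core import retry_async as retries

retry_policy = retries.AsyncRetry(
    initial=0.1,
    maximum=60.0,
    multiplier=1.3,
    predicate=retries.if_exception_type(
        core_exceptions.DeadlineExceeded,
        core_exceptions.ServiceUnavailable,
    ),
    deadline=60.0,
)


async def lookup_with_retry(async_client, request):
    # Passing ``retry=`` to a generated method replaces the default policy
    # baked into the cached wrapped method for this call only.
    return await async_client.lookup(request=request, retry=retry_policy)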
client_info=client_info, + ), + self.reserve_ids: gapic_v1.method_async.wrap_method( + self.reserve_ids, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/datastore_v1/types/__init__.py b/google/cloud/datastore_v1/types/__init__.py index 6aa3d846..3ae809b4 100644 --- a/google/cloud/datastore_v1/types/__init__.py +++ b/google/cloud/datastore_v1/types/__init__.py @@ -28,6 +28,7 @@ LookupResponse, Mutation, MutationResult, + PropertyMask, ReadOptions, ReserveIdsRequest, ReserveIdsResponse, @@ -81,6 +82,7 @@ "LookupResponse", "Mutation", "MutationResult", + "PropertyMask", "ReadOptions", "ReserveIdsRequest", "ReserveIdsResponse", diff --git a/google/cloud/datastore_v1/types/datastore.py b/google/cloud/datastore_v1/types/datastore.py index ccea0458..11974c3d 100644 --- a/google/cloud/datastore_v1/types/datastore.py +++ b/google/cloud/datastore_v1/types/datastore.py @@ -47,6 +47,7 @@ "ReserveIdsResponse", "Mutation", "MutationResult", + "PropertyMask", "ReadOptions", "TransactionOptions", }, @@ -70,6 +71,15 @@ class LookupRequest(proto.Message): The options for this lookup request. keys (MutableSequence[google.cloud.datastore_v1.types.Key]): Required. Keys of entities to look up. + property_mask (google.cloud.datastore_v1.types.PropertyMask): + The properties to return. Defaults to returning all + properties. + + If this field is set and an entity has a property not + referenced in the mask, it will be absent from + [LookupResponse.found.entity.properties][]. + + The entity's key is always returned. """ project_id: str = proto.Field( @@ -90,6 +100,11 @@ class LookupRequest(proto.Message): number=3, message=entity.Key, ) + property_mask: "PropertyMask" = proto.Field( + proto.MESSAGE, + number=5, + message="PropertyMask", + ) class LookupResponse(proto.Message): @@ -186,6 +201,12 @@ class RunQueryRequest(proto.Message): non-aggregation query. This field is a member of `oneof`_ ``query_type``. + property_mask (google.cloud.datastore_v1.types.PropertyMask): + The properties to return. This field must not be set for a + projection query. + + See + [LookupRequest.property_mask][google.datastore.v1.LookupRequest.property_mask]. explain_options (google.cloud.datastore_v1.types.ExplainOptions): Optional. Explain options for the query. If set, additional query statistics will be @@ -223,6 +244,11 @@ class RunQueryRequest(proto.Message): oneof="query_type", message=gd_query.GqlQuery, ) + property_mask: "PropertyMask" = proto.Field( + proto.MESSAGE, + number=10, + message="PropertyMask", + ) explain_options: query_profile.ExplainOptions = proto.Field( proto.MESSAGE, number=12, @@ -770,6 +796,14 @@ class Mutation(proto.Message): mutation conflicts. This field is a member of `oneof`_ ``conflict_detection_strategy``. + property_mask (google.cloud.datastore_v1.types.PropertyMask): + The properties to write in this mutation. None of the + properties in the mask may have a reserved name, except for + ``__key__``. This field is ignored for ``delete``. + + If the entity already exists, only properties referenced in + the mask are updated, others are left untouched. Properties + referenced in the mask but not in the entity are deleted. 
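# Illustrative sketch (not part of the diff): property_mask restricts which
# properties a lookup (or non-projection run_query) returns; the entity key is
# always included. Assumes PropertyMask is re-exported from
# google.cloud.datastore_v1 alongside the other request types; project and
# kind names are placeholders.
from google.cloud import datastore_v1

client = datastore_v1.DatastoreClient()

key = datastore_v1.Key(
    partition_id=datastore_v1.PartitionId(project_id="my-project"),
    path=[datastore_v1.Key.PathElement(kind="Task", name="task-1")],
)

response = client.lookup(
    request=datastore_v1.LookupRequest(
        project_id="my-project",
        keys=[key],
        # Only ``title`` and the nested ``owner.email`` come back; all other
        # properties are absent from the returned entities.
        property_mask=datastore_v1.PropertyMask(paths=["title", "owner.email"]),
    )
)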
""" insert: entity.Entity = proto.Field( @@ -807,6 +841,11 @@ class Mutation(proto.Message): oneof="conflict_detection_strategy", message=timestamp_pb2.Timestamp, ) + property_mask: "PropertyMask" = proto.Field( + proto.MESSAGE, + number=9, + message="PropertyMask", + ) class MutationResult(proto.Message): @@ -866,6 +905,33 @@ class MutationResult(proto.Message): ) +class PropertyMask(proto.Message): + r"""The set of arbitrarily nested property paths used to restrict + an operation to only a subset of properties in an entity. + + Attributes: + paths (MutableSequence[str]): + The paths to the properties covered by this mask. + + A path is a list of property names separated by dots + (``.``), for example ``foo.bar`` means the property ``bar`` + inside the entity property ``foo`` inside the entity + associated with this path. + + If a property name contains a dot ``.`` or a backslash + ``\``, then that name must be escaped. + + A path must not be empty, and may not reference a value + inside an [array + value][google.datastore.v1.Value.array_value]. + """ + + paths: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=1, + ) + + class ReadOptions(proto.Message): r"""The options shared by read requests. diff --git a/scripts/fixup_datastore_v1_keywords.py b/scripts/fixup_datastore_v1_keywords.py index f0406904..661d509b 100644 --- a/scripts/fixup_datastore_v1_keywords.py +++ b/scripts/fixup_datastore_v1_keywords.py @@ -42,11 +42,11 @@ class datastoreCallTransformer(cst.CSTTransformer): 'allocate_ids': ('project_id', 'keys', 'database_id', ), 'begin_transaction': ('project_id', 'database_id', 'transaction_options', ), 'commit': ('project_id', 'database_id', 'mode', 'transaction', 'single_use_transaction', 'mutations', ), - 'lookup': ('project_id', 'keys', 'database_id', 'read_options', ), + 'lookup': ('project_id', 'keys', 'database_id', 'read_options', 'property_mask', ), 'reserve_ids': ('project_id', 'keys', 'database_id', ), 'rollback': ('project_id', 'transaction', 'database_id', ), 'run_aggregation_query': ('project_id', 'database_id', 'partition_id', 'read_options', 'aggregation_query', 'gql_query', 'explain_options', ), - 'run_query': ('project_id', 'database_id', 'partition_id', 'read_options', 'query', 'gql_query', 'explain_options', ), + 'run_query': ('project_id', 'database_id', 'partition_id', 'read_options', 'query', 'gql_query', 'property_mask', 'explain_options', ), } def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: diff --git a/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py b/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py index c08b309a..8e65052c 100644 --- a/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py +++ b/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py @@ -1176,6 +1176,9 @@ def test_export_entities_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_entities), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_entities() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1200,6 +1203,9 @@ def test_export_entities_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.export_entities), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_entities(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1209,6 +1215,45 @@ def test_export_entities_non_empty_request_with_auto_populated_field(): ) +def test_export_entities_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_entities in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_entities] = mock_rpc + request = {} + client.export_entities(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_entities_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1230,6 +1275,51 @@ async def test_export_entities_empty_call_async(): assert args[0] == datastore_admin.ExportEntitiesRequest() +@pytest.mark.asyncio +async def test_export_entities_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_entities + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_entities + ] = mock_object + + request = {} + await client.export_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_entities_async( transport: str = "grpc_asyncio", request_type=datastore_admin.ExportEntitiesRequest @@ -1482,6 +1572,9 @@ def test_import_entities_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_entities), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_entities() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1506,6 +1599,9 @@ def test_import_entities_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_entities), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_entities(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1515,6 +1611,45 @@ def test_import_entities_non_empty_request_with_auto_populated_field(): ) +def test_import_entities_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_entities in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.import_entities] = mock_rpc + request = {} + client.import_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_entities_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1536,6 +1671,51 @@ async def test_import_entities_empty_call_async(): assert args[0] == datastore_admin.ImportEntitiesRequest() +@pytest.mark.asyncio +async def test_import_entities_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_entities + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_entities + ] = mock_object + + request = {} + await client.import_entities(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_entities_async( transport: str = "grpc_asyncio", request_type=datastore_admin.ImportEntitiesRequest @@ -1788,6 +1968,9 @@ def test_create_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1811,6 +1994,9 @@ def test_create_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1819,6 +2005,45 @@ def test_create_index_non_empty_request_with_auto_populated_field(): ) +def test_create_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_index] = mock_rpc + request = {} + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1840,6 +2065,51 @@ async def test_create_index_empty_call_async(): assert args[0] == datastore_admin.CreateIndexRequest() +@pytest.mark.asyncio +async def test_create_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_index + ] = mock_object + + request = {} + await client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_async( transport: str = "grpc_asyncio", request_type=datastore_admin.CreateIndexRequest @@ -1980,6 +2250,9 @@ def test_delete_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2004,6 +2277,9 @@ def test_delete_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2013,6 +2289,45 @@ def test_delete_index_non_empty_request_with_auto_populated_field(): ) +def test_delete_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_index] = mock_rpc + request = {} + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2034,6 +2349,51 @@ async def test_delete_index_empty_call_async(): assert args[0] == datastore_admin.DeleteIndexRequest() +@pytest.mark.asyncio +async def test_delete_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_index + ] = mock_object + + request = {} + await client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_async( transport: str = "grpc_asyncio", request_type=datastore_admin.DeleteIndexRequest @@ -2187,6 +2547,9 @@ def test_get_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2211,6 +2574,9 @@ def test_get_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2220,6 +2586,41 @@ def test_get_index_non_empty_request_with_auto_populated_field(): ) +def test_get_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_index] = mock_rpc + request = {} + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2247,6 +2648,45 @@ async def test_get_index_empty_call_async(): assert args[0] == datastore_admin.GetIndexRequest() +@pytest.mark.asyncio +async def test_get_index_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_index + ] = mock_object + + request = {} + await client.get_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_async( transport: str = "grpc_asyncio", request_type=datastore_admin.GetIndexRequest @@ -2401,6 +2841,9 @@ def test_list_indexes_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_indexes() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2426,6 +2869,9 @@ def test_list_indexes_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_indexes(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2436,6 +2882,41 @@ def test_list_indexes_non_empty_request_with_auto_populated_field(): ) +def test_list_indexes_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_indexes in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.list_indexes] = mock_rpc + request = {} + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_indexes_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2459,6 +2940,47 @@ async def test_list_indexes_empty_call_async(): assert args[0] == datastore_admin.ListIndexesRequest() +@pytest.mark.asyncio +async def test_list_indexes_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_indexes + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_indexes + ] = mock_object + + request = {} + await client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_indexes_async( transport: str = "grpc_asyncio", request_type=datastore_admin.ListIndexesRequest @@ -2596,13 +3118,13 @@ def test_list_indexes_pager(transport_name: str = "grpc"): RuntimeError, ) - metadata = () - metadata = tuple(metadata) + ( + expected_metadata = () + expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("project_id", ""),)), ) pager = client.list_indexes(request={}) - assert pager._metadata == metadata + assert pager._metadata == expected_metadata results = list(pager) assert len(results) == 6 @@ -2784,6 +3306,46 @@ def test_export_entities_rest(request_type): assert response.operation.name == "operations/spam" +def test_export_entities_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_entities in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.export_entities] = mock_rpc + + request = {} + client.export_entities(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_entities_rest_required_fields( request_type=datastore_admin.ExportEntitiesRequest, ): @@ -3060,6 +3622,46 @@ def test_import_entities_rest(request_type): assert response.operation.name == "operations/spam" +def test_import_entities_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_entities in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.import_entities] = mock_rpc + + request = {} + client.import_entities(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_entities_rest_required_fields( request_type=datastore_admin.ImportEntitiesRequest, ): @@ -3411,6 +4013,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_index] = mock_rpc + + request = {} + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.parametrize("null_interceptor", [True, False]) def test_create_index_rest_interceptors(null_interceptor): transport = transports.DatastoreAdminRestTransport( @@ -3535,6 +4177,46 @@ def test_delete_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_index] = mock_rpc + + request = {} + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.parametrize("null_interceptor", [True, False]) def test_delete_index_rest_interceptors(null_interceptor): transport = transports.DatastoreAdminRestTransport( @@ -3672,6 +4354,42 @@ def test_get_index_rest(request_type): assert response.state == index.Index.State.CREATING +def test_get_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_index] = mock_rpc + + request = {} + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.parametrize("null_interceptor", [True, False]) def test_get_index_rest_interceptors(null_interceptor): transport = transports.DatastoreAdminRestTransport( @@ -3797,6 +4515,42 @@ def test_list_indexes_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_indexes_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_indexes in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_indexes] = mock_rpc + + request = {} + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.parametrize("null_interceptor", [True, False]) def test_list_indexes_rest_interceptors(null_interceptor): transport = transports.DatastoreAdminRestTransport( diff --git a/tests/unit/gapic/datastore_v1/test_datastore.py b/tests/unit/gapic/datastore_v1/test_datastore.py index 203d9c3a..73f3d837 100644 --- a/tests/unit/gapic/datastore_v1/test_datastore.py +++ b/tests/unit/gapic/datastore_v1/test_datastore.py @@ -1131,6 +1131,9 @@ def test_lookup_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.lookup() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1155,6 +1158,9 @@ def test_lookup_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.lookup(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1164,6 +1170,41 @@ def test_lookup_non_empty_request_with_auto_populated_field(): ) +def test_lookup_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.lookup in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.lookup] = mock_rpc + request = {} + client.lookup(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.lookup(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_lookup_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1187,6 +1228,45 @@ async def test_lookup_empty_call_async(): assert args[0] == datastore.LookupRequest() +@pytest.mark.asyncio +async def test_lookup_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.lookup + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.lookup + ] = mock_object + + request = {} + await client.lookup(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.lookup(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_lookup_async( transport: str = "grpc_asyncio", request_type=datastore.LookupRequest @@ -1447,6 +1527,9 @@ def test_run_query_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.run_query), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.run_query() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1471,6 +1554,9 @@ def test_run_query_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.run_query), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.run_query(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1480,6 +1566,41 @@ def test_run_query_non_empty_request_with_auto_populated_field(): ) +def test_run_query_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.run_query in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.run_query] = mock_rpc + request = {} + client.run_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.run_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_run_query_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1503,6 +1624,45 @@ async def test_run_query_empty_call_async(): assert args[0] == datastore.RunQueryRequest() +@pytest.mark.asyncio +async def test_run_query_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.run_query + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.run_query + ] = mock_object + + request = {} + await client.run_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.run_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_run_query_async( transport: str = "grpc_asyncio", request_type=datastore.RunQueryRequest @@ -1633,6 +1793,9 @@ def test_run_aggregation_query_empty_call(): with mock.patch.object( type(client.transport.run_aggregation_query), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.run_aggregation_query() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1659,6 +1822,9 @@ def test_run_aggregation_query_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.run_aggregation_query), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.run_aggregation_query(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1668,6 +1834,46 @@ def test_run_aggregation_query_non_empty_request_with_auto_populated_field(): ) +def test_run_aggregation_query_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.run_aggregation_query + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.run_aggregation_query + ] = mock_rpc + request = {} + client.run_aggregation_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.run_aggregation_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_run_aggregation_query_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1693,6 +1899,47 @@ async def test_run_aggregation_query_empty_call_async(): assert args[0] == datastore.RunAggregationQueryRequest() +@pytest.mark.asyncio +async def test_run_aggregation_query_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.run_aggregation_query + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.run_aggregation_query + ] = mock_object + + request = {} + await client.run_aggregation_query(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.run_aggregation_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_run_aggregation_query_async( transport: str = "grpc_asyncio", request_type=datastore.RunAggregationQueryRequest @@ -1829,6 +2076,9 @@ def test_begin_transaction_empty_call(): with mock.patch.object( type(client.transport.begin_transaction), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.begin_transaction() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1855,6 +2105,9 @@ def test_begin_transaction_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.begin_transaction), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.begin_transaction(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1864,6 +2117,43 @@ def test_begin_transaction_non_empty_request_with_auto_populated_field(): ) +def test_begin_transaction_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.begin_transaction in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.begin_transaction + ] = mock_rpc + request = {} + client.begin_transaction(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.begin_transaction(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_begin_transaction_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1889,6 +2179,47 @@ async def test_begin_transaction_empty_call_async(): assert args[0] == datastore.BeginTransactionRequest() +@pytest.mark.asyncio +async def test_begin_transaction_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.begin_transaction + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.begin_transaction + ] = mock_object + + request = {} + await client.begin_transaction(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.begin_transaction(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_begin_transaction_async( transport: str = "grpc_asyncio", request_type=datastore.BeginTransactionRequest @@ -2107,6 +2438,9 @@ def test_commit_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.commit), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.commit() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2131,6 +2465,9 @@ def test_commit_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.commit), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.commit(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2140,6 +2477,41 @@ def test_commit_non_empty_request_with_auto_populated_field(): ) +def test_commit_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.commit in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.commit] = mock_rpc + request = {} + client.commit(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.commit(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_commit_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2163,6 +2535,45 @@ async def test_commit_empty_call_async(): assert args[0] == datastore.CommitRequest() +@pytest.mark.asyncio +async def test_commit_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.commit + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.commit + ] = mock_object + + request = {} + await client.commit(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.commit(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_commit_async( transport: str = "grpc_asyncio", request_type=datastore.CommitRequest @@ -2450,6 +2861,9 @@ def test_rollback_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.rollback), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.rollback() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2474,6 +2888,9 @@ def test_rollback_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.rollback), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.rollback(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2483,6 +2900,41 @@ def test_rollback_non_empty_request_with_auto_populated_field(): ) +def test_rollback_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.rollback in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.rollback] = mock_rpc + request = {} + client.rollback(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.rollback(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_rollback_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2504,6 +2956,45 @@ async def test_rollback_empty_call_async(): assert args[0] == datastore.RollbackRequest() +@pytest.mark.asyncio +async def test_rollback_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.rollback + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.rollback + ] = mock_object + + request = {} + await client.rollback(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.rollback(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_rollback_async( transport: str = "grpc_asyncio", request_type=datastore.RollbackRequest @@ -2716,6 +3207,9 @@ def test_allocate_ids_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.allocate_ids), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.allocate_ids() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2740,6 +3234,9 @@ def test_allocate_ids_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.allocate_ids), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.allocate_ids(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2749,6 +3246,41 @@ def test_allocate_ids_non_empty_request_with_auto_populated_field(): ) +def test_allocate_ids_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.allocate_ids in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.allocate_ids] = mock_rpc + request = {} + client.allocate_ids(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.allocate_ids(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_allocate_ids_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2770,6 +3302,47 @@ async def test_allocate_ids_empty_call_async(): assert args[0] == datastore.AllocateIdsRequest() +@pytest.mark.asyncio +async def test_allocate_ids_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.allocate_ids + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.allocate_ids + ] = mock_object + + request = {} + await client.allocate_ids(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.allocate_ids(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_allocate_ids_async( transport: str = "grpc_asyncio", request_type=datastore.AllocateIdsRequest @@ -3002,6 +3575,9 @@ def test_reserve_ids_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.reserve_ids), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.reserve_ids() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3026,6 +3602,9 @@ def test_reserve_ids_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.reserve_ids), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.reserve_ids(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3035,6 +3614,41 @@ def test_reserve_ids_non_empty_request_with_auto_populated_field(): ) +def test_reserve_ids_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.reserve_ids in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.reserve_ids] = mock_rpc + request = {} + client.reserve_ids(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.reserve_ids(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_reserve_ids_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3056,6 +3670,47 @@ async def test_reserve_ids_empty_call_async(): assert args[0] == datastore.ReserveIdsRequest() +@pytest.mark.asyncio +async def test_reserve_ids_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.reserve_ids + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.reserve_ids + ] = mock_object + + request = {} + await client.reserve_ids(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.reserve_ids(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_reserve_ids_async( transport: str = "grpc_asyncio", request_type=datastore.ReserveIdsRequest @@ -3285,6 +3940,42 @@ def test_lookup_rest(request_type): assert response.transaction == b"transaction_blob" +def test_lookup_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.lookup in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.lookup] = mock_rpc + + request = {} + client.lookup(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.lookup(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_lookup_rest_required_fields(request_type=datastore.LookupRequest): transport_class = transports.DatastoreRestTransport @@ -3569,6 +4260,42 @@ def test_run_query_rest(request_type): assert response.transaction == b"transaction_blob" +def test_run_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.run_query in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.run_query] = mock_rpc + + request = {} + client.run_query(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.run_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_run_query_rest_required_fields(request_type=datastore.RunQueryRequest): transport_class = transports.DatastoreRestTransport @@ -3774,6 +4501,47 @@ def test_run_aggregation_query_rest(request_type): assert response.transaction == b"transaction_blob" +def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.run_aggregation_query + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.run_aggregation_query + ] = mock_rpc + + request = {} + client.run_aggregation_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.run_aggregation_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_run_aggregation_query_rest_required_fields( request_type=datastore.RunAggregationQueryRequest, ): @@ -3983,6 +4751,44 @@ def test_begin_transaction_rest(request_type): assert response.transaction == b"transaction_blob" +def test_begin_transaction_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.begin_transaction in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.begin_transaction + ] = mock_rpc + + request = {} + client.begin_transaction(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.begin_transaction(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_begin_transaction_rest_required_fields( request_type=datastore.BeginTransactionRequest, ): @@ -4248,6 +5054,42 @@ def test_commit_rest(request_type): assert response.index_updates == 1389 +def test_commit_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.commit in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.commit] = mock_rpc + + request = {} + client.commit(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.commit(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_commit_rest_required_fields(request_type=datastore.CommitRequest): transport_class = transports.DatastoreRestTransport @@ -4530,6 +5372,42 @@ def test_rollback_rest(request_type): assert isinstance(response, datastore.RollbackResponse) +def test_rollback_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.rollback in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.rollback] = mock_rpc + + request = {} + client.rollback(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.rollback(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_rollback_rest_required_fields(request_type=datastore.RollbackRequest): transport_class = transports.DatastoreRestTransport @@ -4801,6 +5679,42 @@ def test_allocate_ids_rest(request_type): assert isinstance(response, datastore.AllocateIdsResponse) +def test_allocate_ids_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.allocate_ids in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.allocate_ids] = mock_rpc + + request = {} + client.allocate_ids(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.allocate_ids(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_allocate_ids_rest_required_fields(request_type=datastore.AllocateIdsRequest): transport_class = transports.DatastoreRestTransport @@ -5076,6 +5990,42 @@ def test_reserve_ids_rest(request_type): assert isinstance(response, datastore.ReserveIdsResponse) +def test_reserve_ids_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.reserve_ids in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.reserve_ids] = mock_rpc + + request = {} + client.reserve_ids(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.reserve_ids(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_reserve_ids_rest_required_fields(request_type=datastore.ReserveIdsRequest): transport_class = transports.DatastoreRestTransport