diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index 3556002dbdab..122d0f541f31 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -2,6 +2,9 @@ ## 7.0.0b6 (Unreleased) +**Breaking Changes** + +* `ServiceBusClient.close()` now closes spawned senders and receivers. ## 7.0.0b5 (2020-08-10) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py index 28f253bdfdf2..be27075ea2de 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py @@ -2,11 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from typing import Any, TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING +import logging import uamqp -from ._base_handler import _parse_conn_str, ServiceBusSharedKeyCredential +from ._base_handler import _parse_conn_str, ServiceBusSharedKeyCredential, BaseHandler from ._servicebus_sender import ServiceBusSender from ._servicebus_receiver import ServiceBusReceiver from ._servicebus_session_receiver import ServiceBusSessionReceiver @@ -16,6 +17,8 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential +_LOGGER = logging.getLogger(__name__) + class ServiceBusClient(object): """The ServiceBusClient class defines a high level interface for @@ -69,6 +72,7 @@ def __init__( self._auth_uri = "{}/{}".format(self._auth_uri, self._entity_name) # Internal flag for switching whether to apply connection sharing, pending fix in uamqp library self._connection_sharing = False + self._handlers = [] # type: List[BaseHandler] def __enter__(self): if self._connection_sharing: @@ -89,10 +93,22 @@ def _create_uamqp_connection(self): def close(self): # type: () -> None """ - Close down the ServiceBus client and the underlying connection. + Close down the ServiceBus client. + All spawned senders, receivers and underlying connection will be shutdown. :return: None """ + for handler in self._handlers: + try: + handler.close() + except Exception as exception: # pylint: disable=broad-except + _LOGGER.error( + "Client has met an exception when closing the handler: %r. Exception: %r.", + handler._container_id, # pylint: disable=protected-access + exception, + ) + del self._handlers[:] + if self._connection_sharing and self._connection: self._connection.destroy() @@ -157,7 +173,7 @@ def get_queue_sender(self, queue_name, **kwargs): """ # pylint: disable=protected-access - return ServiceBusSender( + handler = ServiceBusSender( fully_qualified_namespace=self.fully_qualified_namespace, queue_name=queue_name, credential=self._credential, @@ -168,6 +184,8 @@ def get_queue_sender(self, queue_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_queue_receiver(self, queue_name, **kwargs): # type: (str, Any) -> ServiceBusReceiver @@ -205,7 +223,7 @@ def get_queue_receiver(self, queue_name, **kwargs): """ # pylint: disable=protected-access - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, queue_name=queue_name, credential=self._credential, @@ -216,6 +234,8 @@ def get_queue_receiver(self, queue_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_queue_deadletter_receiver(self, queue_name, **kwargs): # type: (str, Any) -> ServiceBusReceiver @@ -265,7 +285,7 @@ def get_queue_deadletter_receiver(self, queue_name, **kwargs): queue_name=queue_name, transfer_deadletter=kwargs.get('transfer_deadletter', False) ) - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, entity_name=entity_name, credential=self._credential, @@ -277,6 +297,8 @@ def get_queue_deadletter_receiver(self, queue_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_topic_sender(self, topic_name, **kwargs): # type: (str, Any) -> ServiceBusSender @@ -300,7 +322,7 @@ def get_topic_sender(self, topic_name, **kwargs): :caption: Create a new instance of the ServiceBusSender from ServiceBusClient. """ - return ServiceBusSender( + handler = ServiceBusSender( fully_qualified_namespace=self.fully_qualified_namespace, topic_name=topic_name, credential=self._credential, @@ -311,6 +333,8 @@ def get_topic_sender(self, topic_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_subscription_receiver(self, topic_name, subscription_name, **kwargs): # type: (str, str, Any) -> ServiceBusReceiver @@ -353,7 +377,7 @@ def get_subscription_receiver(self, topic_name, subscription_name, **kwargs): """ # pylint: disable=protected-access - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, topic_name=topic_name, subscription_name=subscription_name, @@ -365,6 +389,8 @@ def get_subscription_receiver(self, topic_name, subscription_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_subscription_deadletter_receiver(self, topic_name, subscription_name, **kwargs): # type: (str, str, Any) -> ServiceBusReceiver @@ -416,7 +442,7 @@ def get_subscription_deadletter_receiver(self, topic_name, subscription_name, ** subscription_name=subscription_name, transfer_deadletter=kwargs.get('transfer_deadletter', False) ) - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, entity_name=entity_name, credential=self._credential, @@ -428,6 +454,8 @@ def get_subscription_deadletter_receiver(self, topic_name, subscription_name, ** user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_subscription_session_receiver(self, topic_name, subscription_name, session_id=None, **kwargs): # type: (str, str, str, Any) -> ServiceBusSessionReceiver @@ -473,7 +501,7 @@ def get_subscription_session_receiver(self, topic_name, subscription_name, sessi """ # pylint: disable=protected-access - return ServiceBusSessionReceiver( + handler = ServiceBusSessionReceiver( fully_qualified_namespace=self.fully_qualified_namespace, topic_name=topic_name, subscription_name=subscription_name, @@ -486,6 +514,8 @@ def get_subscription_session_receiver(self, topic_name, subscription_name, sessi user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_queue_session_receiver(self, queue_name, session_id=None, **kwargs): # type: (str, str, Any) -> ServiceBusSessionReceiver @@ -526,7 +556,7 @@ def get_queue_session_receiver(self, queue_name, session_id=None, **kwargs): """ # pylint: disable=protected-access - return ServiceBusSessionReceiver( + handler = ServiceBusSessionReceiver( fully_qualified_namespace=self.fully_qualified_namespace, queue_name=queue_name, credential=self._credential, @@ -538,3 +568,5 @@ def get_queue_session_receiver(self, queue_name, session_id=None, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py index 67cd7bff710c..a6827a8ae91a 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py @@ -2,12 +2,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from typing import Any, TYPE_CHECKING, Union +from typing import Any, List, TYPE_CHECKING +import logging import uamqp from .._base_handler import _parse_conn_str -from ._base_handler_async import ServiceBusSharedKeyCredential +from ._base_handler_async import ServiceBusSharedKeyCredential, BaseHandler from ._servicebus_sender_async import ServiceBusSender from ._servicebus_receiver_async import ServiceBusReceiver from ._servicebus_session_receiver_async import ServiceBusSessionReceiver @@ -18,6 +19,8 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential +_LOGGER = logging.getLogger(__name__) + class ServiceBusClient(object): """The ServiceBusClient class defines a high level interface for @@ -71,6 +74,7 @@ def __init__( self._auth_uri = "{}/{}".format(self._auth_uri, self._entity_name) # Internal flag for switching whether to apply connection sharing, pending fix in uamqp library self._connection_sharing = False + self._handlers = [] # type: List[BaseHandler] async def __aenter__(self): if self._connection_sharing: @@ -133,9 +137,21 @@ async def close(self): # type: () -> None """ Close down the ServiceBus client. + All spawned senders, receivers and underlying connection will be shutdown. :return: None """ + for handler in self._handlers: + try: + await handler.close() + except Exception as exception: # pylint: disable=broad-except + _LOGGER.error( + "Client has met an exception when closing the handler: %r. Exception: %r.", + handler._container_id, # pylint: disable=protected-access + exception, + ) + del self._handlers[:] + if self._connection_sharing and self._connection: await self._connection.destroy_async() @@ -159,7 +175,7 @@ def get_queue_sender(self, queue_name, **kwargs): """ # pylint: disable=protected-access - return ServiceBusSender( + handler = ServiceBusSender( fully_qualified_namespace=self.fully_qualified_namespace, queue_name=queue_name, credential=self._credential, @@ -170,6 +186,8 @@ def get_queue_sender(self, queue_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_queue_receiver(self, queue_name, **kwargs): # type: (str, Any) -> ServiceBusReceiver @@ -206,7 +224,7 @@ def get_queue_receiver(self, queue_name, **kwargs): """ # pylint: disable=protected-access - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, queue_name=queue_name, credential=self._credential, @@ -217,6 +235,8 @@ def get_queue_receiver(self, queue_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_queue_deadletter_receiver(self, queue_name, **kwargs): # type: (str, Any) -> ServiceBusReceiver @@ -266,7 +286,7 @@ def get_queue_deadletter_receiver(self, queue_name, **kwargs): queue_name=queue_name, transfer_deadletter=kwargs.get('transfer_deadletter', False) ) - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, entity_name=entity_name, credential=self._credential, @@ -278,6 +298,8 @@ def get_queue_deadletter_receiver(self, queue_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_topic_sender(self, topic_name, **kwargs): # type: (str, Any) -> ServiceBusSender @@ -301,7 +323,7 @@ def get_topic_sender(self, topic_name, **kwargs): :caption: Create a new instance of the ServiceBusSender from ServiceBusClient. """ - return ServiceBusSender( + handler = ServiceBusSender( fully_qualified_namespace=self.fully_qualified_namespace, topic_name=topic_name, credential=self._credential, @@ -312,6 +334,8 @@ def get_topic_sender(self, topic_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_subscription_receiver(self, topic_name, subscription_name, **kwargs): # type: (str, str, Any) -> ServiceBusReceiver @@ -354,7 +378,7 @@ def get_subscription_receiver(self, topic_name, subscription_name, **kwargs): """ # pylint: disable=protected-access - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, topic_name=topic_name, subscription_name=subscription_name, @@ -366,6 +390,8 @@ def get_subscription_receiver(self, topic_name, subscription_name, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_subscription_deadletter_receiver(self, topic_name, subscription_name, **kwargs): # type: (str, str, Any) -> ServiceBusReceiver @@ -417,7 +443,7 @@ def get_subscription_deadletter_receiver(self, topic_name, subscription_name, ** subscription_name=subscription_name, transfer_deadletter=kwargs.get('transfer_deadletter', False) ) - return ServiceBusReceiver( + handler = ServiceBusReceiver( fully_qualified_namespace=self.fully_qualified_namespace, entity_name=entity_name, credential=self._credential, @@ -429,6 +455,8 @@ def get_subscription_deadletter_receiver(self, topic_name, subscription_name, ** user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_subscription_session_receiver(self, topic_name, subscription_name, session_id=None, **kwargs): # type: (str, str, str, Any) -> ServiceBusSessionReceiver @@ -474,7 +502,7 @@ def get_subscription_session_receiver(self, topic_name, subscription_name, sessi """ # pylint: disable=protected-access - return ServiceBusSessionReceiver( + handler = ServiceBusSessionReceiver( fully_qualified_namespace=self.fully_qualified_namespace, topic_name=topic_name, subscription_name=subscription_name, @@ -487,6 +515,8 @@ def get_subscription_session_receiver(self, topic_name, subscription_name, sessi user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler def get_queue_session_receiver(self, queue_name, session_id=None, **kwargs): # type: (str, str, Any) -> ServiceBusSessionReceiver @@ -526,7 +556,7 @@ def get_queue_session_receiver(self, queue_name, session_id=None, **kwargs): """ # pylint: disable=protected-access - return ServiceBusSessionReceiver( + handler = ServiceBusSessionReceiver( fully_qualified_namespace=self.fully_qualified_namespace, queue_name=queue_name, credential=self._credential, @@ -538,3 +568,5 @@ def get_queue_session_receiver(self, queue_name, session_id=None, **kwargs): user_agent=self._config.user_agent, **kwargs ) + self._handlers.append(handler) + return handler diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py new file mode 100644 index 000000000000..351d1fe51a99 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py @@ -0,0 +1,60 @@ +#-------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + + +import logging +import pytest + +from azure.servicebus.aio import ServiceBusClient +from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer +from servicebus_preparer import CachedServiceBusNamespacePreparer, CachedServiceBusQueuePreparer +from utilities import get_logger + +_logger = get_logger(logging.DEBUG) + + +class ServiceBusClientAsyncTests(AzureMgmtTestCase): + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + async def test_async_sb_client_close_spawned_handlers(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string) + + await client.close() + + # context manager + async with client: + assert len(client._handlers) == 0 + sender = client.get_queue_sender(servicebus_queue.name) + receiver = client.get_queue_receiver(servicebus_queue.name) + await sender._open() + await receiver._open() + + assert sender._handler and sender._running + assert receiver._handler and receiver._running + assert len(client._handlers) == 2 + + assert not sender._handler and not sender._running + assert not receiver._handler and not receiver._running + assert len(client._handlers) == 0 + + # close operation + sender = client.get_queue_sender(servicebus_queue.name) + receiver = client.get_queue_receiver(servicebus_queue.name) + await sender._open() + await receiver._open() + + assert sender._handler and sender._running + assert receiver._handler and receiver._running + assert len(client._handlers) == 2 + + await client.close() + + assert not sender._handler and not sender._running + assert not receiver._handler and not receiver._running + assert len(client._handlers) == 0 diff --git a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py index 8f69482f0c4f..f9cf5fb0e306 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py +++ b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py @@ -28,7 +28,8 @@ ServiceBusTopicPreparer, ServiceBusQueuePreparer, ServiceBusNamespaceAuthorizationRulePreparer, - ServiceBusQueueAuthorizationRulePreparer + ServiceBusQueueAuthorizationRulePreparer, + CachedServiceBusQueuePreparer ) class ServiceBusClientTests(AzureMgmtTestCase): @@ -126,3 +127,45 @@ def test_sb_client_incorrect_queue_conn_str(self, servicebus_queue_authorization with pytest.raises(ServiceBusError): with client.get_queue_sender(wrong_queue.name) as sender: sender.send_messages(Message("test")) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @CachedServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) + def test_sb_client_close_spawned_handlers(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string) + + client.close() + + # context manager + with client: + assert len(client._handlers) == 0 + sender = client.get_queue_sender(servicebus_queue.name) + receiver = client.get_queue_receiver(servicebus_queue.name) + sender._open() + receiver._open() + + assert sender._handler and sender._running + assert receiver._handler and receiver._running + assert len(client._handlers) == 2 + + assert not sender._handler and not sender._running + assert not receiver._handler and not receiver._running + assert len(client._handlers) == 0 + + # close operation + sender = client.get_queue_sender(servicebus_queue.name) + receiver = client.get_queue_receiver(servicebus_queue.name) + sender._open() + receiver._open() + + assert sender._handler and sender._running + assert receiver._handler and receiver._running + assert len(client._handlers) == 2 + + client.close() + + assert not sender._handler and not sender._running + assert not receiver._handler and not receiver._running + assert len(client._handlers) == 0