diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 7c6e747695..c25f184bbb 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -50,14 +50,7 @@ ) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import ( - ceil_timeout, - is_ip_address, - noop, - sentinel, - set_exception, - set_result, -) +from .helpers import ceil_timeout, is_ip_address, noop, sentinel from .locks import EventResultOrError from .resolver import DefaultResolver @@ -748,6 +741,35 @@ def expired(self, key: Tuple[str, int]) -> bool: return self._timestamps[key] + self._ttl < monotonic() +def _make_ssl_context(verified: bool) -> SSLContext: + """Create SSL context. + + This method is not async-friendly and should be called from a thread + because it will load certificates from disk and do other blocking I/O. + """ + if ssl is None: + # No ssl support + return None + if verified: + return ssl.create_default_context() + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.check_hostname = False + sslcontext.verify_mode = ssl.CERT_NONE + sslcontext.options |= ssl.OP_NO_COMPRESSION + sslcontext.set_default_verify_paths() + return sslcontext + + +# The default SSLContext objects are created at import time +# since they do blocking I/O to load certificates from disk, +# and imports should always be done before the event loop starts +# or in a thread. +_SSL_CONTEXT_VERIFIED = _make_ssl_context(True) +_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False) + + class TCPConnector(BaseConnector): """TCP connector. @@ -778,7 +800,6 @@ class TCPConnector(BaseConnector): """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) - _made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {} def __init__( self, @@ -982,25 +1003,7 @@ async def _create_connection( return proto - @staticmethod - def _make_ssl_context(verified: bool) -> SSLContext: - """Create SSL context. - - This method is not async-friendly and should be called from a thread - because it will load certificates from disk and do other blocking I/O. - """ - if verified: - return ssl.create_default_context() - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.check_hostname = False - sslcontext.verify_mode = ssl.CERT_NONE - sslcontext.options |= ssl.OP_NO_COMPRESSION - sslcontext.set_default_verify_paths() - return sslcontext - - async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -1024,35 +1027,14 @@ async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: return sslcontext if sslcontext is not True: # not verified or fingerprinted - return await self._make_or_get_ssl_context(False) + return _SSL_CONTEXT_UNVERIFIED sslcontext = self._ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext if sslcontext is not True: # not verified or fingerprinted - return await self._make_or_get_ssl_context(False) - return await self._make_or_get_ssl_context(True) - - async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext: - """Create or get cached SSL context.""" - try: - return await self._made_ssl_context[verified] - except KeyError: - loop = self._loop - future = loop.create_future() - self._made_ssl_context[verified] = future - try: - result = await loop.run_in_executor( - None, self._make_ssl_context, verified - ) - # BaseException is used since we might get CancelledError - except BaseException as ex: - del self._made_ssl_context[verified] - set_exception(future, ex) - raise - else: - set_result(future, result) - return result + return _SSL_CONTEXT_UNVERIFIED + return _SSL_CONTEXT_VERIFIED def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl @@ -1204,13 +1186,11 @@ async def _start_tls_connection( ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS - - # Safety of the `cast()` call here is based on the fact that - # internally `_get_ssl_context()` only returns `None` when - # `req.is_ssl()` evaluates to `False` which is never gonna happen - # in this code path. Of course, it's rather fragile - # maintainability-wise but this is to be solved separately. - sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req)) + sslcontext = self._get_ssl_context(req) + if TYPE_CHECKING: + # _start_tls_connection is unreachable in the current code path + # if sslcontext is None. + assert sslcontext is not None try: async with ceil_timeout( @@ -1288,7 +1268,7 @@ async def _create_direct_connection( *, client_error: Type[Exception] = ClientConnectorError, ) -> Tuple[asyncio.Transport, ResponseHandler]: - sslcontext = await self._get_ssl_context(req) + sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) host = req.url.raw_host diff --git a/tests/test_connector.py b/tests/test_connector.py index 0129f0cc33..9f9dbe66c2 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1,5 +1,4 @@ # Tests of http client with custom Connector - import asyncio import gc import hashlib @@ -9,8 +8,9 @@ import sys import uuid from collections import deque +from concurrent import futures from contextlib import closing, suppress -from typing import Any, List, Optional, Type +from typing import Any, List, Literal, Optional from unittest import mock import pytest @@ -18,10 +18,21 @@ from yarl import URL import aiohttp -from aiohttp import client, web -from aiohttp.client import ClientRequest, ClientTimeout +from aiohttp import ( + ClientRequest, + ClientTimeout, + client, + connector as connector_module, + web, +) from aiohttp.client_reqrep import ConnectionKey -from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable +from aiohttp.connector import ( + _SSL_CONTEXT_UNVERIFIED, + _SSL_CONTEXT_VERIFIED, + Connection, + TCPConnector, + _DNSCacheTable, +) from aiohttp.locks import EventResultOrError from aiohttp.test_utils import make_mocked_coro, unused_port from aiohttp.tracing import Trace @@ -1540,23 +1551,11 @@ async def test_tcp_connector_clear_dns_cache_bad_args(loop) -> None: conn.clear_dns_cache("localhost") -async def test_dont_recreate_ssl_context() -> None: - conn = aiohttp.TCPConnector() - ctx = await conn._make_or_get_ssl_context(True) - assert ctx is await conn._make_or_get_ssl_context(True) - - -async def test_dont_recreate_ssl_context2() -> None: - conn = aiohttp.TCPConnector() - ctx = await conn._make_or_get_ssl_context(False) - assert ctx is await conn._make_or_get_ssl_context(False) - - async def test___get_ssl_context1() -> None: conn = aiohttp.TCPConnector() req = mock.Mock() req.is_ssl.return_value = False - assert await conn._get_ssl_context(req) is None + assert conn._get_ssl_context(req) is None async def test___get_ssl_context2(loop) -> None: @@ -1565,7 +1564,7 @@ async def test___get_ssl_context2(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = ctx - assert await conn._get_ssl_context(req) is ctx + assert conn._get_ssl_context(req) is ctx async def test___get_ssl_context3(loop) -> None: @@ -1574,7 +1573,7 @@ async def test___get_ssl_context3(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert await conn._get_ssl_context(req) is ctx + assert conn._get_ssl_context(req) is ctx async def test___get_ssl_context4(loop) -> None: @@ -1583,9 +1582,7 @@ async def test___get_ssl_context4(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = False - assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( - False - ) + assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED async def test___get_ssl_context5(loop) -> None: @@ -1594,9 +1591,7 @@ async def test___get_ssl_context5(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest()) - assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( - False - ) + assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED async def test___get_ssl_context6() -> None: @@ -1604,7 +1599,7 @@ async def test___get_ssl_context6() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True) + assert conn._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED async def test_ssl_context_once() -> None: @@ -1616,31 +1611,9 @@ async def test_ssl_context_once() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context( - True - ) - assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context( - True - ) - assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context( - True - ) - assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context - assert True in conn1._made_ssl_context - - -@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError]) -async def test_ssl_context_creation_raises(exception: Type[BaseException]) -> None: - """Test that we try again if SSLContext creation fails the first time.""" - conn = aiohttp.TCPConnector() - conn._made_ssl_context.clear() - - with mock.patch.object( - conn, "_make_ssl_context", side_effect=exception - ), pytest.raises(exception): - await conn._make_or_get_ssl_context(True) - - assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext) + assert conn1._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED + assert conn2._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED + assert conn3._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED async def test_close_twice(loop) -> None: @@ -2717,3 +2690,42 @@ async def allow_connection_and_add_dummy_waiter(): ) await connector.close() + + +def test_connector_multiple_event_loop() -> None: + """Test the connector with multiple event loops.""" + + async def async_connect() -> Literal[True]: + conn = aiohttp.TCPConnector() + loop = asyncio.get_running_loop() + req = ClientRequest("GET", URL("https://127.0.0.1"), loop=loop) + with suppress(aiohttp.ClientConnectorError): + with mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + side_effect=ssl.CertificateError, + ): + await conn.connect(req, [], ClientTimeout()) + return True + + def test_connect() -> Literal[True]: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(async_connect()) + finally: + loop.close() + + with futures.ThreadPoolExecutor() as executor: + res_list = [executor.submit(test_connect) for _ in range(2)] + raw_response_list = [res.result() for res in futures.as_completed(res_list)] + + assert raw_response_list == [True, True] + + +def test_default_ssl_context_creation_without_ssl() -> None: + """Verify _make_ssl_context does not raise when ssl is not available.""" + with mock.patch.object(connector_module, "ssl", None): + assert connector_module._make_ssl_context(False) is None + assert connector_module._make_ssl_context(True) is None diff --git a/tests/test_proxy.py b/tests/test_proxy.py index c5e98deb8a..4fa5e93209 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -12,6 +12,7 @@ import aiohttp from aiohttp.client_reqrep import ClientRequest, ClientResponse +from aiohttp.connector import _SSL_CONTEXT_VERIFIED from aiohttp.helpers import TimerNoop from aiohttp.test_utils import make_mocked_coro @@ -817,7 +818,7 @@ async def make_conn(): self.loop.start_tls.assert_called_with( mock.ANY, mock.ANY, - self.loop.run_until_complete(connector._make_or_get_ssl_context(True)), + _SSL_CONTEXT_VERIFIED, server_hostname="www.python.org", ssl_handshake_timeout=mock.ANY, )