diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e5e15cb..f1d3251a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## Unreleased - Add support for HTTPS proxies. Currently only available for async. (#745) +- Handle `sni_hostname` extension with SOCKS proxy. (#774) - Change the type of `Extensions` from `Mapping[Str, Any]` to `MutableMapping[Str, Any]`. (#762) - Handle HTTP/1.1 half-closed connections gracefully. (#641) - Drop Python 3.7 support. (#727) diff --git a/httpcore/_async/socks_proxy.py b/httpcore/_async/socks_proxy.py index f12cb373..08a065d6 100644 --- a/httpcore/_async/socks_proxy.py +++ b/httpcore/_async/socks_proxy.py @@ -216,6 +216,7 @@ def __init__( async def handle_async_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) timeout = timeouts.get("connect", None) async with self._connect_lock: @@ -258,7 +259,8 @@ async def handle_async_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), "timeout": timeout, } async with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_sync/socks_proxy.py b/httpcore/_sync/socks_proxy.py index 407351d0..502e4d7f 100644 --- a/httpcore/_sync/socks_proxy.py +++ b/httpcore/_sync/socks_proxy.py @@ -216,6 +216,7 @@ def __init__( def handle_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) timeout = timeouts.get("connect", None) with self._connect_lock: @@ -258,7 +259,8 @@ def handle_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), "timeout": timeout, } with Trace("start_tls", logger, request, kwargs) as trace: