From accae7b3db6ff603327aae243b1868d91bf886b0 Mon Sep 17 00:00:00 2001 From: MtkN1 <51289448+MtkN1@users.noreply.github.com> Date: Wed, 21 Feb 2024 01:06:20 +0900 Subject: [PATCH] Fix support for connection Upgrade and CONNECT when some data in the stream has been read. (#882) * Add a starting point for the work * Add draft tests * Support connection `Upgrade` and `CONNECT`. * Update CHANGELOG.md * Remove private state assertions * Add Async prefix * Update CHANGELOG.md Co-authored-by: Tom Christie * Update tests/_async/test_http11.py Co-authored-by: T-256 <132141463+T-256@users.noreply.github.com> --------- Co-authored-by: Tom Christie Co-authored-by: T-256 <132141463+T-256@users.noreply.github.com> Co-authored-by: Tom Christie --- CHANGELOG.md | 4 +++ httpcore/_async/http11.py | 50 +++++++++++++++++++++++++++++++++--- httpcore/_sync/http11.py | 50 +++++++++++++++++++++++++++++++++--- tests/_async/test_http11.py | 51 +++++++++++++++++++++++++++++++++++++ tests/_sync/test_http11.py | 51 +++++++++++++++++++++++++++++++++++++ 5 files changed, 200 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f3fe41e8..4c66d6a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## Unreleased + +- Fix support for connection Upgrade and CONNECT when some data in the stream has been read. (#882) + ## 1.0.3 (February 13th, 2024) - Fix support for async cancellations. (#880) diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index a5eb4808..0493a923 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -1,8 +1,10 @@ import enum import logging +import ssl import time from types import TracebackType from typing import ( + Any, AsyncIterable, AsyncIterator, List, @@ -107,6 +109,7 @@ async def handle_async_request(self, request: Request) -> Response: status, reason_phrase, headers, + trailing_data, ) = await self._receive_response_headers(**kwargs) trace.return_value = ( http_version, @@ -115,6 +118,14 @@ async def handle_async_request(self, request: Request) -> Response: headers, ) + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data) + return Response( status=status, headers=headers, @@ -122,7 +133,7 @@ async def handle_async_request(self, request: Request) -> Response: extensions={ "http_version": http_version, "reason_phrase": reason_phrase, - "network_stream": self._network_stream, + "network_stream": network_stream, }, ) except BaseException as exc: @@ -167,7 +178,7 @@ async def _send_event( async def _receive_response_headers( self, request: Request - ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) @@ -187,7 +198,9 @@ async def _receive_response_headers( # raw header casing, rather than the enforced lowercase headers. headers = event.headers.raw_items() - return http_version, event.status_code, event.reason, headers + trailing_data, _ = self._h11_state.trailing_data + + return http_version, event.status_code, event.reason, headers, trailing_data async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]: timeouts = request.extensions.get("timeout", {}) @@ -340,3 +353,34 @@ async def aclose(self) -> None: self._closed = True async with Trace("response_closed", logger, self._request): await self._connection._response_closed() + + +class AsyncHTTP11UpgradeStream(AsyncNetworkStream): + def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None: + self._stream = stream + self._leading_data = leading_data + + async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._leading_data: + buffer = self._leading_data[:max_bytes] + self._leading_data = self._leading_data[max_bytes:] + return buffer + else: + return await self._stream.read(max_bytes, timeout) + + async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None: + await self._stream.write(buffer, timeout) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> AsyncNetworkStream: + return await self._stream.start_tls(ssl_context, server_hostname, timeout) + + def get_extra_info(self, info: str) -> Any: + return self._stream.get_extra_info(info) diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index e108f88b..a74ff8e8 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -1,8 +1,10 @@ import enum import logging +import ssl import time from types import TracebackType from typing import ( + Any, Iterable, Iterator, List, @@ -107,6 +109,7 @@ def handle_request(self, request: Request) -> Response: status, reason_phrase, headers, + trailing_data, ) = self._receive_response_headers(**kwargs) trace.return_value = ( http_version, @@ -115,6 +118,14 @@ def handle_request(self, request: Request) -> Response: headers, ) + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = HTTP11UpgradeStream(network_stream, trailing_data) + return Response( status=status, headers=headers, @@ -122,7 +133,7 @@ def handle_request(self, request: Request) -> Response: extensions={ "http_version": http_version, "reason_phrase": reason_phrase, - "network_stream": self._network_stream, + "network_stream": network_stream, }, ) except BaseException as exc: @@ -167,7 +178,7 @@ def _send_event( def _receive_response_headers( self, request: Request - ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) @@ -187,7 +198,9 @@ def _receive_response_headers( # raw header casing, rather than the enforced lowercase headers. headers = event.headers.raw_items() - return http_version, event.status_code, event.reason, headers + trailing_data, _ = self._h11_state.trailing_data + + return http_version, event.status_code, event.reason, headers, trailing_data def _receive_response_body(self, request: Request) -> Iterator[bytes]: timeouts = request.extensions.get("timeout", {}) @@ -340,3 +353,34 @@ def close(self) -> None: self._closed = True with Trace("response_closed", logger, self._request): self._connection._response_closed() + + +class HTTP11UpgradeStream(NetworkStream): + def __init__(self, stream: NetworkStream, leading_data: bytes) -> None: + self._stream = stream + self._leading_data = leading_data + + def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._leading_data: + buffer = self._leading_data[:max_bytes] + self._leading_data = self._leading_data[max_bytes:] + return buffer + else: + return self._stream.read(max_bytes, timeout) + + def write(self, buffer: bytes, timeout: Optional[float] = None) -> None: + self._stream.write(buffer, timeout) + + def close(self) -> None: + self._stream.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> NetworkStream: + return self._stream.start_tls(ssl_context, server_hostname, timeout) + + def get_extra_info(self, info: str) -> Any: + return self._stream.get_extra_info(info) diff --git a/tests/_async/test_http11.py b/tests/_async/test_http11.py index 489d68b4..94f2febf 100644 --- a/tests/_async/test_http11.py +++ b/tests/_async/test_http11.py @@ -269,6 +269,57 @@ async def test_http11_upgrade_connection(): assert content == b"..." +@pytest.mark.anyio +async def test_http11_upgrade_with_trailing_data(): + """ + HTTP "101 Switching Protocols" indicates an upgraded connection. + + In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data + in the h11.Connection object. + + https://h11.readthedocs.io/en/latest/api.html#switching-protocols + """ + origin = httpcore.Origin(b"wss", b"example.com", 443) + stream = httpcore.AsyncMockStream( + # The first element of this mock network stream buffer simulates networking + # in which response headers and data are received at once. + # This means that "foobar" becomes trailing data. + [ + ( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Connection: upgrade\r\n" + b"Upgrade: custom\r\n" + b"\r\n" + b"foobar" + ), + b"baz", + ] + ) + async with httpcore.AsyncHTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + async with conn.stream( + "GET", + "wss://example.com/", + headers={"Connection": "upgrade", "Upgrade": "custom"}, + ) as response: + assert response.status == 101 + network_stream = response.extensions["network_stream"] + + content = await network_stream.read(max_bytes=3) + assert content == b"foo" + content = await network_stream.read(max_bytes=3) + assert content == b"bar" + content = await network_stream.read(max_bytes=3) + assert content == b"baz" + + # Lazy tests for AsyncHTTP11UpgradeStream + await network_stream.write(b"spam") + invalid = network_stream.get_extra_info("invalid") + assert invalid is None + await network_stream.aclose() + + @pytest.mark.anyio async def test_http11_early_hints(): """ diff --git a/tests/_sync/test_http11.py b/tests/_sync/test_http11.py index dcd80e84..f2fa28f4 100644 --- a/tests/_sync/test_http11.py +++ b/tests/_sync/test_http11.py @@ -270,6 +270,57 @@ def test_http11_upgrade_connection(): +def test_http11_upgrade_with_trailing_data(): + """ + HTTP "101 Switching Protocols" indicates an upgraded connection. + + In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data + in the h11.Connection object. + + https://h11.readthedocs.io/en/latest/api.html#switching-protocols + """ + origin = httpcore.Origin(b"wss", b"example.com", 443) + stream = httpcore.MockStream( + # The first element of this mock network stream buffer simulates networking + # in which response headers and data are received at once. + # This means that "foobar" becomes trailing data. + [ + ( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Connection: upgrade\r\n" + b"Upgrade: custom\r\n" + b"\r\n" + b"foobar" + ), + b"baz", + ] + ) + with httpcore.HTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + with conn.stream( + "GET", + "wss://example.com/", + headers={"Connection": "upgrade", "Upgrade": "custom"}, + ) as response: + assert response.status == 101 + network_stream = response.extensions["network_stream"] + + content = network_stream.read(max_bytes=3) + assert content == b"foo" + content = network_stream.read(max_bytes=3) + assert content == b"bar" + content = network_stream.read(max_bytes=3) + assert content == b"baz" + + # Lazy tests for HTTP11UpgradeStream + network_stream.write(b"spam") + invalid = network_stream.get_extra_info("invalid") + assert invalid is None + network_stream.close() + + + def test_http11_early_hints(): """ HTTP "103 Early Hints" is an interim response.