From ad7a7e330b50635fe3e69ba53c094bfc5497bdad Mon Sep 17 00:00:00 2001 From: Zanie Adkins Date: Fri, 12 May 2023 07:54:06 -0500 Subject: [PATCH] Add check for `h2.connection.ConnectionState.CLOSED` in `AsyncHTTP2Connection.is_available` (#679) * Add check for `h2.connection.ConnectionState.CLOSED` in `AsyncHTTP2Connection.is_available` * Add sync implementation * Add test for closed connection * Regenerate sync tests with `unasync` * Use async with * Add anyio annotation --------- Co-authored-by: Tom Christie --- httpcore/_async/http2.py | 4 ++++ httpcore/_sync/http2.py | 4 ++++ tests/_async/test_http2.py | 34 ++++++++++++++++++++++++++++++++++ tests/_sync/test_http2.py | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+) diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 035cbcd3..fa8062ab 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -433,6 +433,10 @@ def is_available(self) -> bool: self._state != HTTPConnectionState.CLOSED and not self._connection_error and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) ) def has_expired(self) -> bool: diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 35426be5..8e2f55e0 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -433,6 +433,10 @@ def is_available(self) -> bool: self._state != HTTPConnectionState.CLOSED and not self._connection_error and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) ) def has_expired(self) -> bool: diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py index cad89c05..f995465f 100644 --- a/tests/_async/test_http2.py +++ b/tests/_async/test_http2.py @@ -52,6 +52,40 @@ async def test_http2_connection(): ) +@pytest.mark.anyio +async def test_http2_connection_closed(): + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + # Connection is closed after the first response + hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(), + ] + ) + async with AsyncHTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + await conn.request("GET", "https://example.com/") + + with pytest.raises(RemoteProtocolError): + await conn.request("GET", "https://example.com/") + + assert not conn.is_available() + + @pytest.mark.anyio async def test_http2_connection_post_request(): origin = Origin(b"https", b"example.com", 443) diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py index 1c959434..0adb96ef 100644 --- a/tests/_sync/test_http2.py +++ b/tests/_sync/test_http2.py @@ -53,6 +53,40 @@ def test_http2_connection(): +def test_http2_connection_closed(): + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + # Connection is closed after the first response + hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(), + ] + ) + with HTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + conn.request("GET", "https://example.com/") + + with pytest.raises(RemoteProtocolError): + conn.request("GET", "https://example.com/") + + assert not conn.is_available() + + + def test_http2_connection_post_request(): origin = Origin(b"https", b"example.com", 443) stream = MockStream(