diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py index 95def84c..c719dbed 100644 --- a/tests/_async/test_http2.py +++ b/tests/_async/test_http2.py @@ -294,3 +294,55 @@ async def test_http2_request_to_incorrect_origin(): async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn: with pytest.raises(RuntimeError): await conn.request("GET", "https://other.com/") + + +@pytest.mark.anyio +async def test_http2_remote_max_streams_update(): + """ + If the remote server updates the maximum concurrent streams value, we should + be adjusting how many streams we will allow. + """ + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + hyperframe.frame.SettingsFrame( + settings={hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1000} + ).serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(), + hyperframe.frame.SettingsFrame( + settings={hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 50} + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world...again!", flags=["END_STREAM"] + ).serialize(), + ] + ) + async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn: + async with conn.stream("GET", "https://example.com/") as response: + i = 0 + async for chunk in response.aiter_stream(): + if i == 0: + assert chunk == b"Hello, world!" + assert conn._h2_state.remote_settings.max_concurrent_streams == 1000 + assert conn._max_streams == min( + conn._h2_state.remote_settings.max_concurrent_streams, + conn._h2_state.local_settings.max_concurrent_streams, + ) + elif i == 1: + assert chunk == b"Hello, world...again!" + assert conn._h2_state.remote_settings.max_concurrent_streams == 50 + assert conn._max_streams == min( + conn._h2_state.remote_settings.max_concurrent_streams, + conn._h2_state.local_settings.max_concurrent_streams, + ) + i += 1 diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py index 08044419..e17dfa9e 100644 --- a/tests/_sync/test_http2.py +++ b/tests/_sync/test_http2.py @@ -294,3 +294,55 @@ def test_http2_request_to_incorrect_origin(): with HTTP2Connection(origin=origin, stream=stream) as conn: with pytest.raises(RuntimeError): conn.request("GET", "https://other.com/") + + + +def test_http2_remote_max_streams_update(): + """ + If the remote server updates the maximum concurrent streams value, we should + be adjusting how many streams we will allow. + """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + hyperframe.frame.SettingsFrame( + settings={hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1000} + ).serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(), + hyperframe.frame.SettingsFrame( + settings={hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 50} + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world...again!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with HTTP2Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response: + i = 0 + for chunk in response.iter_stream(): + if i == 0: + assert chunk == b"Hello, world!" + assert conn._h2_state.remote_settings.max_concurrent_streams == 1000 + assert conn._max_streams == min( + conn._h2_state.remote_settings.max_concurrent_streams, + conn._h2_state.local_settings.max_concurrent_streams, + ) + elif i == 1: + assert chunk == b"Hello, world...again!" + assert conn._h2_state.remote_settings.max_concurrent_streams == 50 + assert conn._max_streams == min( + conn._h2_state.remote_settings.max_concurrent_streams, + conn._h2_state.local_settings.max_concurrent_streams, + ) + i += 1