From ef743ecb2243786c0573b9fe726290878359689b Mon Sep 17 00:00:00 2001 From: 4JX <4JXcYvmyu3Hz8fV@protonmail.com> Date: Mon, 30 Oct 2023 16:35:36 +0100 Subject: [PATCH] Add a setter for header_table_size (#638) --- src/client.rs | 33 +++++++++++++++++++++++++ src/codec/framed_read.rs | 6 +++++ src/codec/mod.rs | 5 ++++ src/frame/settings.rs | 2 -- src/proto/settings.rs | 4 +++ tests/h2-support/src/frames.rs | 5 ++++ tests/h2-tests/tests/client_request.rs | 34 ++++++++++++++++++++++++++ 7 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index b329121a..35cfc141 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1072,6 +1072,39 @@ impl Builder { self } + /// Sets the header table size. + /// + /// This setting informs the peer of the maximum size of the header compression + /// table used to encode header blocks, in octets. The encoder may select any value + /// equal to or less than the header table size specified by the sender. + /// + /// The default value is 4,096. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use h2::client::*; + /// # use bytes::Bytes; + /// # + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> + /// # { + /// // `client_fut` is a future representing the completion of the HTTP/2 + /// // handshake. + /// let client_fut = Builder::new() + /// .header_table_size(1_000_000) + /// .handshake(my_io); + /// # client_fut.await + /// # } + /// # + /// # pub fn main() {} + /// ``` + pub fn header_table_size(&mut self, size: u32) -> &mut Self { + self.settings.set_header_table_size(Some(size)); + self + } + /// Sets the first stream ID to something other than 1. #[cfg(feature = "unstable")] pub fn initial_stream_id(&mut self, stream_id: u32) -> &mut Self { diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index a874d773..3b0030d9 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -88,6 +88,12 @@ impl FramedRead { pub fn set_max_header_list_size(&mut self, val: usize) { self.max_header_list_size = val; } + + /// Update the header table size setting. + #[inline] + pub fn set_header_table_size(&mut self, val: usize) { + self.hpack.queue_size_update(val); + } } /// Decodes a frame. diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 359adf6e..6cbdc1e1 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -95,6 +95,11 @@ impl Codec { self.framed_write().set_header_table_size(val) } + /// Set the decoder header table size size. + pub fn set_recv_header_table_size(&mut self, val: usize) { + self.inner.set_header_table_size(val) + } + /// Set the max header list size that can be received. pub fn set_max_recv_header_list_size(&mut self, val: usize) { self.inner.set_max_header_list_size(val); diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 0c913f05..484498a9 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -121,11 +121,9 @@ impl Settings { self.header_table_size } - /* pub fn set_header_table_size(&mut self, size: Option) { self.header_table_size = size; } - */ pub fn load(head: Head, payload: &[u8]) -> Result { use self::Setting::*; diff --git a/src/proto/settings.rs b/src/proto/settings.rs index 6cc61720..28065cc6 100644 --- a/src/proto/settings.rs +++ b/src/proto/settings.rs @@ -60,6 +60,10 @@ impl Settings { codec.set_max_recv_header_list_size(max as usize); } + if let Some(val) = local.header_table_size() { + codec.set_recv_header_table_size(val as usize); + } + streams.apply_local_settings(local)?; self.local = Local::Synced; Ok(()) diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index d302d3ce..a76dd3b6 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -391,6 +391,11 @@ impl Mock { self.0.set_enable_connect_protocol(Some(val)); self } + + pub fn header_table_size(mut self, val: u32) -> Self { + self.0.set_header_table_size(Some(val)); + self + } } impl From> for frame::Settings { diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 7b431600..88c7df46 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1627,6 +1627,40 @@ async fn rogue_server_reused_headers() { join(srv, h2).await; } +#[tokio::test] +async fn client_builder_header_table_size() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + let mut settings = frame::Settings::default(); + + settings.set_header_table_size(Some(10000)); + + let srv = async move { + let recv_settings = srv.assert_client_handshake().await; + assert_frame_eq(recv_settings, settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let mut builder = client::Builder::new(); + builder.header_table_size(10000); + + let h2 = async move { + let (mut client, mut h2) = builder.handshake::<_, Bytes>(io).await.unwrap(); + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; + + join(srv, h2).await; +} + const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];