diff --git a/src/client.rs b/src/client.rs
index 4147e8a4..e83ef6a4 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -510,8 +510,10 @@ where
         self.inner
             .send_request(request, end_of_stream, self.pending.as_ref())
             .map_err(Into::into)
-            .map(|stream| {
-                if stream.is_pending_open() {
+            .map(|(stream, is_full)| {
+                if stream.is_pending_open() && is_full {
+                    // Only prevent sending another request when the request
+                    // queue is full.
                     self.pending = Some(stream.clone_to_opaque());
                 }
 
diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs
index 6a5aa9cc..add1312e 100644
--- a/src/proto/streams/counts.rs
+++ b/src/proto/streams/counts.rs
@@ -49,6 +49,14 @@ impl Counts {
         }
     }
 
+    /// Returns true when the next opened stream will reach the capacity of outbound streams
+    ///
+    /// The number of client send streams is incremented in prioritize; send_request has to guess
+    /// whether it should wait before allowing another request to be sent.
+    pub fn next_send_stream_will_reach_capacity(&self) -> bool {
+        self.max_send_streams <= (self.num_send_streams + 1)
+    }
+
     /// Returns the current peer
     pub fn peer(&self) -> peer::Dyn {
         self.peer
diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs
index 35795fae..3196049a 100644
--- a/src/proto/streams/prioritize.rs
+++ b/src/proto/streams/prioritize.rs
@@ -520,7 +520,9 @@ impl Prioritize {
         tracing::trace!("poll_complete");
 
         loop {
-            self.schedule_pending_open(store, counts);
+            if let Some(mut stream) = self.pop_pending_open(store, counts) {
+                self.pending_send.push_front(&mut stream);
+            }
 
             match self.pop_frame(buffer, store, max_frame_len, counts) {
                 Some(frame) => {
@@ -874,20 +876,24 @@ impl Prioritize {
         }
     }
 
-    fn schedule_pending_open(&mut self, store: &mut Store, counts: &mut Counts) {
+    fn pop_pending_open<'s>(
+        &mut self,
+        store: &'s mut Store,
+        counts: &mut Counts,
+    ) -> Option<store::Ptr<'s>> {
         tracing::trace!("schedule_pending_open");
         // check for any pending open streams
-        while counts.can_inc_num_send_streams() {
+        if counts.can_inc_num_send_streams() {
             if let Some(mut stream) = self.pending_open.pop(store) {
                 tracing::trace!("schedule_pending_open; stream={:?}", stream.id);
 
                 counts.inc_num_send_streams(&mut stream);
-                self.pending_send.push(&mut stream);
                 stream.notify_send();
-            } else {
-                return;
+                return Some(stream);
             }
         }
+
+        None
     }
 }
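Taken together, the changes above move stream admission out of `send_request`: every locally initiated stream now goes through `pending_open`, and `poll_complete` admits at most one per iteration, pushing it to the front of `pending_send`. The capacity check in counts.rs is deliberately one ahead of the real counter, since `num_send_streams` is only incremented once prioritize actually opens the stream. A small standalone sketch of that boundary (the struct here is an illustrative model, not the h2 type):

```rust
// Minimal model of the capacity guess made in counts.rs. The field names
// mirror Counts above, but this struct is illustrative only.
struct Counts {
    max_send_streams: usize,
    num_send_streams: usize,
}

impl Counts {
    fn next_send_stream_will_reach_capacity(&self) -> bool {
        self.max_send_streams <= (self.num_send_streams + 1)
    }
}

fn main() {
    // With SETTINGS_MAX_CONCURRENT_STREAMS = 1 and nothing in flight, the
    // next stream takes the last slot: the client should park itself in
    // `pending` and make poll_ready wait.
    let c = Counts { max_send_streams: 1, num_send_streams: 0 };
    assert!(c.next_send_stream_will_reach_capacity());

    // With a limit of 2, one more request can still follow without waiting.
    let c = Counts { max_send_streams: 2, num_send_streams: 0 };
    assert!(!c.next_send_stream_will_reach_capacity());
}
```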
diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs
index dcb5225c..626e61a3 100644
--- a/src/proto/streams/send.rs
+++ b/src/proto/streams/send.rs
@@ -143,22 +143,27 @@ impl Send {
         // Update the state
         stream.state.send_open(end_stream)?;
 
-        if counts.peer().is_local_init(frame.stream_id()) {
-            // If we're waiting on a PushPromise anyway
-            // handle potentially queueing the stream at that point
-            if !stream.is_pending_push {
-                if counts.can_inc_num_send_streams() {
-                    counts.inc_num_send_streams(stream);
-                } else {
-                    self.prioritize.queue_open(stream);
-                }
-            }
+        let mut pending_open = false;
+        if counts.peer().is_local_init(frame.stream_id()) && !stream.is_pending_push {
+            self.prioritize.queue_open(stream);
+            pending_open = true;
         }
 
         // Queue the frame for sending
+        //
+        // This call expects that, since new streams are in the open queue, new
+        // streams won't be pushed on pending_send.
         self.prioritize
             .queue_frame(frame.into(), buffer, stream, task);
 
+        // Need to notify the connection when pushing onto pending_open since
+        // queue_frame only notifies for pending_send.
+        if pending_open {
+            if let Some(task) = task.take() {
+                task.wake();
+            }
+        }
+
         Ok(())
     }
 
diff --git a/src/proto/streams/store.rs b/src/proto/streams/store.rs
index d33a01cc..67b377b1 100644
--- a/src/proto/streams/store.rs
+++ b/src/proto/streams/store.rs
@@ -256,7 +256,7 @@ where
     ///
     /// If the stream is already contained by the list, return `false`.
     pub fn push(&mut self, stream: &mut store::Ptr) -> bool {
-        tracing::trace!("Queue::push");
+        tracing::trace!("Queue::push_back");
 
         if N::is_queued(stream) {
             tracing::trace!(" -> already queued");
@@ -292,6 +292,46 @@ where
         true
     }
 
+    /// Queue the stream at the front of the list
+    ///
+    /// If the stream is already contained by the list, return `false`.
+    pub fn push_front(&mut self, stream: &mut store::Ptr) -> bool {
+        tracing::trace!("Queue::push_front");
+
+        if N::is_queued(stream) {
+            tracing::trace!(" -> already queued");
+            return false;
+        }
+
+        N::set_queued(stream, true);
+
+        // The next pointer shouldn't be set
+        debug_assert!(N::next(stream).is_none());
+
+        // Queue the stream
+        match self.indices {
+            Some(ref mut idxs) => {
+                tracing::trace!(" -> existing entries");
+
+                // Update the provided stream to point to the head node
+                let head_key = stream.resolve(idxs.head).key();
+                N::set_next(stream, Some(head_key));
+
+                // Update the head pointer
+                idxs.head = stream.key();
+            }
+            None => {
+                tracing::trace!(" -> first entry");
+                self.indices = Some(store::Indices {
+                    head: stream.key(),
+                    tail: stream.key(),
+                });
+            }
+        }
+
+        true
+    }
+
     pub fn pop<'a, R>(&mut self, store: &'a mut R) -> Option<store::Ptr<'a>>
     where
         R: Resolve,
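The new `push_front` lets `poll_complete` put the stream it just admitted at the head of `pending_send`, so `pop_frame` services it on the very next iteration and the send.rs assumption that new streams are never pushed directly onto `pending_send` stays intact. A self-contained sketch of the same head/tail index manipulation on a slab-threaded list (`Entry`, `Indices`, and `Queue` here are simplified stand-ins for the store.rs types, not the real ones):

```rust
// Simplified stand-ins for the store.rs types: streams live in a slab and
// the queue is threaded through them by index. Hypothetical model only.
struct Entry {
    next: Option<usize>,
    queued: bool,
}

#[derive(Clone, Copy)]
struct Indices {
    head: usize,
    tail: usize,
}

struct Queue {
    indices: Option<Indices>,
}

impl Queue {
    /// Append at the tail, as `Queue::push` (traced as "push_back") does.
    fn push_back(&mut self, slab: &mut [Entry], key: usize) -> bool {
        if slab[key].queued {
            return false;
        }
        slab[key].queued = true;
        match self.indices {
            Some(ref mut idxs) => {
                slab[idxs.tail].next = Some(key);
                idxs.tail = key;
            }
            None => self.indices = Some(Indices { head: key, tail: key }),
        }
        true
    }

    /// Prepend at the head: point the new entry at the old head, then move
    /// the head index. The tail is untouched when the list is non-empty.
    fn push_front(&mut self, slab: &mut [Entry], key: usize) -> bool {
        if slab[key].queued {
            return false;
        }
        slab[key].queued = true;
        match self.indices {
            Some(ref mut idxs) => {
                slab[key].next = Some(idxs.head);
                idxs.head = key;
            }
            None => self.indices = Some(Indices { head: key, tail: key }),
        }
        true
    }

    /// Pop from the head; an empty next pointer means the queue is drained.
    fn pop(&mut self, slab: &mut [Entry]) -> Option<usize> {
        let mut idxs = self.indices?;
        let key = idxs.head;
        match slab[key].next.take() {
            Some(next) => {
                idxs.head = next;
                self.indices = Some(idxs);
            }
            None => self.indices = None,
        }
        slab[key].queued = false;
        Some(key)
    }
}

fn main() {
    let mut slab: Vec<Entry> = (0..3).map(|_| Entry { next: None, queued: false }).collect();
    let mut q = Queue { indices: None };
    q.push_back(&mut slab, 0);
    q.push_back(&mut slab, 1);
    // A just-admitted stream jumps the line, like in poll_complete above.
    q.push_front(&mut slab, 2);
    assert_eq!(q.pop(&mut slab), Some(2));
    assert_eq!(q.pop(&mut slab), Some(0));
    assert_eq!(q.pop(&mut slab), Some(1));
}
```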
diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs
index 02a0f61b..274bf455 100644
--- a/src/proto/streams/streams.rs
+++ b/src/proto/streams/streams.rs
@@ -216,7 +216,7 @@ where
         mut request: Request<()>,
         end_of_stream: bool,
         pending: Option<&OpaqueStreamRef>,
-    ) -> Result<StreamRef<B>, SendError> {
+    ) -> Result<(StreamRef<B>, bool), SendError> {
         use super::stream::ContentLength;
         use http::Method;
 
@@ -298,10 +298,14 @@
         // the lock, so it can't.
         me.refs += 1;
 
-        Ok(StreamRef {
-            opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream),
-            send_buffer: self.send_buffer.clone(),
-        })
+        let is_full = me.counts.next_send_stream_will_reach_capacity();
+        Ok((
+            StreamRef {
+                opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream),
+                send_buffer: self.send_buffer.clone(),
+            },
+            is_full,
+        ))
     }
 
     pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs
index 258826d1..7b431600 100644
--- a/tests/h2-tests/tests/client_request.rs
+++ b/tests/h2-tests/tests/client_request.rs
@@ -239,6 +239,8 @@ async fn request_over_max_concurrent_streams_errors() {
 
     // first request is allowed
     let (resp1, mut stream1) = client.send_request(request, false).unwrap();
+    // as long as we let the connection internals tick
+    client = h2.drive(client.ready()).await.unwrap();
 
     let request = Request::builder()
         .method(Method::POST)
@@ -284,6 +286,90 @@
     join(srv, h2).await;
 }
 
+#[tokio::test]
+async fn recv_decrement_max_concurrent_streams_when_requests_queued() {
+    h2_support::trace_init!();
+    let (io, mut srv) = mock::new();
+
+    let srv = async move {
+        let settings = srv.assert_client_handshake().await;
+        assert_default_settings!(settings);
+        srv.recv_frame(
+            frames::headers(1)
+                .request("POST", "https://example.com/")
+                .eos(),
+        )
+        .await;
+        srv.send_frame(frames::headers(1).response(200).eos()).await;
+
+        srv.ping_pong([0; 8]).await;
+
+        // limit this server later in life
+        srv.send_frame(frames::settings().max_concurrent_streams(1))
+            .await;
+        srv.recv_frame(frames::settings_ack()).await;
+        srv.recv_frame(
+            frames::headers(3)
+                .request("POST", "https://example.com/")
+                .eos(),
+        )
+        .await;
+        srv.ping_pong([1; 8]).await;
+        srv.send_frame(frames::headers(3).response(200).eos()).await;
+
+        srv.recv_frame(
+            frames::headers(5)
+                .request("POST", "https://example.com/")
+                .eos(),
+        )
+        .await;
+        srv.send_frame(frames::headers(5).response(200).eos()).await;
+    };
+
+    let h2 = async move {
+        let (mut client, mut h2) = client::handshake(io).await.expect("handshake");
+        // we send a simple req here just to drive the connection so we can
+        // receive the server settings.
+        let request = Request::builder()
+            .method(Method::POST)
+            .uri("https://example.com/")
+            .body(())
+            .unwrap();
+        // this request goes out before the server lowers the limit
+        let (response, _) = client.send_request(request, true).unwrap();
+        h2.drive(response).await.unwrap();
+
+        let request = Request::builder()
+            .method(Method::POST)
+            .uri("https://example.com/")
+            .body(())
+            .unwrap();
+
+        // first request is allowed
+        let (resp1, _) = client.send_request(request, true).unwrap();
+
+        let request = Request::builder()
+            .method(Method::POST)
+            .uri("https://example.com/")
+            .body(())
+            .unwrap();
+
+        // second request is put into pending_open
+        let (resp2, _) = client.send_request(request, true).unwrap();
+
+        h2.drive(async move {
+            resp1.await.expect("req");
+        })
+        .await;
+        join(async move { h2.await.unwrap() }, async move {
+            resp2.await.unwrap()
+        })
+        .await;
+    };
+
+    join(srv, h2).await;
+}
+
 #[tokio::test]
 async fn send_request_poll_ready_when_connection_error() {
     h2_support::trace_init!();
@@ -336,6 +422,8 @@
 
     // first request is allowed
     let (resp1, _) = client.send_request(request, true).unwrap();
+    // as long as we let the connection internals tick
+    client = h2.drive(client.ready()).await.unwrap();
 
     let request = Request::builder()
         .method(Method::POST)
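For callers, the visible behavior is what the updated tests exercise: `send_request` can succeed while the stream is only queued in `pending_open`, and back-pressure is delivered through `poll_ready`/`ready()` once `is_full` marks the last slot as taken. A hedged usage sketch against the public h2 client API (the address, loop bounds, and error handling are illustrative):

```rust
use h2::client;
use http::Request;
use tokio::net::TcpStream;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tcp = TcpStream::connect("127.0.0.1:8080").await?;
    let (mut client, connection) = client::handshake(tcp).await?;

    // The connection future must be polled for requests to make progress,
    // just like `h2.drive(...)` in the tests above.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    for _ in 0..3 {
        // ready() resolves once a new stream may be opened under the
        // server's SETTINGS_MAX_CONCURRENT_STREAMS limit.
        client = client.ready().await?;
        let request = Request::post("https://example.com/").body(())?;
        let (response, _stream) = client.send_request(request, true)?;
        let response = response.await?;
        println!("status: {}", response.status());
    }
    Ok(())
}
```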
diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs
index 33e08c19..6075c7dc 100644
--- a/tests/h2-tests/tests/server.rs
+++ b/tests/h2-tests/tests/server.rs
@@ -296,10 +296,10 @@
             .await;
         client.recv_frame(frames::data(2, &b""[..]).eos()).await;
         client
-            .recv_frame(frames::headers(1).response(200).eos())
+            .recv_frame(frames::headers(4).response(200).eos())
            .await;
         client
-            .recv_frame(frames::headers(4).response(200).eos())
+            .recv_frame(frames::headers(1).response(200).eos())
             .await;
     };
 
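The swapped expectations in server.rs follow from the same scheduling change: a stream admitted from `pending_open` is now pushed to the front of the send queue rather than the back, so the pushed response on stream 4 is written before the response on stream 1. A toy model of the reordering (a `VecDeque` standing in for `pending_send`; not the h2 internals):

```rust
use std::collections::VecDeque;

fn main() {
    // Stream 1's response HEADERS is already queued for sending.
    let mut pending_send: VecDeque<u32> = VecDeque::from([1]);

    // Before this change the just-admitted pushed stream 4 went to the
    // back, so stream 1's response was written first. Now it goes to the
    // front and is written first instead.
    pending_send.push_front(4);

    assert_eq!(pending_send.pop_front(), Some(4));
    assert_eq!(pending_send.pop_front(), Some(1));
}
```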