Skip to content

Commit

Permalink
feat(s2n-quic-platform): wire up tokio sockets to ring
Browse files Browse the repository at this point in the history
  • Loading branch information
camshaft committed Jun 5, 2023
1 parent d6867fb commit b61c15a
Show file tree
Hide file tree
Showing 12 changed files with 378 additions and 40 deletions.
1 change: 1 addition & 0 deletions quic/s2n-quic-platform/src/io/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use tokio::{net::UdpSocket, runtime::Handle};
pub type PathHandle = socket::Handle;

mod clock;
mod task;
pub(crate) use clock::Clock;

impl crate::socket::std::Socket for UdpSocket {
Expand Down
47 changes: 47 additions & 0 deletions quic/s2n-quic-platform/src/io/tokio/task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

// depending on the platform, some of these implementations aren't used
#![allow(dead_code)]

mod simple;
#[cfg(unix)]
mod unix;

cfg_if::cfg_if! {
if #[cfg(s2n_quic_platform_socket_mmsg)] {
pub use mmsg::{rx, tx};
} else if #[cfg(s2n_quic_platform_socket_msg)] {
pub use msg::{rx, tx};
} else {
pub use simple::{rx, tx};
}
}

macro_rules! libc_msg {
($message:ident, $cfg:ident) => {
#[cfg($cfg)]
mod $message {
use super::unix;
use crate::{features::Gso, message::$message::Message, socket::ring};

pub async fn rx<S: Into<std::net::UdpSocket>>(
socket: S,
producer: ring::Producer<Message>,
) -> std::io::Result<()> {
unix::rx(socket, producer).await
}

pub async fn tx<S: Into<std::net::UdpSocket>>(
socket: S,
consumer: ring::Consumer<Message>,
gso: Gso,
) -> std::io::Result<()> {
unix::tx(socket, consumer, gso).await
}
}
};
}

libc_msg!(msg, s2n_quic_platform_socket_msg);
libc_msg!(mmsg, s2n_quic_platform_socket_mmsg);
127 changes: 127 additions & 0 deletions quic/s2n-quic-platform/src/io/tokio/task/simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{
features::Gso,
message::{simple::Message, Message as _},
socket::{
ring, task,
task::{rx, tx},
},
syscall::SocketEvents,
};
use core::task::{Context, Poll};
use tokio::{io, net::UdpSocket};

pub async fn rx<S: Into<std::net::UdpSocket>>(
socket: S,
producer: ring::Producer<Message>,
) -> io::Result<()> {
let socket = socket.into();
socket.set_nonblocking(true).unwrap();

let socket = UdpSocket::from_std(socket).unwrap();
let result = task::Receiver::new(producer, socket).await;
if let Some(err) = result {
Err(err)
} else {
Ok(())
}
}

pub async fn tx<S: Into<std::net::UdpSocket>>(
socket: S,
consumer: ring::Consumer<Message>,
gso: Gso,
) -> io::Result<()> {
let socket = socket.into();
socket.set_nonblocking(true).unwrap();

let socket = UdpSocket::from_std(socket).unwrap();
let result = task::Sender::new(consumer, socket, gso).await;
if let Some(err) = result {
Err(err)
} else {
Ok(())
}
}

impl tx::Socket<Message> for UdpSocket {
type Error = io::Error;

#[inline]
fn send(
&mut self,
cx: &mut Context,
entries: &mut [Message],
events: &mut tx::Events,
) -> io::Result<()> {
let mut index = 0;
while let Some(entry) = entries.get_mut(index) {
let target = (*entry.remote_address()).into();
let payload = entry.payload_mut();
match self.poll_send_to(cx, payload, target) {
Poll::Ready(Ok(_)) => {
index += 1;
if events.on_complete(1).is_break() {
return Ok(());
}
}
Poll::Ready(Err(err)) => {
if events.on_error(err).is_break() {
return Ok(());
}
}
Poll::Pending => {
events.blocked();
break;
}
}
}

Ok(())
}
}

impl rx::Socket<Message> for UdpSocket {
type Error = io::Error;

#[inline]
fn recv(
&mut self,
cx: &mut Context,
entries: &mut [Message],
events: &mut rx::Events,
) -> io::Result<()> {
let mut index = 0;
while let Some(entry) = entries.get_mut(index) {
let payload = entry.payload_mut();
let mut buf = io::ReadBuf::new(payload);
match self.poll_recv_from(cx, &mut buf) {
Poll::Ready(Ok(addr)) => {
unsafe {
let len = buf.filled().len();
entry.set_payload_len(len);
}
entry.set_remote_address(&(addr.into()));

index += 1;
if events.on_complete(1).is_break() {
return Ok(());
}
}
Poll::Ready(Err(err)) => {
if events.on_error(err).is_break() {
return Ok(());
}
}
Poll::Pending => {
events.blocked();
break;
}
}
}

Ok(())
}
}
124 changes: 124 additions & 0 deletions quic/s2n-quic-platform/src/io/tokio/task/unix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{
features::Gso,
socket::{
ring,
task::{rx, tx},
},
syscall::{SocketType, UnixMessage},
};
use core::task::{Context, Poll};
use std::{io, os::unix::io::AsRawFd};
use tokio::io::unix::AsyncFd;

pub async fn rx<S: Into<std::net::UdpSocket>, M: UnixMessage + Unpin>(
socket: S,
producer: ring::Producer<M>,
) -> io::Result<()> {
let socket = socket.into();
socket.set_nonblocking(true).unwrap();

let socket = AsyncFd::new(socket).unwrap();
let result = rx::Receiver::new(producer, socket).await;
if let Some(err) = result {
Err(err)
} else {
Ok(())
}
}

pub async fn tx<S: Into<std::net::UdpSocket>, M: UnixMessage + Unpin>(
socket: S,
consumer: ring::Consumer<M>,
gso: Gso,
) -> io::Result<()> {
let socket = socket.into();
socket.set_nonblocking(true).unwrap();

let socket = AsyncFd::new(socket).unwrap();
let result = tx::Sender::new(consumer, socket, gso).await;
if let Some(err) = result {
Err(err)
} else {
Ok(())
}
}

impl<S: AsRawFd, M: UnixMessage> tx::Socket<M> for AsyncFd<S> {
type Error = io::Error;

#[inline]
fn send(
&mut self,
cx: &mut Context,
entries: &mut [M],
events: &mut tx::Events,
) -> io::Result<()> {
M::send(self.get_ref().as_raw_fd(), entries, events);

if !events.is_blocked() {
return Ok(());
}

for i in 0..2 {
match self.poll_write_ready(cx) {
Poll::Ready(guard) => {
let mut guard = guard?;
if i == 0 {
guard.clear_ready();
} else {
events.take_blocked();
}
}
Poll::Pending => {
return Ok(());
}
}
}

Ok(())
}
}

impl<S: AsRawFd, M: UnixMessage> rx::Socket<M> for AsyncFd<S> {
type Error = io::Error;

#[inline]
fn recv(
&mut self,
cx: &mut Context,
entries: &mut [M],
events: &mut rx::Events,
) -> io::Result<()> {
M::recv(
self.get_ref().as_raw_fd(),
SocketType::NonBlocking,
entries,
events,
);

if !events.is_blocked() {
return Ok(());
}

for i in 0..2 {
match self.poll_read_ready(cx) {
Poll::Ready(guard) => {
let mut guard = guard?;
if i == 0 {
guard.clear_ready();
} else {
events.take_blocked();
}
}
Poll::Pending => {
return Ok(());
}
}
}

Ok(())
}
}
4 changes: 2 additions & 2 deletions quic/s2n-quic-platform/src/message/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ impl Message {
ExplicitCongestionNotification::default()
}

pub(crate) fn remote_address(&self) -> Option<SocketAddress> {
Some(self.address)
pub(crate) fn remote_address(&self) -> &SocketAddress {
&self.address
}

pub(crate) fn set_remote_address(&mut self, remote_address: &SocketAddress) {
Expand Down
36 changes: 17 additions & 19 deletions quic/s2n-quic-platform/src/socket/std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,26 @@ impl<B: Buffer> Queue<B> {
let mut entries = self.0.occupied_mut();

for entry in entries.as_mut() {
if let Some(remote_address) = entry.remote_address() {
match socket.send_to(entry.payload_mut(), &remote_address) {
Ok(_) => {
count += 1;
let remote_address = *entry.remote_address();
match socket.send_to(entry.payload_mut(), &remote_address) {
Ok(_) => {
count += 1;

publisher.on_platform_tx(event::builder::PlatformTx { count: 1 });
}
Err(err) if count > 0 && err.would_block() => {
break;
}
Err(err) if err.was_interrupted() || err.permission_denied() => {
break;
}
Err(err) => {
entries.finish(count);
publisher.on_platform_tx(event::builder::PlatformTx { count: 1 });
}
Err(err) if count > 0 && err.would_block() => {
break;
}
Err(err) if err.was_interrupted() || err.permission_denied() => {
break;
}
Err(err) => {
entries.finish(count);

publisher.on_platform_tx_error(event::builder::PlatformTxError {
errno: errno().0,
});
publisher
.on_platform_tx_error(event::builder::PlatformTxError { errno: errno().0 });

return Err(err);
}
return Err(err);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion quic/s2n-quic-platform/src/socket/task.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

pub mod events;
mod events;
pub mod rx;
pub mod tx;

Expand Down
Loading

0 comments on commit b61c15a

Please sign in to comment.