Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add AcceptAsync cancellation overloads #53340

Merged
merged 4 commits into from
May 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public Task WaitForConnectionAsync(CancellationToken cancellationToken)
WaitForConnectionAsyncCore();

async Task WaitForConnectionAsyncCore() =>
HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync().ConfigureAwait(false));
HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync(cancellationToken).ConfigureAwait(false));
}

private void HandleAcceptedSocket(Socket acceptedSocket)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,10 @@ public async Task CancelTokenOn_ServerWaitForConnectionAsync_Throws_OperationCan

var ctx = new CancellationTokenSource();

if (OperatingSystem.IsWindows()) // cancellation token after the operation has been initiated
{
Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token);
ctx.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => serverWaitTimeout);
}

Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token);
ctx.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => serverWaitTimeout);

Assert.True(server.WaitForConnectionAsync(ctx.Token).IsCanceled);
}

Expand Down
4 changes: 4 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.Proto
public bool UseOnlyOverlappedIO { get { throw null; } set { } }
public System.Net.Sockets.Socket Accept() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket, System.Threading.CancellationToken cancellationToken) { throw null; }
public bool AcceptAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.IAsyncResult BeginAccept(System.AsyncCallback? callback, object? state) { throw null; }
public System.IAsyncResult BeginAccept(int receiveSize, System.AsyncCallback? callback, object? state) { throw null; }
Expand Down Expand Up @@ -691,8 +693,10 @@ public TcpListener(System.Net.IPEndPoint localEP) { }
public System.Net.Sockets.Socket Server { get { throw null; } }
public System.Net.Sockets.Socket AcceptSocket() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptSocketAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptSocketAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Net.Sockets.TcpClient AcceptTcpClient() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.TcpClient> AcceptTcpClientAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.TcpClient> AcceptTcpClientAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
public void AllowNatTraversal(bool allowed) { }
public System.IAsyncResult BeginAcceptSocket(System.AsyncCallback? callback, object? state) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ namespace System.Net.Sockets
{
public partial class Socket
{
/// <summary>Cached instance for accept operations.</summary>
private TaskSocketAsyncEventArgs<Socket>? _acceptEventArgs;

/// <summary>Cached instance for receive operations that return <see cref="ValueTask{Int32}"/>. Also used for ConnectAsync operations.</summary>
private AwaitableSocketAsyncEventArgs? _singleBufferReceiveEventArgs;
/// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>.</summary>
/// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>. Also used for AcceptAsync operations.</summary>
private AwaitableSocketAsyncEventArgs? _singleBufferSendEventArgs;

/// <summary>Cached instance for receive operations that return <see cref="Task{Int32}"/>.</summary>
Expand All @@ -32,54 +29,44 @@ public partial class Socket
/// Accepts an incoming connection.
/// </summary>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null);
public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null, CancellationToken.None).AsTask();

/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync(Socket? acceptSocket)
{
// Get any cached SocketAsyncEventArg we may have.
TaskSocketAsyncEventArgs<Socket>? saea = Interlocked.Exchange(ref _acceptEventArgs, null);
if (saea is null)
{
saea = new TaskSocketAsyncEventArgs<Socket>();
saea.Completed += (s, e) => CompleteAccept((Socket)s!, (TaskSocketAsyncEventArgs<Socket>)e);
}
public ValueTask<Socket> AcceptAsync(CancellationToken cancellationToken) => AcceptAsync((Socket?)null, cancellationToken);

// Configure the SAEA.
saea.AcceptSocket = acceptSocket;
/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync(Socket? acceptSocket) => AcceptAsync(acceptSocket, CancellationToken.None).AsTask();

// Initiate the accept operation.
Task<Socket> t;
if (AcceptAsync(saea))
/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public ValueTask<Socket> AcceptAsync(Socket? acceptSocket, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
// The operation is completing asynchronously (it may have already completed).
// Get the task for the operation, with appropriate synchronization to coordinate
// with the async callback that'll be completing the task.
bool responsibleForReturningToPool;
t = saea.GetCompletionResponsibility(out responsibleForReturningToPool).Task;
if (responsibleForReturningToPool)
{
// We're responsible for returning it only if the callback has already been invoked
// and gotten what it needs from the SAEA; otherwise, the callback will return it.
ReturnSocketAsyncEventArgs(saea);
}
return ValueTask.FromCanceled<Socket>(cancellationToken);
}
else
{
// The operation completed synchronously. Get a task for it.
t = saea.SocketError == SocketError.Success ?
Task.FromResult(saea.AcceptSocket!) :
Task.FromException<Socket>(GetException(saea.SocketError));

// There won't be a callback, and we're done with the SAEA, so return it to the pool.
ReturnSocketAsyncEventArgs(saea);
}
AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

return t;
Debug.Assert(saea.BufferList == null);
saea.SetBuffer(null, 0, 0);
saea.AcceptSocket = acceptSocket;
saea.WrapExceptionsForNetworkStream = false;
return saea.AcceptAsync(this, cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -738,34 +725,6 @@ private Task<int> GetTaskForSendReceive(bool pending, TaskSocketAsyncEventArgs<i
return t;
}

/// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
private static void CompleteAccept(Socket s, TaskSocketAsyncEventArgs<Socket> saea)
{
// Pull the relevant state off of the SAEA
SocketError error = saea.SocketError;
Socket? acceptSocket = saea.AcceptSocket;

// Synchronize with the initiating thread. If the synchronous caller already got what
// it needs from the SAEA, then we can return it to the pool now. Otherwise, it'll be
// responsible for returning it once it's gotten what it needs from it.
bool responsibleForReturningToPool;
AsyncTaskMethodBuilder<Socket> builder = saea.GetCompletionResponsibility(out responsibleForReturningToPool);
if (responsibleForReturningToPool)
{
s.ReturnSocketAsyncEventArgs(saea);
}

// Complete the builder/task with the results.
if (error == SocketError.Success)
{
builder.SetResult(acceptSocket!);
}
else
{
builder.SetException(GetException(error));
}
}

/// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
private static void CompleteSendReceive(Socket s, TaskSocketAsyncEventArgs<int> saea, bool isReceive)
{
Expand Down Expand Up @@ -824,29 +783,9 @@ private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<int> saea, bool
}
}

/// <summary>Returns a <see cref="TaskSocketAsyncEventArgs{TResult}"/> instance for reuse.</summary>
/// <param name="saea">The instance to return.</param>
private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<Socket> saea)
{
// Reset state on the SAEA before returning it. But do not reset buffer state. That'll be done
// if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse
// and the costs associated with changing them.
saea.AcceptSocket = null;
saea._accessed = false;
saea._builder = default;

// Write this instance back as a cached instance, only if there isn't currently one cached.
if (Interlocked.CompareExchange(ref _acceptEventArgs, saea, null) != null)
{
// Couldn't return it, so dispose it.
saea.Dispose();
}
}

/// <summary>Dispose of any cached <see cref="TaskSocketAsyncEventArgs{TResult}"/> instances.</summary>
private void DisposeCachedTaskSocketAsyncEventArgs()
{
Interlocked.Exchange(ref _acceptEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _multiBufferReceiveEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _multiBufferSendEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null)?.Dispose();
Expand Down Expand Up @@ -907,7 +846,7 @@ internal AsyncTaskMethodBuilder<TResult> GetCompletionResponsibility(out bool re
}

/// <summary>A SocketAsyncEventArgs that can be awaited to get the result of an operation.</summary>
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<Socket>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
{
private static readonly Action<object?> s_completedSentinel = new Action<object?>(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel))));
/// <summary>The owning socket.</summary>
Expand Down Expand Up @@ -987,6 +926,28 @@ protected override void OnCompleted(SocketAsyncEventArgs _)
}
}

/// <summary>Initiates an accept operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<Socket> AcceptAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");

if (socket.AcceptAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<Socket>(this, _token);
}

Socket acceptSocket = AcceptSocket!;
SocketError error = SocketError;

Release();

return error == SocketError.Success ?
new ValueTask<Socket>(acceptSocket) :
ValueTask.FromException<Socket>(CreateException(error));
}

/// <summary>Initiates a receive operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1288,7 +1249,7 @@ private void InvokeContinuation(Action<object?> continuation, object? state, boo
/// Unlike TaskAwaiter's GetResult, this does not block until the operation completes: it must only
/// be used once the operation has completed. This is handled implicitly by await.
/// </remarks>
public int GetResult(short token)
int IValueTaskSource<int>.GetResult(short token)
{
if (token != _token)
{
Expand Down Expand Up @@ -1326,6 +1287,26 @@ void IValueTaskSource.GetResult(short token)
}
}

Socket IValueTaskSource<Socket>.GetResult(short token)
{
if (token != _token)
{
ThrowIncorrectTokenException();
}

SocketError error = SocketError;
Socket acceptSocket = AcceptSocket!;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error, cancellationToken);
}
return acceptSocket;
}

SocketReceiveFromResult IValueTaskSource<SocketReceiveFromResult>.GetResult(short token)
{
if (token != _token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,9 @@ public void Shutdown(SocketShutdown how)
// Async methods
//

public bool AcceptAsync(SocketAsyncEventArgs e)
public bool AcceptAsync(SocketAsyncEventArgs e) => AcceptAsync(e, CancellationToken.None);

private bool AcceptAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
ThrowIfDisposed();

Expand Down Expand Up @@ -2689,7 +2691,7 @@ public bool AcceptAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationAccept(this, _handle, acceptHandle);
socketError = e.DoOperationAccept(this, _handle, acceptHandle, cancellationToken);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out In
return operation.ErrorCode;
}

public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback)
public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback, CancellationToken cancellationToken)
{
Debug.Assert(socketAddress != null, "Expected non-null socketAddress");
Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}");
Expand All @@ -1456,7 +1456,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o
operation.SocketAddress = socketAddress;
operation.SocketAddressLen = socketAddressLen;

if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
socketAddressLen = operation.SocketAddressLen;
acceptedFd = operation.AcceptedFileDescriptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, byte[] socke
_acceptAddressBufferCount = socketAddressSize;
}

internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle)
internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken)
{
if (!_buffer.Equals(default))
{
Expand All @@ -64,7 +64,7 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha

IntPtr acceptedFd;
int socketAddressLen = _acceptAddressBufferCount / 2;
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback);
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken);

if (socketError != SocketError.IOPending)
{
Expand Down
Loading