Skip to content

Commit

Permalink
basic support for TCP fast open (dotnet#99490)
Browse files Browse the repository at this point in the history
* initial drop

* update

* cleanup

* err

* 'test'

* windows

* sync

* feedback & updates

* macos

* windows

* feedback

* await

* docs

* note

* macos

---------

Co-authored-by: Ubuntu <toweinfu@toweinfu-ubu22.c5goow0wwwee5hembzptmbtr0h.xx.internal.cloudapp.net>
  • Loading branch information
2 people authored and steveharter committed May 28, 2024
1 parent c2c6f9b commit caab483
Show file tree
Hide file tree
Showing 16 changed files with 419 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ internal static partial class Sys
{
// Binds to the native SystemNative_Connect export: takes the socket handle, a raw pointer to the
// socket address bytes and the address length, and returns the native Error code.
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Connect")]
internal static unsafe partial Error Connect(SafeHandle socket, byte* socketAddress, int socketAddressLen);

// Binds to the native SystemNative_Connectx export. Besides the socket address it takes a data
// buffer with its length, an integer TFO flag (the managed caller passes 1 to enable, 0 otherwise),
// and a pointer through which the number of bytes sent is reported back to the caller.
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Connectx")]
internal static unsafe partial Error Connectx(SafeHandle socket, byte* socketAddress, int socketAddressLen, Span<byte> buffer, int bufferLen, int enableTfo, int* sent);
}
}
1 change: 1 addition & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ public enum SocketOptionName
DropMembership = 13,
DontFragment = 14,
AddSourceMembership = 15,
FastOpen = 15,
DontRoute = 16,
DropSourceMembership = 16,
TcpKeepAliveRetryCount = 16,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>$(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-unix;$(NetCoreAppCurrent)</TargetFrameworks>
<TargetFrameworks>$(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-unix;$(NetCoreAppCurrent)-osx;$(NetCoreAppCurrent)-ios;$(NetCoreAppCurrent)-tvos;$(NetCoreAppCurrent)</TargetFrameworks>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<!-- SYSTEM_NET_SOCKETS_DLL is required to allow source-level code sharing for types defined within the
System.Net.Internals namespace. -->
Expand All @@ -13,6 +13,8 @@
<PropertyGroup>
<TargetPlatformIdentifier>$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)'))</TargetPlatformIdentifier>
<GeneratePlatformNotSupportedAssemblyMessage Condition="'$(TargetPlatformIdentifier)' == ''">SR.SystemNetSockets_PlatformNotSupported</GeneratePlatformNotSupportedAssemblyMessage>
<IsApplePlatform Condition="'$(TargetPlatformIdentifier)' == 'osx' or '$(TargetPlatformIdentifier)' == 'ios' or '$(TargetPlatformIdentifier)' == 'tvos'">true</IsApplePlatform>
<DefineConstants Condition="'$(IsApplePlatform)' == 'true'">$(DefineConstants);SYSTEM_NET_SOCKETS_APPLE_PLATFROM</DefineConstants>
</PropertyGroup>

<ItemGroup Condition="'$(TargetPlatformIdentifier)' != ''">
Expand Down Expand Up @@ -181,7 +183,7 @@
Link="Common\System\Net\CompletionPortHelper.Windows.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'unix'">
<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'unix' or '$(TargetPlatformIdentifier)' == 'osx' or '$(TargetPlatformIdentifier)' == 'ios' or '$(TargetPlatformIdentifier)' == 'tvos'">
<Compile Include="System\Net\Sockets\SafeSocketHandle.Unix.cs" />
<Compile Include="System\Net\Sockets\Socket.Unix.cs" />
<Compile Include="System\Net\Sockets\SocketAsyncContext.Unix.cs" />
Expand Down Expand Up @@ -301,7 +303,7 @@
<Reference Include="System.Threading.ThreadPool" />
</ItemGroup>

<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'unix'">
<ItemGroup Condition="'$(TargetPlatformIdentifier)' != 'windows'">
<Reference Include="System.Threading.Thread" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ public partial class SafeSocketHandle
internal bool ExposedHandleOrUntrackedConfiguration { get; private set; }
internal bool PreferInlineCompletions { get; set; } = SocketAsyncEngine.InlineSocketCompletionsEnabled;
internal bool IsSocket { get; set; } = true; // (ab)use Socket class for performing async I/O on non-socket fds.

#if SYSTEM_NET_SOCKETS_APPLE_PLATFROM
internal bool TfoEnabled { get; set; }
#endif
internal void RegisterConnectResult(SocketError error)
{
switch (error)
Expand All @@ -44,6 +46,9 @@ internal void TransferTrackedState(SafeSocketHandle target)
target.DualMode = DualMode;
target.ExposedHandleOrUntrackedConfiguration = ExposedHandleOrUntrackedConfiguration;
target.IsSocket = IsSocket;
#if SYSTEM_NET_SOCKETS_APPLE_PLATFROM
target.TfoEnabled = TfoEnabled;
#endif
}

internal void SetExposed() => ExposedHandleOrUntrackedConfiguration = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ public override unsafe void InvokeCallback(bool allowPooling) =>
Callback!(BytesTransferred, SocketAddress, SocketFlags.None, ErrorCode);
}

private sealed class BufferMemorySendOperation : SendOperation
private class BufferMemorySendOperation : SendOperation
{
public Memory<byte> Buffer;

Expand Down Expand Up @@ -648,21 +648,47 @@ public override void InvokeCallback(bool allowPooling)
}
}

// Asynchronous connect operation. It is created by ConnectAsync when the connect could not be
// finished at start time and is then placed on the send queue. Deriving from
// BufferMemorySendOperation lets it carry a data buffer that is sent once the connect succeeds.
// NOTE(review): this listing appears to show both the previous and the new form of changed lines
// (two class headers, two InvokeCallback definitions); only the later form of each belongs together.
private sealed class ConnectOperation : WriteOperation
private sealed class ConnectOperation : BufferMemorySendOperation
{
public ConnectOperation(SocketAsyncContext context) : base(context) { }

public Action<SocketError>? Callback { get; set; }

// Tries to complete the pending connect; returns whatever SocketPal.TryCompleteConnect returned.
protected override bool DoTryComplete(SocketAsyncContext context)
{
// Ask the PAL for the connect outcome and record it on the socket handle.
bool result = SocketPal.TryCompleteConnect(context._socket, out ErrorCode);
context._socket.RegisterConnectResult(ErrorCode);

// The connect succeeded and there is still data in Buffer: start a send for it, passing the
// same Callback and accumulating into BytesTransferred.
if (result && ErrorCode == SocketError.Success && Buffer.Length > 0)
{
SocketError error = context.SendToAsync(Buffer, 0, Buffer.Length, SocketFlags.None, Memory<byte>.Empty, ref BytesTransferred, Callback!, default);
if (error != SocketError.Success && error != SocketError.IOPending)
{
// NOTE(review): this branch is reached when the follow-up send failed, yet it registers
// ErrorCode (the connect result, which was Success when this block was entered) rather
// than the local 'error'. Confirm whether 'error' was intended here.
context._socket.RegisterConnectResult(ErrorCode);
}
}
return result;
}

// Reports the operation result to the stored Callback.
public override void InvokeCallback(bool allowPooling) =>
Callback!(ErrorCode);
public override unsafe void InvokeCallback(bool allowPooling)
{
// Copy everything needed out of the operation first, because the operation may be handed
// back through ReturnOperation below before the callback runs.
var cb = Callback!;
int bt = BytesTransferred;
Memory<byte> sa = SocketAddress;
SocketError ec = ErrorCode;
Memory<byte> buffer = Buffer;

if (allowPooling)
{
AssociatedContext.ReturnOperation(this);
}

if (buffer.Length == 0)
{
// Invoke the callback only when we are completely done.
// If data was provided for Connect, we may or may not have sent all of it.
// If we did not, a follow-up Send operation is needed (started in DoTryComplete).
// NOTE(review): when Buffer is non-empty and the connect failed, or the follow-up
// SendToAsync finished synchronously, nothing in this listing invokes the callback.
// Verify that completion is still signalled on those paths.
cb(bt, sa, SocketFlags.None, ec);
}
}
}

private sealed class SendFileOperation : WriteOperation
Expand Down Expand Up @@ -1478,7 +1504,6 @@ public SocketError AcceptAsync(Memory<byte> socketAddress, out int socketAddress
public SocketError Connect(Memory<byte> socketAddress)
{
Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}");

// Connect is different than the usual "readiness" pattern of other operations.
// We need to call TryStartConnect to initiate the connect with the OS,
// before we try to complete it via epoll notification.
Expand All @@ -1503,7 +1528,7 @@ public SocketError Connect(Memory<byte> socketAddress)
return operation.ErrorCode;
}

public SocketError ConnectAsync(Memory<byte> socketAddress, Action<SocketError> callback)
public SocketError ConnectAsync(Memory<byte> socketAddress, Action<int, Memory<byte>, SocketFlags, SocketError> callback, Memory<byte> buffer, out int sentBytes)
{
Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}");
Debug.Assert(callback != null, "Expected non-null callback");
Expand All @@ -1516,20 +1541,37 @@ public SocketError ConnectAsync(Memory<byte> socketAddress, Action<SocketError>
SocketError errorCode;
int observedSequenceNumber;
_sendQueue.IsReady(this, out observedSequenceNumber);
if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode))
#if SYSTEM_NET_SOCKETS_APPLE_PLATFROM
if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode, buffer.Span, _socket.TfoEnabled, out sentBytes))
#else
if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode, buffer.Span, false, out sentBytes)) // In Linux, we can figure it out as needed inside PAL.
#endif
{
_socket.RegisterConnectResult(errorCode);

int remains = buffer.Length - sentBytes;

if (errorCode == SocketError.Success && remains > 0)
{
errorCode = SendToAsync(buffer.Slice(sentBytes), 0, remains, SocketFlags.None, Memory<byte>.Empty, ref sentBytes, callback!, default);
}
return errorCode;
}

var operation = new ConnectOperation(this)
{
Callback = callback,
SocketAddress = socketAddress,
Buffer = buffer.Slice(sentBytes),
BytesTransferred = sentBytes,
};

if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
{
if (operation.ErrorCode == SocketError.Success)
{
sentBytes += operation.BytesTransferred;
}
return operation.ErrorCode;
}

Expand Down Expand Up @@ -1880,7 +1922,8 @@ public SocketError Send(byte[] buffer, int offset, int count, SocketFlags flags,

// Starts an asynchronous send by forwarding to SendToAsync with an empty socket address.
// NOTE(review): this listing appears to show both the previous forwarding line (out bytesSent)
// and the new pair of lines (explicit reset, then ref bytesSent).
public SocketError SendAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action<int, Memory<byte>, SocketFlags, SocketError> callback, CancellationToken cancellationToken)
{
return SendToAsync(buffer, offset, count, flags, Memory<byte>.Empty, out bytesSent, callback, cancellationToken);
// The out parameter must be assigned before it can be passed by ref to SendToAsync.
bytesSent = 0;
return SendToAsync(buffer, offset, count, flags, Memory<byte>.Empty, ref bytesSent, callback, cancellationToken);
}

public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, Memory<byte> socketAddress, int timeout, out int bytesSent)
Expand Down Expand Up @@ -1947,11 +1990,10 @@ public unsafe SocketError SendTo(ReadOnlySpan<byte> buffer, SocketFlags flags, M
}
}

public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, Memory<byte> socketAddress, out int bytesSent, Action<int, Memory<byte>, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default)
public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, Memory<byte> socketAddress, ref int bytesSent, Action<int, Memory<byte>, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default)
{
SetHandleNonBlocking();

bytesSent = 0;
SocketError errorCode;
int observedSequenceNumber;
if (_sendQueue.IsReady(this, out observedSequenceNumber) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,24 @@ internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHan
return socketError;
}

// Callback handed to AsyncContext.ConnectAsync. It forwards the result to CompletionCallback,
// always with SocketFlags.None; the socketAddress and receivedFlags arguments are not used.
// NOTE(review): this listing appears to show both the previous (error-only, 0 bytes) and the new
// (bytesTransferred-aware) form of the signature and of the forwarding call.
private void ConnectCompletionCallback(SocketError socketError)
private void ConnectCompletionCallback(int bytesTransferred, Memory<byte> socketAddress, SocketFlags receivedFlags, SocketError socketError)
{
CompletionCallback(0, SocketFlags.None, socketError);
CompletionCallback(bytesTransferred, SocketFlags.None, socketError);
}

// Starts a connect that also carries the slice of _buffer described by _offset and _count.
// Returns the SocketError produced by AsyncContext.ConnectAsync.
// NOTE(review): this listing appears to show both the previous expression body (delegating to
// DoOperationConnect) and the new block body.
internal unsafe SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocketHandle handle)
=> DoOperationConnect(handle);
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, _buffer.Slice(_offset, _count), out int sentBytes);
// Anything other than IOPending means no callback is pending for this call, so finish
// here with the number of bytes reported as sent.
if (socketError != SocketError.IOPending)
{
FinishOperationSync(socketError, sentBytes, SocketFlags.None);
}
return socketError;
}

internal unsafe SocketError DoOperationConnect(SafeSocketHandle handle)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback);
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, Memory<byte>.Empty, out int _);
if (socketError != SocketError.IOPending)
{
FinishOperationSync(socketError, 0, SocketFlags.None);
Expand Down Expand Up @@ -299,11 +306,11 @@ internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToke
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;

int bytesSent;
int bytesSent = 0;
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.Buffer, out bytesSent, TransferCompletionCallback, cancellationToken);
errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.Buffer, ref bytesSent, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,34 @@ public enum SocketOptionName
#endregion

#region SocketOptionLevel.Tcp
// Disables the Nagle algorithm for send coalescing.
/// <summary>
/// Disables the Nagle algorithm for send coalescing.
/// </summary>
NoDelay = 1,
/// <summary>
/// Use urgent data as defined in RFC-1222. This option can be set only once; after it is set, it cannot be turned off.
/// </summary>
BsdUrgent = 2,
/// <summary>
/// Use expedited data as defined in RFC-1222. This option can be set only once; after it is set, it cannot be turned off.
/// </summary>
Expedited = 2,
/// <summary>
/// This enables TCP Fast Open as defined in RFC-7413. The actual observed behavior depends on OS configuration and state of kernel TCP cookie cache.
/// Enabling TFO can impact interoperability and cause connectivity issues.
/// </summary>
FastOpen = 15,
/// <summary>
/// The number of TCP keep alive probes that will be sent before the connection is terminated.
/// </summary>
TcpKeepAliveRetryCount = 16,
/// <summary>
/// The number of seconds a TCP connection will remain alive/idle before keepalive probes are sent to the remote.
/// </summary>
TcpKeepAliveTime = 3,
/// <summary>
/// The number of seconds a TCP connection will wait for a keepalive response before sending another keepalive probe.
/// </summary>
TcpKeepAliveInterval = 17,
#endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,12 @@ public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, Memory<byte
return false;
}

public static unsafe bool TryStartConnect(SafeSocketHandle socket, Memory<byte> socketAddress, out SocketError errorCode)
public static unsafe bool TryStartConnect(SafeSocketHandle socket, Memory<byte> socketAddress, out SocketError errorCode) => TryStartConnect(socket, socketAddress, out errorCode, Span<byte>.Empty, false, out int _ );

public static unsafe bool TryStartConnect(SafeSocketHandle socket, Memory<byte> socketAddress, out SocketError errorCode, Span<byte> data, bool tfo, out int sent)
{
Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}");
sent = 0;

if (socket.IsDisconnected)
{
Expand All @@ -660,7 +663,16 @@ public static unsafe bool TryStartConnect(SafeSocketHandle socket, Memory<byte>
Interop.Error err;
fixed (byte* rawSocketAddress = socketAddress.Span)
{
err = Interop.Sys.Connect(socket, rawSocketAddress, socketAddress.Length);
if (data.Length > 0)
{
int sentBytes = 0;
err = Interop.Sys.Connectx(socket, rawSocketAddress, socketAddress.Length, data, data.Length, tfo ? 1 : 0, &sentBytes);
sent = sentBytes;
}
else
{
err = Interop.Sys.Connect(socket, rawSocketAddress, socketAddress.Length);
}
}

if (err == Interop.Error.SUCCESS)
Expand Down Expand Up @@ -1451,6 +1463,18 @@ public static unsafe SocketError SetSockOpt(SafeSocketHandle handle, SocketOptio
}
}

#if SYSTEM_NET_SOCKETS_APPLE_PLATFROM
// macOS fails to even query it if socket is not actively listening.
// To provide consistent platform experience we will track if
// it was set and we will use it later as needed.
if (optionLevel == SocketOptionLevel.Tcp && optionName == SocketOptionName.FastOpen)
{
handle.TfoEnabled = optionValue != 0;
// Silently ignore errors - TFO is best effort and it may be disabled by configuration or not
// supported by OS.
err = Interop.Error.SUCCESS;
}
#endif
return GetErrorAndTrackSetting(handle, optionLevel, optionName, err);
}

Expand Down Expand Up @@ -1580,6 +1604,17 @@ public static unsafe SocketError GetSockOpt(SafeSocketHandle handle, SocketOptio
int optLen = sizeof(int);
Interop.Error err = Interop.Sys.GetSockOpt(handle, optionLevel, optionName, (byte*)&value, &optLen);

#if SYSTEM_NET_SOCKETS_APPLE_PLATFROM
// macOS fails to even query it if socket is not actively listening.
// To provide consistent platform experience we will track if
// it was set and we will use it later as needed.
if (optionLevel == SocketOptionLevel.Tcp && optionName == SocketOptionName.FastOpen && err != Interop.Error.SUCCESS)
{
value = handle.TfoEnabled ? 1 : 0;
err = Interop.Error.SUCCESS;
}
#endif

optionValue = value;
return err == Interop.Error.SUCCESS ? SocketError.Success : GetSocketErrorForErrorCode(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ public sealed class ConnectEap : Connect<SocketHelperEap>
public ConnectEap(ITestOutputHelper output) : base(output) {}

[Theory]
[PlatformSpecific(TestPlatforms.Windows)]
[InlineData(true)]
[InlineData(false)]
public async Task ConnectAsync_WithData_DataReceived(bool useArrayApi)
Expand Down
Loading

0 comments on commit caab483

Please sign in to comment.