Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TLS Resume with client certificates on Linux #102656

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,42 @@ internal static partial class OpenSsl
private const string TlsCacheSizeCtxName = "System.Net.Security.TlsCacheSize";
private const string TlsCacheSizeEnvironmentVariable = "DOTNET_SYSTEM_NET_SECURITY_TLSCACHESIZE";
private const SslProtocols FakeAlpnSslProtocol = (SslProtocols)1; // used to distinguish server sessions with ALPN
private static readonly ConcurrentDictionary<SslProtocols, SafeSslContextHandle> s_clientSslContexts = new ConcurrentDictionary<SslProtocols, SafeSslContextHandle>();

private sealed class SafeSslContextCache : SafeHandleCache<SslContextCacheKey, SafeSslContextHandle> { }

private static readonly SafeSslContextCache s_clientSslContexts = new();

internal readonly struct SslContextCacheKey : IEquatable<SslContextCacheKey>
{
public readonly byte[]? CertificateThumbprint;
public readonly SslProtocols SslProtocols;

public SslContextCacheKey(SslProtocols sslProtocols, byte[]? certificateThumbprint)
{
SslProtocols = sslProtocols;
CertificateThumbprint = certificateThumbprint;
}

public override bool Equals(object? obj) => obj is SslContextCacheKey key && Equals(key);

public bool Equals(SslContextCacheKey other) =>
wfurt marked this conversation as resolved.
Show resolved Hide resolved
SslProtocols == other.SslProtocols &&
(CertificateThumbprint == null && other.CertificateThumbprint == null ||
CertificateThumbprint != null && other.CertificateThumbprint != null && CertificateThumbprint.AsSpan().SequenceEqual(other.CertificateThumbprint));

public override int GetHashCode()
rzikm marked this conversation as resolved.
Show resolved Hide resolved
{
HashCode hash = default;

hash.Add(SslProtocols);
if (CertificateThumbprint != null)
{
hash.AddBytes(CertificateThumbprint);
}

return hash.ToHashCode();
}
}

#region internal methods
internal static SafeChannelBindingHandle? QueryChannelBinding(SafeSslHandle context, ChannelBindingKind bindingType)
Expand Down Expand Up @@ -113,6 +148,54 @@ private static SslProtocols CalculateEffectiveProtocols(SslAuthenticationOptions
return protocols;
}

internal static SafeSslContextHandle GetOrCreateSslContextHandle(SslAuthenticationOptions sslAuthenticationOptions, bool allowCached)
{
SslProtocols protocols = CalculateEffectiveProtocols(sslAuthenticationOptions);

if (!allowCached)
{
return AllocateSslContext(sslAuthenticationOptions, protocols, allowCached);
}

if (!sslAuthenticationOptions.IsServer)
rzikm marked this conversation as resolved.
Show resolved Hide resolved
rzikm marked this conversation as resolved.
Show resolved Hide resolved
{
var key = new SslContextCacheKey(protocols, sslAuthenticationOptions.CertificateContext?.TargetCertificate.GetCertHash(HashAlgorithmName.SHA256));

return s_clientSslContexts.GetOrCreate(key, static (args) =>
{
var (sslAuthOptions, protocols, allowCached) = args;
return AllocateSslContext(sslAuthOptions, protocols, allowCached);
}, (sslAuthenticationOptions, protocols, allowCached));
}

// cache in SslStreamCertificateContext is bounded and there is no eviction
// so the handle should always be valid,

bool hasAlpn = sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0;

SafeSslContextHandle? handle = AllocateSslContext(sslAuthenticationOptions, protocols, allowCached);

if (!sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), out handle))
{
// not found in cache, create and insert
handle = AllocateSslContext(sslAuthenticationOptions, protocols, allowCached);

SafeSslContextHandle cached = sslAuthenticationOptions.CertificateContext!.SslContexts!.GetOrAdd(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), handle);

if (handle != cached)
{
// lost the race, another thread created the SSL_CTX meanwhile, prefer the cached one
handle.Dispose();
Debug.Assert(handle.IsClosed);
handle = cached;
}
}

Debug.Assert(!handle.IsClosed);
handle.TryAddRentCount();
return handle;
}

// This essentially wraps SSL_CTX* aka SSL_CTX_new + setting
internal static unsafe SafeSslContextHandle AllocateSslContext(SslAuthenticationOptions sslAuthenticationOptions, SslProtocols protocols, bool enableResume)
{
Expand Down Expand Up @@ -188,7 +271,7 @@ internal static unsafe SafeSslContextHandle AllocateSslContext(SslAuthentication
Interop.Ssl.SslCtxSetAlpnSelectCb(sslCtx, &AlpnServerSelectCallback, IntPtr.Zero);
}

if (sslAuthenticationOptions.CertificateContext != null)
if (sslAuthenticationOptions.CertificateContext != null && sslAuthenticationOptions.IsServer)
{
SetSslCertificate(sslCtx, sslAuthenticationOptions.CertificateContext.CertificateHandle, sslAuthenticationOptions.CertificateContext.KeyHandle);

Expand Down Expand Up @@ -257,10 +340,6 @@ internal static void UpdateClientCertificate(SafeSslHandle ssl, SslAuthenticatio
internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuthenticationOptions)
{
SafeSslHandle? sslHandle = null;
SafeSslContextHandle? sslCtxHandle = null;
SafeSslContextHandle? newCtxHandle = null;
SslProtocols protocols = CalculateEffectiveProtocols(sslAuthenticationOptions);
bool hasAlpn = sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0;
bool cacheSslContext = sslAuthenticationOptions.AllowTlsResume && !SslStream.DisableTlsResume && sslAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.RequireEncryption && sslAuthenticationOptions.CipherSuitesPolicy == null;

if (cacheSslContext)
Expand All @@ -269,13 +348,12 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth
{
// We don't support client resume on old OpenSSL versions.
// We don't want to try on empty TargetName since that is our key.
// And we don't want to mess up with client authentication. It may be possible
// but it seems safe to get full new session.
// If we already have CertificateContext, then we know which cert the user wants to use and we can cache.
// The only client auth scenario where we can't cache is when user provides a cert callback and we don't know
// beforehand which cert will be used. and wan't to avoid resuming session created with different certificate.
if (!Interop.Ssl.Capabilities.Tls13Supported ||
string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) ||
sslAuthenticationOptions.CertificateContext != null ||
sslAuthenticationOptions.ClientCertificates?.Count > 0 ||
sslAuthenticationOptions.CertSelectionDelegate != null)
(sslAuthenticationOptions.CertificateContext == null && sslAuthenticationOptions.CertSelectionDelegate != null))
{
cacheSslContext = false;
}
Expand All @@ -292,35 +370,14 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth
}
}

if (cacheSslContext)
{
if (sslAuthenticationOptions.IsServer)
{
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), out sslCtxHandle);
}
else
{

s_clientSslContexts.TryGetValue(protocols, out sslCtxHandle);
}
}

if (sslCtxHandle == null)
{
// We did not get SslContext from cache
sslCtxHandle = newCtxHandle = AllocateSslContext(sslAuthenticationOptions, protocols, cacheSslContext);

if (cacheSslContext)
{
bool added = sslAuthenticationOptions.IsServer ?
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle) :
s_clientSslContexts.TryAdd(protocols, newCtxHandle);
if (added)
{
newCtxHandle = null;
}
}
}
// We do not touch the SSL_CTX after we create and configure SSL
// objects, and SSL object created later in this function will keep an
// outstanding up-ref on SSL_CTX.
//
// For uncached SafeSslContextHandles, the handle will be disposed and closed.
// Cached SafeSslContextHandles are returned with increaset rent count so that
// Dispose() here will not close the handle.
using SafeSslContextHandle sslCtxHandle = GetOrCreateSslContextHandle(sslAuthenticationOptions, cacheSslContext);

GCHandle alpnHandle = default;
try
Expand Down Expand Up @@ -361,19 +418,25 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth
Crypto.ErrClearError();
}


if (cacheSslContext)
{
sslCtxHandle.TrySetSession(sslHandle, sslAuthenticationOptions.TargetHost);
bool ignored = false;
sslCtxHandle.DangerousAddRef(ref ignored);

// Maintain additional rent count for the context so
// that it is not evicted from the cache and future
// SSL objects can reuse it. This call should always
// succeed because already have increased rent count
// when getting the context from the cache
bool success = sslCtxHandle.TryAddRentCount();
Debug.Assert(success);
sslHandle.SslContextHandle = sslCtxHandle;
}
}

// relevant to TLS 1.3 only: if user supplied a client cert or cert callback,
// advertise that we are willing to send the certificate post-handshake.
if (sslAuthenticationOptions.ClientCertificates?.Count > 0 ||
if (sslAuthenticationOptions.CertificateContext != null ||
sslAuthenticationOptions.ClientCertificates?.Count > 0 ||
sslAuthenticationOptions.CertSelectionDelegate != null)
{
Ssl.SslSetPostHandshakeAuth(sslHandle, 1);
Expand Down Expand Up @@ -434,10 +497,6 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth

throw;
}
finally
{
newCtxHandle?.Dispose();
}

return sslHandle;
}
Expand Down Expand Up @@ -708,6 +767,12 @@ private static unsafe int NewSessionCallback(IntPtr ssl, IntPtr session)
Debug.Assert(ssl != IntPtr.Zero);
Debug.Assert(session != IntPtr.Zero);

// remember if the session used a certificate, this information is used after
// session resumption, the pointer is not being dereferenced and the refcount
// is not going to be manipulated.
IntPtr cert = Interop.Ssl.SslGetCertificate(ssl);
Interop.Ssl.SslSessionSetData(session, cert);

IntPtr ptr = Ssl.SslGetData(ssl);
if (ptr != IntPtr.Zero)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ internal static unsafe ReadOnlySpan<byte> SslGetAlpnSelected(SafeSslHandle ssl)
[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertificate")]
internal static partial IntPtr SslGetPeerCertificate(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetCertificate")]
internal static partial IntPtr SslGetCertificate(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetCertificate")]
internal static partial IntPtr SslGetCertificate(IntPtr ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertChain")]
internal static partial SafeSharedX509StackHandle SslGetPeerCertChain(SafeSslHandle ssl);

Expand All @@ -129,6 +135,9 @@ internal static unsafe ReadOnlySpan<byte> SslGetAlpnSelected(SafeSslHandle ssl)
[return: MarshalAs(UnmanagedType.Bool)]
internal static partial bool SslSessionReused(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetSession")]
internal static partial IntPtr SslGetSession(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetClientCAList")]
private static partial SafeSharedX509NameStackHandle SslGetClientCAList_private(SafeSslHandle ssl);

Expand Down Expand Up @@ -182,6 +191,12 @@ internal static unsafe ReadOnlySpan<byte> SslGetAlpnSelected(SafeSslHandle ssl)
[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetHostname")]
internal static partial int SessionSetHostname(IntPtr session, IntPtr name);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionGetData")]
internal static partial IntPtr SslSessionGetData(IntPtr session);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetData")]
internal static partial void SslSessionSetData(IntPtr session, IntPtr val);

internal static class Capabilities
{
// needs separate type (separate static cctor) to be sure OpenSSL is initialized.
Expand Down Expand Up @@ -430,7 +445,9 @@ protected override bool ReleaseHandle()
Disconnect();
}

SslContextHandle?.DangerousRelease();
// drop reference to any SSL_CTX handle, any handle present here is being
// rented from (client) SSL_CTX cache.
SslContextHandle?.Dispose();

if (AlpnHandle.IsAllocated)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Net;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using Microsoft.Win32.SafeHandles;

internal static partial class Interop
Expand Down Expand Up @@ -65,12 +67,17 @@ internal static bool AddExtraChainCertificates(SafeSslContextHandle ctx, ReadOnl

namespace Microsoft.Win32.SafeHandles
{
internal sealed class SafeSslContextHandle : SafeHandle
internal sealed class SafeSslContextHandle : SafeHandle, ISafeHandleCachable
{
// This is session cache keyed by SNI e.g. TargetHost
private Dictionary<string, IntPtr>? _sslSessions;
private GCHandle _gch;

// SSL_CTX handles are cached, so we need to keep track of the
// number of times a handle is being used. Once we decide to dispose the handle,
// we set the _rentCount to -1.
private volatile int _rentCount;

public SafeSslContextHandle()
: base(IntPtr.Zero, true)
{
Expand All @@ -86,6 +93,38 @@ public override bool IsInvalid
get { return handle == IntPtr.Zero; }
}

public bool TryAddRentCount()
{
int oldCount;

do
{
oldCount = _rentCount;
if (oldCount < 0)
{
// The handle is already disposed.
return false;
}
} while (Interlocked.CompareExchange(ref _rentCount, oldCount + 1, oldCount) != oldCount);

return true;
}

public bool TryMarkForDispose()
{
return Interlocked.CompareExchange(ref _rentCount, -1, 0) == 0;
}

protected override void Dispose(bool disposing)
{
if (Interlocked.Decrement(ref _rentCount) < 0)
{
// _rentCount is 0 if the handle was never rented (e.g. failure during creation),
// and is -1 when evicted from cache.
base.Dispose(disposing);
}
}

protected override bool ReleaseHandle()
{
if (_sslSessions != null)
Expand Down
Loading
Loading