Skip to content

Commit

Permalink
refactor rediscovery to use DU
Browse files Browse the repository at this point in the history
  • Loading branch information
thefringeninja committed Feb 16, 2022
1 parent b08718f commit efe7bf7
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 46 deletions.
25 changes: 15 additions & 10 deletions src/EventStore.Client/EventStoreClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public abstract class EventStoreClientBase :
private readonly IDictionary<string, Func<RpcException, Exception>> _exceptionMap;
private readonly CancellationTokenSource _cts;
private readonly ChannelCache _channelCache;
private readonly SharingProvider<DnsEndPoint?, ChannelInfo> _channelInfoProvider;
private readonly SharingProvider<ReconnectionRequired, ChannelInfo> _channelInfoProvider;

/// <summary>
/// The name of the connection.
Expand All @@ -48,27 +48,31 @@ protected EventStoreClientBase(EventStoreClientSettings? settings,
ConnectionName = Settings.ConnectionName ?? $"ES-{Guid.NewGuid()}";

var channelSelector = new ChannelSelector(Settings, _channelCache);
_channelInfoProvider = new SharingProvider<DnsEndPoint?, ChannelInfo>(
factory: (endPoint, onBroken) => GetChannelInfoExpensive(endPoint, onBroken, channelSelector, _cts.Token),
initialInput: null);
_channelInfoProvider = new SharingProvider<ReconnectionRequired, ChannelInfo>(
factory: (endPoint, onBroken) =>
GetChannelInfoExpensive(endPoint, onBroken, channelSelector, _cts.Token),
initialInput: ReconnectionRequired.Rediscover.Instance);
}

// Select a channel and query its capabilities. This is an expensive call that
// we don't want to do often.
private async Task<ChannelInfo> GetChannelInfoExpensive(
DnsEndPoint? endPoint,
Action<DnsEndPoint?> onBroken,
ReconnectionRequired reconnectionRequired,
Action<ReconnectionRequired> onReconnectionRequired,
IChannelSelector channelSelector,
CancellationToken cancellationToken) {

var channel = endPoint is null
? await channelSelector.SelectChannelAsync(cancellationToken).ConfigureAwait(false)
: channelSelector.SelectChannel(endPoint);
var channel = reconnectionRequired switch {
ReconnectionRequired.Rediscover => await channelSelector.SelectChannelAsync(cancellationToken)
.ConfigureAwait(false),
ReconnectionRequired.NewLeader (var endPoint) => channelSelector.SelectChannel(endPoint),
_ => throw new ArgumentException(null, nameof(reconnectionRequired))
};

var invoker = channel.CreateCallInvoker()
.Intercept(new TypedExceptionInterceptor(_exceptionMap))
.Intercept(new ConnectionNameInterceptor(ConnectionName))
.Intercept(new ReportLeaderInterceptor(onBroken));
.Intercept(new ReportLeaderInterceptor(onReconnectionRequired));

if (Settings.Interceptors is not null) {
foreach (var interceptor in Settings.Interceptors) {
Expand All @@ -92,6 +96,7 @@ protected async ValueTask<ChannelInfo> GetChannelInfo(CancellationToken cancella
// in cases where the server doesn't yet let the client know that it needs to.
// see EventStoreClientExtensions.WarmUpWith.
// note if rediscovery is already in progress it will continue, not restart.
// ReSharper disable once UnusedMember.Local
private void Rediscover() {
_channelInfoProvider.Reset();
}
Expand Down
39 changes: 21 additions & 18 deletions src/EventStore.Client/Interceptors/ReportLeaderInterceptor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
Expand All @@ -10,21 +9,21 @@ namespace EventStore.Client.Interceptors {
// this has become more general than just detecting leader changes.
// triggers the action on any rpc exception with StatusCode.Unavailable
internal class ReportLeaderInterceptor : Interceptor {
private readonly Action<DnsEndPoint?> _onError;
private readonly Action<ReconnectionRequired> _onReconnectionRequired;

private const TaskContinuationOptions ContinuationOptions =
TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted;

internal ReportLeaderInterceptor(Action<DnsEndPoint?> onError) {
_onError = onError;
internal ReportLeaderInterceptor(Action<ReconnectionRequired> onReconnectionRequired) {
_onReconnectionRequired = onReconnectionRequired;
}

public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request,
ClientInterceptorContext<TRequest, TResponse> context,
AsyncUnaryCallContinuation<TRequest, TResponse> continuation) {
var response = continuation(request, context);

response.ResponseAsync.ContinueWith(ReportNewLeader, ContinuationOptions);
response.ResponseAsync.ContinueWith(OnReconnectionRequired, ContinuationOptions);

return new AsyncUnaryCall<TResponse>(response.ResponseAsync, response.ResponseHeadersAsync,
response.GetStatus, response.GetTrailers, response.Dispose);
Expand All @@ -35,7 +34,7 @@ public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreami
AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation) {
var response = continuation(context);

response.ResponseAsync.ContinueWith(ReportNewLeader, ContinuationOptions);
response.ResponseAsync.ContinueWith(OnReconnectionRequired, ContinuationOptions);

return new AsyncClientStreamingCall<TRequest, TResponse>(response.RequestStream, response.ResponseAsync,
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
Expand All @@ -47,7 +46,8 @@ public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreami
var response = continuation(context);

return new AsyncDuplexStreamingCall<TRequest, TResponse>(response.RequestStream,
new StreamReader<TResponse>(response.ResponseStream, ReportNewLeader), response.ResponseHeadersAsync,
new StreamReader<TResponse>(response.ResponseStream, OnReconnectionRequired),
response.ResponseHeadersAsync,
response.GetStatus, response.GetTrailers, response.Dispose);
}

Expand All @@ -57,20 +57,23 @@ public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRe
var response = continuation(request, context);

return new AsyncServerStreamingCall<TResponse>(
new StreamReader<TResponse>(response.ResponseStream, ReportNewLeader), response.ResponseHeadersAsync,
new StreamReader<TResponse>(response.ResponseStream, OnReconnectionRequired),
response.ResponseHeadersAsync,
response.GetStatus, response.GetTrailers, response.Dispose);
}

private void ReportNewLeader<TResponse>(Task<TResponse> task) {
if (task.Exception?.InnerException is NotLeaderException ex) {
_onError(ex.LeaderEndpoint);
} else if (task.Exception?.InnerException is RpcException {
StatusCode: StatusCode.Unavailable or
// StatusCode.Unknown or TODO: use RPC exceptions on server
StatusCode.Aborted
}) {
_onError(null);
}
private void OnReconnectionRequired<TResponse>(Task<TResponse> task) {
ReconnectionRequired reconnectionRequired = task.Exception?.InnerException switch {
NotLeaderException ex => new ReconnectionRequired.NewLeader(ex.LeaderEndpoint),
RpcException {
StatusCode: StatusCode.Unavailable
// or StatusCode.Unknown or TODO: use RPC exceptions on server
} => ReconnectionRequired.Rediscover.Instance,
_ => ReconnectionRequired.None.Instance
};

if (reconnectionRequired is not ReconnectionRequired.None)
_onReconnectionRequired(reconnectionRequired);
}

private class StreamReader<T> : IAsyncStreamReader<T> {
Expand Down
15 changes: 15 additions & 0 deletions src/EventStore.Client/ReconnectionRequired.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System.Net;

namespace EventStore.Client {
internal abstract record ReconnectionRequired {
public record None : ReconnectionRequired {
public static None Instance = new();
}

public record Rediscover : ReconnectionRequired {
public static Rediscover Instance = new();
}

public record NewLeader(DnsEndPoint EndPoint) : ReconnectionRequired;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public class ReportLeaderInterceptorTests {
private static readonly Marshaller<object> _marshaller = new(_ => Array.Empty<byte>(), _ => new object());

private static readonly StatusCode[] ForcesRediscoveryStatusCodes = {
StatusCode.Aborted,
//StatusCode.Unknown, TODO: use RPC exceptions on server
StatusCode.Unavailable
};
Expand All @@ -32,12 +31,12 @@ private static IEnumerable<GrpcCall> GrpcCalls() {

[Theory, MemberData(nameof(ReportsNewLeaderCases))]
public async Task ReportsNewLeader(GrpcCall call) {
EndPoint actual = default;
var sut = new ReportLeaderInterceptor(ep => actual = ep);
ReconnectionRequired actual = default;
var sut = new ReportLeaderInterceptor(result => actual = result);

var result = await Assert.ThrowsAsync<NotLeaderException>(() =>
call(sut, Task.FromException<object>(new NotLeaderException("a.host", 2112))));
Assert.Equal(result.LeaderEndpoint, actual);
Assert.Equal(new ReconnectionRequired.NewLeader(result.LeaderEndpoint), actual);
}

public static IEnumerable<object[]> ForcesRediscoveryCases() => from call in GrpcCalls()
Expand All @@ -46,18 +45,12 @@ from statusCode in ForcesRediscoveryStatusCodes

[Theory, MemberData(nameof(ForcesRediscoveryCases))]
public async Task ForcesRediscovery(GrpcCall call, StatusCode statusCode) {
EndPoint actual = default;
bool invoked = false;
ReconnectionRequired actual = default;
var sut = new ReportLeaderInterceptor(result => actual = result);

var sut = new ReportLeaderInterceptor(ep => {
invoked = true;
actual = ep;
});

var result = await Assert.ThrowsAsync<RpcException>(() => call(sut,
await Assert.ThrowsAsync<RpcException>(() => call(sut,
Task.FromException<object>(new RpcException(new Status(statusCode, "oops")))));
Assert.Null(actual);
Assert.True(invoked);
Assert.Equal(ReconnectionRequired.Rediscover.Instance, actual);
}

public static IEnumerable<object[]> DoesNotForceRediscoveryCases() => from call in GrpcCalls()
Expand All @@ -68,12 +61,12 @@ from statusCode in Enum.GetValues(typeof(StatusCode))

[Theory, MemberData(nameof(DoesNotForceRediscoveryCases))]
public async Task DoesNotForceRediscovery(GrpcCall call, StatusCode statusCode) {
bool invoked = false;
var sut = new ReportLeaderInterceptor(ep => invoked = true);
ReconnectionRequired actual = ReconnectionRequired.None.Instance;
var sut = new ReportLeaderInterceptor(result => actual = result);

var result = await Assert.ThrowsAsync<RpcException>(() => call(sut,
await Assert.ThrowsAsync<RpcException>(() => call(sut,
Task.FromException<object>(new RpcException(new Status(statusCode, "oops")))));
Assert.False(invoked);
Assert.Equal(ReconnectionRequired.None.Instance, actual);
}


Expand Down

0 comments on commit efe7bf7

Please sign in to comment.