diff --git a/docs/docfx/articles/destination-resolvers.md b/docs/docfx/articles/destination-resolvers.md new file mode 100644 index 000000000..e3ebd89ca --- /dev/null +++ b/docs/docfx/articles/destination-resolvers.md @@ -0,0 +1,58 @@ +# Extensibility: Destination Resolvers + +## Introduction + +YARP uses a destination resolver to expand the set of configured destination addresses. The destination resolver can be used as an integration point with service discovery systems. + +## Structure +[IDestinationResolver](xref:Yarp.ReverseProxy.ServiceDiscovery.IDestinationResolver) has a single method `ResolveDestinationsAsync(IReadOnlyDictionary destinations, CancellationToken cancellationToken)` which should return a [ResolvedDestinationCollection](xref:Yarp.ReverseProxy.ServiceDiscovery.ResolvedDestinationCollection) instance. The [ResolvedDestinationCollection](xref:Yarp.ReverseProxy.ServiceDiscovery.ResolvedDestinationCollection) has a collection of [DestinationConfig](xref:Yarp.ReverseProxy.Configuration.DestinationConfig) instances, as well as an `IChangeToken` to notify the proxy when this information is out of date and should be reloaded, which will cause `ResolveDestinationsAsync` to be called again. + +### DestinationConfig +`DestinationConfig` has a `Host` property which can be used to specify the default `Host` header value which the proxy should use when communicating with that destination. This allows the `IDestinationResolver` to resolve destinations to a collection of IP addresses, for example, without causing SNI or host-based routing to fail. + +## Lifecycle + +### Startup +The `IDestinationResolver` should be registered in the DI container as a singleton. At startup, the proxy will resolve this instance and call `ResolveDestinationsAsync(...)` with the configured destinations retrieved from the resolved `IProxyConfigProviders`. On this first call the provider may choose to: +- Throw an exception if the provider cannot produce a valid proxy configuration for any reason. This will prevent the application from starting. +- Asynchronously resolve the destinations. This will stop the application from starting until resolved destinations are available. +- Or, it may choose to return an empty `ResolvedDestinationCollection` instance while it resolves destinations in the background. The provider will need to trigger the `IChangeToken` when the configuration is available. + +### Atomicity +The destinations objects and collections supplied to the proxy should be read-only and not modified once they have been handed to the proxy via `GetConfig()`. + +### Reload +If the `IChangeToken` supports `ActiveChangeCallbacks`, once the proxy has processed the initial set of destinations it will register a callback with this token. If the provider does not support callbacks then `HasChanged` will be polled alongside `IProxyConfig` change tokens, every 5 minutes. + +When the provider wants to provide a new set of destinations to the proxy it should: +- Resolve those destinations in the background. + - `ResolvedDestinationCollection` is immutable, so new instances have to be created for any new data. + - Objects for unchanged destinations can be re-used, or new instances can be created. +- Invalidate the `IChangeToken` returned from the previous `ResolveDestinationsAsync` invocation. + +Once the new destinations have been applied, the proxy will register a callback with the new `IChangeToken`. Note if there are multiple reloads signaled in close succession, the proxy may skip some and resolve destinations as soon as it's ready. + +## DNS Destination Resolver + +YARP includes an [IDestinationResolver](xref:Yarp.ReverseProxy.ServiceDiscovery.IDestinationResolver) implementation which expands the set of configured destinations by resolving each host name to one or more IP addresses using DNS, creating a destination for each resolved IP. +The DNS destination resolver can be added to your reverse proxy using the `IReverseProxyBuilder.AddDnsDestinationResolver(Action)` method. +The method accepts an optional delegate to configure the resolver's options, [DnsDestinationResolverOptions](xref:Yarp.ReverseProxy.ServiceDiscovery.DnsDestinationResolverOptions). + +### Example + +```csharp +// Add the DNS destination resolver, restricting results to IPv4 addresses +reverseProxyBuilder.AddDnsDestinationResolver(o => o.AddressFamily = AddressFamily.InterNetwork); +``` + +### Configuration + +The DNS destination resolver's options, [DnsDestinationResolverOptions](xref:Yarp.ReverseProxy.ServiceDiscovery.DnsDestinationResolverOptions), has the following properties: + +#### RefreshPeriod + +The period between requesting a refresh of a resolved name. This defaults to 5 minutes. + +#### AddressFamily + +Optionally, specify an `System.Net.Sockets.AddressFamily` value of `AddressFamily.InterNetwork` or `AddressFamily.InterNetworkV6` to restrict resolved resolution to IPv4 or IPv6 addresses, respectively. The default value, `null`, instructs the resolver to not restrict the address family of the results and to use accept all returned addresses. diff --git a/docs/docfx/articles/toc.yml b/docs/docfx/articles/toc.yml index 38a1eb346..e340d8442 100644 --- a/docs/docfx/articles/toc.yml +++ b/docs/docfx/articles/toc.yml @@ -10,6 +10,8 @@ href: config-filters.md - name: Direct Forwarding href: direct-forwarding.md +- name: Destination Resolvers + href: destination-resolvers.md - name: HTTP client configuration href: http-client-config.md - name: HTTPS & TLS diff --git a/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs b/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs index def838ded..7768b4113 100644 --- a/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs +++ b/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs @@ -377,6 +377,7 @@ private static DestinationConfig CreateDestination(IConfigurationSection section Address = section[nameof(DestinationConfig.Address)]!, Health = section[nameof(DestinationConfig.Health)], Metadata = section.GetSection(nameof(DestinationConfig.Metadata)).ReadStringDictionary(), + Host = section[nameof(DestinationConfig.Host)] }; } diff --git a/src/ReverseProxy/Configuration/DestinationConfig.cs b/src/ReverseProxy/Configuration/DestinationConfig.cs index 2e3c56205..295170ef6 100644 --- a/src/ReverseProxy/Configuration/DestinationConfig.cs +++ b/src/ReverseProxy/Configuration/DestinationConfig.cs @@ -28,6 +28,12 @@ public sealed record DestinationConfig /// public IReadOnlyDictionary? Metadata { get; init; } + /// + /// Host header value to pass to this destination. + /// Used as a fallback if a host is not already specified by request transforms. + /// + public string? Host { get; init; } + public bool Equals(DestinationConfig? other) { if (other is null) @@ -37,6 +43,7 @@ public bool Equals(DestinationConfig? other) return string.Equals(Address, other.Address, StringComparison.OrdinalIgnoreCase) && string.Equals(Health, other.Health, StringComparison.OrdinalIgnoreCase) + && string.Equals(Host, other.Host, StringComparison.OrdinalIgnoreCase) && CaseSensitiveEqualHelper.Equals(Metadata, other.Metadata); } @@ -45,6 +52,7 @@ public override int GetHashCode() return HashCode.Combine( Address?.GetHashCode(StringComparison.OrdinalIgnoreCase), Health?.GetHashCode(StringComparison.OrdinalIgnoreCase), + Host?.GetHashCode(StringComparison.OrdinalIgnoreCase), CaseSensitiveEqualHelper.GetHashCode(Metadata)); } } diff --git a/src/ReverseProxy/Health/DefaultProbingRequestFactory.cs b/src/ReverseProxy/Health/DefaultProbingRequestFactory.cs index 97d3c43a6..52b0fbf9e 100644 --- a/src/ReverseProxy/Health/DefaultProbingRequestFactory.cs +++ b/src/ReverseProxy/Health/DefaultProbingRequestFactory.cs @@ -28,6 +28,11 @@ public HttpRequestMessage CreateRequest(ClusterModel cluster, DestinationModel d VersionPolicy = cluster.Config.HttpRequest?.VersionPolicy ?? HttpVersionPolicy.RequestVersionOrLower, }; + if (!string.IsNullOrEmpty(destination.Config.Host)) + { + request.Headers.Add(HeaderNames.Host, destination.Config.Host); + } + request.Headers.Add(HeaderNames.UserAgent, _defaultUserAgent); return request; diff --git a/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs b/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs index 5583c95ff..d91dc0062 100644 --- a/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs +++ b/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs @@ -12,6 +12,7 @@ using Yarp.ReverseProxy.LoadBalancing; using Yarp.ReverseProxy.Model; using Yarp.ReverseProxy.Routing; +using Yarp.ReverseProxy.ServiceDiscovery; using Yarp.ReverseProxy.SessionAffinity; using Yarp.ReverseProxy.Transforms; using Yarp.ReverseProxy.Utilities; @@ -125,4 +126,10 @@ public static IReverseProxyBuilder AddHttpSysDelegation(this IReverseProxyBuilde return builder; } + + public static IReverseProxyBuilder AddDestinationResolver(this IReverseProxyBuilder builder) + { + builder.Services.TryAddSingleton(); + return builder; + } } diff --git a/src/ReverseProxy/Management/ProxyConfigManager.cs b/src/ReverseProxy/Management/ProxyConfigManager.cs index 40a29b3e0..7c4209316 100644 --- a/src/ReverseProxy/Management/ProxyConfigManager.cs +++ b/src/ReverseProxy/Management/ProxyConfigManager.cs @@ -20,6 +20,7 @@ using Yarp.ReverseProxy.Health; using Yarp.ReverseProxy.Model; using Yarp.ReverseProxy.Routing; +using Yarp.ReverseProxy.ServiceDiscovery; using Yarp.ReverseProxy.Transforms.Builder; namespace Yarp.ReverseProxy.Management; @@ -49,6 +50,7 @@ internal sealed class ProxyConfigManager : EndpointDataSource, IProxyStateLookup private readonly List> _conventions; private readonly IActiveHealthCheckMonitor _activeHealthCheckMonitor; private readonly IClusterDestinationsUpdater _clusterDestinationsUpdater; + private readonly IDestinationResolver _destinationResolver; private readonly IConfigChangeListener[] _configChangeListeners; private List? _endpoints; private CancellationTokenSource _endpointsChangeSource = new(); @@ -67,7 +69,8 @@ public ProxyConfigManager( IForwarderHttpClientFactory httpClientFactory, IActiveHealthCheckMonitor activeHealthCheckMonitor, IClusterDestinationsUpdater clusterDestinationsUpdater, - IEnumerable configChangeListeners) + IEnumerable configChangeListeners, + IDestinationResolver destinationResolver) { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _providers = providers?.ToArray() ?? throw new ArgumentNullException(nameof(providers)); @@ -80,7 +83,7 @@ public ProxyConfigManager( _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); _activeHealthCheckMonitor = activeHealthCheckMonitor ?? throw new ArgumentNullException(nameof(activeHealthCheckMonitor)); _clusterDestinationsUpdater = clusterDestinationsUpdater ?? throw new ArgumentNullException(nameof(clusterDestinationsUpdater)); - + _destinationResolver = destinationResolver ?? throw new ArgumentNullException(nameof(destinationResolver)); _configChangeListeners = configChangeListeners?.ToArray() ?? Array.Empty(); if (_providers.Length == 0) @@ -158,11 +161,19 @@ internal async Task InitialLoadAsync() var routes = new List(); var clusters = new List(); + // Begin resolving config providers concurrently. + var resolvedConfigs = new List<(int Index, IProxyConfigProvider Provider, ValueTask Config)>(_providers.Length); for (var i = 0; i < _providers.Length; i++) { var provider = _providers[i]; - var config = provider.GetConfig(); - ValidateConfigProperties(config); + var configLoadTask = LoadConfigAsync(provider, cancellationToken: default); + resolvedConfigs.Add((i, provider, configLoadTask)); + } + + // Wait for all configs to be resolved. + foreach (var (i, provider, configLoadTask) in resolvedConfigs) + { + var config = await configLoadTask; _configs[i] = new ConfigState(provider, config); routes.AddRange(config.Routes ?? Array.Empty()); clusters.AddRange(config.Clusters ?? Array.Empty()); @@ -202,33 +213,52 @@ private async Task ReloadConfigAsync() var sourcesChanged = false; var routes = new List(); var clusters = new List(); + var reloadedConfigs = new List<(ConfigState Config, ValueTask ResolveTask)>(); + + // Start reloading changed configurations. foreach (var instance in _configs) { - try + if (instance.LatestConfig.ChangeToken.HasChanged) { - if (instance.LatestConfig.ChangeToken.HasChanged) + try + { + var reloadTask = LoadConfigAsync(instance.Provider, cancellationToken: default); + reloadedConfigs.Add((instance, reloadTask)); + } + catch (Exception ex) { - var config = instance.Provider.GetConfig(); - ValidateConfigProperties(config); - instance.LatestConfig = config; - instance.LoadFailed = false; - sourcesChanged = true; + OnConfigLoadError(instance, ex); } } + } + + // Wait for all changed config providers to be reloaded. + foreach (var (instance, loadTask) in reloadedConfigs) + { + try + { + instance.LatestConfig = await loadTask.ConfigureAwait(false); + instance.LoadFailed = false; + sourcesChanged = true; + } catch (Exception ex) { - instance.LoadFailed = true; - Log.ErrorReloadingConfig(_logger, ex); + OnConfigLoadError(instance, ex); + } + } - foreach (var configChangeListener in _configChangeListeners) - { - configChangeListener.ConfigurationLoadingFailed(instance.Provider, ex); - } + // Extract the routes and clusters from the configs, regardless of whether they were reloaded. + foreach (var instance in _configs) + { + if (instance.LatestConfig.Routes is { Count: > 0 } updatedRoutes) + { + routes.AddRange(updatedRoutes); } - // If we didn't/couldn't get a new config then re-use the last one. - routes.AddRange(instance.LatestConfig.Routes ?? Array.Empty()); - clusters.AddRange(instance.LatestConfig.Clusters ?? Array.Empty()); + if (instance.LatestConfig.Clusters is { Count: > 0 } updatedClusters) + { + clusters.AddRange(updatedClusters); + } } var proxyConfigs = ExtractListOfProxyConfigs(_configs); @@ -270,6 +300,17 @@ private async Task ReloadConfigAsync() } ListenForConfigChanges(); + + void OnConfigLoadError(ConfigState instance, Exception ex) + { + instance.LoadFailed = true; + Log.ErrorReloadingConfig(_logger, ex); + + foreach (var configChangeListener in _configChangeListeners) + { + configChangeListener.ConfigurationLoadingFailed(instance.Provider, ex); + } + } } private static void ValidateConfigProperties(IProxyConfig config) @@ -278,12 +319,90 @@ private static void ValidateConfigProperties(IProxyConfig config) { throw new InvalidOperationException($"{nameof(IProxyConfigProvider.GetConfig)} returned a null value."); } + if (config.ChangeToken is null) { throw new InvalidOperationException($"{nameof(IProxyConfig.ChangeToken)} has a null value."); } } + private ValueTask LoadConfigAsync(IProxyConfigProvider provider, CancellationToken cancellationToken) + { + var config = provider.GetConfig(); + ValidateConfigProperties(config); + + if (_destinationResolver.GetType() == typeof(NoOpDestinationResolver)) + { + return new(config); + } + + return LoadConfigAsyncCore(config, cancellationToken); + } + + private async ValueTask LoadConfigAsyncCore(IProxyConfig config, CancellationToken cancellationToken) + { + List<(int Index, ValueTask Task)> resolverTasks = new(); + List clusters = new(config.Clusters); + List? changeTokens = null; + for (var i = 0; i < clusters.Count; i++) + { + var cluster = clusters[i]; + if (cluster.Destinations is { Count: > 0 } destinations) + { + // Resolve destinations if there are any. + var task = _destinationResolver.ResolveDestinationsAsync(destinations, cancellationToken); + resolverTasks.Add((i, task)); + } + } + + if (resolverTasks.Count > 0) + { + foreach (var (i, task) in resolverTasks) + { + var resolvedDestinations = await task; + clusters[i] = clusters[i] with { Destinations = resolvedDestinations.Destinations }; + if (resolvedDestinations.ChangeToken is { } token) + { + changeTokens ??= new(); + changeTokens.Add(token); + } + } + + IChangeToken changeToken; + if (changeTokens is not null) + { + // Combine change tokens from the resolver with the configuration's existing change token. + changeTokens.Add(config.ChangeToken); + changeToken = new CompositeChangeToken(changeTokens); + } + else + { + changeToken = config.ChangeToken; + } + + // Return updated config + return new ResolvedProxyConfig(config, clusters, changeToken); + } + + return config; + } + + private sealed class ResolvedProxyConfig : IProxyConfig + { + private readonly IProxyConfig _innerConfig; + + public ResolvedProxyConfig(IProxyConfig innerConfig, IReadOnlyList clusters, IChangeToken changeToken) + { + _innerConfig = innerConfig; + Clusters = clusters; + ChangeToken = changeToken; + } + + public IReadOnlyList Routes => _innerConfig.Routes; + public IReadOnlyList Clusters { get; } + public IChangeToken ChangeToken { get; } + } + private void ListenForConfigChanges() { // Use a central change token to avoid overlap between different sources. diff --git a/src/ReverseProxy/Management/ReverseProxyServiceCollectionExtensions.cs b/src/ReverseProxy/Management/ReverseProxyServiceCollectionExtensions.cs index d08fe6258..b2a6bcf6c 100644 --- a/src/ReverseProxy/Management/ReverseProxyServiceCollectionExtensions.cs +++ b/src/ReverseProxy/Management/ReverseProxyServiceCollectionExtensions.cs @@ -13,8 +13,8 @@ using Yarp.ReverseProxy.Forwarder; using Yarp.ReverseProxy.Management; using Yarp.ReverseProxy.Routing; +using Yarp.ReverseProxy.ServiceDiscovery; using Yarp.ReverseProxy.Transforms.Builder; -using Yarp.ReverseProxy.Utilities; namespace Microsoft.Extensions.DependencyInjection; @@ -53,6 +53,7 @@ public static IReverseProxyBuilder AddReverseProxy(this IServiceCollection servi .AddPassiveHealthCheck() .AddLoadBalancingPolicies() .AddHttpSysDelegation() + .AddDestinationResolver() .AddProxy(); services.TryAddSingleton(); @@ -166,4 +167,18 @@ public static IReverseProxyBuilder ConfigureHttpClient(this IReverseProxyBuilder }); return builder; } + + /// + /// Provides a implementation which uses to resolve destinations. + /// + public static IReverseProxyBuilder AddDnsDestinationResolver(this IReverseProxyBuilder builder, Action? configureOptions = null) + { + builder.Services.AddSingleton(); + if (configureOptions is not null) + { + builder.Services.Configure(configureOptions); + } + + return builder; + } } diff --git a/src/ReverseProxy/ServiceDiscovery/DnsDestinationResolver.cs b/src/ReverseProxy/ServiceDiscovery/DnsDestinationResolver.cs new file mode 100644 index 000000000..3ed8a7fbb --- /dev/null +++ b/src/ReverseProxy/ServiceDiscovery/DnsDestinationResolver.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Yarp.ReverseProxy.Configuration; + +namespace Yarp.ReverseProxy.ServiceDiscovery; + +/// +/// Implementation of which resolves host names to IP addresses using DNS. +/// +internal class DnsDestinationResolver : IDestinationResolver +{ + private readonly IOptionsMonitor _options; + + /// + /// Initializes a new instance. + /// + /// The options. + public DnsDestinationResolver(IOptionsMonitor options) + { + _options = options; + } + + /// + public async ValueTask ResolveDestinationsAsync(IReadOnlyDictionary destinations, CancellationToken cancellationToken) + { + var options = _options.CurrentValue; + Dictionary results = new(); + var tasks = new List>>(destinations.Count); + foreach (var (destinationId, destinationConfig) in destinations) + { + tasks.Add(ResolveHostAsync(options, destinationId, destinationConfig, cancellationToken)); + } + + await Task.WhenAll(tasks); + foreach (var task in tasks) + { + var configs = await task; + foreach (var (name, config) in configs) + { + results[name] = config; + } + } + + var changeToken = options.RefreshPeriod switch + { + { } refreshPeriod when refreshPeriod > TimeSpan.Zero => new CancellationChangeToken(new CancellationTokenSource(refreshPeriod).Token), + _ => null, + }; + + return new ResolvedDestinationCollection(results, changeToken); + } + + private static async Task> ResolveHostAsync( + DnsDestinationResolverOptions options, + string originalName, + DestinationConfig originalConfig, + CancellationToken cancellationToken) + { + var originalUri = new Uri(originalConfig.Address); + var originalHost = originalConfig.Host is { Length: > 0 } host ? host : originalUri.Authority; + var addresses = options.AddressFamily switch + { + { } addressFamily => await Dns.GetHostAddressesAsync(originalUri.DnsSafeHost, addressFamily, cancellationToken).ConfigureAwait(false), + null => await Dns.GetHostAddressesAsync(originalUri.DnsSafeHost, cancellationToken).ConfigureAwait(false) + }; + var results = new List<(string Name, DestinationConfig Config)>(addresses.Length); + var uriBuilder = new UriBuilder(originalUri); + var healthUri = originalConfig.Health is { Length: > 0 } health ? new Uri(health) : null; + var healthUriBuilder = healthUri is { } ? new UriBuilder(healthUri) : null; + foreach (var address in addresses) + { + var addressString = address.ToString(); + uriBuilder.Host = addressString; + var resolvedAddress = uriBuilder.Uri.ToString(); + var healthAddress = originalConfig.Health; + if (healthUriBuilder is not null) + { + healthUriBuilder.Host = addressString; + healthAddress = healthUriBuilder.Uri.ToString(); + } + + var name = $"{originalName}[{addressString}]"; + var config = originalConfig with { Host = originalHost, Address = resolvedAddress, Health = healthAddress }; + results.Add((name, config)); + } + + return results; + } +} diff --git a/src/ReverseProxy/ServiceDiscovery/DnsDestinationResolverOptions.cs b/src/ReverseProxy/ServiceDiscovery/DnsDestinationResolverOptions.cs new file mode 100644 index 000000000..48540e970 --- /dev/null +++ b/src/ReverseProxy/ServiceDiscovery/DnsDestinationResolverOptions.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Net.Sockets; + +namespace Yarp.ReverseProxy.ServiceDiscovery; + +/// +/// Options for . +/// +public class DnsDestinationResolverOptions +{ + /// + /// The period between requesting a refresh of a resolved name. + /// + /// + /// Defaults to 5 minutes. + /// + public TimeSpan? RefreshPeriod { get; set; } = TimeSpan.FromMinutes(5); + + /// + /// The optional address family to query for. + /// Use for IPv4 addresses and for IPv6 addresses. + /// + /// + /// Defaults to (any address). + /// + public AddressFamily? AddressFamily { get; set; } +} diff --git a/src/ReverseProxy/ServiceDiscovery/IDestinationResolver.cs b/src/ReverseProxy/ServiceDiscovery/IDestinationResolver.cs new file mode 100644 index 000000000..81d0f1daf --- /dev/null +++ b/src/ReverseProxy/ServiceDiscovery/IDestinationResolver.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Yarp.ReverseProxy.Configuration; + +namespace Yarp.ReverseProxy.ServiceDiscovery; + +/// +/// Resolves destination addresses. +/// +public interface IDestinationResolver +{ + /// + /// Resolves the provided destinations and returns resolved destinations. + /// + /// The destinations to resolve. + /// The cancellation token. + /// + /// The resolved destinations and a change token used to indicate when resolution should be performed again. + /// + ValueTask ResolveDestinationsAsync( + IReadOnlyDictionary destinations, + CancellationToken cancellationToken); +} diff --git a/src/ReverseProxy/ServiceDiscovery/NoOpDestinationResolver.cs b/src/ReverseProxy/ServiceDiscovery/NoOpDestinationResolver.cs new file mode 100644 index 000000000..956588b7f --- /dev/null +++ b/src/ReverseProxy/ServiceDiscovery/NoOpDestinationResolver.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Yarp.ReverseProxy.Configuration; + +namespace Yarp.ReverseProxy.ServiceDiscovery; + +/// +/// An which performs no action. +/// +internal sealed class NoOpDestinationResolver : IDestinationResolver +{ + public ValueTask ResolveDestinationsAsync(IReadOnlyDictionary destinations, CancellationToken cancellationToken) + => new(new ResolvedDestinationCollection(destinations, changeToken: null)); +} diff --git a/src/ReverseProxy/ServiceDiscovery/ResolvedDestinationCollection.cs b/src/ReverseProxy/ServiceDiscovery/ResolvedDestinationCollection.cs new file mode 100644 index 000000000..13cd5d8c3 --- /dev/null +++ b/src/ReverseProxy/ServiceDiscovery/ResolvedDestinationCollection.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using Microsoft.Extensions.Primitives; +using Yarp.ReverseProxy.Configuration; + +namespace Yarp.ReverseProxy.ServiceDiscovery +{ + /// + /// Represents a collection of resolved destinations. + /// + public sealed class ResolvedDestinationCollection + { + public ResolvedDestinationCollection(IReadOnlyDictionary destinations, IChangeToken? changeToken) + { + Destinations = destinations; + ChangeToken = changeToken; + } + + /// + /// Gets the map of destination names to destination configurations. + /// + public IReadOnlyDictionary Destinations { get; init; } + + /// + /// Gets the optional change token used to signal when this collection should be refreshed. + /// + public IChangeToken? ChangeToken { get; init; } + } +} diff --git a/src/ReverseProxy/Transforms/Builder/TransformBuilder.cs b/src/ReverseProxy/Transforms/Builder/TransformBuilder.cs index f3e0b2c40..f9cc4012b 100644 --- a/src/ReverseProxy/Transforms/Builder/TransformBuilder.cs +++ b/src/ReverseProxy/Transforms/Builder/TransformBuilder.cs @@ -157,10 +157,8 @@ internal static StructuredTransformer CreateTransformer(TransformBuilderContext { // RequestHeaderOriginalHostKey defaults to false, and CopyRequestHeaders defaults to true. // If RequestHeaderOriginalHostKey was not specified then we need to make sure the transform gets - // added anyways to remove the original host. If CopyRequestHeaders is false then we can omit the - // transform. - if (context.CopyRequestHeaders.GetValueOrDefault(true) - && !context.RequestTransforms.Any(item => item is RequestHeaderOriginalHostTransform)) + // added anyways to remove the original host and to observe hosts specified in DestinationConfig. + if (!context.RequestTransforms.Any(item => item is RequestHeaderOriginalHostTransform)) { context.AddOriginalHost(false); } diff --git a/src/ReverseProxy/Transforms/RequestHeaderOriginalHostTransform.cs b/src/ReverseProxy/Transforms/RequestHeaderOriginalHostTransform.cs index c62a5b807..5e71ad1f5 100644 --- a/src/ReverseProxy/Transforms/RequestHeaderOriginalHostTransform.cs +++ b/src/ReverseProxy/Transforms/RequestHeaderOriginalHostTransform.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using Microsoft.Net.Http.Headers; using Yarp.ReverseProxy.Forwarder; +using Yarp.ReverseProxy.Model; namespace Yarp.ReverseProxy.Transforms; @@ -31,24 +32,25 @@ private RequestHeaderOriginalHostTransform(bool useOriginalHost) public override ValueTask ApplyAsync(RequestTransformContext context) { + var destinationConfigHost = context.HttpContext.Features.Get()?.ProxiedDestination?.Model.Config?.Host; + var originalHost = context.HttpContext.Request.Host.Value is { Length: > 0 } host ? host : null; + var existingHost = RequestUtilities.TryGetValues(context.ProxyRequest.Headers, HeaderNames.Host, out var currentHost) ? currentHost.ToString() : null; + if (UseOriginalHost) { - if (!context.HeadersCopied) + if (!context.HeadersCopied && existingHost is null) { - // Don't override a custom host - if (!context.ProxyRequest.Headers.NonValidated.Contains(HeaderNames.Host)) - { - context.ProxyRequest.Headers.TryAddWithoutValidation(HeaderNames.Host, context.HttpContext.Request.Host.Value); - } + // Propagate the host if the transform pipeline didn't already override it. + // If there was no original host specified, allow the destination config host to flow through. + context.ProxyRequest.Headers.TryAddWithoutValidation(HeaderNames.Host, originalHost ?? destinationConfigHost); } } - else if (context.HeadersCopied - // Don't remove a custom host, only the original - && RequestUtilities.TryGetValues(context.ProxyRequest.Headers, HeaderNames.Host, out var existingHost) - && string.Equals(context.HttpContext.Request.Host.Value, existingHost.ToString(), StringComparison.Ordinal)) + else if (existingHost is null || string.Equals(originalHost, existingHost, StringComparison.Ordinal)) { - // Remove it after the copy, use the destination host instead. - context.ProxyRequest.Headers.Host = null; + // Use the host from destination configuration (which may be null) if either: + // * there is no host header set, or + // * the original host header is being suppressed and has not been modified by the transform pipeline + context.ProxyRequest.Headers.Host = destinationConfigHost; } return default; diff --git a/src/ReverseProxy/Yarp.ReverseProxy.csproj b/src/ReverseProxy/Yarp.ReverseProxy.csproj index d71f3bf33..b6f447e10 100644 --- a/src/ReverseProxy/Yarp.ReverseProxy.csproj +++ b/src/ReverseProxy/Yarp.ReverseProxy.csproj @@ -23,7 +23,7 @@ - + diff --git a/test/ReverseProxy.Tests/Configuration/ConfigProvider/ConfigurationConfigProviderTests.cs b/test/ReverseProxy.Tests/Configuration/ConfigProvider/ConfigurationConfigProviderTests.cs index fbf345dfb..8dceb677d 100644 --- a/test/ReverseProxy.Tests/Configuration/ConfigProvider/ConfigurationConfigProviderTests.cs +++ b/test/ReverseProxy.Tests/Configuration/ConfigProvider/ConfigurationConfigProviderTests.cs @@ -40,7 +40,8 @@ public class ConfigurationConfigProviderTests { Address = "https://localhost:10000/destA", Health = "https://localhost:20000/destA", - Metadata = new Dictionary { { "destA-K1", "destA-V1" }, { "destA-K2", "destA-V2" } } + Metadata = new Dictionary { { "destA-K1", "destA-V1" }, { "destA-K2", "destA-V2" } }, + Host = "localhost" } }, { @@ -49,7 +50,8 @@ public class ConfigurationConfigProviderTests { Address = "https://localhost:10000/destB", Health = "https://localhost:20000/destB", - Metadata = new Dictionary { { "destB-K1", "destB-V1" }, { "destB-K2", "destB-V2" } } + Metadata = new Dictionary { { "destB-K1", "destB-V1" }, { "destB-K2", "destB-V2" } }, + Host = "localhost" } } }, @@ -113,8 +115,8 @@ public class ConfigurationConfigProviderTests ClusterId = "cluster2", Destinations = new Dictionary(StringComparer.OrdinalIgnoreCase) { - { "destinationC", new DestinationConfig { Address = "https://localhost:10001/destC" } }, - { "destinationD", new DestinationConfig { Address = "https://localhost:10000/destB" } } + { "destinationC", new DestinationConfig { Address = "https://localhost:10001/destC", Host = "localhost" } }, + { "destinationD", new DestinationConfig { Address = "https://localhost:10000/destB", Host = "remotehost" } } }, LoadBalancingPolicy = LoadBalancingPolicies.RoundRobin } @@ -262,6 +264,7 @@ public class ConfigurationConfigProviderTests ""destinationA"": { ""Address"": ""https://localhost:10000/destA"", ""Health"": ""https://localhost:20000/destA"", + ""Host"": ""localhost"", ""Metadata"": { ""destA-K1"": ""destA-V1"", ""destA-K2"": ""destA-V2"" @@ -270,6 +273,7 @@ public class ConfigurationConfigProviderTests ""destinationB"": { ""Address"": ""https://localhost:10000/destB"", ""Health"": ""https://localhost:20000/destB"", + ""Host"": ""localhost"", ""Metadata"": { ""destB-K1"": ""destB-V1"", ""destB-K2"": ""destB-V2"" @@ -292,10 +296,12 @@ public class ConfigurationConfigProviderTests ""Destinations"": { ""destinationC"": { ""Address"": ""https://localhost:10001/destC"", + ""Host"": ""localhost"", ""Metadata"": null }, ""destinationD"": { ""Address"": ""https://localhost:10000/destB"", + ""Host"": ""remotehost"", ""Metadata"": null } }, @@ -511,9 +517,11 @@ private void VerifyValidAbstractConfig(IProxyConfig validConfig, IProxyConfig ab Assert.Equal(cluster1.Destinations["destinationA"].Address, abstractCluster1.Destinations["destinationA"].Address); Assert.Equal(cluster1.Destinations["destinationA"].Health, abstractCluster1.Destinations["destinationA"].Health); Assert.Equal(cluster1.Destinations["destinationA"].Metadata, abstractCluster1.Destinations["destinationA"].Metadata); + Assert.Equal(cluster1.Destinations["destinationA"].Host, abstractCluster1.Destinations["destinationA"].Host); Assert.Equal(cluster1.Destinations["destinationB"].Address, abstractCluster1.Destinations["destinationB"].Address); Assert.Equal(cluster1.Destinations["destinationB"].Health, abstractCluster1.Destinations["destinationB"].Health); Assert.Equal(cluster1.Destinations["destinationB"].Metadata, abstractCluster1.Destinations["destinationB"].Metadata); + Assert.Equal(cluster1.Destinations["destinationB"].Host, abstractCluster1.Destinations["destinationB"].Host); Assert.Equal(cluster1.HealthCheck.AvailableDestinationsPolicy, abstractCluster1.HealthCheck.AvailableDestinationsPolicy); Assert.Equal(cluster1.HealthCheck.Passive.Enabled, abstractCluster1.HealthCheck.Passive.Enabled); Assert.Equal(cluster1.HealthCheck.Passive.Policy, abstractCluster1.HealthCheck.Passive.Policy); @@ -552,8 +560,10 @@ private void VerifyValidAbstractConfig(IProxyConfig validConfig, IProxyConfig ab var abstractCluster2 = abstractConfig.Clusters.Single(c => c.ClusterId == "cluster2"); Assert.Equal(cluster2.Destinations["destinationC"].Address, abstractCluster2.Destinations["destinationC"].Address); Assert.Equal(cluster2.Destinations["destinationC"].Metadata, abstractCluster2.Destinations["destinationC"].Metadata); + Assert.Equal(cluster2.Destinations["destinationC"].Host, abstractCluster2.Destinations["destinationC"].Host); Assert.Equal(cluster2.Destinations["destinationD"].Address, abstractCluster2.Destinations["destinationD"].Address); Assert.Equal(cluster2.Destinations["destinationD"].Metadata, abstractCluster2.Destinations["destinationD"].Metadata); + Assert.Equal(cluster2.Destinations["destinationD"].Host, abstractCluster2.Destinations["destinationD"].Host); Assert.Equal(LoadBalancingPolicies.RoundRobin, abstractCluster2.LoadBalancingPolicy); Assert.Equal(2, abstractConfig.Routes.Count); diff --git a/test/ReverseProxy.Tests/Management/ProxyConfigManagerTests.cs b/test/ReverseProxy.Tests/Management/ProxyConfigManagerTests.cs index 0f4a3bfd3..18bfadd81 100644 --- a/test/ReverseProxy.Tests/Management/ProxyConfigManagerTests.cs +++ b/test/ReverseProxy.Tests/Management/ProxyConfigManagerTests.cs @@ -4,9 +4,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net.Sockets; using System.Security.Authentication; using System.Text; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting.Server; @@ -18,12 +20,12 @@ using Moq; using Xunit; using Yarp.ReverseProxy.Configuration; -using Yarp.ReverseProxy.Configuration.ConfigProvider; using Yarp.ReverseProxy.Forwarder; using Yarp.ReverseProxy.Forwarder.Tests; using Yarp.ReverseProxy.Health; using Yarp.ReverseProxy.Model; using Yarp.ReverseProxy.Routing; +using Yarp.ReverseProxy.ServiceDiscovery; using Yarp.Tests.Common; namespace Yarp.ReverseProxy.Management.Tests; @@ -34,7 +36,8 @@ private static IServiceProvider CreateServices( List routes, List clusters, Action configureProxy = null, - IEnumerable configListeners = null) + IEnumerable configListeners = null, + IDestinationResolver destinationResolver = null) { var serviceCollection = new ServiceCollection(); serviceCollection.AddLogging(); @@ -53,6 +56,12 @@ private static IServiceProvider CreateServices( serviceCollection.AddSingleton(configListener); } } + + if (destinationResolver is not null) + { + serviceCollection.AddSingleton(destinationResolver); + } + var services = serviceCollection.BuildServiceProvider(); var routeBuilder = services.GetRequiredService(); routeBuilder.SetProxyPipeline(context => Task.CompletedTask); @@ -1243,7 +1252,6 @@ public async Task LoadAsync_ConfigFilterClusterActionThrows_Throws() Assert.IsType(agex.InnerExceptions.First().InnerException); } - [Fact] public async Task LoadAsync_ConfigFilterRouteActionThrows_Throws() { @@ -1264,4 +1272,375 @@ public async Task LoadAsync_ConfigFilterRouteActionThrows_Throws() Assert.IsType(agex.InnerExceptions.First().InnerException); Assert.IsType(agex.InnerExceptions.Skip(1).First().InnerException); } + + private class FakeDestinationResolver : IDestinationResolver + { + private readonly Func, CancellationToken, ValueTask> _delegate; + + public FakeDestinationResolver( + Func, CancellationToken, ValueTask> @delegate) + { + _delegate = @delegate; + } + + public ValueTask ResolveDestinationsAsync(IReadOnlyDictionary destinations, CancellationToken cancellationToken) + => _delegate(destinations, cancellationToken); + } + + private class TestConfigChangeListener : IConfigChangeListener + { + private readonly bool _includeLoad; + private readonly bool _includeApply; + + public Channel Events { get; } = Channel.CreateUnbounded(); + + public TestConfigChangeListener(bool includeLoad = true, bool includeApply = true) + { + _includeLoad = includeLoad; + _includeApply = includeApply; + } + + public void ConfigurationApplied(IReadOnlyList proxyConfigs) + { + if (!_includeApply) + { + return; + } + + Assert.True(Events.Writer.TryWrite(new ConfigurationAppliedEvent(proxyConfigs))); + } + + public void ConfigurationApplyingFailed(IReadOnlyList proxyConfigs, Exception exception) + { + if (!_includeApply) + { + return; + } + + Assert.True(Events.Writer.TryWrite(new ConfigurationApplyingFailedEvent(proxyConfigs, exception))); + } + + public void ConfigurationLoaded(IReadOnlyList proxyConfigs) + { + if (!_includeLoad) + { + return; + } + + Assert.True(Events.Writer.TryWrite(new ConfigurationLoadedEvent(proxyConfigs))); + } + + public void ConfigurationLoadingFailed(IProxyConfigProvider configProvider, Exception exception) + { + if (!_includeLoad) + { + return; + } + + Assert.True(Events.Writer.TryWrite(new ConfigurationLoadingFailedEvent(configProvider, exception))); + } + + public record ConfigChangeListenerEvent { }; + public record ConfigurationAppliedEvent(IReadOnlyList ProxyConfigs) : ConfigChangeListenerEvent; + public record ConfigurationApplyingFailedEvent(IReadOnlyList ProxyConfigs, Exception exception) : ConfigChangeListenerEvent; + public record ConfigurationLoadedEvent(IReadOnlyList ProxyConfigs) : ConfigChangeListenerEvent; + public record ConfigurationLoadingFailedEvent(IProxyConfigProvider ConfigProvider, Exception Exception) : ConfigChangeListenerEvent; + } + + [Fact] + public async Task LoadAsync_DestinationResolver_Initial_ThrowsSync() + { + var throwResolver = new FakeDestinationResolver((destinations, cancellation) => throw new InvalidOperationException("Throwing!")); + + var cluster = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost" } } + } + }; + var services = CreateServices( + new List(), + new List() { cluster }, + destinationResolver: throwResolver); + var configManager = services.GetRequiredService(); + + var ioEx = await Assert.ThrowsAsync(() => configManager.InitialLoadAsync()); + Assert.Equal("Unable to load or apply the proxy configuration.", ioEx.Message); + + var innerExc = Assert.IsType(ioEx.InnerException); + Assert.Equal("Throwing!", innerExc.Message); + } + + [Fact] + public async Task LoadAsync_DestinationResolver_Initial_ThrowsAsync() + { + var throwResolver = new FakeDestinationResolver((destinations, cancellation) => ValueTask.FromException(new InvalidOperationException("Throwing!"))); + + var cluster = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost" } } + } + }; + var services = CreateServices(new List(), new List() { cluster }, destinationResolver: throwResolver); + var configManager = services.GetRequiredService(); + + var ioEx = await Assert.ThrowsAsync(() => configManager.InitialLoadAsync()); + Assert.Equal("Unable to load or apply the proxy configuration.", ioEx.Message); + + var innerExc = Assert.IsType(ioEx.InnerException); + Assert.Equal("Throwing!", innerExc.Message); + } + + [Fact] + public async Task LoadAsync_DestinationResolver_Successful() + { + var destinationsToExpand = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost" } } + }; + + var syncExpandResolver = new FakeDestinationResolver((destinations, cancellation) => + { + var expandedDestinations = new Dictionary(); + + foreach (var destKvp in destinations) + { + expandedDestinations[$"{destKvp.Key}-1"] = new DestinationConfig { Address = "http://127.0.0.1:8080" }; + expandedDestinations[$"{destKvp.Key}-2"] = new DestinationConfig { Address = "http://127.1.1.1:8080" }; + } + + var result = new ResolvedDestinationCollection(expandedDestinations, null); + return new(result); + }); + + var cluster1 = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = destinationsToExpand + }; + + var services = CreateServices(new List(), new List() { cluster1 }, destinationResolver: syncExpandResolver); + var configManager = services.GetRequiredService(); + + await configManager.InitialLoadAsync(); + + Assert.True(configManager.TryGetCluster(cluster1.ClusterId, out var cluster)); + + var expectedDestinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1-1", new DestinationConfig() { Address = "http://127.0.0.1:8080" } }, + { "d1-2", new DestinationConfig() { Address = "http://127.1.1.1:8080" } } + }; + + var actualDestinations = cluster.Destinations.ToDictionary(static k => k.Key, static v => v.Value.Model.Config); + Assert.Equal(expectedDestinations, actualDestinations); + } + + [Fact] + public async Task LoadAsync_DestinationResolver_Dns() + { + var destinationsToExpand = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost/a/b/c", Health = "http://localhost/healthz" } }, + { "d2", new DestinationConfig() { Address = "http://localhost:8080/a/b/c", Health = "http://localhost:8080/healthz"} }, + { "d3", new DestinationConfig() { Address = "https://localhost/a/b/c", Health = "https://localhost/healthz" } }, + { "d4", new DestinationConfig() { Address = "https://localhost:8443/a/b/c", Health = "https://localhost:8443/healthz", Host = "overriddenhost" } } + }; + + var cluster1 = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = destinationsToExpand + }; + + var services = CreateServices( + new List(), + new List() { cluster1 }, + configureProxy: rp => rp.AddDnsDestinationResolver(o => o.AddressFamily = AddressFamily.InterNetwork)); + var configManager = services.GetRequiredService(); + + await configManager.InitialLoadAsync(); + + Assert.True(configManager.TryGetCluster(cluster1.ClusterId, out var cluster)); + + var expectedDestinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1[127.0.0.1]", new DestinationConfig() { Address = "http://127.0.0.1/a/b/c", Health = "http://127.0.0.1/healthz", Host = "localhost" } }, + { "d2[127.0.0.1]", new DestinationConfig() { Address = "http://127.0.0.1:8080/a/b/c", Health = "http://127.0.0.1:8080/healthz", Host = "localhost:8080" } }, + { "d3[127.0.0.1]", new DestinationConfig() { Address = "https://127.0.0.1/a/b/c", Health = "https://127.0.0.1/healthz", Host = "localhost" } }, + { "d4[127.0.0.1]", new DestinationConfig() { Address = "https://127.0.0.1:8443/a/b/c", Health = "https://127.0.0.1:8443/healthz", Host = "overriddenhost" } } + }; + + var actualDestinations = cluster.Destinations.ToDictionary(static k => k.Key, static v => v.Value.Model.Config); + Assert.Equal(expectedDestinations, actualDestinations); + } + + [Fact] + public async Task LoadAsync_DestinationResolver_ReloadResolution() + { + var configListener = new TestConfigChangeListener(includeApply: false); + var destinationsToExpand = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost" } } + }; + + var cts = new[] { new CancellationTokenSource() }; + var signaled = new[] { 0 }; + var syncExpandResolver = new FakeDestinationResolver((destinations, cancellation) => + { + signaled[0]++; + var expandedDestinations = new Dictionary(); + + foreach (var destKvp in destinations) + { + expandedDestinations[$"{destKvp.Key}-1"] = new DestinationConfig { Address = $"http://127.0.0.1:8080/{signaled[0]}" }; + expandedDestinations[$"{destKvp.Key}-2"] = new DestinationConfig { Address = $"http://127.1.1.1:8080/{signaled[0]}" }; + } + + var result = new ResolvedDestinationCollection(expandedDestinations, new CancellationChangeToken(cts[0].Token)); + return new(result); + }); + + var cluster1 = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = destinationsToExpand + }; + + var services = CreateServices( + new List(), + new List() { cluster1 }, + configListeners: new[] { configListener }, + destinationResolver: syncExpandResolver); + var configManager = services.GetRequiredService(); + + await configManager.InitialLoadAsync(); + var configEvent = await configListener.Events.Reader.ReadAsync(); + var configLoadEvent = Assert.IsType(configEvent); + + Assert.True(configManager.TryGetCluster(cluster1.ClusterId, out var cluster)); + + var expectedDestinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1-1", new DestinationConfig() { Address = "http://127.0.0.1:8080/1" } }, + { "d1-2", new DestinationConfig() { Address = "http://127.1.1.1:8080/1" } } + }; + + var actualDestinations = cluster.Destinations.ToDictionary(static k => k.Key, static v => v.Value.Model.Config); + Assert.Equal(expectedDestinations, actualDestinations); + + // Trigger the change token and wait for a subsequent load + var initialCts = cts[0]; + cts[0] = new(); + initialCts.Cancel(); + + configEvent = await configListener.Events.Reader.ReadAsync(); + configLoadEvent = Assert.IsType(configEvent); + + Assert.True(configManager.TryGetCluster(cluster1.ClusterId, out cluster)); + + expectedDestinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1-1", new DestinationConfig() { Address = "http://127.0.0.1:8080/2" } }, + { "d1-2", new DestinationConfig() { Address = "http://127.1.1.1:8080/2" } } + }; + + actualDestinations = cluster.Destinations.ToDictionary(static k => k.Key, static v => v.Value.Model.Config); + Assert.Equal(expectedDestinations, actualDestinations); + } + + [Fact] + public async Task LoadAsync_DestinationResolver_Reload_ThrowsSync() + { + var configListener = new TestConfigChangeListener(includeApply: false); + var cts = new CancellationTokenSource(); + var syncThrowResolver = new FakeDestinationResolver((destinations, cancellation) => + { + if (cts.IsCancellationRequested) + { + throw new InvalidOperationException("Throwing!"); + } + else + { + return new(new ResolvedDestinationCollection(destinations, new CancellationChangeToken(cts.Token))); + } + }); + var cluster = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost" } } + } + }; + var services = CreateServices( + new List(), + new List() { cluster }, + configListeners: new[] { configListener }, + destinationResolver: syncThrowResolver); + var configManager = services.GetRequiredService(); + await configManager.InitialLoadAsync(); + + // Read the successful load event + Assert.IsType(await configListener.Events.Reader.ReadAsync()); + + // Trigger invalidation + cts.Cancel(); + + // Read the failure event + var configLoadException = Assert.IsType(await configListener.Events.Reader.ReadAsync()); + var ex = configLoadException.Exception; + Assert.Equal("Throwing!", ex.Message); + } + + [Fact] + public async Task LoadAsync_DestinationResolver_Reload_ThrowsAsync() + { + var configListener = new TestConfigChangeListener(includeApply: false); + var cts = new CancellationTokenSource(); + var syncThrowResolver = new FakeDestinationResolver(async (destinations, cancellation) => + { + await Task.Yield(); + + if (cts.IsCancellationRequested) + { + throw new InvalidOperationException("Throwing!"); + } + else + { + return new ResolvedDestinationCollection(destinations, new CancellationChangeToken(cts.Token)); + } + }); + var cluster = new ClusterConfig() + { + ClusterId = "cluster1", + Destinations = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "d1", new DestinationConfig() { Address = "http://localhost" } } + } + }; + var services = CreateServices( + new List(), + new List() { cluster }, + configListeners: new[] { configListener }, + destinationResolver: syncThrowResolver); + var configManager = services.GetRequiredService(); + await configManager.InitialLoadAsync(); + + // Read the successful load event + Assert.IsType(await configListener.Events.Reader.ReadAsync()); + + // Trigger invalidation + cts.Cancel(); + + // Read the failure event + var configLoadException = Assert.IsType(await configListener.Events.Reader.ReadAsync()); + var ex = configLoadException.Exception; + Assert.Equal("Throwing!", ex.Message); + } } diff --git a/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs b/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs index 17cc98580..181e9d8e3 100644 --- a/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs +++ b/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -13,6 +14,7 @@ using Microsoft.Extensions.DependencyInjection; using Xunit; using Yarp.ReverseProxy.Configuration; +using Yarp.ReverseProxy.Model; using Yarp.Tests.Common; namespace Yarp.ReverseProxy.Transforms.Builder.Tests; @@ -240,22 +242,31 @@ public void DefaultsCanBeDisabled() var results = transformBuilder.BuildInternal(route, new ClusterConfig()); Assert.NotNull(results); Assert.False(results.ShouldCopyRequestHeaders); - Assert.Empty(results.RequestTransforms); + Assert.Single(results.RequestTransforms); Assert.Empty(results.ResponseTransforms); Assert.Empty(results.ResponseTrailerTransforms); } [Theory] - [InlineData(null, null)] - [InlineData(null, true)] - [InlineData(null, false)] - [InlineData(true, null)] - [InlineData(false, null)] - [InlineData(true, true)] - [InlineData(true, false)] - [InlineData(false, true)] - [InlineData(false, false)] - public async Task UseOriginalHost(bool? useOriginalHost, bool? copyHeaders) + [InlineData(null, null, false)] + [InlineData(null, true, false)] + [InlineData(null, false, false)] + [InlineData(true, null, false)] + [InlineData(false, null, false)] + [InlineData(true, true, false)] + [InlineData(true, false, false)] + [InlineData(false, true, false)] + [InlineData(false, false, false)] + [InlineData(null, null, true)] + [InlineData(null, true, true)] + [InlineData(null, false, true)] + [InlineData(true, null, true)] + [InlineData(false, null, true)] + [InlineData(true, true, true)] + [InlineData(true, false, true)] + [InlineData(false, true, true)] + [InlineData(false, false, true)] + public async Task UseOriginalHost(bool? useOriginalHost, bool? copyHeaders, bool hasDestinationHost) { var transformBuilder = CreateTransformBuilder(); var transforms = new List>(); @@ -283,41 +294,48 @@ public async Task UseOriginalHost(bool? useOriginalHost, bool? copyHeaders) var errors = transformBuilder.ValidateRoute(route); Assert.Empty(errors); - var results = transformBuilder.BuildInternal(route, new ClusterConfig()); + var destinationHost = hasDestinationHost ? "d1-host" : null; + var clusterConfig = new ClusterConfig + { + ClusterId = "cluster1", + Destinations = new Dictionary + { + ["d1"] = new DestinationConfig + { + Address = "https://localhost", + Host = destinationHost + } + } + }; + var results = transformBuilder.BuildInternal(route, clusterConfig); Assert.NotNull(results); Assert.Equal(copyHeaders, results.ShouldCopyRequestHeaders); Assert.Empty(results.ResponseTransforms); Assert.Empty(results.ResponseTrailerTransforms); - if (useOriginalHost.HasValue) - { - var transform = Assert.Single(results.RequestTransforms); - var hostTransform = Assert.IsType(transform); - Assert.Equal(useOriginalHost.Value, hostTransform.UseOriginalHost); - } - else if (copyHeaders.GetValueOrDefault(true)) - { - var transform = Assert.Single(results.RequestTransforms); - var hostTransform = Assert.IsType(transform); - Assert.False(hostTransform.UseOriginalHost); - } - else - { - Assert.Empty(results.RequestTransforms); - } - var httpContext = new DefaultHttpContext(); + httpContext.Features.Set(new ReverseProxyFeature + { + ProxiedDestination = new DestinationState("d1") { Model = new(clusterConfig.Destinations.Single().Value) } + }); httpContext.Request.Host = new HostString("StartHost"); var proxyRequest = new HttpRequestMessage(); var destinationPrefix = "http://destinationhost:9090/path"; await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + // We expect the host to be flowed as long as it is being explicitly flowed or it wasn't suppressed and headers are being copied. if (useOriginalHost.GetValueOrDefault(false)) { Assert.Equal("StartHost", proxyRequest.Headers.Host); } - else + else if (destinationHost is not null) + { + // Otherwise, fall back to the destination config host, which will be null if it's not set. + Assert.Equal(destinationHost, proxyRequest.Headers.Host); + } + else { + // Otherwise, the host should be null Assert.Null(proxyRequest.Headers.Host); } } @@ -368,10 +386,26 @@ public async Task UseCustomHost(bool? useOriginalHost, bool? copyHeaders) var errors = transformBuilder.ValidateRoute(route); Assert.Empty(errors); - var results = transformBuilder.BuildInternal(route, new ClusterConfig()); + var clusterConfig = new ClusterConfig + { + ClusterId = "cluster1", + Destinations = new Dictionary + { + ["d1"] = new DestinationConfig + { + Address = "https://localhost", + Host = "d1-host" + } + } + }; + var results = transformBuilder.BuildInternal(route, clusterConfig); Assert.Equal(copyHeaders, results.ShouldCopyRequestHeaders); var httpContext = new DefaultHttpContext(); + httpContext.Features.Set(new ReverseProxyFeature + { + ProxiedDestination = new DestinationState("d1") { Model = new(clusterConfig.Destinations.Single().Value) } + }); httpContext.Request.Host = new HostString("StartHost"); var proxyRequest = new HttpRequestMessage(); var destinationPrefix = "http://destinationhost:9090/path"; @@ -402,9 +436,9 @@ public void DefaultsCanBeOverridenByForwarded() Assert.Empty(errors); var results = transformBuilder.BuildInternal(route, new ClusterConfig()); - Assert.Equal(5, results.RequestTransforms.Length); + Assert.Equal(6, results.RequestTransforms.Length); Assert.All( - results.RequestTransforms.Skip(1).Select(t => (dynamic)t), + results.RequestTransforms.Skip(1).SkipLast(1).Select(t => (dynamic)t), t => { Assert.StartsWith("X-Forwarded-", t.HeaderName);