From 1d18b77fc87bd20c08ef84be1cec1294303fd32d Mon Sep 17 00:00:00 2001 From: Jake Banks Date: Tue, 14 Nov 2023 04:33:08 +1100 Subject: [PATCH] Fix limiter policy validation when 'default' or 'disable' used (#2283) --- .../Configuration/ConfigValidator.cs | 19 +++-- .../Configuration/ConfigValidatorTests.cs | 81 +++++++++++++++++++ 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/src/ReverseProxy/Configuration/ConfigValidator.cs b/src/ReverseProxy/Configuration/ConfigValidator.cs index 51d9c4da0..d6cb70ed9 100644 --- a/src/ReverseProxy/Configuration/ConfigValidator.cs +++ b/src/ReverseProxy/Configuration/ConfigValidator.cs @@ -302,6 +302,18 @@ private async ValueTask ValidateRateLimiterPolicyAsync(IList errors, return; } + if (string.Equals(RateLimitingConstants.Default, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase) + || string.Equals(RateLimitingConstants.Disable, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase)) + { + var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName); + if (policy is not null) + { + // We weren't expecting to find a policy with these names. + errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function.")); + } + return; + } + try { var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName); @@ -309,13 +321,6 @@ private async ValueTask ValidateRateLimiterPolicyAsync(IList errors, if (policy is null) { errors.Add(new ArgumentException($"RateLimiter policy '{rateLimiterPolicyName}' not found for route '{routeId}'.")); - return; - } - - if (string.Equals(RateLimitingConstants.Default, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase) - || string.Equals(RateLimitingConstants.Disable, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase)) - { - errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function.")); } } catch (Exception ex) diff --git a/test/ReverseProxy.Tests/Configuration/ConfigValidatorTests.cs b/test/ReverseProxy.Tests/Configuration/ConfigValidatorTests.cs index fd450802c..b0dbcecf1 100644 --- a/test/ReverseProxy.Tests/Configuration/ConfigValidatorTests.cs +++ b/test/ReverseProxy.Tests/Configuration/ConfigValidatorTests.cs @@ -4,7 +4,13 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; +#if NET7_0_OR_GREATER +using Microsoft.AspNetCore.Builder; +#endif using Microsoft.AspNetCore.Cors.Infrastructure; +#if NET7_0_OR_GREATER +using Microsoft.AspNetCore.RateLimiting; +#endif using Microsoft.Extensions.DependencyInjection; using Moq; using Xunit; @@ -710,6 +716,81 @@ public async Task Rejects_ReservedCorsPolicyIsUsed(string corsPolicy) Assert.Contains(result, err => err.Message.Equals($"The application has registered a CORS policy named '{corsPolicy}' that conflicts with the reserved CORS policy name used on this route. The registered policy name needs to be changed for this route to function.")); } +#if NET7_0_OR_GREATER + [Theory] + [InlineData("Default")] + [InlineData("Disable")] + public async Task Accepts_BuiltInRateLimiterPolicy(string rateLimiterPolicy) + { + var route = new RouteConfig { + RouteId = "route1", + Match = new RouteMatch + { + Hosts = new[] { "localhost" }, + }, + ClusterId = "cluster1", + RateLimiterPolicy = rateLimiterPolicy + }; + + var services = CreateServices(); + var validator = services.GetRequiredService(); + + var result = await validator.ValidateRouteAsync(route); + + Assert.Empty(result); + } + + [Theory] + [InlineData("Default")] + [InlineData("Disable")] + public async Task Reports_BuildInRateLimiterPolicyNameConflict(string rateLimiterPolicy) + { + var route = new RouteConfig + { + RouteId = "route1", + Match = new RouteMatch + { + Hosts = new[] { "localhost" }, + }, + ClusterId = "cluster1", + RateLimiterPolicy = rateLimiterPolicy + }; + + var services = CreateServices(s => + { + s.AddRateLimiter(o => o.AddConcurrencyLimiter(rateLimiterPolicy, c => { })); + }); + var validator = services.GetRequiredService(); + + var result = await validator.ValidateRouteAsync(route); + + Assert.NotEmpty(result); + Assert.Contains(result, err => err.Message.Contains($"The application has registered a RateLimiter policy named '{rateLimiterPolicy}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function.")); + } + + [Theory] + [InlineData("NotAPolicy")] + public async Task Rejects_InvalidRateLimiterPolicy(string rateLimiterPolicy) + { + var route = new RouteConfig { + RouteId = "route1", + Match = new RouteMatch + { + Hosts = new[] { "localhost" }, + }, + ClusterId = "cluster1", + RateLimiterPolicy = rateLimiterPolicy }; + + var services = CreateServices(); + var validator = services.GetRequiredService(); + + var result = await validator.ValidateRouteAsync(route); + + Assert.NotEmpty(result); + Assert.Contains(result, err => err.Message.Contains($"RateLimiter policy '{rateLimiterPolicy}' not found")); + } +#endif + [Fact] public async Task EmptyCluster_Works() {