Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SearchValues in RequestUtilities.EncodePath #2267

Merged
merged 3 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions src/ReverseProxy/Forwarder/RequestUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
Expand All @@ -12,6 +13,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
using Yarp.ReverseProxy.Utilities;

namespace Yarp.ReverseProxy.Forwarder;

Expand All @@ -20,6 +22,11 @@ namespace Yarp.ReverseProxy.Forwarder;
/// </summary>
public static class RequestUtilities
{
#if NET8_0_OR_GREATER
private static readonly SearchValues<char> s_validPathChars =
SearchValues.Create("!$&'()*+,-./0123456789:;=@ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~");
#endif

/// <summary>
/// Converts the given HTTP method (usually obtained from <see cref="HttpRequest.Method"/>)
/// into the corresponding <see cref="HttpMethod"/> static instance.
Expand Down Expand Up @@ -124,27 +131,37 @@ public static Uri MakeDestinationAddress(string destinationPrefix, PathString pa
// This isn't using PathString.ToUriComponent() because it doesn't round trip some escape sequences the way we want.
private static string EncodePath(PathString path)
{
if (!path.HasValue)
var value = path.Value;

if (string.IsNullOrEmpty(value))
{
return string.Empty;
}

// Check if any escaping is required.
var value = path.Value!;
#if NET8_0_OR_GREATER
var indexOfInvalidChar = value.AsSpan().IndexOfAnyExcept(s_validPathChars);
#else
var indexOfInvalidChar = -1;

for (var i = 0; i < value.Length; i++)
{
if (!IsValidPathChar(value[i]))
{
return EncodePath(value, i);
indexOfInvalidChar = i;
break;
}
}
#endif

return value;
return indexOfInvalidChar < 0
? value
: EncodePath(value, indexOfInvalidChar);
}

private static string EncodePath(string value, int i)
{
StringBuilder? buffer = null;
var builder = new ValueStringBuilder(stackalloc char[ValueStringBuilder.StackallocThreshold]);

var start = 0;
var count = i;
Expand All @@ -157,8 +174,7 @@ private static string EncodePath(string value, int i)
if (requiresEscaping)
{
// the current segment requires escape
buffer ??= new StringBuilder(value.Length * 3);
buffer.Append(Uri.EscapeDataString(value.Substring(start, count)));
builder.Append(Uri.EscapeDataString(value.Substring(start, count)));

requiresEscaping = false;
start = i;
Expand All @@ -173,8 +189,7 @@ private static string EncodePath(string value, int i)
if (!requiresEscaping)
{
// the current segment doesn't require escape
buffer ??= new StringBuilder(value.Length * 3);
buffer.Append(value, start, count);
builder.Append(value.AsSpan(start, count));

requiresEscaping = true;
start = i;
Expand All @@ -186,30 +201,24 @@ private static string EncodePath(string value, int i)
}
}

if (count == value.Length && !requiresEscaping)
Debug.Assert(count > 0);

if (requiresEscaping)
{
return value;
builder.Append(Uri.EscapeDataString(value.Substring(start, count)));
}
else
{
if (count > 0)
{
buffer ??= new StringBuilder(value.Length * 3);

if (requiresEscaping)
{
buffer.Append(Uri.EscapeDataString(value.Substring(start, count)));
}
else
{
buffer.Append(value, start, count);
}
}

return buffer?.ToString() ?? string.Empty;
builder.Append(value.AsSpan(start, count));
}

return builder.ToString();
}

#if NET8_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool IsValidPathChar(char c) => s_validPathChars.Contains(c);
#else
// https://datatracker.ietf.org/doc/html/rfc3986/#appendix-A
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// pct-encoded = "%" HEXDIG HEXDIG
Expand Down Expand Up @@ -244,6 +253,7 @@ internal static bool IsValidPathChar(char c)
return (uint)offset < (uint)validChars.Length &&
((validChars[offset] & significantBit) != 0);
}
#endif

// Note: HttpClient.SendAsync will end up sending the union of
// HttpRequestMessage.Headers and HttpRequestMessage.Content.Headers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ public override ValueTask ApplyAsync(RequestTransformContext context)

private string GetHeaderValue(HttpContext httpContext)
{
var builder = new ValueStringBuilder();
var builder = new ValueStringBuilder(stackalloc char[ValueStringBuilder.StackallocThreshold]);
AppendProto(httpContext, ref builder);
AppendHost(httpContext, ref builder);
AppendFor(httpContext, ref builder);
AppendBy(httpContext, ref builder);
var value = builder.ToString();
return value;
return builder.ToString();
}

private void AppendProto(HttpContext context, ref ValueStringBuilder builder)
Expand Down
5 changes: 5 additions & 0 deletions src/ReverseProxy/Utilities/SkipLocalsInit.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// Used to indicate to the compiler that the .locals init flag should not be set in method headers.
[module: System.Runtime.CompilerServices.SkipLocalsInit]
63 changes: 46 additions & 17 deletions src/ReverseProxy/Utilities/ValueStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,48 @@

namespace Yarp.ReverseProxy.Utilities;

//Copied from https://github.com/dotnet/runtime/blob/1ee59da9f6104c611b137c9d14add04becefdf14/src/libraries/Common/src/System/Text/ValueStringBuilder.cs
// Adapted from https://github.com/dotnet/runtime/blob/82fee2692b3954ba8903fa4764f1f4e36a26341a/src/libraries/Common/src/System/Text/ValueStringBuilder.cs
internal ref partial struct ValueStringBuilder
{
private char[] _arrayToReturnToPool;
public const int StackallocThreshold = 512;

private char[]? _arrayToReturnToPool;
private Span<char> _chars;
private int _pos;

public ValueStringBuilder(Span<char> initialBuffer)
{
_arrayToReturnToPool = null;
_chars = initialBuffer;
_pos = 0;
}

public int Length
{
get => _pos;
set
{
Debug.Assert(value >= 0);
Debug.Assert(value <= RawChars.Length);
Debug.Assert(value <= _chars.Length);
_pos = value;
}
}

public override string ToString()
{
var s = RawChars.Slice(0, _pos).ToString();
var s = _chars.Slice(0, _pos).ToString();
Dispose();
return s;
}

/// <summary>Returns the underlying storage of the builder.</summary>
public Span<char> RawChars { get; private set; }

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Append(char c)
{
var pos = _pos;
if ((uint)pos < (uint)RawChars.Length)
var chars = _chars;
if ((uint)pos < (uint)chars.Length)
{
RawChars[pos] = c;
chars[pos] = c;
_pos = pos + 1;
}
else
Expand All @@ -60,20 +68,32 @@ public void Append(string s)
}

var pos = _pos;
if (pos > RawChars.Length - s.Length)
if (pos > _chars.Length - s.Length)
{
Grow(s.Length);
}

s.AsSpan().CopyTo(RawChars.Slice(pos));
s.CopyTo(_chars.Slice(pos));
_pos += s.Length;
}

public void Append(ReadOnlySpan<char> value)
{
var pos = _pos;
if (pos > _chars.Length - value.Length)
{
Grow(value.Length);
}

value.CopyTo(_chars.Slice(_pos));
_pos += value.Length;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Append(int i)
{
var pos = _pos;
if (i.TryFormat(RawChars.Slice(pos), out var charsWritten, default, null))
if (i.TryFormat(_chars.Slice(pos), out var charsWritten, default, null))
{
_pos = pos + charsWritten;
}
Expand Down Expand Up @@ -102,15 +122,24 @@ private void GrowAndAppend(char c)
private void Grow(int additionalCapacityBeyondPos)
{
Debug.Assert(additionalCapacityBeyondPos > 0);
Debug.Assert(_pos > RawChars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed.");
Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed.");

const uint ArrayMaxLength = 0x7FFFFFC7; // same as Array.MaxLength

// Increase to at least the required size (_pos + additionalCapacityBeyondPos), but try
// to double the size if possible, bounding the doubling to not go beyond the max array length.
var newCapacity = (int)Math.Max(
(uint)(_pos + additionalCapacityBeyondPos),
Math.Min((uint)_chars.Length * 2, ArrayMaxLength));

// Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative
var poolArray = ArrayPool<char>.Shared.Rent((int)Math.Max((uint)(_pos + additionalCapacityBeyondPos), (uint)RawChars.Length * 2));
// Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative.
// This could also go negative if the actual required length wraps around.
var poolArray = ArrayPool<char>.Shared.Rent(newCapacity);

RawChars.Slice(0, _pos).CopyTo(poolArray);
_chars.Slice(0, _pos).CopyTo(poolArray);

var toReturn = _arrayToReturnToPool;
RawChars = _arrayToReturnToPool = poolArray;
_chars = _arrayToReturnToPool = poolArray;
if (toReturn is not null)
{
ArrayPool<char>.Shared.Return(toReturn);
Expand Down
Loading