Skip to content

Commit

Permalink
Revert "Dedup UnboundedChannel and UnboundedPriorityChannel (#101396)" (
Browse files Browse the repository at this point in the history
#103325)

* Revert "Dedup UnboundedChannel and UnboundedPriorityChannel (#101396)"

This reverts commit b4e0169.

* Put back lock

This fix has been part of the previous deduping.
  • Loading branch information
stephentoub committed Jun 13, 2024
1 parent f528563 commit 7e2e874
Show file tree
Hide file tree
Showing 7 changed files with 409 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
<Compile Include="System\Threading\Channels\Channel_1.cs" />
<Compile Include="System\Threading\Channels\Channel_2.cs" />
<Compile Include="System\Threading\Channels\IDebugEnumerator.cs" />
<Compile Include="System\Threading\Channels\IUnboundedChannelQueue.cs" />
<Compile Include="System\Threading\Channels\SingleConsumerUnboundedChannel.cs" />
<Compile Include="System\Threading\Channels\UnboundedChannel.cs" />
<Compile Include="$(CommonPath)Internal\Padding.cs" Link="Common\Internal\Padding.cs" />
Expand All @@ -45,6 +44,7 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
<Compile Include="System\Threading\Channels\AsyncOperation.netcoreapp.cs" />
<Compile Include="System\Threading\Channels\Channel.netcoreapp.cs" />
<Compile Include="System\Threading\Channels\ChannelOptions.netcoreapp.cs" />
<Compile Include="System\Threading\Channels\UnboundedPriorityChannel.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == '$(NetCoreAppCurrent)'">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace System.Threading.Channels
{
/// <summary>Provides static methods for creating channels.</summary>
Expand All @@ -13,7 +9,7 @@ public static partial class Channel
/// <summary>Creates an unbounded channel usable by any number of readers and writers concurrently.</summary>
/// <returns>The created channel.</returns>
public static Channel<T> CreateUnbounded<T>() =>
new UnboundedChannel<T, UnboundedChannelConcurrentQueue<T>>(new(new()), runContinuationsAsynchronously: true);
new UnboundedChannel<T>(runContinuationsAsynchronously: true);

/// <summary>Creates an unbounded channel subject to the provided options.</summary>
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
Expand All @@ -31,7 +27,7 @@ public static Channel<T> CreateUnbounded<T>(UnboundedChannelOptions options)
return new SingleConsumerUnboundedChannel<T>(!options.AllowSynchronousContinuations);
}

return new UnboundedChannel<T, UnboundedChannelConcurrentQueue<T>>(new(new()), !options.AllowSynchronousContinuations);
return new UnboundedChannel<T>(!options.AllowSynchronousContinuations);
}

/// <summary>Creates a channel with the specified maximum capacity.</summary>
Expand Down Expand Up @@ -75,32 +71,5 @@ public static Channel<T> CreateBounded<T>(BoundedChannelOptions options, Action<

return new BoundedChannel<T>(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped);
}

/// <summary>Provides an <see cref="IUnboundedChannelQueue{T}"/> for a <see cref="ConcurrentQueue{T}"/>.</summary>
private readonly struct UnboundedChannelConcurrentQueue<T>(ConcurrentQueue<T> queue) : IUnboundedChannelQueue<T>
{
private readonly ConcurrentQueue<T> _queue = queue;

/// <inheritdoc/>
public bool IsThreadSafe => true;

/// <inheritdoc/>
public void Enqueue(T item) => _queue.Enqueue(item);

/// <inheritdoc/>
public bool TryDequeue([MaybeNullWhen(false)] out T item) => _queue.TryDequeue(out item);

/// <inheritdoc/>
public bool TryPeek([MaybeNullWhen(false)] out T item) => _queue.TryPeek(out item);

/// <inheritdoc/>
public int Count => _queue.Count;

/// <inheritdoc/>
public bool IsEmpty => _queue.IsEmpty;

/// <inheritdoc/>
public IEnumerator<T> GetEnumerator() => _queue.GetEnumerator();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace System.Threading.Channels
{
/// <summary>Provides static methods for creating channels.</summary>
public static partial class Channel
{
/// <summary>Creates an unbounded prioritized channel usable by any number of readers and writers concurrently.</summary>
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
/// <returns>The created channel.</returns>
/// <remarks>
/// <see cref="Comparer{T}.Default"/> is used to determine priority of elements.
/// The next item read from the channel will be the element available in the channel with the lowest priority value.
/// </remarks>
public static Channel<T> CreateUnboundedPrioritized<T>() =>
new UnboundedChannel<T, UnboundedChannelPriorityQueue<T>>(new(new()), runContinuationsAsynchronously: true);
new UnboundedPrioritizedChannel<T>(runContinuationsAsynchronously: true, comparer: null);

/// <summary>Creates an unbounded prioritized channel subject to the provided options.</summary>
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
Expand All @@ -32,45 +30,7 @@ public static Channel<T> CreateUnboundedPrioritized<T>(UnboundedPrioritizedChann
{
ArgumentNullException.ThrowIfNull(options);

return new UnboundedChannel<T, UnboundedChannelPriorityQueue<T>>(new(new(options.Comparer)), !options.AllowSynchronousContinuations);
}

/// <summary>Provides an <see cref="IUnboundedChannelQueue{T}"/> for a <see cref="PriorityQueue{TElement, TPriority}"/>.</summary>
private readonly struct UnboundedChannelPriorityQueue<T>(PriorityQueue<bool, T> queue) : IUnboundedChannelQueue<T>
{
private readonly PriorityQueue<bool, T> _queue = queue;

/// <inheritdoc/>
public bool IsThreadSafe => false;

/// <inheritdoc/>
public void Enqueue(T item) => _queue.Enqueue(true, item);

/// <inheritdoc/>
public bool TryDequeue([MaybeNullWhen(false)] out T item) => _queue.TryDequeue(out _, out item);

/// <inheritdoc/>
public bool TryPeek([MaybeNullWhen(false)] out T item) => _queue.TryPeek(out _, out item);

/// <inheritdoc/>
public int Count => _queue.Count;

/// <inheritdoc/>
public bool IsEmpty => _queue.Count == 0;

/// <inheritdoc/>
public IEnumerator<T> GetEnumerator()
{
List<T> list = [];
foreach ((bool _, T Priority) item in _queue.UnorderedItems)
{
list.Add(item.Priority);
}

list.Sort(_queue.Comparer);

return list.GetEnumerator();
}
return new UnboundedPrioritizedChannel<T>(!options.AllowSynchronousContinuations, options.Comparer);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal interface IDebugEnumerable<T>
IEnumerator<T> GetEnumerator();
}

internal class DebugEnumeratorDebugView<T>
internal sealed class DebugEnumeratorDebugView<T>
{
public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable)
{
Expand All @@ -26,6 +26,4 @@ public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable)
[DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
public T[] Items { get; }
}

internal sealed class DebugEnumeratorDebugView<T, TOther>(IDebugEnumerable<T> enumerable) : DebugEnumeratorDebugView<T>(enumerable);
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

namespace System.Threading.Channels
{
/// <summary>Provides a buffered channel of unbounded capacity.</summary>
[DebuggerDisplay("Items = {ItemsCountForDebugger}, Closed = {ChannelIsClosedForDebugger}")]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
internal sealed class UnboundedChannel<T, TQueue> : Channel<T>, IDebugEnumerable<T> where TQueue : struct, IUnboundedChannelQueue<T>
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
internal sealed class UnboundedChannel<T> : Channel<T>, IDebugEnumerable<T>
{
/// <summary>Task that indicates the channel has completed.</summary>
private readonly TaskCompletionSource _completion;
/// <summary>The items in the channel.</summary>
private readonly TQueue _items;
private readonly ConcurrentQueue<T> _items = new ConcurrentQueue<T>();
/// <summary>Readers blocked reading from the channel.</summary>
private readonly Deque<AsyncOperation<T>> _blockedReaders = new Deque<AsyncOperation<T>>();
/// <summary>Whether to force continuations to be executed asynchronously from producer writes.</summary>
Expand All @@ -30,24 +29,23 @@ internal sealed class UnboundedChannel<T, TQueue> : Channel<T>, IDebugEnumerable
private Exception? _doneWriting;

/// <summary>Initialize the channel.</summary>
internal UnboundedChannel(TQueue items, bool runContinuationsAsynchronously)
internal UnboundedChannel(bool runContinuationsAsynchronously)
{
_items = items;
_runContinuationsAsynchronously = runContinuationsAsynchronously;
_completion = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
Reader = new UnboundedChannelReader(this);
Writer = new UnboundedChannelWriter(this);
}

[DebuggerDisplay("Items = {Count}")]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
private sealed class UnboundedChannelReader : ChannelReader<T>, IDebugEnumerable<T>
{
internal readonly UnboundedChannel<T, TQueue> _parent;
internal readonly UnboundedChannel<T> _parent;
private readonly AsyncOperation<T> _readerSingleton;
private readonly AsyncOperation<bool> _waiterSingleton;

internal UnboundedChannelReader(UnboundedChannel<T, TQueue> parent)
internal UnboundedChannelReader(UnboundedChannel<T> parent)
{
_parent = parent;
_readerSingleton = new AsyncOperation<T>(parent._runContinuationsAsynchronously, pooled: true);
Expand All @@ -70,8 +68,8 @@ public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)
}

// Dequeue an item if we can.
UnboundedChannel<T, TQueue> parent = _parent;
if (parent._items.IsThreadSafe && parent._items.TryDequeue(out T? item))
UnboundedChannel<T> parent = _parent;
if (parent._items.TryDequeue(out T? item))
{
CompleteIfDone(parent);
return new ValueTask<T>(item);
Expand Down Expand Up @@ -114,60 +112,24 @@ public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)

public override bool TryRead([MaybeNullWhen(false)] out T item)
{
UnboundedChannel<T, TQueue> parent = _parent;
return parent._items.IsThreadSafe ?
LockFree(parent, out item) :
Locked(parent, out item);
UnboundedChannel<T> parent = _parent;

static bool LockFree(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
// Dequeue an item if we can
if (parent._items.TryDequeue(out item))
{
if (parent._items.TryDequeue(out item))
{
CompleteIfDone(parent);
return true;
}

item = default;
return false;
CompleteIfDone(parent);
return true;
}

static bool Locked(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
{
lock (parent.SyncObj)
{
if (parent._items.TryDequeue(out item))
{
CompleteIfDone(parent);
return true;
}
}

item = default;
return false;
}
item = default;
return false;
}

public override bool TryPeek([MaybeNullWhen(false)] out T item)
{
UnboundedChannel<T, TQueue> parent = _parent;
return parent._items.IsThreadSafe ?
parent._items.TryPeek(out item) :
Locked(parent, out item);

// Separated out to keep the try/finally from preventing TryPeek from being inlined
static bool Locked(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
{
lock (parent.SyncObj)
{
return parent._items.TryPeek(out item);
}
}
}
public override bool TryPeek([MaybeNullWhen(false)] out T item) =>
_parent._items.TryPeek(out item);

private static void CompleteIfDone(UnboundedChannel<T, TQueue> parent)
private static void CompleteIfDone(UnboundedChannel<T> parent)
{
Debug.Assert(parent._items.IsThreadSafe || Monitor.IsEntered(parent.SyncObj));

if (parent._doneWriting != null && parent._items.IsEmpty)
{
// If we've now emptied the items queue and we're not getting any more, complete.
Expand All @@ -182,12 +144,12 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
return new ValueTask<bool>(Task.FromCanceled<bool>(cancellationToken));
}

if (_parent._items.IsThreadSafe && !_parent._items.IsEmpty)
if (!_parent._items.IsEmpty)
{
return new ValueTask<bool>(true);
}

UnboundedChannel<T, TQueue> parent = _parent;
UnboundedChannel<T> parent = _parent;

lock (parent.SyncObj)
{
Expand Down Expand Up @@ -230,15 +192,15 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
}

[DebuggerDisplay("Items = {ItemsCountForDebugger}")]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
private sealed class UnboundedChannelWriter : ChannelWriter<T>, IDebugEnumerable<T>
{
internal readonly UnboundedChannel<T, TQueue> _parent;
internal UnboundedChannelWriter(UnboundedChannel<T, TQueue> parent) => _parent = parent;
internal readonly UnboundedChannel<T> _parent;
internal UnboundedChannelWriter(UnboundedChannel<T> parent) => _parent = parent;

public override bool TryComplete(Exception? error)
{
UnboundedChannel<T, TQueue> parent = _parent;
UnboundedChannel<T> parent = _parent;
bool completeTask;

lock (parent.SyncObj)
Expand Down Expand Up @@ -278,7 +240,7 @@ public override bool TryComplete(Exception? error)

public override bool TryWrite(T item)
{
UnboundedChannel<T, TQueue> parent = _parent;
UnboundedChannel<T> parent = _parent;
while (true)
{
AsyncOperation<T>? blockedReader = null;
Expand Down Expand Up @@ -359,7 +321,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken
}

/// <summary>Gets the object used to synchronize access to all state on this instance.</summary>
private object SyncObj => _blockedReaders;
private object SyncObj => _items;

[Conditional("DEBUG")]
private void AssertInvariants()
Expand Down
Loading

0 comments on commit 7e2e874

Please sign in to comment.