From dd18c9fff8d447410ea6cc8213226649037f8c72 Mon Sep 17 00:00:00 2001 From: neuecc Date: Mon, 18 May 2020 02:34:29 +0900 Subject: [PATCH] Add Channel.CreateSingleConsumerUnbounded --- src/UniTask.NetCoreTests/ChannelTest.cs | 370 ++++++++++++++++ .../Assets/Plugins/UniTask/Runtime/Channel.cs | 403 ++++++++++++++++++ .../Plugins/UniTask/Runtime/Channel.cs.meta | 11 + 3 files changed, 784 insertions(+) create mode 100644 src/UniTask.NetCoreTests/ChannelTest.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs.meta diff --git a/src/UniTask.NetCoreTests/ChannelTest.cs b/src/UniTask.NetCoreTests/ChannelTest.cs new file mode 100644 index 0000000..cf58411 --- /dev/null +++ b/src/UniTask.NetCoreTests/ChannelTest.cs @@ -0,0 +1,370 @@ +using Cysharp.Threading.Tasks; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using Cysharp.Threading.Tasks.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace NetCoreTests +{ + public class ChannelTest + { + (System.Threading.Channels.Channel, Cysharp.Threading.Tasks.Channel) CreateChannel() + { + var reference = System.Threading.Channels.Channel.CreateUnbounded(new UnboundedChannelOptions + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false + }); + + var channel = Cysharp.Threading.Tasks.Channel.CreateSingleConsumerUnbounded(); + + return (reference, channel); + } + + [Fact] + public async Task SingleWriteSingleRead() + { + var (reference, channel) = CreateChannel(); + + foreach (var item in new[] { 10, 20, 30 }) + { + var t1 = reference.Reader.WaitToReadAsync(); + var t2 = channel.Reader.WaitToReadAsync(); + + t1.IsCompleted.Should().BeFalse(); + t2.Status.IsCompleted().Should().BeFalse(); + + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + + (await t1).Should().BeTrue(); + (await t2).Should().BeTrue(); + + reference.Reader.TryRead(out var refitem).Should().BeTrue(); + channel.Reader.TryRead(out var chanitem).Should().BeTrue(); + refitem.Should().Be(item); + chanitem.Should().Be(item); + } + } + + [Fact] + public async Task MultiWrite() + { + var (reference, channel) = CreateChannel(); + + foreach (var item in new[] { 10, 20, 30 }) + { + var t1 = reference.Reader.WaitToReadAsync(); + var t2 = channel.Reader.WaitToReadAsync(); + + t1.IsCompleted.Should().BeFalse(); + t2.Status.IsCompleted().Should().BeFalse(); + + foreach (var i in Enumerable.Range(1, 3)) + { + reference.Writer.TryWrite(item * i); + channel.Writer.TryWrite(item * i); + } + + (await t1).Should().BeTrue(); + (await t2).Should().BeTrue(); + + foreach (var i in Enumerable.Range(1, 3)) + { + (await reference.Reader.WaitToReadAsync()).Should().BeTrue(); + (await channel.Reader.WaitToReadAsync()).Should().BeTrue(); + + reference.Reader.TryRead(out var refitem).Should().BeTrue(); + channel.Reader.TryRead(out var chanitem).Should().BeTrue(); + refitem.Should().Be(item * i); + chanitem.Should().Be(item * i); + } + } + } + + [Fact] + public async Task CompleteOnEmpty() + { + var (reference, channel) = CreateChannel(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + reference.Reader.TryRead(out var refitem); + channel.Reader.TryRead(out var chanitem); + } + + // Empty. + + var completion1 = reference.Reader.Completion; + var wait1 = reference.Reader.WaitToReadAsync(); + + var completion2 = channel.Reader.Completion; + var wait2 = channel.Reader.WaitToReadAsync(); + + reference.Writer.TryComplete(); + channel.Writer.TryComplete(); + + completion1.Status.Should().Be(TaskStatus.RanToCompletion); + completion2.Status.Should().Be(UniTaskStatus.Succeeded); + + (await wait1).Should().BeFalse(); + (await wait2).Should().BeFalse(); + } + + [Fact] + public async Task CompleteErrorOnEmpty() + { + var (reference, channel) = CreateChannel(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + reference.Reader.TryRead(out var refitem); + channel.Reader.TryRead(out var chanitem); + } + + // Empty. + + var completion1 = reference.Reader.Completion; + var wait1 = reference.Reader.WaitToReadAsync(); + + var completion2 = channel.Reader.Completion; + var wait2 = channel.Reader.WaitToReadAsync(); + + var ex = new Exception(); + reference.Writer.TryComplete(ex); + channel.Writer.TryComplete(ex); + + completion1.Status.Should().Be(TaskStatus.Faulted); + completion2.Status.Should().Be(UniTaskStatus.Faulted); + + (await Assert.ThrowsAsync(async () => await wait1)).Should().Be(ex); + (await Assert.ThrowsAsync(async () => await wait2)).Should().Be(ex); + } + + [Fact] + public async Task CompleteWithRest() + { + var (reference, channel) = CreateChannel(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + // Three Item2. + + var completion1 = reference.Reader.Completion; + var wait1 = reference.Reader.WaitToReadAsync(); + + var completion2 = channel.Reader.Completion; + var wait2 = channel.Reader.WaitToReadAsync(); + + reference.Writer.TryComplete(); + channel.Writer.TryComplete(); + + // completion1.Status.Should().Be(TaskStatus.WaitingForActivation); + completion2.Status.Should().Be(UniTaskStatus.Pending); + + (await wait1).Should().BeTrue(); + (await wait2).Should().BeTrue(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Reader.TryRead(out var i1).Should().BeTrue(); + channel.Reader.TryRead(out var i2).Should().BeTrue(); + i1.Should().Be(item); + i2.Should().Be(item); + } + + (await reference.Reader.WaitToReadAsync()).Should().BeFalse(); + (await channel.Reader.WaitToReadAsync()).Should().BeFalse(); + + completion1.Status.Should().Be(TaskStatus.RanToCompletion); + completion2.Status.Should().Be(UniTaskStatus.Succeeded); + } + + + [Fact] + public async Task CompleteErrorWithRest() + { + var (reference, channel) = CreateChannel(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + // Three Item2. + + var completion1 = reference.Reader.Completion; + var wait1 = reference.Reader.WaitToReadAsync(); + + var completion2 = channel.Reader.Completion; + var wait2 = channel.Reader.WaitToReadAsync(); + + var ex = new Exception(); + reference.Writer.TryComplete(ex); + channel.Writer.TryComplete(ex); + + // completion1.Status.Should().Be(TaskStatus.WaitingForActivation); + completion2.Status.Should().Be(UniTaskStatus.Pending); + + (await wait1).Should().BeTrue(); + (await wait2).Should().BeTrue(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Reader.TryRead(out var i1).Should().BeTrue(); + channel.Reader.TryRead(out var i2).Should().BeTrue(); + i1.Should().Be(item); + i2.Should().Be(item); + } + + wait1 = reference.Reader.WaitToReadAsync(); + wait2 = channel.Reader.WaitToReadAsync(); + + (await Assert.ThrowsAsync(async () => await wait1)).Should().Be(ex); + (await Assert.ThrowsAsync(async () => await wait2)).Should().Be(ex); + + completion1.Status.Should().Be(TaskStatus.Faulted); + completion2.Status.Should().Be(UniTaskStatus.Faulted); + } + + [Fact] + public async Task Cancellation() + { + var (reference, channel) = CreateChannel(); + + var cts = new CancellationTokenSource(); + + var wait1 = reference.Reader.WaitToReadAsync(cts.Token); + var wait2 = channel.Reader.WaitToReadAsync(cts.Token); + + cts.Cancel(); + + (await Assert.ThrowsAsync(async () => await wait1)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await wait2)).CancellationToken.Should().Be(cts.Token); + } + + [Fact] + public async Task AsyncEnumerator() + { + var (reference, channel) = CreateChannel(); + + var ta1 = reference.Reader.ReadAllAsync().ToArrayAsync(); + var ta2 = channel.Reader.ReadAllAsync().ToArrayAsync(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + reference.Writer.TryComplete(); + channel.Writer.TryComplete(); + + (await ta1).Should().BeEquivalentTo(new[] { 10, 20, 30 }); + (await ta2).Should().BeEquivalentTo(new[] { 10, 20, 30 }); + } + + [Fact] + public async Task AsyncEnumeratorCancellation() + { + // Token1, Token2 and Cancel1 + { + var cts1 = new CancellationTokenSource(); + var cts2 = new CancellationTokenSource(); + + var (reference, channel) = CreateChannel(); + + var ta1 = reference.Reader.ReadAllAsync(cts1.Token).ToArrayAsync(cts2.Token); + var ta2 = channel.Reader.ReadAllAsync(cts1.Token).ToArrayAsync(cts2.Token); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + cts1.Cancel(); + + await Assert.ThrowsAsync(async () => await ta1); + (await Assert.ThrowsAsync(async () => await ta2)).CancellationToken.Should().Be(cts1.Token); + } + // Token1, Token2 and Cancel2 + { + var cts1 = new CancellationTokenSource(); + var cts2 = new CancellationTokenSource(); + + var (reference, channel) = CreateChannel(); + + var ta1 = reference.Reader.ReadAllAsync(cts1.Token).ToArrayAsync(cts2.Token); + var ta2 = channel.Reader.ReadAllAsync(cts1.Token).ToArrayAsync(cts2.Token); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + cts2.Cancel(); + + await Assert.ThrowsAsync(async () => await ta1); + (await Assert.ThrowsAsync(async () => await ta2)).CancellationToken.Should().Be(cts2.Token); + } + // Token1 and Cancel1 + { + var cts1 = new CancellationTokenSource(); + + var (reference, channel) = CreateChannel(); + + var ta1 = reference.Reader.ReadAllAsync(cts1.Token).ToArrayAsync(); + var ta2 = channel.Reader.ReadAllAsync(cts1.Token).ToArrayAsync(); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + cts1.Cancel(); + + await Assert.ThrowsAsync(async () => await ta1); + (await Assert.ThrowsAsync(async () => await ta2)).CancellationToken.Should().Be(cts1.Token); + } + // Token2 and Cancel2 + { + var cts2 = new CancellationTokenSource(); + + var (reference, channel) = CreateChannel(); + + var ta1 = reference.Reader.ReadAllAsync().ToArrayAsync(cts2.Token); + var ta2 = channel.Reader.ReadAllAsync().ToArrayAsync(cts2.Token); + + foreach (var item in new[] { 10, 20, 30 }) + { + reference.Writer.TryWrite(item); + channel.Writer.TryWrite(item); + } + + cts2.Cancel(); + + await Assert.ThrowsAsync(async () => await ta1); + (await Assert.ThrowsAsync(async () => await ta2)).CancellationToken.Should().Be(cts2.Token); + } + } + } +} diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs new file mode 100644 index 0000000..bbed075 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs @@ -0,0 +1,403 @@ +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Cysharp.Threading.Tasks +{ + public static class Channel + { + public static Channel CreateSingleConsumerUnbounded() + { + return new SingleConsumerUnboundedChannel(); + } + } + + public abstract class Channel + { + public ChannelReader Reader { get; protected set; } + public ChannelWriter Writer { get; protected set; } + + public static implicit operator ChannelReader(Channel channel) => channel.Reader; + public static implicit operator ChannelWriter(Channel channel) => channel.Writer; + } + + public abstract class Channel : Channel + { + } + + public abstract class ChannelReader + { + public abstract bool TryRead(out T item); + public abstract UniTask WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken)); + + public abstract UniTask Completion { get; } + + public virtual UniTask ReadAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + if (this.TryRead(out var item)) + { + return UniTask.FromResult(item); + } + + return ReadAsyncCore(cancellationToken); + } + + async UniTask ReadAsyncCore(CancellationToken cancellationToken = default(CancellationToken)) + { + if (await WaitToReadAsync(cancellationToken)) + { + if (TryRead(out var item)) + { + return item; + } + } + + throw new ChannelClosedException(); + } + + public abstract IUniTaskAsyncEnumerable ReadAllAsync(CancellationToken cancellationToken = default(CancellationToken)); + } + + public abstract class ChannelWriter + { + public abstract bool TryWrite(T item); + public abstract bool TryComplete(Exception error = null); + + public void Complete(Exception error = null) + { + if (!TryComplete(error)) + { + throw new ChannelClosedException(); + } + } + } + + public partial class ChannelClosedException : InvalidOperationException + { + public ChannelClosedException() : + base("Channel is already closed.") + { } + + public ChannelClosedException(string message) : base(message) { } + + public ChannelClosedException(Exception innerException) : + base("Channel is already closed", innerException) + { } + + public ChannelClosedException(string message, Exception innerException) : base(message, innerException) { } + } + + internal class SingleConsumerUnboundedChannel : Channel + { + readonly Queue items; + readonly SingleConsumerUnboundedChannelReader readerSource; + readonly UniTaskCompletionSource completedTask; + + Exception completionError; + bool closed; + + public SingleConsumerUnboundedChannel() + { + items = new Queue(); + completedTask = new UniTaskCompletionSource(); + Writer = new SingleConsumerUnboundedChannelWriter(this); + readerSource = new SingleConsumerUnboundedChannelReader(this); + Reader = readerSource; + } + + sealed class SingleConsumerUnboundedChannelWriter : ChannelWriter + { + readonly SingleConsumerUnboundedChannel parent; + + public SingleConsumerUnboundedChannelWriter(SingleConsumerUnboundedChannel parent) + { + this.parent = parent; + } + + public override bool TryWrite(T item) + { + bool waiting; + lock (parent.items) + { + if (parent.closed) return false; + + parent.items.Enqueue(item); + waiting = parent.readerSource.isWaiting; + } + + if (waiting) + { + parent.readerSource.SingalContinuation(); + } + + return true; + } + + public override bool TryComplete(Exception error = null) + { + bool waiting; + lock (parent.items) + { + if (parent.closed) return false; + parent.closed = true; + waiting = parent.readerSource.isWaiting; + + if (parent.items.Count == 0) + { + if (error == null) + { + parent.completedTask.TrySetResult(); + } + else + { + parent.completedTask.TrySetException(error); + } + + if (waiting) + { + parent.readerSource.SingalCompleted(error); + } + } + + parent.completionError = error; + } + + return true; + } + } + + sealed class SingleConsumerUnboundedChannelReader : ChannelReader, IUniTaskSource + { + readonly Action CancellationCallbackDelegate = CancellationCallback; + readonly SingleConsumerUnboundedChannel parent; + + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + UniTaskCompletionSourceCore core; + internal bool isWaiting; + + public SingleConsumerUnboundedChannelReader(SingleConsumerUnboundedChannel parent) + { + this.parent = parent; + } + + public override UniTask Completion => parent.completedTask.Task; + + public override bool TryRead(out T item) + { + lock (parent.items) + { + if (parent.items.Count != 0) + { + item = parent.items.Dequeue(); + + // complete when all value was consumed. + if (parent.closed && parent.items.Count == 0) + { + if (parent.completionError != null) + { + parent.completedTask.TrySetException(parent.completionError); + } + else + { + parent.completedTask.TrySetResult(); + } + } + } + else + { + item = default; + return false; + } + } + + return true; + } + + public override UniTask WaitToReadAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return UniTask.FromCanceled(cancellationToken); + } + + lock (parent.items) + { + if (parent.items.Count != 0) + { + return CompletedTasks.True; + } + + if (parent.closed) + { + if (parent.completionError == null) + { + return CompletedTasks.False; + } + else + { + return UniTask.FromException(parent.completionError); + } + } + + cancellationTokenRegistration.Dispose(); + + core.Reset(); + isWaiting = true; + + this.cancellationToken = cancellationToken; + if (this.cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = this.cancellationToken.RegisterWithoutCaptureExecutionContext(CancellationCallbackDelegate, this); + } + + return new UniTask(this, core.Version); + } + } + + public void SingalContinuation() + { + core.TrySetResult(true); + } + + public void SingalCancellation(CancellationToken cancellationToken) + { + core.TrySetCanceled(cancellationToken); + } + + public void SingalCompleted(Exception error) + { + if (error != null) + { + core.TrySetException(error); + } + else + { + core.TrySetResult(false); + } + } + + public override IUniTaskAsyncEnumerable ReadAllAsync(CancellationToken cancellationToken = default) + { + return new ReadAllAsyncEnumerable(this, cancellationToken); + } + + bool IUniTaskSource.GetResult(short token) + { + return core.GetResult(token); + } + + void IUniTaskSource.GetResult(short token) + { + core.GetResult(token); + } + + UniTaskStatus IUniTaskSource.GetStatus(short token) + { + return core.GetStatus(token); + } + + void IUniTaskSource.OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + UniTaskStatus IUniTaskSource.UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + static void CancellationCallback(object state) + { + var self = (SingleConsumerUnboundedChannelReader)state; + self.SingalCancellation(self.cancellationToken); + } + + sealed class ReadAllAsyncEnumerable : IUniTaskAsyncEnumerable, IUniTaskAsyncEnumerator + { + readonly Action CancellationCallback1Delegate = CancellationCallback1; + readonly Action CancellationCallback2Delegate = CancellationCallback2; + + readonly SingleConsumerUnboundedChannelReader parent; + CancellationToken cancellationToken1; + CancellationToken cancellationToken2; + CancellationTokenRegistration CancellationTokenRegistration1; + CancellationTokenRegistration CancellationTokenRegistration2; + + T current; + bool cacheValue; + bool running; + + public ReadAllAsyncEnumerable(SingleConsumerUnboundedChannelReader parent, CancellationToken cancellationToken) + { + this.parent = parent; + this.cancellationToken1 = cancellationToken; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + if (running) + { + throw new InvalidOperationException("Enumerator is already running, does not allow call GetAsyncEnumerator twice."); + } + + if (this.cancellationToken1 != cancellationToken) + { + this.cancellationToken2 = cancellationToken; + } + + if (this.cancellationToken1.CanBeCanceled) + { + this.cancellationToken1.RegisterWithoutCaptureExecutionContext(CancellationCallback1Delegate, this); + } + + if (this.cancellationToken2.CanBeCanceled) + { + this.cancellationToken2.RegisterWithoutCaptureExecutionContext(CancellationCallback2Delegate, this); + } + + running = true; + return this; + } + + public T Current + { + get + { + if (cacheValue) + { + return current; + } + parent.TryRead(out current); + return current; + } + } + + public UniTask MoveNextAsync() + { + cacheValue = false; + return parent.WaitToReadAsync(CancellationToken.None); // ok to use None, registered in ctor. + } + + public UniTask DisposeAsync() + { + CancellationTokenRegistration1.Dispose(); + CancellationTokenRegistration2.Dispose(); + return default; + } + + static void CancellationCallback1(object state) + { + var self = (ReadAllAsyncEnumerable)state; + self.parent.SingalCancellation(self.cancellationToken1); + } + + static void CancellationCallback2(object state) + { + var self = (ReadAllAsyncEnumerable)state; + self.parent.SingalCancellation(self.cancellationToken2); + } + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs.meta new file mode 100644 index 0000000..32edb9c --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Channel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5ceb3107bbdd1f14eb39091273798360 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: