From 769b5c6babf028661f4cabd5ca703ae277f78f40 Mon Sep 17 00:00:00 2001 From: neuecc Date: Thu, 18 Jun 2020 02:34:56 +0900 Subject: [PATCH] UniTaskCompletionSource can await multiple times(same behaviour as TaskCompletionSource) --- .../CompletionSourceTest.cs | 590 ++++++++++++++++++ .../UniTask/Runtime/UniTask.Factory.cs | 170 ++++- .../Runtime/UniTaskCompletionSource.cs | 544 +++++++++++----- 3 files changed, 1107 insertions(+), 197 deletions(-) create mode 100644 src/UniTask.NetCoreTests/CompletionSourceTest.cs diff --git a/src/UniTask.NetCoreTests/CompletionSourceTest.cs b/src/UniTask.NetCoreTests/CompletionSourceTest.cs new file mode 100644 index 0000000..4ab3ae1 --- /dev/null +++ b/src/UniTask.NetCoreTests/CompletionSourceTest.cs @@ -0,0 +1,590 @@ +using Cysharp.Threading.Tasks; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using Cysharp.Threading.Tasks.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace NetCoreTests +{ + public class CompletionSourceTest + { + [Fact] + public async Task SetFirst() + { + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetResult(); + await tcs.Task; // ok. + await tcs.Task; // ok. + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetException(new TestException()); + + await Assert.ThrowsAsync(async () => await tcs.Task); + await Assert.ThrowsAsync(async () => await tcs.Task); + + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetCanceled(cts.Token); + + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + [Fact] + public async Task SingleOnFirst() + { + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + + tcs.TrySetResult(); + await a; + await tcs.Task; // ok. + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + + tcs.TrySetException(new TestException()); + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await tcs.Task); + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + + tcs.TrySetCanceled(cts.Token); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + [Fact] + public async Task MultiOne() + { + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + tcs.TrySetResult(); + await a; + await b; + await tcs.Task; // ok. + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + + tcs.TrySetException(new TestException()); + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await tcs.Task); + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + + tcs.TrySetCanceled(cts.Token); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + [Fact] + public async Task MultiTwo() + { + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + tcs.TrySetResult(); + await a; + await b; + await c; + await tcs.Task; // ok. + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + + tcs.TrySetException(new TestException()); + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + await Assert.ThrowsAsync(async () => await tcs.Task); + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await c)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + + tcs.TrySetCanceled(cts.Token); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await c)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + class TestException : Exception + { + + } + } + + public class CompletionSourceTest2 + { + [Fact] + public async Task SetFirst() + { + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetResult(10); + var a = await tcs.Task; // ok. + var b = await tcs.Task; // ok. + a.Should().Be(10); + b.Should().Be(10); + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetException(new TestException()); + + await Assert.ThrowsAsync(async () => await tcs.Task); + await Assert.ThrowsAsync(async () => await tcs.Task); + + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + + { + var tcs = new UniTaskCompletionSource(); + + tcs.TrySetCanceled(cts.Token); + + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + [Fact] + public async Task SingleOnFirst() + { + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + + tcs.TrySetResult(10); + var r1 = await a; + var r2 = await tcs.Task; // ok. + r1.Should().Be(10); + r2.Should().Be(10); + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + + tcs.TrySetException(new TestException()); + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await tcs.Task); + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + + tcs.TrySetCanceled(cts.Token); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + [Fact] + public async Task MultiOne() + { + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + tcs.TrySetResult(10); + var r1 = await a; + var r2 = await b; + var r3 = await tcs.Task; // ok. + (r1, r2, r3).Should().Be((10, 10, 10)); + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + + tcs.TrySetException(new TestException()); + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await tcs.Task); + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + + tcs.TrySetCanceled(cts.Token); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + [Fact] + public async Task MultiTwo() + { + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + tcs.TrySetResult(10); + var r1 = await a; + var r2 = await b; + var r3 = await c; + var r4 = await tcs.Task; // ok. + (r1, r2, r3, r4).Should().Be((10, 10, 10, 10)); + tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded); + } + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + + tcs.TrySetException(new TestException()); + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + await Assert.ThrowsAsync(async () => await tcs.Task); + tcs.Task.Status.Should().Be(UniTaskStatus.Faulted); + } + + var cts = new CancellationTokenSource(); + + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + + tcs.TrySetException(new OperationCanceledException(cts.Token)); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await c)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + { + var tcs = new UniTaskCompletionSource(); + + async UniTask Await() + { + return await tcs.Task; + } + + var a = Await(); + var b = Await(); + var c = Await(); + + tcs.TrySetCanceled(cts.Token); + (await Assert.ThrowsAsync(async () => await a)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await b)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await c)).CancellationToken.Should().Be(cts.Token); + (await Assert.ThrowsAsync(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token); + tcs.Task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + + class TestException : Exception + { + + } + } +} diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs index 915ee78..387dca8 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs @@ -11,10 +11,7 @@ namespace Cysharp.Threading.Tasks { static readonly UniTask CanceledUniTask = new Func(() => { - var promise = new UniTaskCompletionSource(); - promise.TrySetCanceled(CancellationToken.None); - promise.MarkHandled(); - return promise.Task; + return new UniTask(new CanceledResultSource(CancellationToken.None), 0); })(); static class CanceledUniTaskCache @@ -23,10 +20,7 @@ namespace Cysharp.Threading.Tasks static CanceledUniTaskCache() { - var promise = new UniTaskCompletionSource(); - promise.TrySetCanceled(CancellationToken.None); - promise.MarkHandled(); - Task = promise.Task; + Task = new UniTask(new CanceledResultSource(CancellationToken.None), 0); } } @@ -34,18 +28,22 @@ namespace Cysharp.Threading.Tasks public static UniTask FromException(Exception ex) { - var promise = new UniTaskCompletionSource(); - promise.TrySetException(ex); - promise.MarkHandled(); - return promise.Task; + if (ex is OperationCanceledException oce) + { + return FromCanceled(oce.CancellationToken); + } + + return new UniTask(new ExceptionResultSource(ex), 0); } public static UniTask FromException(Exception ex) { - var promise = new UniTaskCompletionSource(); - promise.TrySetException(ex); - promise.MarkHandled(); - return promise.Task; + if (ex is OperationCanceledException oce) + { + return FromCanceled(oce.CancellationToken); + } + + return new UniTask(new ExceptionResultSource(ex), 0); } public static UniTask FromResult(T value) @@ -61,10 +59,7 @@ namespace Cysharp.Threading.Tasks } else { - var promise = new UniTaskCompletionSource(); - promise.TrySetCanceled(cancellationToken); - promise.MarkHandled(); - return promise.Task; + return new UniTask(new CanceledResultSource(cancellationToken), 0); } } @@ -76,10 +71,7 @@ namespace Cysharp.Threading.Tasks } else { - var promise = new UniTaskCompletionSource(); - promise.TrySetCanceled(cancellationToken); - promise.MarkHandled(); - return promise.Task; + return new UniTask(new CanceledResultSource(cancellationToken), 0); } } @@ -182,6 +174,136 @@ namespace Cysharp.Threading.Tasks return new UniTask(new DeferPromise(factory), 0); } + sealed class ExceptionResultSource : IUniTaskSource + { + readonly Exception exception; + + public ExceptionResultSource(Exception exception) + { + this.exception = exception; + } + + public void GetResult(short token) + { + throw exception; + } + + public UniTaskStatus GetStatus(short token) + { + return UniTaskStatus.Faulted; + } + + public UniTaskStatus UnsafeGetStatus() + { + return UniTaskStatus.Faulted; + } + + public void OnCompleted(Action continuation, object state, short token) + { + continuation(state); + } + } + + sealed class ExceptionResultSource : IUniTaskSource + { + readonly Exception exception; + + public ExceptionResultSource(Exception exception) + { + this.exception = exception; + } + + public T GetResult(short token) + { + throw exception; + } + + void IUniTaskSource.GetResult(short token) + { + throw exception; + } + + public UniTaskStatus GetStatus(short token) + { + return UniTaskStatus.Faulted; + } + + public UniTaskStatus UnsafeGetStatus() + { + return UniTaskStatus.Faulted; + } + + public void OnCompleted(Action continuation, object state, short token) + { + continuation(state); + } + } + + sealed class CanceledResultSource : IUniTaskSource + { + readonly CancellationToken cancellationToken; + + public CanceledResultSource(CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + } + + public void GetResult(short token) + { + throw new OperationCanceledException(cancellationToken); + } + + public UniTaskStatus GetStatus(short token) + { + return UniTaskStatus.Canceled; + } + + public UniTaskStatus UnsafeGetStatus() + { + return UniTaskStatus.Canceled; + } + + public void OnCompleted(Action continuation, object state, short token) + { + continuation(state); + } + } + + sealed class CanceledResultSource : IUniTaskSource + { + readonly CancellationToken cancellationToken; + + public CanceledResultSource(CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + } + + public T GetResult(short token) + { + throw new OperationCanceledException(cancellationToken); + } + + void IUniTaskSource.GetResult(short token) + { + throw new OperationCanceledException(cancellationToken); + } + + public UniTaskStatus GetStatus(short token) + { + return UniTaskStatus.Canceled; + } + + public UniTaskStatus UnsafeGetStatus() + { + return UniTaskStatus.Canceled; + } + + public void OnCompleted(Action continuation, object state, short token) + { + continuation(state); + } + } + sealed class DeferPromise : IUniTaskSource { Func factory; diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs index 02d6ad7..acc3000 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs @@ -1,6 +1,7 @@ #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member using System; +using System.Collections.Generic; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; @@ -315,89 +316,6 @@ namespace Cysharp.Threading.Tasks } } - public class UniTaskCompletionSource : IUniTaskSource, IPromise - { - UniTaskCompletionSourceCore core; - bool handled = false; - - public UniTaskCompletionSource() - { - TaskTracker.TrackActiveTask(this, 2); - } - - [DebuggerHidden] - internal void MarkHandled() - { - if (!handled) - { - handled = true; - core.MarkHandled(); - TaskTracker.RemoveTracking(this); - } - } - - public UniTask Task - { - [DebuggerHidden] - get - { - return new UniTask(this, core.Version); - } - } - - [DebuggerHidden] - public void Reset() - { - // Reset, re-active tracker - handled = false; - TaskTracker.TrackActiveTask(this, 2); - core.Reset(); - } - - [DebuggerHidden] - public bool TrySetResult() - { - return core.TrySetResult(AsyncUnit.Default); - } - - [DebuggerHidden] - public bool TrySetCanceled(CancellationToken cancellationToken = default) - { - return core.TrySetCanceled(cancellationToken); - } - - [DebuggerHidden] - public bool TrySetException(Exception exception) - { - return core.TrySetException(exception); - } - - [DebuggerHidden] - public void GetResult(short token) - { - MarkHandled(); - core.GetResult(token); - } - - [DebuggerHidden] - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - [DebuggerHidden] - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - [DebuggerHidden] - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - } - public class AutoResetUniTaskCompletionSource : IUniTaskSource, ITaskPoolNode, IPromise { static TaskPool pool; @@ -520,95 +438,6 @@ namespace Cysharp.Threading.Tasks } } - public class UniTaskCompletionSource : IUniTaskSource, IPromise - { - UniTaskCompletionSourceCore core; - bool handled = false; - - [DebuggerHidden] - public UniTaskCompletionSource() - { - TaskTracker.TrackActiveTask(this, 2); - } - - [DebuggerHidden] - internal void MarkHandled() - { - if (!handled) - { - handled = true; - core.MarkHandled(); - TaskTracker.RemoveTracking(this); - } - } - - [DebuggerHidden] - public UniTask Task - { - get - { - return new UniTask(this, core.Version); - } - } - - [DebuggerHidden] - public void Reset() - { - handled = false; - core.Reset(); - TaskTracker.TrackActiveTask(this, 2); - } - - [DebuggerHidden] - public bool TrySetResult(T result) - { - return core.TrySetResult(result); - } - - [DebuggerHidden] - public bool TrySetCanceled(CancellationToken cancellationToken = default) - { - return core.TrySetCanceled(cancellationToken); - } - - [DebuggerHidden] - public bool TrySetException(Exception exception) - { - return core.TrySetException(exception); - } - - [DebuggerHidden] - public T GetResult(short token) - { - MarkHandled(); - return core.GetResult(token); - } - - [DebuggerHidden] - void IUniTaskSource.GetResult(short token) - { - GetResult(token); - } - - [DebuggerHidden] - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - [DebuggerHidden] - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - [DebuggerHidden] - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - } - public class AutoResetUniTaskCompletionSource : IUniTaskSource, ITaskPoolNode>, IPromise { static TaskPool> pool; @@ -735,5 +564,374 @@ namespace Cysharp.Threading.Tasks return pool.TryPush(this); } } -} + public class UniTaskCompletionSource : IUniTaskSource, IPromise + { + CancellationToken cancellationToken; + ExceptionHolder exception; + object gate; + Action singleContinuation; + object singleState; + List<(Action, object)> secondaryContinuationList; + + int intStatus; // UniTaskStatus + bool handled = false; + + public UniTaskCompletionSource() + { + TaskTracker.TrackActiveTask(this, 2); + } + + [DebuggerHidden] + internal void MarkHandled() + { + if (!handled) + { + handled = true; + TaskTracker.RemoveTracking(this); + } + } + + public UniTask Task + { + [DebuggerHidden] + get + { + return new UniTask(this, 0); + } + } + + [DebuggerHidden] + public bool TrySetResult() + { + return TrySignalCompletion(UniTaskStatus.Succeeded); + } + + [DebuggerHidden] + public bool TrySetCanceled(CancellationToken cancellationToken = default) + { + if (UnsafeGetStatus() != UniTaskStatus.Pending) return false; + + this.cancellationToken = cancellationToken; + return TrySignalCompletion(UniTaskStatus.Canceled); + } + + [DebuggerHidden] + public bool TrySetException(Exception exception) + { + if (exception is OperationCanceledException oce) + { + return TrySetCanceled(oce.CancellationToken); + } + + if (UnsafeGetStatus() != UniTaskStatus.Pending) return false; + + this.exception = new ExceptionHolder(ExceptionDispatchInfo.Capture(exception)); + return TrySignalCompletion(UniTaskStatus.Faulted); + } + + [DebuggerHidden] + public void GetResult(short token) + { + MarkHandled(); + + var status = (UniTaskStatus)intStatus; + switch (status) + { + case UniTaskStatus.Succeeded: + return; + case UniTaskStatus.Faulted: + exception.GetException().Throw(); + return; + case UniTaskStatus.Canceled: + throw new OperationCanceledException(cancellationToken); + default: + case UniTaskStatus.Pending: + throw new InvalidOperationException("not yet completed."); + } + } + + [DebuggerHidden] + public UniTaskStatus GetStatus(short token) + { + return (UniTaskStatus)intStatus; + } + + [DebuggerHidden] + public UniTaskStatus UnsafeGetStatus() + { + return (UniTaskStatus)intStatus; + } + + [DebuggerHidden] + public void OnCompleted(Action continuation, object state, short token) + { + if (gate == null) + { + Interlocked.CompareExchange(ref gate, new object(), null); + } + + var lockGate = Thread.VolatileRead(ref gate); + lock (lockGate) // wait TrySignalCompletion, after status is not pending. + { + if ((UniTaskStatus)intStatus != UniTaskStatus.Pending) + { + continuation(state); + return; + } + + if (singleContinuation == null) + { + singleContinuation = continuation; + singleState = state; + } + else + { + if (secondaryContinuationList == null) + { + secondaryContinuationList = new List<(Action, object)>(); + } + secondaryContinuationList.Add((continuation, state)); + } + } + } + + bool TrySignalCompletion(UniTaskStatus status) + { + if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending) + { + if (gate == null) + { + Interlocked.CompareExchange(ref gate, new object(), null); + } + + var lockGate = Thread.VolatileRead(ref gate); + lock (lockGate) // wait OnCompleted. + { + if (singleContinuation != null) + { + try + { + singleContinuation(singleState); + } + catch (Exception ex) + { + UniTaskScheduler.PublishUnobservedTaskException(ex); + } + } + + if (secondaryContinuationList != null) + { + foreach (var (c, state) in secondaryContinuationList) + { + try + { + c(state); + } + catch (Exception ex) + { + UniTaskScheduler.PublishUnobservedTaskException(ex); + } + } + } + + singleContinuation = null; + singleState = null; + secondaryContinuationList = null; + } + return true; + } + return false; + } + } + + public class UniTaskCompletionSource : IUniTaskSource, IPromise + { + CancellationToken cancellationToken; + T result; + ExceptionHolder exception; + object gate; + Action singleContinuation; + object singleState; + List<(Action, object)> secondaryContinuationList; + + int intStatus; // UniTaskStatus + bool handled = false; + + public UniTaskCompletionSource() + { + TaskTracker.TrackActiveTask(this, 2); + } + + [DebuggerHidden] + internal void MarkHandled() + { + if (!handled) + { + handled = true; + TaskTracker.RemoveTracking(this); + } + } + + public UniTask Task + { + [DebuggerHidden] + get + { + return new UniTask(this, 0); + } + } + + [DebuggerHidden] + public bool TrySetResult(T result) + { + if (UnsafeGetStatus() != UniTaskStatus.Pending) return false; + + this.result = result; + return TrySignalCompletion(UniTaskStatus.Succeeded); + } + + [DebuggerHidden] + public bool TrySetCanceled(CancellationToken cancellationToken = default) + { + if (UnsafeGetStatus() != UniTaskStatus.Pending) return false; + + this.cancellationToken = cancellationToken; + return TrySignalCompletion(UniTaskStatus.Canceled); + } + + [DebuggerHidden] + public bool TrySetException(Exception exception) + { + if (exception is OperationCanceledException oce) + { + return TrySetCanceled(oce.CancellationToken); + } + + if (UnsafeGetStatus() != UniTaskStatus.Pending) return false; + + this.exception = new ExceptionHolder(ExceptionDispatchInfo.Capture(exception)); + return TrySignalCompletion(UniTaskStatus.Faulted); + } + + [DebuggerHidden] + public T GetResult(short token) + { + MarkHandled(); + + var status = (UniTaskStatus)intStatus; + switch (status) + { + case UniTaskStatus.Succeeded: + return result; + case UniTaskStatus.Faulted: + exception.GetException().Throw(); + return default; + case UniTaskStatus.Canceled: + throw new OperationCanceledException(cancellationToken); + default: + case UniTaskStatus.Pending: + throw new InvalidOperationException("not yet completed."); + } + } + + [DebuggerHidden] + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + [DebuggerHidden] + public UniTaskStatus GetStatus(short token) + { + return (UniTaskStatus)intStatus; + } + + [DebuggerHidden] + public UniTaskStatus UnsafeGetStatus() + { + return (UniTaskStatus)intStatus; + } + + [DebuggerHidden] + public void OnCompleted(Action continuation, object state, short token) + { + if (gate == null) + { + Interlocked.CompareExchange(ref gate, new object(), null); + } + + var lockGate = Thread.VolatileRead(ref gate); + lock (lockGate) // wait TrySignalCompletion, after status is not pending. + { + if ((UniTaskStatus)intStatus != UniTaskStatus.Pending) + { + continuation(state); + return; + } + + if (singleContinuation == null) + { + singleContinuation = continuation; + singleState = state; + } + else + { + if (secondaryContinuationList == null) + { + secondaryContinuationList = new List<(Action, object)>(); + } + secondaryContinuationList.Add((continuation, state)); + } + } + } + + bool TrySignalCompletion(UniTaskStatus status) + { + if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending) + { + if (gate == null) + { + Interlocked.CompareExchange(ref gate, new object(), null); + } + + var lockGate = Thread.VolatileRead(ref gate); + lock (lockGate) // wait OnCompleted. + { + if (singleContinuation != null) + { + try + { + singleContinuation(singleState); + } + catch (Exception ex) + { + UniTaskScheduler.PublishUnobservedTaskException(ex); + } + } + + if (secondaryContinuationList != null) + { + foreach (var (c, state) in secondaryContinuationList) + { + try + { + c(state); + } + catch (Exception ex) + { + UniTaskScheduler.PublishUnobservedTaskException(ex); + } + } + } + + singleContinuation = null; + singleState = null; + secondaryContinuationList = null; + } + return true; + } + return false; + } + } +} \ No newline at end of file