UniTaskCompletionSource can await multiple times(same behaviour as TaskCompletionSource)

master
neuecc 2020-06-18 02:34:56 +09:00
parent bdd569e213
commit 769b5c6bab
3 changed files with 1107 additions and 197 deletions

View File

@ -0,0 +1,590 @@
using Cysharp.Threading.Tasks;
using FluentAssertions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using Cysharp.Threading.Tasks.Linq;
using System.Threading.Tasks;
using Xunit;
namespace NetCoreTests
{
public class CompletionSourceTest
{
[Fact]
public async Task SetFirst()
{
{
var tcs = new UniTaskCompletionSource();
tcs.TrySetResult();
await tcs.Task; // ok.
await tcs.Task; // ok.
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
[Fact]
public async Task SingleOnFirst()
{
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
tcs.TrySetResult();
await a;
await tcs.Task; // ok.
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await a);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
[Fact]
public async Task MultiOne()
{
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetResult();
await a;
await b;
await tcs.Task; // ok.
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await a);
await Assert.ThrowsAsync<TestException>(async () => await b);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
[Fact]
public async Task MultiTwo()
{
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetResult();
await a;
await b;
await c;
await tcs.Task; // ok.
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await a);
await Assert.ThrowsAsync<TestException>(async () => await b);
await Assert.ThrowsAsync<TestException>(async () => await c);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await c)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource();
async UniTask Await()
{
await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await c)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
class TestException : Exception
{
}
}
public class CompletionSourceTest2
{
[Fact]
public async Task SetFirst()
{
{
var tcs = new UniTaskCompletionSource<int>();
tcs.TrySetResult(10);
var a = await tcs.Task; // ok.
var b = await tcs.Task; // ok.
a.Should().Be(10);
b.Should().Be(10);
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource<int>();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource<int>();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource<int>();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
[Fact]
public async Task SingleOnFirst()
{
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
tcs.TrySetResult(10);
var r1 = await a;
var r2 = await tcs.Task; // ok.
r1.Should().Be(10);
r2.Should().Be(10);
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await a);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
[Fact]
public async Task MultiOne()
{
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetResult(10);
var r1 = await a;
var r2 = await b;
var r3 = await tcs.Task; // ok.
(r1, r2, r3).Should().Be((10, 10, 10));
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await a);
await Assert.ThrowsAsync<TestException>(async () => await b);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
[Fact]
public async Task MultiTwo()
{
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetResult(10);
var r1 = await a;
var r2 = await b;
var r3 = await c;
var r4 = await tcs.Task; // ok.
(r1, r2, r3, r4).Should().Be((10, 10, 10, 10));
tcs.Task.Status.Should().Be(UniTaskStatus.Succeeded);
}
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetException(new TestException());
await Assert.ThrowsAsync<TestException>(async () => await a);
await Assert.ThrowsAsync<TestException>(async () => await b);
await Assert.ThrowsAsync<TestException>(async () => await c);
await Assert.ThrowsAsync<TestException>(async () => await tcs.Task);
tcs.Task.Status.Should().Be(UniTaskStatus.Faulted);
}
var cts = new CancellationTokenSource();
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetException(new OperationCanceledException(cts.Token));
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await c)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
{
var tcs = new UniTaskCompletionSource<int>();
async UniTask<int> Await()
{
return await tcs.Task;
}
var a = Await();
var b = Await();
var c = Await();
tcs.TrySetCanceled(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await a)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await b)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await c)).CancellationToken.Should().Be(cts.Token);
(await Assert.ThrowsAsync<OperationCanceledException>(async () => await tcs.Task)).CancellationToken.Should().Be(cts.Token);
tcs.Task.Status.Should().Be(UniTaskStatus.Canceled);
}
}
class TestException : Exception
{
}
}
}

View File

@ -11,10 +11,7 @@ namespace Cysharp.Threading.Tasks
{
static readonly UniTask CanceledUniTask = new Func<UniTask>(() =>
{
var promise = new UniTaskCompletionSource();
promise.TrySetCanceled(CancellationToken.None);
promise.MarkHandled();
return promise.Task;
return new UniTask(new CanceledResultSource(CancellationToken.None), 0);
})();
static class CanceledUniTaskCache<T>
@ -23,10 +20,7 @@ namespace Cysharp.Threading.Tasks
static CanceledUniTaskCache()
{
var promise = new UniTaskCompletionSource<T>();
promise.TrySetCanceled(CancellationToken.None);
promise.MarkHandled();
Task = promise.Task;
Task = new UniTask<T>(new CanceledResultSource<T>(CancellationToken.None), 0);
}
}
@ -34,18 +28,22 @@ namespace Cysharp.Threading.Tasks
public static UniTask FromException(Exception ex)
{
var promise = new UniTaskCompletionSource();
promise.TrySetException(ex);
promise.MarkHandled();
return promise.Task;
if (ex is OperationCanceledException oce)
{
return FromCanceled(oce.CancellationToken);
}
return new UniTask(new ExceptionResultSource(ex), 0);
}
public static UniTask<T> FromException<T>(Exception ex)
{
var promise = new UniTaskCompletionSource<T>();
promise.TrySetException(ex);
promise.MarkHandled();
return promise.Task;
if (ex is OperationCanceledException oce)
{
return FromCanceled<T>(oce.CancellationToken);
}
return new UniTask<T>(new ExceptionResultSource<T>(ex), 0);
}
public static UniTask<T> FromResult<T>(T value)
@ -61,10 +59,7 @@ namespace Cysharp.Threading.Tasks
}
else
{
var promise = new UniTaskCompletionSource();
promise.TrySetCanceled(cancellationToken);
promise.MarkHandled();
return promise.Task;
return new UniTask(new CanceledResultSource(cancellationToken), 0);
}
}
@ -76,10 +71,7 @@ namespace Cysharp.Threading.Tasks
}
else
{
var promise = new UniTaskCompletionSource<T>();
promise.TrySetCanceled(cancellationToken);
promise.MarkHandled();
return promise.Task;
return new UniTask<T>(new CanceledResultSource<T>(cancellationToken), 0);
}
}
@ -182,6 +174,136 @@ namespace Cysharp.Threading.Tasks
return new UniTask<T>(new DeferPromise<T>(factory), 0);
}
sealed class ExceptionResultSource : IUniTaskSource
{
readonly Exception exception;
public ExceptionResultSource(Exception exception)
{
this.exception = exception;
}
public void GetResult(short token)
{
throw exception;
}
public UniTaskStatus GetStatus(short token)
{
return UniTaskStatus.Faulted;
}
public UniTaskStatus UnsafeGetStatus()
{
return UniTaskStatus.Faulted;
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
continuation(state);
}
}
sealed class ExceptionResultSource<T> : IUniTaskSource<T>
{
readonly Exception exception;
public ExceptionResultSource(Exception exception)
{
this.exception = exception;
}
public T GetResult(short token)
{
throw exception;
}
void IUniTaskSource.GetResult(short token)
{
throw exception;
}
public UniTaskStatus GetStatus(short token)
{
return UniTaskStatus.Faulted;
}
public UniTaskStatus UnsafeGetStatus()
{
return UniTaskStatus.Faulted;
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
continuation(state);
}
}
sealed class CanceledResultSource : IUniTaskSource
{
readonly CancellationToken cancellationToken;
public CanceledResultSource(CancellationToken cancellationToken)
{
this.cancellationToken = cancellationToken;
}
public void GetResult(short token)
{
throw new OperationCanceledException(cancellationToken);
}
public UniTaskStatus GetStatus(short token)
{
return UniTaskStatus.Canceled;
}
public UniTaskStatus UnsafeGetStatus()
{
return UniTaskStatus.Canceled;
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
continuation(state);
}
}
sealed class CanceledResultSource<T> : IUniTaskSource<T>
{
readonly CancellationToken cancellationToken;
public CanceledResultSource(CancellationToken cancellationToken)
{
this.cancellationToken = cancellationToken;
}
public T GetResult(short token)
{
throw new OperationCanceledException(cancellationToken);
}
void IUniTaskSource.GetResult(short token)
{
throw new OperationCanceledException(cancellationToken);
}
public UniTaskStatus GetStatus(short token)
{
return UniTaskStatus.Canceled;
}
public UniTaskStatus UnsafeGetStatus()
{
return UniTaskStatus.Canceled;
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
continuation(state);
}
}
sealed class DeferPromise : IUniTaskSource
{
Func<UniTask> factory;

View File

@ -1,6 +1,7 @@
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
@ -315,89 +316,6 @@ namespace Cysharp.Threading.Tasks
}
}
public class UniTaskCompletionSource : IUniTaskSource, IPromise
{
UniTaskCompletionSourceCore<AsyncUnit> core;
bool handled = false;
public UniTaskCompletionSource()
{
TaskTracker.TrackActiveTask(this, 2);
}
[DebuggerHidden]
internal void MarkHandled()
{
if (!handled)
{
handled = true;
core.MarkHandled();
TaskTracker.RemoveTracking(this);
}
}
public UniTask Task
{
[DebuggerHidden]
get
{
return new UniTask(this, core.Version);
}
}
[DebuggerHidden]
public void Reset()
{
// Reset, re-active tracker
handled = false;
TaskTracker.TrackActiveTask(this, 2);
core.Reset();
}
[DebuggerHidden]
public bool TrySetResult()
{
return core.TrySetResult(AsyncUnit.Default);
}
[DebuggerHidden]
public bool TrySetCanceled(CancellationToken cancellationToken = default)
{
return core.TrySetCanceled(cancellationToken);
}
[DebuggerHidden]
public bool TrySetException(Exception exception)
{
return core.TrySetException(exception);
}
[DebuggerHidden]
public void GetResult(short token)
{
MarkHandled();
core.GetResult(token);
}
[DebuggerHidden]
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
[DebuggerHidden]
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
[DebuggerHidden]
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
}
public class AutoResetUniTaskCompletionSource : IUniTaskSource, ITaskPoolNode<AutoResetUniTaskCompletionSource>, IPromise
{
static TaskPool<AutoResetUniTaskCompletionSource> pool;
@ -520,95 +438,6 @@ namespace Cysharp.Threading.Tasks
}
}
public class UniTaskCompletionSource<T> : IUniTaskSource<T>, IPromise<T>
{
UniTaskCompletionSourceCore<T> core;
bool handled = false;
[DebuggerHidden]
public UniTaskCompletionSource()
{
TaskTracker.TrackActiveTask(this, 2);
}
[DebuggerHidden]
internal void MarkHandled()
{
if (!handled)
{
handled = true;
core.MarkHandled();
TaskTracker.RemoveTracking(this);
}
}
[DebuggerHidden]
public UniTask<T> Task
{
get
{
return new UniTask<T>(this, core.Version);
}
}
[DebuggerHidden]
public void Reset()
{
handled = false;
core.Reset();
TaskTracker.TrackActiveTask(this, 2);
}
[DebuggerHidden]
public bool TrySetResult(T result)
{
return core.TrySetResult(result);
}
[DebuggerHidden]
public bool TrySetCanceled(CancellationToken cancellationToken = default)
{
return core.TrySetCanceled(cancellationToken);
}
[DebuggerHidden]
public bool TrySetException(Exception exception)
{
return core.TrySetException(exception);
}
[DebuggerHidden]
public T GetResult(short token)
{
MarkHandled();
return core.GetResult(token);
}
[DebuggerHidden]
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
[DebuggerHidden]
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
[DebuggerHidden]
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
[DebuggerHidden]
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
}
public class AutoResetUniTaskCompletionSource<T> : IUniTaskSource<T>, ITaskPoolNode<AutoResetUniTaskCompletionSource<T>>, IPromise<T>
{
static TaskPool<AutoResetUniTaskCompletionSource<T>> pool;
@ -735,5 +564,374 @@ namespace Cysharp.Threading.Tasks
return pool.TryPush(this);
}
}
}
public class UniTaskCompletionSource : IUniTaskSource, IPromise
{
CancellationToken cancellationToken;
ExceptionHolder exception;
object gate;
Action<object> singleContinuation;
object singleState;
List<(Action<object>, object)> secondaryContinuationList;
int intStatus; // UniTaskStatus
bool handled = false;
public UniTaskCompletionSource()
{
TaskTracker.TrackActiveTask(this, 2);
}
[DebuggerHidden]
internal void MarkHandled()
{
if (!handled)
{
handled = true;
TaskTracker.RemoveTracking(this);
}
}
public UniTask Task
{
[DebuggerHidden]
get
{
return new UniTask(this, 0);
}
}
[DebuggerHidden]
public bool TrySetResult()
{
return TrySignalCompletion(UniTaskStatus.Succeeded);
}
[DebuggerHidden]
public bool TrySetCanceled(CancellationToken cancellationToken = default)
{
if (UnsafeGetStatus() != UniTaskStatus.Pending) return false;
this.cancellationToken = cancellationToken;
return TrySignalCompletion(UniTaskStatus.Canceled);
}
[DebuggerHidden]
public bool TrySetException(Exception exception)
{
if (exception is OperationCanceledException oce)
{
return TrySetCanceled(oce.CancellationToken);
}
if (UnsafeGetStatus() != UniTaskStatus.Pending) return false;
this.exception = new ExceptionHolder(ExceptionDispatchInfo.Capture(exception));
return TrySignalCompletion(UniTaskStatus.Faulted);
}
[DebuggerHidden]
public void GetResult(short token)
{
MarkHandled();
var status = (UniTaskStatus)intStatus;
switch (status)
{
case UniTaskStatus.Succeeded:
return;
case UniTaskStatus.Faulted:
exception.GetException().Throw();
return;
case UniTaskStatus.Canceled:
throw new OperationCanceledException(cancellationToken);
default:
case UniTaskStatus.Pending:
throw new InvalidOperationException("not yet completed.");
}
}
[DebuggerHidden]
public UniTaskStatus GetStatus(short token)
{
return (UniTaskStatus)intStatus;
}
[DebuggerHidden]
public UniTaskStatus UnsafeGetStatus()
{
return (UniTaskStatus)intStatus;
}
[DebuggerHidden]
public void OnCompleted(Action<object> continuation, object state, short token)
{
if (gate == null)
{
Interlocked.CompareExchange(ref gate, new object(), null);
}
var lockGate = Thread.VolatileRead(ref gate);
lock (lockGate) // wait TrySignalCompletion, after status is not pending.
{
if ((UniTaskStatus)intStatus != UniTaskStatus.Pending)
{
continuation(state);
return;
}
if (singleContinuation == null)
{
singleContinuation = continuation;
singleState = state;
}
else
{
if (secondaryContinuationList == null)
{
secondaryContinuationList = new List<(Action<object>, object)>();
}
secondaryContinuationList.Add((continuation, state));
}
}
}
bool TrySignalCompletion(UniTaskStatus status)
{
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
{
if (gate == null)
{
Interlocked.CompareExchange(ref gate, new object(), null);
}
var lockGate = Thread.VolatileRead(ref gate);
lock (lockGate) // wait OnCompleted.
{
if (singleContinuation != null)
{
try
{
singleContinuation(singleState);
}
catch (Exception ex)
{
UniTaskScheduler.PublishUnobservedTaskException(ex);
}
}
if (secondaryContinuationList != null)
{
foreach (var (c, state) in secondaryContinuationList)
{
try
{
c(state);
}
catch (Exception ex)
{
UniTaskScheduler.PublishUnobservedTaskException(ex);
}
}
}
singleContinuation = null;
singleState = null;
secondaryContinuationList = null;
}
return true;
}
return false;
}
}
public class UniTaskCompletionSource<T> : IUniTaskSource<T>, IPromise<T>
{
CancellationToken cancellationToken;
T result;
ExceptionHolder exception;
object gate;
Action<object> singleContinuation;
object singleState;
List<(Action<object>, object)> secondaryContinuationList;
int intStatus; // UniTaskStatus
bool handled = false;
public UniTaskCompletionSource()
{
TaskTracker.TrackActiveTask(this, 2);
}
[DebuggerHidden]
internal void MarkHandled()
{
if (!handled)
{
handled = true;
TaskTracker.RemoveTracking(this);
}
}
public UniTask<T> Task
{
[DebuggerHidden]
get
{
return new UniTask<T>(this, 0);
}
}
[DebuggerHidden]
public bool TrySetResult(T result)
{
if (UnsafeGetStatus() != UniTaskStatus.Pending) return false;
this.result = result;
return TrySignalCompletion(UniTaskStatus.Succeeded);
}
[DebuggerHidden]
public bool TrySetCanceled(CancellationToken cancellationToken = default)
{
if (UnsafeGetStatus() != UniTaskStatus.Pending) return false;
this.cancellationToken = cancellationToken;
return TrySignalCompletion(UniTaskStatus.Canceled);
}
[DebuggerHidden]
public bool TrySetException(Exception exception)
{
if (exception is OperationCanceledException oce)
{
return TrySetCanceled(oce.CancellationToken);
}
if (UnsafeGetStatus() != UniTaskStatus.Pending) return false;
this.exception = new ExceptionHolder(ExceptionDispatchInfo.Capture(exception));
return TrySignalCompletion(UniTaskStatus.Faulted);
}
[DebuggerHidden]
public T GetResult(short token)
{
MarkHandled();
var status = (UniTaskStatus)intStatus;
switch (status)
{
case UniTaskStatus.Succeeded:
return result;
case UniTaskStatus.Faulted:
exception.GetException().Throw();
return default;
case UniTaskStatus.Canceled:
throw new OperationCanceledException(cancellationToken);
default:
case UniTaskStatus.Pending:
throw new InvalidOperationException("not yet completed.");
}
}
[DebuggerHidden]
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
[DebuggerHidden]
public UniTaskStatus GetStatus(short token)
{
return (UniTaskStatus)intStatus;
}
[DebuggerHidden]
public UniTaskStatus UnsafeGetStatus()
{
return (UniTaskStatus)intStatus;
}
[DebuggerHidden]
public void OnCompleted(Action<object> continuation, object state, short token)
{
if (gate == null)
{
Interlocked.CompareExchange(ref gate, new object(), null);
}
var lockGate = Thread.VolatileRead(ref gate);
lock (lockGate) // wait TrySignalCompletion, after status is not pending.
{
if ((UniTaskStatus)intStatus != UniTaskStatus.Pending)
{
continuation(state);
return;
}
if (singleContinuation == null)
{
singleContinuation = continuation;
singleState = state;
}
else
{
if (secondaryContinuationList == null)
{
secondaryContinuationList = new List<(Action<object>, object)>();
}
secondaryContinuationList.Add((continuation, state));
}
}
}
bool TrySignalCompletion(UniTaskStatus status)
{
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
{
if (gate == null)
{
Interlocked.CompareExchange(ref gate, new object(), null);
}
var lockGate = Thread.VolatileRead(ref gate);
lock (lockGate) // wait OnCompleted.
{
if (singleContinuation != null)
{
try
{
singleContinuation(singleState);
}
catch (Exception ex)
{
UniTaskScheduler.PublishUnobservedTaskException(ex);
}
}
if (secondaryContinuationList != null)
{
foreach (var (c, state) in secondaryContinuationList)
{
try
{
c(state);
}
catch (Exception ex)
{
UniTaskScheduler.PublishUnobservedTaskException(ex);
}
}
}
singleContinuation = null;
singleState = null;
secondaryContinuationList = null;
}
return true;
}
return false;
}
}
}