WhenAll and WhenAny

master
neuecc 2020-04-21 13:36:23 +09:00
parent 082f3e7335
commit 3654a9e2f9
16 changed files with 11143 additions and 3797 deletions

View File

@ -96,7 +96,7 @@ namespace UniRx.Async
{
if (cancellationToken.IsCancellationRequested)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
return false;
}
@ -109,11 +109,11 @@ namespace UniRx.Async
}
catch (Exception ex)
{
core.SetException(ex);
core.TrySetException(ex);
return false;
}
core.SetResult(null);
core.TrySetResult(null);
return false;
}
@ -126,6 +126,14 @@ namespace UniRx.Async
exception = default;
}
~EnumeratorPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
// Unwrap YieldInstructions
static IEnumerator ConsumeEnumerator(IEnumerator enumerator)

View File

@ -40,7 +40,7 @@ namespace UniRx.Async.Internal
return new RentArray<T>(array, array.Length, null);
}
var defaultCount = 4;
var defaultCount = 32;
if (source is ICollection<T> coll)
{
defaultCount = coll.Count;

View File

@ -10,6 +10,11 @@ namespace UniRx.Async.Internal
{
return StatePool<T1, T2>.Create(item1, item2);
}
public static StateTuple<T1, T2, T3> Create<T1, T2, T3>(T1 item1, T2 item2, T3 item3)
{
return StatePool<T1, T2, T3>.Create(item1, item2, item3);
}
}
internal class StateTuple<T1, T2> : IDisposable
@ -54,4 +59,51 @@ namespace UniRx.Async.Internal
queue.Enqueue(tuple);
}
}
internal class StateTuple<T1, T2, T3> : IDisposable
{
public T1 Item1;
public T2 Item2;
public T3 Item3;
public void Deconstruct(out T1 item1, out T2 item2, out T3 item3)
{
item1 = this.Item1;
item2 = this.Item2;
item3 = this.Item3;
}
public void Dispose()
{
StatePool<T1, T2, T3>.Return(this);
}
}
internal static class StatePool<T1, T2, T3>
{
static readonly ConcurrentQueue<StateTuple<T1, T2, T3>> queue = new ConcurrentQueue<StateTuple<T1, T2, T3>>();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static StateTuple<T1, T2, T3> Create(T1 item1, T2 item2, T3 item3)
{
if (queue.TryDequeue(out var value))
{
value.Item1 = item1;
value.Item2 = item2;
value.Item3 = item3;
return value;
}
return new StateTuple<T1, T2, T3> { Item1 = item1, Item2 = item2, Item3 = item3 };
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Return(StateTuple<T1, T2, T3> tuple)
{
tuple.Item1 = default;
tuple.Item2 = default;
tuple.Item3 = default;
queue.Enqueue(tuple);
}
}
}

View File

@ -24,8 +24,6 @@ namespace UniRx.Async
public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
{
PlayerLoopHelper.Initialize(
if (delayFrameCount < 0)
{
throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. delayFrameCount:" + delayFrameCount);
@ -121,11 +119,11 @@ namespace UniRx.Async
{
if (cancellationToken.IsCancellationRequested)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
return false;
}
core.SetResult(null);
core.TrySetResult(null);
return false;
}
@ -134,6 +132,14 @@ namespace UniRx.Async
core.Reset();
cancellationToken = default;
}
~YieldPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
sealed class DelayFramePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
@ -202,13 +208,13 @@ namespace UniRx.Async
{
if (cancellationToken.IsCancellationRequested)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
return false;
}
if (currentFrameCount == delayFrameCount)
{
core.SetResult(null);
core.TrySetResult(null);
return false;
}
@ -223,6 +229,14 @@ namespace UniRx.Async
delayFrameCount = default;
cancellationToken = default;
}
~DelayFramePromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
sealed class DelayPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
@ -292,14 +306,14 @@ namespace UniRx.Async
{
if (cancellationToken.IsCancellationRequested)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
return false;
}
elapsed += Time.deltaTime;
if (elapsed >= delayFrameTimeSpan)
{
core.SetResult(null);
core.TrySetResult(null);
return false;
}
@ -313,6 +327,14 @@ namespace UniRx.Async
elapsed = default;
cancellationToken = default;
}
~DelayPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
sealed class DelayIgnoreTimeScalePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
@ -382,14 +404,14 @@ namespace UniRx.Async
{
if (cancellationToken.IsCancellationRequested)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
return false;
}
elapsed += Time.unscaledDeltaTime;
if (elapsed >= delayFrameTimeSpan)
{
core.SetResult(null);
core.TrySetResult(null);
return false;
}
@ -403,6 +425,14 @@ namespace UniRx.Async
elapsed = default;
cancellationToken = default;
}
~DelayIgnoreTimeScalePromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
}

View File

@ -12,14 +12,12 @@ namespace UniRx.Async
{
public static UniTask WaitUntil(Func<bool> predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
{
var promise = new WaitUntilPromise(predicate, timing, cancellationToken);
return promise.Task;
return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, out var token), token);
}
public static UniTask WaitWhile(Func<bool> predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
{
var promise = new WaitWhilePromise(predicate, timing, cancellationToken);
return promise.Task;
return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, out var token), token);
}
public static UniTask<U> WaitUntilValueChanged<T, U>(T target, Func<T, U> monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer<U> equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken))
@ -28,130 +26,293 @@ namespace UniRx.Async
var unityObject = target as UnityEngine.Object;
var isUnityObject = !object.ReferenceEquals(target, null); // don't use (unityObject == null)
return (isUnityObject)
? new WaitUntilValueChangedUnityObjectPromise<T, U>(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken).Task
: new WaitUntilValueChangedStandardObjectPromise<T, U>(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken).Task;
return new UniTask<U>(isUnityObject
? WaitUntilValueChangedUnityObjectPromise<T, U>.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out var token)
: WaitUntilValueChangedStandardObjectPromise<T, U>.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out token), token);
}
class WaitUntilPromise : PlayerLoopReusablePromiseBase
sealed class WaitUntilPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
{
readonly Func<bool> predicate;
static readonly PromisePool<WaitUntilPromise> pool = new PromisePool<WaitUntilPromise>();
public WaitUntilPromise(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken)
: base(timing, cancellationToken, 1)
{
this.predicate = predicate;
}
Func<bool> predicate;
CancellationToken cancellationToken;
protected override void OnRunningStart()
UniTaskCompletionSourceCore<object> core;
WaitUntilPromise()
{
}
public override bool MoveNext()
public static IUniTaskSource Create(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested)
{
Complete();
TrySetCanceled();
return false;
return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token);
}
bool result = default(bool);
var result = pool.TryRent() ?? new WaitUntilPromise();
result.predicate = predicate;
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public void GetResult(short token)
{
try
{
result = predicate();
TaskTracker2.RemoveTracking(this);
core.GetResult(token);
}
catch (Exception ex)
finally
{
Complete();
TrySetException(ex);
return false;
pool.TryReturn(this);
}
if (result)
{
Complete();
TrySetResult();
return false;
}
return true;
}
}
class WaitWhilePromise : PlayerLoopReusablePromiseBase
{
readonly Func<bool> predicate;
public WaitWhilePromise(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken)
: base(timing, cancellationToken, 1)
public UniTaskStatus GetStatus(short token)
{
this.predicate = predicate;
return core.GetStatus(token);
}
protected override void OnRunningStart()
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public override bool MoveNext()
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested)
{
Complete();
TrySetCanceled();
core.TrySetCanceled(cancellationToken);
return false;
}
bool result = default(bool);
try
{
result = predicate();
if (!predicate())
{
return true;
}
}
catch (Exception ex)
{
Complete();
TrySetException(ex);
core.TrySetException(ex);
return false;
}
if (!result)
core.TrySetResult(null);
return false;
}
public void Reset()
{
core.Reset();
predicate = default;
cancellationToken = default;
}
~WaitUntilPromise()
{
if (pool.TryReturn(this))
{
Complete();
TrySetResult();
GC.ReRegisterForFinalize(this);
}
}
}
sealed class WaitWhilePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
{
static readonly PromisePool<WaitWhilePromise> pool = new PromisePool<WaitWhilePromise>();
Func<bool> predicate;
CancellationToken cancellationToken;
UniTaskCompletionSourceCore<object> core;
WaitWhilePromise()
{
}
public static IUniTaskSource Create(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested)
{
return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token);
}
var result = pool.TryRent() ?? new WaitWhilePromise();
result.predicate = predicate;
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public void GetResult(short token)
{
try
{
TaskTracker2.RemoveTracking(this);
core.GetResult(token);
}
finally
{
pool.TryReturn(this);
}
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested)
{
core.TrySetCanceled(cancellationToken);
return false;
}
return true;
try
{
if (predicate())
{
return true;
}
}
catch (Exception ex)
{
core.TrySetException(ex);
return false;
}
core.TrySetResult(null);
return false;
}
public void Reset()
{
core.Reset();
predicate = default;
cancellationToken = default;
}
~WaitWhilePromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
// where T : UnityEngine.Object, can not add constraint
class WaitUntilValueChangedUnityObjectPromise<T, U> : PlayerLoopReusablePromiseBase<U>
sealed class WaitUntilValueChangedUnityObjectPromise<T, U> : IUniTaskSource<U>, IPlayerLoopItem, IPromisePoolItem
{
readonly T target;
readonly Func<T, U> monitorFunction;
readonly IEqualityComparer<U> equalityComparer;
static readonly PromisePool<WaitUntilValueChangedUnityObjectPromise<T, U>> pool = new PromisePool<WaitUntilValueChangedUnityObjectPromise<T, U>>();
T target;
U currentValue;
Func<T, U> monitorFunction;
IEqualityComparer<U> equalityComparer;
CancellationToken cancellationToken;
public WaitUntilValueChangedUnityObjectPromise(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken)
: base(timing, cancellationToken, 1)
{
this.target = target;
this.monitorFunction = monitorFunction;
this.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
this.currentValue = monitorFunction(target);
}
UniTaskCompletionSourceCore<U> core;
protected override void OnRunningStart()
WaitUntilValueChangedUnityObjectPromise()
{
}
public override bool MoveNext()
public static IUniTaskSource<U> Create(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested)
{
return AutoResetUniTaskCompletionSource<U>.CreateFromCanceled(cancellationToken, out token);
}
var result = pool.TryRent() ?? new WaitUntilValueChangedUnityObjectPromise<T, U>();
result.target = target;
result.monitorFunction = monitorFunction;
result.currentValue = monitorFunction(target);
result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public U GetResult(short token)
{
try
{
TaskTracker2.RemoveTracking(this);
return core.GetResult(token);
}
finally
{
pool.TryReturn(this);
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested || target == null) // destroyed = cancel.
{
Complete();
TrySetCanceled();
core.TrySetCanceled(cancellationToken);
return false;
}
@ -166,45 +327,112 @@ namespace UniRx.Async
}
catch (Exception ex)
{
Complete();
TrySetException(ex);
core.TrySetException(ex);
return false;
}
Complete();
currentValue = nextValue;
TrySetResult(nextValue);
core.TrySetResult(nextValue);
return false;
}
public void Reset()
{
core.Reset();
target = default;
currentValue = default;
monitorFunction = default;
equalityComparer = default;
cancellationToken = default;
}
~WaitUntilValueChangedUnityObjectPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
class WaitUntilValueChangedStandardObjectPromise<T, U> : PlayerLoopReusablePromiseBase<U>
sealed class WaitUntilValueChangedStandardObjectPromise<T, U> : IUniTaskSource<U>, IPlayerLoopItem, IPromisePoolItem
where T : class
{
readonly WeakReference<T> target;
readonly Func<T, U> monitorFunction;
readonly IEqualityComparer<U> equalityComparer;
static readonly PromisePool<WaitUntilValueChangedStandardObjectPromise<T, U>> pool = new PromisePool<WaitUntilValueChangedStandardObjectPromise<T, U>>();
WeakReference<T> target;
U currentValue;
Func<T, U> monitorFunction;
IEqualityComparer<U> equalityComparer;
CancellationToken cancellationToken;
public WaitUntilValueChangedStandardObjectPromise(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken)
: base(timing, cancellationToken, 1)
{
this.target = new WeakReference<T>(target, false); // wrap in WeakReference.
this.monitorFunction = monitorFunction;
this.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
this.currentValue = monitorFunction(target);
}
UniTaskCompletionSourceCore<U> core;
protected override void OnRunningStart()
WaitUntilValueChangedStandardObjectPromise()
{
}
public override bool MoveNext()
public static IUniTaskSource<U> Create(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested || !target.TryGetTarget(out var t))
if (cancellationToken.IsCancellationRequested)
{
Complete();
TrySetCanceled();
return AutoResetUniTaskCompletionSource<U>.CreateFromCanceled(cancellationToken, out token);
}
var result = pool.TryRent() ?? new WaitUntilValueChangedStandardObjectPromise<T, U>();
result.target = new WeakReference<T>(target, false); // wrap in WeakReference.
result.monitorFunction = monitorFunction;
result.currentValue = monitorFunction(target);
result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public U GetResult(short token)
{
try
{
TaskTracker2.RemoveTracking(this);
return core.GetResult(token);
}
finally
{
pool.TryReturn(this);
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested || !target.TryGetTarget(out var t)) // doesn't find = cancel.
{
core.TrySetCanceled(cancellationToken);
return false;
}
@ -219,16 +447,31 @@ namespace UniRx.Async
}
catch (Exception ex)
{
Complete();
TrySetException(ex);
core.TrySetException(ex);
return false;
}
Complete();
currentValue = nextValue;
TrySetResult(nextValue);
core.TrySetResult(nextValue);
return false;
}
public void Reset()
{
core.Reset();
target = default;
currentValue = default;
monitorFunction = default;
equalityComparer = default;
cancellationToken = default;
}
~WaitUntilValueChangedStandardObjectPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +0,0 @@
fileFormatVersion: 2
guid: 5110117231c8a6d4095fd0cbd3f4c142
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,127 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async
{
public partial struct UniTask
{
<# for(var i = 2; i <= 15; i++ ) {
var range = Enumerable.Range(1, i);
var t = string.Join(", ", range.Select(x => "T" + x));
var args = string.Join(", ", range.Select(x => $"UniTask<T{x}> task{x}"));
var targs = string.Join(", ", range.Select(x => $"task{x}"));
var tresult = string.Join(", ", range.Select(x => $"task{x}.GetAwaiter().GetResult()"));
var completedSuccessfullyAnd = string.Join(" && ", range.Select(x => $"task{x}.Status.IsCompletedSuccessfully()"));
var tfield = string.Join(", ", range.Select(x => $"self.t{x}"));
#>
public static UniTask<(<#= t #>)> WhenAll<<#= t #>>(<#= args #>)
{
if (<#= completedSuccessfullyAnd #>)
{
return new UniTask<(<#= t #>)>((<#= tresult #>));
}
return new UniTask<(<#= t #>)>(new WhenAllPromise<<#= t #>>(<#= targs #>), 0);
}
sealed class WhenAllPromise<<#= t #>> : IUniTaskSource<(<#= t #>)>
{
<# for(var j = 1; j <= i; j++) { #>
T<#= j #> t<#= j #> = default;
<# } #>
int completedCount;
UniTaskCompletionSourceCore<(<#= t #>)> core;
public WhenAllPromise(<#= args #>)
{
TaskTracker2.TrackActiveTask(this, 3);
this.completedCount = 0;
<# for(var j = 1; j <= i; j++) { #>
{
var awaiter = task<#= j #>.GetAwaiter();
if (awaiter.IsCompleted)
{
TryInvokeContinuationT<#= j #>(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise<<#= t #>>, UniTask<T<#= j #>>.Awaiter>)state)
{
TryInvokeContinuationT<#= j #>(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
<# } #>
}
<# for(var j = 1; j <= i; j++) { #>
static void TryInvokeContinuationT<#= j #>(WhenAllPromise<<#= t #>> self, in UniTask<T<#= j #>>.Awaiter awaiter)
{
try
{
self.t<#= j #> = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == <#= i #>)
{
self.core.TrySetResult((<#= tfield #>));
}
}
<# } #>
public (<#= t #>) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
~WhenAllPromise()
{
core.Reset();
}
}
<# } #>
}
}

View File

@ -3,8 +3,6 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using UniRx.Async.Internal;
@ -12,285 +10,214 @@ namespace UniRx.Async
{
public partial struct UniTask
{
// UniTask
public static async UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
public static UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
{
return await new WhenAllPromise<T>(tasks, tasks.Length);
return new UniTask<T[]>(new WhenAllPromise<T>(tasks, tasks.Length), 0);
}
public static async UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
public static UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
{
WhenAllPromise<T> promise;
using (var span = ArrayPoolUtil.Materialize(tasks))
{
promise = new WhenAllPromise<T>(span.Array, span.Length);
var promise = new WhenAllPromise<T>(span.Array, span.Length); // consumed array in constructor.
return new UniTask<T[]>(promise, 0);
}
return await promise;
}
public static async UniTask WhenAll(params UniTask[] tasks)
public static UniTask WhenAll(params UniTask[] tasks)
{
await new WhenAllPromise(tasks, tasks.Length);
return new UniTask(new WhenAllPromise(tasks, tasks.Length), 0);
}
public static async UniTask WhenAll(IEnumerable<UniTask> tasks)
public static UniTask WhenAll(IEnumerable<UniTask> tasks)
{
WhenAllPromise promise;
using (var span = ArrayPoolUtil.Materialize(tasks))
{
promise = new WhenAllPromise(span.Array, span.Length);
var promise = new WhenAllPromise(span.Array, span.Length); // consumed array in constructor.
return new UniTask(promise, 0);
}
await promise;
}
class WhenAllPromise<T>
sealed class WhenAllPromise<T> : IUniTaskSource<T[]>
{
readonly T[] result;
T[] result;
int completeCount;
Action whenComplete;
ExceptionDispatchInfo exception;
UniTaskCompletionSourceCore<T[]> core; // don't reset(called after GetResult, will invoke TrySetException.)
public WhenAllPromise(UniTask<T>[] tasks, int tasksLength)
{
TaskTracker2.TrackActiveTask(this, 3);
this.completeCount = 0;
this.whenComplete = null;
this.exception = null;
this.result = new T[tasksLength];
for (int i = 0; i < tasksLength; i++)
{
if (tasks[i].Status.IsCompleted())
UniTask<T>.Awaiter awaiter;
try
{
T value = default(T);
try
{
value = tasks[i].GetAwaiter().GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
continue;
}
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue;
}
result[i] = value;
var count = Interlocked.Increment(ref completeCount);
if (count == result.Length)
{
TryCallContinuation();
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
RunTask(tasks[i], i).Forget();
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise<T>, UniTask<T>.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
}
}
void TryCallContinuation()
static void TryInvokeContinuation(WhenAllPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask(UniTask<T> task, int index)
{
T value = default(T);
try
{
value = await task;
self.result[i] = awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
result[index] = value;
var count = Interlocked.Increment(ref completeCount);
if (count == result.Length)
if (Interlocked.Increment(ref self.completeCount) == self.result.Length)
{
TryCallContinuation();
self.core.TrySetResult(self.result);
}
}
public Awaiter GetAwaiter()
public T[] GetResult(short token)
{
return new Awaiter(this);
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public struct Awaiter : ICriticalNotifyCompletion
void IUniTaskSource.GetResult(short token)
{
WhenAllPromise<T> parent;
GetResult(token);
}
public Awaiter(WhenAllPromise<T> parent)
{
this.parent = parent;
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public bool IsCompleted
{
get
{
return parent.exception != null || parent.result.Length == parent.completeCount;
}
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public T[] GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
return parent.result;
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAllPromise()
{
core.Reset();
}
}
class WhenAllPromise
sealed class WhenAllPromise : IUniTaskSource
{
int completeCount;
int resultLength;
Action whenComplete;
ExceptionDispatchInfo exception;
int tasksLength;
UniTaskCompletionSourceCore<AsyncUnit> core; // don't reset(called after GetResult, will invoke TrySetException.)
public WhenAllPromise(UniTask[] tasks, int tasksLength)
{
TaskTracker2.TrackActiveTask(this, 3);
this.tasksLength = tasksLength;
this.completeCount = 0;
this.whenComplete = null;
this.exception = null;
this.resultLength = tasksLength;
for (int i = 0; i < tasksLength; i++)
{
if (tasks[i].Status.IsCompleted())
UniTask.Awaiter awaiter;
try
{
try
{
tasks[i].GetAwaiter().GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
continue;
}
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue;
}
var count = Interlocked.Increment(ref completeCount);
if (count == resultLength)
{
TryCallContinuation();
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter);
}
else
{
RunTask(tasks[i], i).Forget();
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise, UniTask.Awaiter>)state)
{
TryInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
}
void TryCallContinuation()
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask(UniTask task, int index)
static void TryInvokeContinuation(WhenAllPromise self, in UniTask.Awaiter awaiter)
{
try
{
await task;
awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == resultLength)
if (Interlocked.Increment(ref self.completeCount) == self.tasksLength)
{
TryCallContinuation();
self.core.TrySetResult(AsyncUnit.Default);
}
}
public Awaiter GetAwaiter()
public void GetResult(short token)
{
return new Awaiter(this);
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
core.GetResult(token);
}
public struct Awaiter : ICriticalNotifyCompletion
public UniTaskStatus GetStatus(short token)
{
WhenAllPromise parent;
return core.GetStatus(token);
}
public Awaiter(WhenAllPromise parent)
{
this.parent = parent;
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public bool IsCompleted
{
get
{
return parent.exception != null || parent.resultLength == parent.completeCount;
}
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public void GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAllPromise()
{
core.Reset();
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +0,0 @@
fileFormatVersion: 2
guid: 13d604ac281570c4eac9962429f19ca9
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,122 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async
{
public partial struct UniTask
{
<# for(var i = 2; i <= 15; i++ ) {
var range = Enumerable.Range(1, i);
var t = string.Join(", ", range.Select(x => "T" + x));
var args = string.Join(", ", range.Select(x => $"UniTask<T{x}> task{x}"));
var targs = string.Join(", ", range.Select(x => $"task{x}"));
var tresult = string.Join(", ", range.Select(x => $"task{x}.GetAwaiter().GetResult()"));
var tBool = string.Join(", ", range.Select(x => $"(bool hasResult, T{x} result{x})"));
var tfield = string.Join(", ", range.Select(x => $"self.t{x}"));
Func<int, string> getResult = j => string.Join(", ", range.Select(x => (x == j) ? "(true, result)" : "(false, default)"));
#>
public static UniTask<(int winArgumentIndex, <#= tBool #>)> WhenAny<<#= t #>>(<#= args #>)
{
return new UniTask<(int winArgumentIndex, <#= tBool #>)>(new WhenAnyPromise<<#= t #>>(<#= targs #>), 0);
}
sealed class WhenAnyPromise<<#= t #>> : IUniTaskSource<(int, <#= tBool #>)>
{
int completedCount;
UniTaskCompletionSourceCore<(int, <#= tBool #>)> core;
public WhenAnyPromise(<#= args #>)
{
TaskTracker2.TrackActiveTask(this, 3);
this.completedCount = 0;
<# for(var j = 1; j <= i; j++) { #>
{
var awaiter = task<#= j #>.GetAwaiter();
if (awaiter.IsCompleted)
{
TryInvokeContinuationT<#= j #>(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise<<#= t #>>, UniTask<T<#= j #>>.Awaiter>)state)
{
TryInvokeContinuationT<#= j #>(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
<# } #>
}
<# for(var j = 1; j <= i; j++) { #>
static void TryInvokeContinuationT<#= j #>(WhenAnyPromise<<#= t #>> self, in UniTask<T<#= j #>>.Awaiter awaiter)
{
T<#= j #> result;
try
{
result = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((<#= j - 1 #>, <#= getResult(j) #>));
}
}
<# } #>
public (int, <#= tBool #>) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
~WhenAnyPromise()
{
core.Reset();
}
}
<# } #>
}
}

View File

@ -2,370 +2,365 @@
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Collections.Generic;
using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async
{
public partial struct UniTask
{
// UniTask
public static async UniTask<(bool hasResultLeft, T0 result)> WhenAny<T0>(UniTask<T0> task0, UniTask task1)
public static UniTask<(bool hasResultLeft, T result)> WhenAny<T>(UniTask<T> leftTask, UniTask rightTask)
{
return await new UnitWhenAnyPromise<T0>(task0, task1);
return new UniTask<(bool, T)>(new WhenAnyLRPromise<T>(leftTask, rightTask), 0);
}
public static async UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks)
public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks)
{
return await new WhenAnyPromise<T>(tasks);
return new UniTask<(int, T)>(new WhenAnyPromise<T>(tasks, tasks.Length), 0);
}
public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(IEnumerable<UniTask<T>> tasks)
{
using (var span = ArrayPoolUtil.Materialize(tasks))
{
return new UniTask<(int, T)>(new WhenAnyPromise<T>(span.Array, span.Length), 0);
}
}
/// <summary>Return value is winArgumentIndex</summary>
public static async UniTask<int> WhenAny(params UniTask[] tasks)
public static UniTask<int> WhenAny(params UniTask[] tasks)
{
return await new WhenAnyPromise(tasks);
return new UniTask<int>(new WhenAnyPromise(tasks, tasks.Length), 0);
}
class UnitWhenAnyPromise<T0>
/// <summary>Return value is winArgumentIndex</summary>
public static UniTask<int> WhenAny(IEnumerable<UniTask> tasks)
{
T0 result0;
ExceptionDispatchInfo exception;
Action whenComplete;
int completeCount;
using (var span = ArrayPoolUtil.Materialize(tasks))
{
return new UniTask<int>(new WhenAnyPromise(span.Array, span.Length), 0);
}
}
sealed class WhenAnyLRPromise<T> : IUniTaskSource<(bool, T)>
{
int completedCount;
int winArgumentIndex;
UniTaskCompletionSourceCore<(bool, T)> core;
bool IsCompleted => exception != null || Volatile.Read(ref winArgumentIndex) != -1;
public UnitWhenAnyPromise(UniTask<T0> task0, UniTask task1)
public WhenAnyLRPromise(UniTask<T> leftTask, UniTask rightTask)
{
this.whenComplete = null;
this.exception = null;
this.completeCount = 0;
this.winArgumentIndex = -1;
this.result0 = default(T0);
TaskTracker2.TrackActiveTask(this, 3);
RunTask0(task0).Forget();
RunTask1(task1).Forget();
}
void TryCallContinuation()
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask0(UniTask<T0> task)
{
T0 value;
try
{
value = await task;
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
{
result0 = value;
Volatile.Write(ref winArgumentIndex, 0);
TryCallContinuation();
}
}
async UniTaskVoid RunTask1(UniTask task)
{
try
{
await task;
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
{
Volatile.Write(ref winArgumentIndex, 1);
TryCallContinuation();
}
}
public Awaiter GetAwaiter()
{
return new Awaiter(this);
}
public struct Awaiter : ICriticalNotifyCompletion
{
UnitWhenAnyPromise<T0> parent;
public Awaiter(UnitWhenAnyPromise<T0> parent)
{
this.parent = parent;
}
public bool IsCompleted
{
get
UniTask<T>.Awaiter awaiter;
try
{
return parent.IsCompleted;
awaiter = leftTask.GetAwaiter();
}
}
public (bool, T0) GetResult()
{
if (parent.exception != null)
catch (Exception ex)
{
parent.exception.Throw();
core.TrySetException(ex);
goto RIGHT;
}
return (parent.winArgumentIndex == 0, parent.result0);
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
if (awaiter.IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
TryLeftInvokeContinuation(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
action();
}
using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask<T>.Awaiter>)state)
{
TryLeftInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
RIGHT:
{
UniTask.Awaiter awaiter;
try
{
awaiter = rightTask.GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
return;
}
if (awaiter.IsCompleted)
{
TryRightInvokeContinuation(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask.Awaiter>)state)
{
TryRightInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
}
static void TryLeftInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask<T>.Awaiter awaiter)
{
T result;
try
{
result = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((true, result));
}
}
static void TryRightInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask.Awaiter awaiter)
{
try
{
awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((false, default));
}
}
public (bool, T) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
~WhenAnyLRPromise()
{
core.Reset();
}
}
class WhenAnyPromise<T>
sealed class WhenAnyPromise<T> : IUniTaskSource<(int, T)>
{
T result;
int completeCount;
int completedCount;
int winArgumentIndex;
Action whenComplete;
ExceptionDispatchInfo exception;
UniTaskCompletionSourceCore<(int, T)> core;
public bool IsComplete => exception != null || Volatile.Read(ref winArgumentIndex) != -1;
public WhenAnyPromise(UniTask<T>[] tasks)
public WhenAnyPromise(UniTask<T>[] tasks, int tasksLength)
{
this.completeCount = 0;
this.winArgumentIndex = -1;
this.whenComplete = null;
this.exception = null;
this.result = default(T);
TaskTracker2.TrackActiveTask(this, 3);
for (int i = 0; i < tasks.Length; i++)
for (int i = 0; i < tasksLength; i++)
{
RunTask(tasks[i], i).Forget();
UniTask<T>.Awaiter awaiter;
try
{
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue; // consume others.
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise<T>, UniTask<T>.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
}
}
async UniTaskVoid RunTask(UniTask<T> task, int index)
static void TryInvokeContinuation(WhenAnyPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
{
T value;
T result;
try
{
value = await task;
result = awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
if (Interlocked.Increment(ref self.completedCount) == 1)
{
result = value;
Volatile.Write(ref winArgumentIndex, index);
TryCallContinuation();
self.core.TrySetResult((i, result));
}
}
void TryCallContinuation()
public (int, T) GetResult(short token)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public Awaiter GetAwaiter()
public UniTaskStatus GetStatus(short token)
{
return new Awaiter(this);
return core.GetStatus(token);
}
public struct Awaiter : ICriticalNotifyCompletion
public void OnCompleted(Action<object> continuation, object state, short token)
{
WhenAnyPromise<T> parent;
core.OnCompleted(continuation, state, token);
}
public Awaiter(WhenAnyPromise<T> parent)
{
this.parent = parent;
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public bool IsCompleted
{
get
{
return parent.IsComplete;
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public (int, T) GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
return (parent.winArgumentIndex, parent.result);
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAnyPromise()
{
core.Reset();
}
}
class WhenAnyPromise
sealed class WhenAnyPromise : IUniTaskSource<int>
{
int completeCount;
int completedCount;
int winArgumentIndex;
Action whenComplete;
ExceptionDispatchInfo exception;
UniTaskCompletionSourceCore<int> core;
public bool IsComplete => exception != null || Volatile.Read(ref winArgumentIndex) != -1;
public WhenAnyPromise(UniTask[] tasks)
public WhenAnyPromise(UniTask[] tasks, int tasksLength)
{
this.completeCount = 0;
this.winArgumentIndex = -1;
this.whenComplete = null;
this.exception = null;
TaskTracker2.TrackActiveTask(this, 3);
for (int i = 0; i < tasks.Length; i++)
for (int i = 0; i < tasksLength; i++)
{
RunTask(tasks[i], i).Forget();
UniTask.Awaiter awaiter;
try
{
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue; // consume others.
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise, UniTask.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
}
}
async UniTaskVoid RunTask(UniTask task, int index)
static void TryInvokeContinuation(WhenAnyPromise self, in UniTask.Awaiter awaiter, int i)
{
try
{
await task;
awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
if (Interlocked.Increment(ref self.completedCount) == 1)
{
Volatile.Write(ref winArgumentIndex, index);
TryCallContinuation();
self.core.TrySetResult(i);
}
}
void TryCallContinuation()
public int GetResult(short token)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public Awaiter GetAwaiter()
public UniTaskStatus GetStatus(short token)
{
return new Awaiter(this);
return core.GetStatus(token);
}
public struct Awaiter : ICriticalNotifyCompletion
public void OnCompleted(Action<object> continuation, object state, short token)
{
WhenAnyPromise parent;
core.OnCompleted(continuation, state, token);
}
public Awaiter(WhenAnyPromise parent)
{
this.parent = parent;
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public bool IsCompleted
{
get
{
return parent.IsComplete;
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public int GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
return parent.winArgumentIndex;
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAnyPromise()
{
core.Reset();
}
}
}

View File

@ -47,9 +47,8 @@ namespace UniRx.Async
TResult result;
object error; // ExceptionDispatchInfo or OperationCanceledException
short version;
bool completed;
bool hasUnhandledError;
int completedCount; // 0: completed == false
Action<object> continuation;
object continuationState;
@ -61,7 +60,7 @@ namespace UniRx.Async
{
version += 1; // incr version.
}
completed = false;
completedCount = 0;
result = default;
error = null;
hasUnhandledError = false;
@ -92,25 +91,59 @@ namespace UniRx.Async
/// <summary>Completes with a successful result.</summary>
/// <param name="result">The result.</param>
public void SetResult(TResult result)
public bool TrySetResult(TResult result)
{
this.result = result;
SignalCompletion();
if (Interlocked.Increment(ref completedCount) == 1)
{
// setup result
this.result = result;
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
return true;
}
}
return false;
}
/// <summary>Completes with an error.</summary>
/// <param name="error">The exception.</param>
public void SetException(Exception error)
public bool TrySetException(Exception error)
{
this.hasUnhandledError = true;
this.error = ExceptionDispatchInfo.Capture(error);
SignalCompletion();
if (Interlocked.Increment(ref completedCount) == 1)
{
// setup result
this.hasUnhandledError = true;
this.error = ExceptionDispatchInfo.Capture(error);
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
return true;
}
}
return false;
}
public void SetCanceled(CancellationToken cancellationToken = default)
public bool TrySetCanceled(CancellationToken cancellationToken = default)
{
this.error = new OperationCanceledException(cancellationToken);
SignalCompletion();
if (Interlocked.Increment(ref completedCount) == 1)
{
// setup result
this.hasUnhandledError = true;
this.error = new OperationCanceledException(cancellationToken);
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
return true;
}
}
return false;
}
/// <summary>Gets the operation version.</summary>
@ -122,7 +155,7 @@ namespace UniRx.Async
public UniTaskStatus GetStatus(short token)
{
ValidateToken(token);
return (continuation == null || !completed) ? UniTaskStatus.Pending
return (continuation == null || (completedCount == 0)) ? UniTaskStatus.Pending
: (error == null) ? UniTaskStatus.Succeeded
: (error is OperationCanceledException) ? UniTaskStatus.Canceled
: UniTaskStatus.Faulted;
@ -132,7 +165,7 @@ namespace UniRx.Async
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public UniTaskStatus UnsafeGetStatus()
{
return (continuation == null || !completed) ? UniTaskStatus.Pending
return (continuation == null || (completedCount == 0)) ? UniTaskStatus.Pending
: (error == null) ? UniTaskStatus.Succeeded
: (error is OperationCanceledException) ? UniTaskStatus.Canceled
: UniTaskStatus.Faulted;
@ -145,7 +178,7 @@ namespace UniRx.Async
public TResult GetResult(short token)
{
ValidateToken(token);
if (!completed)
if (!(completedCount == 0))
{
throw new InvalidOperationException("not yet completed.");
}
@ -183,6 +216,15 @@ namespace UniRx.Async
/* no use ValueTaskSourceOnCOmpletedFlags, always no capture ExecutionContext and SynchronizationContext. */
/*
PatternA: GetStatus=Pending => OnCompleted => TrySet*** => GetResult
PatternB: TrySet*** => GetStatus=!Pending => GetResult
PatternC: GetStatus=Pending => TrySet/OnCompleted(race condition) => GetResult
C.1: win OnCompleted -> TrySet invoke saved continuation
C.2: win TrySet -> should invoke continuation here.
*/
// not set continuation yet.
object oldContinuation = this.continuation;
if (oldContinuation == null)
{
@ -192,10 +234,11 @@ namespace UniRx.Async
if (oldContinuation != null)
{
// Operation already completed, so we need to queue the supplied callback.
// already running continuation in TrySet.
// It will cause call OnCompleted multiple time, invalid.
if (!ReferenceEquals(oldContinuation, UniTaskCompletionSourceCoreShared.s_sentinel))
{
throw new InvalidOperationException("already completed.");
throw new InvalidOperationException();
}
continuation(state);
@ -210,23 +253,6 @@ namespace UniRx.Async
throw new InvalidOperationException("token version is not matched.");
}
}
/// <summary>Signals that the operation has completed. Invoked after the result or error has been set.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private void SignalCompletion()
{
if (completed)
{
throw new InvalidOperationException();
}
completed = true;
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
}
}
}
internal static class UniTaskCompletionSourceCoreShared // separated out of generic to avoid unnecessary duplication
@ -277,17 +303,17 @@ namespace UniRx.Async
public void SetResult()
{
core.SetResult(AsyncUnit.Default);
core.TrySetResult(AsyncUnit.Default);
}
public void SetCanceled(CancellationToken cancellationToken = default)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
}
public void SetException(Exception exception)
{
core.SetException(exception);
core.TrySetException(exception);
}
public void GetResult(short token)
@ -369,17 +395,17 @@ namespace UniRx.Async
public void SetResult()
{
core.SetResult(AsyncUnit.Default);
core.TrySetResult(AsyncUnit.Default);
}
public void SetCanceled(CancellationToken cancellationToken = default)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
}
public void SetException(Exception exception)
{
core.SetException(exception);
core.TrySetException(exception);
}
public void GetResult(short token)
@ -463,17 +489,17 @@ namespace UniRx.Async
public void SetResult(T result)
{
core.SetResult(result);
core.TrySetResult(result);
}
public void SetCanceled(CancellationToken cancellationToken = default)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
}
public void SetException(Exception exception)
{
core.SetException(exception);
core.TrySetException(exception);
}
public T GetResult(short token)
@ -560,17 +586,17 @@ namespace UniRx.Async
public void SetResult(T result)
{
core.SetResult(result);
core.TrySetResult(result);
}
public void SetCanceled(CancellationToken cancellationToken = default)
{
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
}
public void SetException(Exception exception)
{
core.SetException(exception);
core.TrySetException(exception);
}
public T GetResult(short token)
@ -616,7 +642,6 @@ namespace UniRx.Async
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
return;
}
}
}

View File

@ -76,7 +76,16 @@ namespace UniRx.Async
{
try
{
var awaiter = task.GetAwaiter();
UniTask<T>.Awaiter awaiter;
try
{
awaiter = task.GetAwaiter();
}
catch (Exception ex)
{
return Task.FromException<T>(ex);
}
if (awaiter.IsCompleted)
{
try
@ -121,7 +130,16 @@ namespace UniRx.Async
{
try
{
var awaiter = task.GetAwaiter();
UniTask.Awaiter awaiter;
try
{
awaiter = task.GetAwaiter();
}
catch (Exception ex)
{
return Task.FromException(ex);
}
if (awaiter.IsCompleted)
{
try

View File

@ -573,7 +573,7 @@ namespace UniRx.Async
{
// TODO:Remove Tracking
// TaskTracker.RemoveTracking();
core.SetCanceled(cancellationToken);
core.TrySetCanceled(cancellationToken);
return false;
}
@ -586,7 +586,7 @@ namespace UniRx.Async
{
// TODO:Remove Tracking
// TaskTracker.RemoveTracking();
core.SetResult(asyncOperation.asset);
core.TrySetResult(asyncOperation.asset);
return false;
}