WhenAll and WhenAny

master
neuecc 2020-04-21 13:36:23 +09:00
parent 082f3e7335
commit 3654a9e2f9
16 changed files with 11143 additions and 3797 deletions

View File

@ -96,7 +96,7 @@ namespace UniRx.Async
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
return false; return false;
} }
@ -109,11 +109,11 @@ namespace UniRx.Async
} }
catch (Exception ex) catch (Exception ex)
{ {
core.SetException(ex); core.TrySetException(ex);
return false; return false;
} }
core.SetResult(null); core.TrySetResult(null);
return false; return false;
} }
@ -126,6 +126,14 @@ namespace UniRx.Async
exception = default; exception = default;
} }
~EnumeratorPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
// Unwrap YieldInstructions // Unwrap YieldInstructions
static IEnumerator ConsumeEnumerator(IEnumerator enumerator) static IEnumerator ConsumeEnumerator(IEnumerator enumerator)

View File

@ -40,7 +40,7 @@ namespace UniRx.Async.Internal
return new RentArray<T>(array, array.Length, null); return new RentArray<T>(array, array.Length, null);
} }
var defaultCount = 4; var defaultCount = 32;
if (source is ICollection<T> coll) if (source is ICollection<T> coll)
{ {
defaultCount = coll.Count; defaultCount = coll.Count;

View File

@ -10,6 +10,11 @@ namespace UniRx.Async.Internal
{ {
return StatePool<T1, T2>.Create(item1, item2); return StatePool<T1, T2>.Create(item1, item2);
} }
public static StateTuple<T1, T2, T3> Create<T1, T2, T3>(T1 item1, T2 item2, T3 item3)
{
return StatePool<T1, T2, T3>.Create(item1, item2, item3);
}
} }
internal class StateTuple<T1, T2> : IDisposable internal class StateTuple<T1, T2> : IDisposable
@ -54,4 +59,51 @@ namespace UniRx.Async.Internal
queue.Enqueue(tuple); queue.Enqueue(tuple);
} }
} }
internal class StateTuple<T1, T2, T3> : IDisposable
{
public T1 Item1;
public T2 Item2;
public T3 Item3;
public void Deconstruct(out T1 item1, out T2 item2, out T3 item3)
{
item1 = this.Item1;
item2 = this.Item2;
item3 = this.Item3;
}
public void Dispose()
{
StatePool<T1, T2, T3>.Return(this);
}
}
internal static class StatePool<T1, T2, T3>
{
static readonly ConcurrentQueue<StateTuple<T1, T2, T3>> queue = new ConcurrentQueue<StateTuple<T1, T2, T3>>();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static StateTuple<T1, T2, T3> Create(T1 item1, T2 item2, T3 item3)
{
if (queue.TryDequeue(out var value))
{
value.Item1 = item1;
value.Item2 = item2;
value.Item3 = item3;
return value;
}
return new StateTuple<T1, T2, T3> { Item1 = item1, Item2 = item2, Item3 = item3 };
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Return(StateTuple<T1, T2, T3> tuple)
{
tuple.Item1 = default;
tuple.Item2 = default;
tuple.Item3 = default;
queue.Enqueue(tuple);
}
}
} }

View File

@ -24,8 +24,6 @@ namespace UniRx.Async
public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
{ {
PlayerLoopHelper.Initialize(
if (delayFrameCount < 0) if (delayFrameCount < 0)
{ {
throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. delayFrameCount:" + delayFrameCount); throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. delayFrameCount:" + delayFrameCount);
@ -121,11 +119,11 @@ namespace UniRx.Async
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
return false; return false;
} }
core.SetResult(null); core.TrySetResult(null);
return false; return false;
} }
@ -134,6 +132,14 @@ namespace UniRx.Async
core.Reset(); core.Reset();
cancellationToken = default; cancellationToken = default;
} }
~YieldPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
} }
sealed class DelayFramePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem sealed class DelayFramePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
@ -202,13 +208,13 @@ namespace UniRx.Async
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
return false; return false;
} }
if (currentFrameCount == delayFrameCount) if (currentFrameCount == delayFrameCount)
{ {
core.SetResult(null); core.TrySetResult(null);
return false; return false;
} }
@ -223,6 +229,14 @@ namespace UniRx.Async
delayFrameCount = default; delayFrameCount = default;
cancellationToken = default; cancellationToken = default;
} }
~DelayFramePromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
} }
sealed class DelayPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem sealed class DelayPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
@ -292,14 +306,14 @@ namespace UniRx.Async
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
return false; return false;
} }
elapsed += Time.deltaTime; elapsed += Time.deltaTime;
if (elapsed >= delayFrameTimeSpan) if (elapsed >= delayFrameTimeSpan)
{ {
core.SetResult(null); core.TrySetResult(null);
return false; return false;
} }
@ -313,6 +327,14 @@ namespace UniRx.Async
elapsed = default; elapsed = default;
cancellationToken = default; cancellationToken = default;
} }
~DelayPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
} }
sealed class DelayIgnoreTimeScalePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem sealed class DelayIgnoreTimeScalePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
@ -382,14 +404,14 @@ namespace UniRx.Async
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
return false; return false;
} }
elapsed += Time.unscaledDeltaTime; elapsed += Time.unscaledDeltaTime;
if (elapsed >= delayFrameTimeSpan) if (elapsed >= delayFrameTimeSpan)
{ {
core.SetResult(null); core.TrySetResult(null);
return false; return false;
} }
@ -403,6 +425,14 @@ namespace UniRx.Async
elapsed = default; elapsed = default;
cancellationToken = default; cancellationToken = default;
} }
~DelayIgnoreTimeScalePromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
} }
} }

View File

@ -12,14 +12,12 @@ namespace UniRx.Async
{ {
public static UniTask WaitUntil(Func<bool> predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) public static UniTask WaitUntil(Func<bool> predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
{ {
var promise = new WaitUntilPromise(predicate, timing, cancellationToken); return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, out var token), token);
return promise.Task;
} }
public static UniTask WaitWhile(Func<bool> predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) public static UniTask WaitWhile(Func<bool> predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
{ {
var promise = new WaitWhilePromise(predicate, timing, cancellationToken); return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, out var token), token);
return promise.Task;
} }
public static UniTask<U> WaitUntilValueChanged<T, U>(T target, Func<T, U> monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer<U> equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken)) public static UniTask<U> WaitUntilValueChanged<T, U>(T target, Func<T, U> monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer<U> equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken))
@ -28,130 +26,293 @@ namespace UniRx.Async
var unityObject = target as UnityEngine.Object; var unityObject = target as UnityEngine.Object;
var isUnityObject = !object.ReferenceEquals(target, null); // don't use (unityObject == null) var isUnityObject = !object.ReferenceEquals(target, null); // don't use (unityObject == null)
return (isUnityObject) return new UniTask<U>(isUnityObject
? new WaitUntilValueChangedUnityObjectPromise<T, U>(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken).Task ? WaitUntilValueChangedUnityObjectPromise<T, U>.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out var token)
: new WaitUntilValueChangedStandardObjectPromise<T, U>(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken).Task; : WaitUntilValueChangedStandardObjectPromise<T, U>.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out token), token);
} }
class WaitUntilPromise : PlayerLoopReusablePromiseBase sealed class WaitUntilPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
{ {
readonly Func<bool> predicate; static readonly PromisePool<WaitUntilPromise> pool = new PromisePool<WaitUntilPromise>();
public WaitUntilPromise(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken) Func<bool> predicate;
: base(timing, cancellationToken, 1) CancellationToken cancellationToken;
{
this.predicate = predicate;
}
protected override void OnRunningStart() UniTaskCompletionSourceCore<object> core;
WaitUntilPromise()
{ {
} }
public override bool MoveNext() public static IUniTaskSource Create(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
Complete(); return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token);
TrySetCanceled();
return false;
} }
bool result = default(bool); var result = pool.TryRent() ?? new WaitUntilPromise();
result.predicate = predicate;
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public void GetResult(short token)
{
try try
{ {
result = predicate(); TaskTracker2.RemoveTracking(this);
core.GetResult(token);
} }
catch (Exception ex) finally
{ {
Complete(); pool.TryReturn(this);
TrySetException(ex);
return false;
} }
if (result)
{
Complete();
TrySetResult();
return false;
}
return true;
} }
}
class WaitWhilePromise : PlayerLoopReusablePromiseBase public UniTaskStatus GetStatus(short token)
{
readonly Func<bool> predicate;
public WaitWhilePromise(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken)
: base(timing, cancellationToken, 1)
{ {
this.predicate = predicate; return core.GetStatus(token);
} }
protected override void OnRunningStart() public UniTaskStatus UnsafeGetStatus()
{ {
return core.UnsafeGetStatus();
} }
public override bool MoveNext() public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{ {
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
Complete(); core.TrySetCanceled(cancellationToken);
TrySetCanceled();
return false; return false;
} }
bool result = default(bool);
try try
{ {
result = predicate(); if (!predicate())
{
return true;
}
} }
catch (Exception ex) catch (Exception ex)
{ {
Complete(); core.TrySetException(ex);
TrySetException(ex);
return false; return false;
} }
if (!result) core.TrySetResult(null);
return false;
}
public void Reset()
{
core.Reset();
predicate = default;
cancellationToken = default;
}
~WaitUntilPromise()
{
if (pool.TryReturn(this))
{ {
Complete(); GC.ReRegisterForFinalize(this);
TrySetResult(); }
}
}
sealed class WaitWhilePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem
{
static readonly PromisePool<WaitWhilePromise> pool = new PromisePool<WaitWhilePromise>();
Func<bool> predicate;
CancellationToken cancellationToken;
UniTaskCompletionSourceCore<object> core;
WaitWhilePromise()
{
}
public static IUniTaskSource Create(Func<bool> predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested)
{
return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token);
}
var result = pool.TryRent() ?? new WaitWhilePromise();
result.predicate = predicate;
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public void GetResult(short token)
{
try
{
TaskTracker2.RemoveTracking(this);
core.GetResult(token);
}
finally
{
pool.TryReturn(this);
}
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested)
{
core.TrySetCanceled(cancellationToken);
return false; return false;
} }
return true; try
{
if (predicate())
{
return true;
}
}
catch (Exception ex)
{
core.TrySetException(ex);
return false;
}
core.TrySetResult(null);
return false;
}
public void Reset()
{
core.Reset();
predicate = default;
cancellationToken = default;
}
~WaitWhilePromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
} }
} }
// where T : UnityEngine.Object, can not add constraint // where T : UnityEngine.Object, can not add constraint
class WaitUntilValueChangedUnityObjectPromise<T, U> : PlayerLoopReusablePromiseBase<U> sealed class WaitUntilValueChangedUnityObjectPromise<T, U> : IUniTaskSource<U>, IPlayerLoopItem, IPromisePoolItem
{ {
readonly T target; static readonly PromisePool<WaitUntilValueChangedUnityObjectPromise<T, U>> pool = new PromisePool<WaitUntilValueChangedUnityObjectPromise<T, U>>();
readonly Func<T, U> monitorFunction;
readonly IEqualityComparer<U> equalityComparer; T target;
U currentValue; U currentValue;
Func<T, U> monitorFunction;
IEqualityComparer<U> equalityComparer;
CancellationToken cancellationToken;
public WaitUntilValueChangedUnityObjectPromise(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken) UniTaskCompletionSourceCore<U> core;
: base(timing, cancellationToken, 1)
{
this.target = target;
this.monitorFunction = monitorFunction;
this.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
this.currentValue = monitorFunction(target);
}
protected override void OnRunningStart() WaitUntilValueChangedUnityObjectPromise()
{ {
} }
public override bool MoveNext() public static IUniTaskSource<U> Create(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested)
{
return AutoResetUniTaskCompletionSource<U>.CreateFromCanceled(cancellationToken, out token);
}
var result = pool.TryRent() ?? new WaitUntilValueChangedUnityObjectPromise<T, U>();
result.target = target;
result.monitorFunction = monitorFunction;
result.currentValue = monitorFunction(target);
result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public U GetResult(short token)
{
try
{
TaskTracker2.RemoveTracking(this);
return core.GetResult(token);
}
finally
{
pool.TryReturn(this);
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{ {
if (cancellationToken.IsCancellationRequested || target == null) // destroyed = cancel. if (cancellationToken.IsCancellationRequested || target == null) // destroyed = cancel.
{ {
Complete(); core.TrySetCanceled(cancellationToken);
TrySetCanceled();
return false; return false;
} }
@ -166,45 +327,112 @@ namespace UniRx.Async
} }
catch (Exception ex) catch (Exception ex)
{ {
Complete(); core.TrySetException(ex);
TrySetException(ex);
return false; return false;
} }
Complete(); core.TrySetResult(nextValue);
currentValue = nextValue;
TrySetResult(nextValue);
return false; return false;
} }
public void Reset()
{
core.Reset();
target = default;
currentValue = default;
monitorFunction = default;
equalityComparer = default;
cancellationToken = default;
}
~WaitUntilValueChangedUnityObjectPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
} }
class WaitUntilValueChangedStandardObjectPromise<T, U> : PlayerLoopReusablePromiseBase<U>
sealed class WaitUntilValueChangedStandardObjectPromise<T, U> : IUniTaskSource<U>, IPlayerLoopItem, IPromisePoolItem
where T : class where T : class
{ {
readonly WeakReference<T> target; static readonly PromisePool<WaitUntilValueChangedStandardObjectPromise<T, U>> pool = new PromisePool<WaitUntilValueChangedStandardObjectPromise<T, U>>();
readonly Func<T, U> monitorFunction;
readonly IEqualityComparer<U> equalityComparer; WeakReference<T> target;
U currentValue; U currentValue;
Func<T, U> monitorFunction;
IEqualityComparer<U> equalityComparer;
CancellationToken cancellationToken;
public WaitUntilValueChangedStandardObjectPromise(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken) UniTaskCompletionSourceCore<U> core;
: base(timing, cancellationToken, 1)
{
this.target = new WeakReference<T>(target, false); // wrap in WeakReference.
this.monitorFunction = monitorFunction;
this.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
this.currentValue = monitorFunction(target);
}
protected override void OnRunningStart() WaitUntilValueChangedStandardObjectPromise()
{ {
} }
public override bool MoveNext() public static IUniTaskSource<U> Create(T target, Func<T, U> monitorFunction, IEqualityComparer<U> equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{ {
if (cancellationToken.IsCancellationRequested || !target.TryGetTarget(out var t)) if (cancellationToken.IsCancellationRequested)
{ {
Complete(); return AutoResetUniTaskCompletionSource<U>.CreateFromCanceled(cancellationToken, out token);
TrySetCanceled(); }
var result = pool.TryRent() ?? new WaitUntilValueChangedStandardObjectPromise<T, U>();
result.target = new WeakReference<T>(target, false); // wrap in WeakReference.
result.monitorFunction = monitorFunction;
result.currentValue = monitorFunction(target);
result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault<U>();
result.cancellationToken = cancellationToken;
TaskTracker2.TrackActiveTask(result, 3);
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
return result;
}
public U GetResult(short token)
{
try
{
TaskTracker2.RemoveTracking(this);
return core.GetResult(token);
}
finally
{
pool.TryReturn(this);
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested || !target.TryGetTarget(out var t)) // doesn't find = cancel.
{
core.TrySetCanceled(cancellationToken);
return false; return false;
} }
@ -219,16 +447,31 @@ namespace UniRx.Async
} }
catch (Exception ex) catch (Exception ex)
{ {
Complete(); core.TrySetException(ex);
TrySetException(ex);
return false; return false;
} }
Complete(); core.TrySetResult(nextValue);
currentValue = nextValue;
TrySetResult(nextValue);
return false; return false;
} }
public void Reset()
{
core.Reset();
target = default;
currentValue = default;
monitorFunction = default;
equalityComparer = default;
cancellationToken = default;
}
~WaitUntilValueChangedStandardObjectPromise()
{
if (pool.TryReturn(this))
{
GC.ReRegisterForFinalize(this);
}
}
} }
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +0,0 @@
fileFormatVersion: 2
guid: 5110117231c8a6d4095fd0cbd3f4c142
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,127 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async
{
public partial struct UniTask
{
<# for(var i = 2; i <= 15; i++ ) {
var range = Enumerable.Range(1, i);
var t = string.Join(", ", range.Select(x => "T" + x));
var args = string.Join(", ", range.Select(x => $"UniTask<T{x}> task{x}"));
var targs = string.Join(", ", range.Select(x => $"task{x}"));
var tresult = string.Join(", ", range.Select(x => $"task{x}.GetAwaiter().GetResult()"));
var completedSuccessfullyAnd = string.Join(" && ", range.Select(x => $"task{x}.Status.IsCompletedSuccessfully()"));
var tfield = string.Join(", ", range.Select(x => $"self.t{x}"));
#>
public static UniTask<(<#= t #>)> WhenAll<<#= t #>>(<#= args #>)
{
if (<#= completedSuccessfullyAnd #>)
{
return new UniTask<(<#= t #>)>((<#= tresult #>));
}
return new UniTask<(<#= t #>)>(new WhenAllPromise<<#= t #>>(<#= targs #>), 0);
}
sealed class WhenAllPromise<<#= t #>> : IUniTaskSource<(<#= t #>)>
{
<# for(var j = 1; j <= i; j++) { #>
T<#= j #> t<#= j #> = default;
<# } #>
int completedCount;
UniTaskCompletionSourceCore<(<#= t #>)> core;
public WhenAllPromise(<#= args #>)
{
TaskTracker2.TrackActiveTask(this, 3);
this.completedCount = 0;
<# for(var j = 1; j <= i; j++) { #>
{
var awaiter = task<#= j #>.GetAwaiter();
if (awaiter.IsCompleted)
{
TryInvokeContinuationT<#= j #>(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise<<#= t #>>, UniTask<T<#= j #>>.Awaiter>)state)
{
TryInvokeContinuationT<#= j #>(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
<# } #>
}
<# for(var j = 1; j <= i; j++) { #>
static void TryInvokeContinuationT<#= j #>(WhenAllPromise<<#= t #>> self, in UniTask<T<#= j #>>.Awaiter awaiter)
{
try
{
self.t<#= j #> = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == <#= i #>)
{
self.core.TrySetResult((<#= tfield #>));
}
}
<# } #>
public (<#= t #>) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
~WhenAllPromise()
{
core.Reset();
}
}
<# } #>
}
}

View File

@ -3,8 +3,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading; using System.Threading;
using UniRx.Async.Internal; using UniRx.Async.Internal;
@ -12,285 +10,214 @@ namespace UniRx.Async
{ {
public partial struct UniTask public partial struct UniTask
{ {
// UniTask public static UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
public static async UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
{ {
return await new WhenAllPromise<T>(tasks, tasks.Length); return new UniTask<T[]>(new WhenAllPromise<T>(tasks, tasks.Length), 0);
} }
public static async UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks) public static UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
{ {
WhenAllPromise<T> promise;
using (var span = ArrayPoolUtil.Materialize(tasks)) using (var span = ArrayPoolUtil.Materialize(tasks))
{ {
promise = new WhenAllPromise<T>(span.Array, span.Length); var promise = new WhenAllPromise<T>(span.Array, span.Length); // consumed array in constructor.
return new UniTask<T[]>(promise, 0);
} }
return await promise;
} }
public static async UniTask WhenAll(params UniTask[] tasks) public static UniTask WhenAll(params UniTask[] tasks)
{ {
await new WhenAllPromise(tasks, tasks.Length); return new UniTask(new WhenAllPromise(tasks, tasks.Length), 0);
} }
public static async UniTask WhenAll(IEnumerable<UniTask> tasks) public static UniTask WhenAll(IEnumerable<UniTask> tasks)
{ {
WhenAllPromise promise;
using (var span = ArrayPoolUtil.Materialize(tasks)) using (var span = ArrayPoolUtil.Materialize(tasks))
{ {
promise = new WhenAllPromise(span.Array, span.Length); var promise = new WhenAllPromise(span.Array, span.Length); // consumed array in constructor.
return new UniTask(promise, 0);
} }
await promise;
} }
class WhenAllPromise<T> sealed class WhenAllPromise<T> : IUniTaskSource<T[]>
{ {
readonly T[] result; T[] result;
int completeCount; int completeCount;
Action whenComplete; UniTaskCompletionSourceCore<T[]> core; // don't reset(called after GetResult, will invoke TrySetException.)
ExceptionDispatchInfo exception;
public WhenAllPromise(UniTask<T>[] tasks, int tasksLength) public WhenAllPromise(UniTask<T>[] tasks, int tasksLength)
{ {
TaskTracker2.TrackActiveTask(this, 3);
this.completeCount = 0; this.completeCount = 0;
this.whenComplete = null;
this.exception = null;
this.result = new T[tasksLength]; this.result = new T[tasksLength];
for (int i = 0; i < tasksLength; i++) for (int i = 0; i < tasksLength; i++)
{ {
if (tasks[i].Status.IsCompleted()) UniTask<T>.Awaiter awaiter;
try
{ {
T value = default(T); awaiter = tasks[i].GetAwaiter();
try }
{ catch (Exception ex)
value = tasks[i].GetAwaiter().GetResult(); {
} core.TrySetException(ex);
catch (Exception ex) continue;
{ }
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
continue;
}
result[i] = value; if (awaiter.IsCompleted)
var count = Interlocked.Increment(ref completeCount); {
if (count == result.Length) TryInvokeContinuation(this, awaiter, i);
{
TryCallContinuation();
}
} }
else else
{ {
RunTask(tasks[i], i).Forget(); awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise<T>, UniTask<T>.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
} }
} }
} }
void TryCallContinuation() static void TryInvokeContinuation(WhenAllPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
{ {
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask(UniTask<T> task, int index)
{
T value = default(T);
try try
{ {
value = await task; self.result[i] = awaiter.GetResult();
} }
catch (Exception ex) catch (Exception ex)
{ {
exception = ExceptionDispatchInfo.Capture(ex); self.core.TrySetException(ex);
TryCallContinuation();
return; return;
} }
result[index] = value; if (Interlocked.Increment(ref self.completeCount) == self.result.Length)
var count = Interlocked.Increment(ref completeCount);
if (count == result.Length)
{ {
TryCallContinuation(); self.core.TrySetResult(self.result);
} }
} }
public Awaiter GetAwaiter() public T[] GetResult(short token)
{ {
return new Awaiter(this); TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
} }
public struct Awaiter : ICriticalNotifyCompletion void IUniTaskSource.GetResult(short token)
{ {
WhenAllPromise<T> parent; GetResult(token);
}
public Awaiter(WhenAllPromise<T> parent) public UniTaskStatus GetStatus(short token)
{ {
this.parent = parent; return core.GetStatus(token);
} }
public bool IsCompleted public UniTaskStatus UnsafeGetStatus()
{ {
get return core.UnsafeGetStatus();
{ }
return parent.exception != null || parent.result.Length == parent.completeCount;
}
}
public T[] GetResult() public void OnCompleted(Action<object> continuation, object state, short token)
{ {
if (parent.exception != null) core.OnCompleted(continuation, state, token);
{ }
parent.exception.Throw();
}
return parent.result; ~WhenAllPromise()
} {
core.Reset();
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
} }
} }
class WhenAllPromise sealed class WhenAllPromise : IUniTaskSource
{ {
int completeCount; int completeCount;
int resultLength; int tasksLength;
Action whenComplete; UniTaskCompletionSourceCore<AsyncUnit> core; // don't reset(called after GetResult, will invoke TrySetException.)
ExceptionDispatchInfo exception;
public WhenAllPromise(UniTask[] tasks, int tasksLength) public WhenAllPromise(UniTask[] tasks, int tasksLength)
{ {
TaskTracker2.TrackActiveTask(this, 3);
this.tasksLength = tasksLength;
this.completeCount = 0; this.completeCount = 0;
this.whenComplete = null;
this.exception = null;
this.resultLength = tasksLength;
for (int i = 0; i < tasksLength; i++) for (int i = 0; i < tasksLength; i++)
{ {
if (tasks[i].Status.IsCompleted()) UniTask.Awaiter awaiter;
try
{ {
try awaiter = tasks[i].GetAwaiter();
{ }
tasks[i].GetAwaiter().GetResult(); catch (Exception ex)
} {
catch (Exception ex) core.TrySetException(ex);
{ continue;
exception = ExceptionDispatchInfo.Capture(ex); }
TryCallContinuation();
continue;
}
var count = Interlocked.Increment(ref completeCount); if (awaiter.IsCompleted)
if (count == resultLength) {
{ TryInvokeContinuation(this, awaiter);
TryCallContinuation();
}
} }
else else
{ {
RunTask(tasks[i], i).Forget(); awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise, UniTask.Awaiter>)state)
{
TryInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
} }
} }
} }
void TryCallContinuation() static void TryInvokeContinuation(WhenAllPromise self, in UniTask.Awaiter awaiter)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask(UniTask task, int index)
{ {
try try
{ {
await task; awaiter.GetResult();
} }
catch (Exception ex) catch (Exception ex)
{ {
exception = ExceptionDispatchInfo.Capture(ex); self.core.TrySetException(ex);
TryCallContinuation();
return; return;
} }
var count = Interlocked.Increment(ref completeCount); if (Interlocked.Increment(ref self.completeCount) == self.tasksLength)
if (count == resultLength)
{ {
TryCallContinuation(); self.core.TrySetResult(AsyncUnit.Default);
} }
} }
public Awaiter GetAwaiter() public void GetResult(short token)
{ {
return new Awaiter(this); TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
core.GetResult(token);
} }
public struct Awaiter : ICriticalNotifyCompletion public UniTaskStatus GetStatus(short token)
{ {
WhenAllPromise parent; return core.GetStatus(token);
}
public Awaiter(WhenAllPromise parent) public UniTaskStatus UnsafeGetStatus()
{ {
this.parent = parent; return core.UnsafeGetStatus();
} }
public bool IsCompleted public void OnCompleted(Action<object> continuation, object state, short token)
{ {
get core.OnCompleted(continuation, state, token);
{ }
return parent.exception != null || parent.resultLength == parent.completeCount;
}
}
public void GetResult() ~WhenAllPromise()
{ {
if (parent.exception != null) core.Reset();
{
parent.exception.Throw();
}
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
} }
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +0,0 @@
fileFormatVersion: 2
guid: 13d604ac281570c4eac9962429f19ca9
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,122 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async
{
public partial struct UniTask
{
<# for(var i = 2; i <= 15; i++ ) {
var range = Enumerable.Range(1, i);
var t = string.Join(", ", range.Select(x => "T" + x));
var args = string.Join(", ", range.Select(x => $"UniTask<T{x}> task{x}"));
var targs = string.Join(", ", range.Select(x => $"task{x}"));
var tresult = string.Join(", ", range.Select(x => $"task{x}.GetAwaiter().GetResult()"));
var tBool = string.Join(", ", range.Select(x => $"(bool hasResult, T{x} result{x})"));
var tfield = string.Join(", ", range.Select(x => $"self.t{x}"));
Func<int, string> getResult = j => string.Join(", ", range.Select(x => (x == j) ? "(true, result)" : "(false, default)"));
#>
public static UniTask<(int winArgumentIndex, <#= tBool #>)> WhenAny<<#= t #>>(<#= args #>)
{
return new UniTask<(int winArgumentIndex, <#= tBool #>)>(new WhenAnyPromise<<#= t #>>(<#= targs #>), 0);
}
sealed class WhenAnyPromise<<#= t #>> : IUniTaskSource<(int, <#= tBool #>)>
{
int completedCount;
UniTaskCompletionSourceCore<(int, <#= tBool #>)> core;
public WhenAnyPromise(<#= args #>)
{
TaskTracker2.TrackActiveTask(this, 3);
this.completedCount = 0;
<# for(var j = 1; j <= i; j++) { #>
{
var awaiter = task<#= j #>.GetAwaiter();
if (awaiter.IsCompleted)
{
TryInvokeContinuationT<#= j #>(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise<<#= t #>>, UniTask<T<#= j #>>.Awaiter>)state)
{
TryInvokeContinuationT<#= j #>(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
<# } #>
}
<# for(var j = 1; j <= i; j++) { #>
static void TryInvokeContinuationT<#= j #>(WhenAnyPromise<<#= t #>> self, in UniTask<T<#= j #>>.Awaiter awaiter)
{
T<#= j #> result;
try
{
result = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((<#= j - 1 #>, <#= getResult(j) #>));
}
}
<# } #>
public (int, <#= tBool #>) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
~WhenAnyPromise()
{
core.Reset();
}
}
<# } #>
}
}

View File

@ -2,370 +2,365 @@
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System; using System;
using System.Runtime.CompilerServices; using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading; using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async namespace UniRx.Async
{ {
public partial struct UniTask public partial struct UniTask
{ {
// UniTask public static UniTask<(bool hasResultLeft, T result)> WhenAny<T>(UniTask<T> leftTask, UniTask rightTask)
public static async UniTask<(bool hasResultLeft, T0 result)> WhenAny<T0>(UniTask<T0> task0, UniTask task1)
{ {
return await new UnitWhenAnyPromise<T0>(task0, task1); return new UniTask<(bool, T)>(new WhenAnyLRPromise<T>(leftTask, rightTask), 0);
} }
public static async UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks) public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks)
{ {
return await new WhenAnyPromise<T>(tasks); return new UniTask<(int, T)>(new WhenAnyPromise<T>(tasks, tasks.Length), 0);
}
public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(IEnumerable<UniTask<T>> tasks)
{
using (var span = ArrayPoolUtil.Materialize(tasks))
{
return new UniTask<(int, T)>(new WhenAnyPromise<T>(span.Array, span.Length), 0);
}
} }
/// <summary>Return value is winArgumentIndex</summary> /// <summary>Return value is winArgumentIndex</summary>
public static async UniTask<int> WhenAny(params UniTask[] tasks) public static UniTask<int> WhenAny(params UniTask[] tasks)
{ {
return await new WhenAnyPromise(tasks); return new UniTask<int>(new WhenAnyPromise(tasks, tasks.Length), 0);
} }
class UnitWhenAnyPromise<T0> /// <summary>Return value is winArgumentIndex</summary>
public static UniTask<int> WhenAny(IEnumerable<UniTask> tasks)
{ {
T0 result0; using (var span = ArrayPoolUtil.Materialize(tasks))
ExceptionDispatchInfo exception; {
Action whenComplete; return new UniTask<int>(new WhenAnyPromise(span.Array, span.Length), 0);
int completeCount; }
}
sealed class WhenAnyLRPromise<T> : IUniTaskSource<(bool, T)>
{
int completedCount;
int winArgumentIndex; int winArgumentIndex;
UniTaskCompletionSourceCore<(bool, T)> core;
bool IsCompleted => exception != null || Volatile.Read(ref winArgumentIndex) != -1; public WhenAnyLRPromise(UniTask<T> leftTask, UniTask rightTask)
public UnitWhenAnyPromise(UniTask<T0> task0, UniTask task1)
{ {
this.whenComplete = null; TaskTracker2.TrackActiveTask(this, 3);
this.exception = null;
this.completeCount = 0;
this.winArgumentIndex = -1;
this.result0 = default(T0);
RunTask0(task0).Forget();
RunTask1(task1).Forget();
}
void TryCallContinuation()
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{ {
action.Invoke(); UniTask<T>.Awaiter awaiter;
} try
}
async UniTaskVoid RunTask0(UniTask<T0> task)
{
T0 value;
try
{
value = await task;
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
{
result0 = value;
Volatile.Write(ref winArgumentIndex, 0);
TryCallContinuation();
}
}
async UniTaskVoid RunTask1(UniTask task)
{
try
{
await task;
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
{
Volatile.Write(ref winArgumentIndex, 1);
TryCallContinuation();
}
}
public Awaiter GetAwaiter()
{
return new Awaiter(this);
}
public struct Awaiter : ICriticalNotifyCompletion
{
UnitWhenAnyPromise<T0> parent;
public Awaiter(UnitWhenAnyPromise<T0> parent)
{
this.parent = parent;
}
public bool IsCompleted
{
get
{ {
return parent.IsCompleted; awaiter = leftTask.GetAwaiter();
} }
} catch (Exception ex)
public (bool, T0) GetResult()
{
if (parent.exception != null)
{ {
parent.exception.Throw(); core.TrySetException(ex);
goto RIGHT;
} }
return (parent.winArgumentIndex == 0, parent.result0); if (awaiter.IsCompleted)
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{ {
var action = Interlocked.Exchange(ref parent.whenComplete, null); TryLeftInvokeContinuation(this, awaiter);
if (action != null) }
else
{
awaiter.SourceOnCompleted(state =>
{ {
action(); using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask<T>.Awaiter>)state)
} {
TryLeftInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
} }
} }
RIGHT:
{
UniTask.Awaiter awaiter;
try
{
awaiter = rightTask.GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
return;
}
if (awaiter.IsCompleted)
{
TryRightInvokeContinuation(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask.Awaiter>)state)
{
TryRightInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
}
static void TryLeftInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask<T>.Awaiter awaiter)
{
T result;
try
{
result = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((true, result));
}
}
static void TryRightInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask.Awaiter awaiter)
{
try
{
awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((false, default));
}
}
public (bool, T) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
~WhenAnyLRPromise()
{
core.Reset();
} }
} }
class WhenAnyPromise<T> sealed class WhenAnyPromise<T> : IUniTaskSource<(int, T)>
{ {
T result; int completedCount;
int completeCount;
int winArgumentIndex; int winArgumentIndex;
Action whenComplete; UniTaskCompletionSourceCore<(int, T)> core;
ExceptionDispatchInfo exception;
public bool IsComplete => exception != null || Volatile.Read(ref winArgumentIndex) != -1; public WhenAnyPromise(UniTask<T>[] tasks, int tasksLength)
public WhenAnyPromise(UniTask<T>[] tasks)
{ {
this.completeCount = 0; TaskTracker2.TrackActiveTask(this, 3);
this.winArgumentIndex = -1;
this.whenComplete = null;
this.exception = null;
this.result = default(T);
for (int i = 0; i < tasks.Length; i++) for (int i = 0; i < tasksLength; i++)
{ {
RunTask(tasks[i], i).Forget(); UniTask<T>.Awaiter awaiter;
try
{
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue; // consume others.
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise<T>, UniTask<T>.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
} }
} }
async UniTaskVoid RunTask(UniTask<T> task, int index) static void TryInvokeContinuation(WhenAnyPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
{ {
T value; T result;
try try
{ {
value = await task; result = awaiter.GetResult();
} }
catch (Exception ex) catch (Exception ex)
{ {
exception = ExceptionDispatchInfo.Capture(ex); self.core.TrySetException(ex);
TryCallContinuation();
return; return;
} }
var count = Interlocked.Increment(ref completeCount); if (Interlocked.Increment(ref self.completedCount) == 1)
if (count == 1)
{ {
result = value; self.core.TrySetResult((i, result));
Volatile.Write(ref winArgumentIndex, index);
TryCallContinuation();
} }
} }
void TryCallContinuation() public (int, T) GetResult(short token)
{ {
var action = Interlocked.Exchange(ref whenComplete, null); TaskTracker2.RemoveTracking(this);
if (action != null) GC.SuppressFinalize(this);
{ return core.GetResult(token);
action.Invoke();
}
} }
public Awaiter GetAwaiter() public UniTaskStatus GetStatus(short token)
{ {
return new Awaiter(this); return core.GetStatus(token);
} }
public struct Awaiter : ICriticalNotifyCompletion public void OnCompleted(Action<object> continuation, object state, short token)
{ {
WhenAnyPromise<T> parent; core.OnCompleted(continuation, state, token);
}
public Awaiter(WhenAnyPromise<T> parent) public UniTaskStatus UnsafeGetStatus()
{ {
this.parent = parent; return core.UnsafeGetStatus();
} }
public bool IsCompleted void IUniTaskSource.GetResult(short token)
{ {
get GetResult(token);
{ }
return parent.IsComplete;
}
}
public (int, T) GetResult() ~WhenAnyPromise()
{ {
if (parent.exception != null) core.Reset();
{
parent.exception.Throw();
}
return (parent.winArgumentIndex, parent.result);
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
} }
} }
class WhenAnyPromise sealed class WhenAnyPromise : IUniTaskSource<int>
{ {
int completeCount; int completedCount;
int winArgumentIndex; int winArgumentIndex;
Action whenComplete; UniTaskCompletionSourceCore<int> core;
ExceptionDispatchInfo exception;
public bool IsComplete => exception != null || Volatile.Read(ref winArgumentIndex) != -1; public WhenAnyPromise(UniTask[] tasks, int tasksLength)
public WhenAnyPromise(UniTask[] tasks)
{ {
this.completeCount = 0; TaskTracker2.TrackActiveTask(this, 3);
this.winArgumentIndex = -1;
this.whenComplete = null;
this.exception = null;
for (int i = 0; i < tasks.Length; i++) for (int i = 0; i < tasksLength; i++)
{ {
RunTask(tasks[i], i).Forget(); UniTask.Awaiter awaiter;
try
{
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue; // consume others.
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise, UniTask.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
} }
} }
async UniTaskVoid RunTask(UniTask task, int index) static void TryInvokeContinuation(WhenAnyPromise self, in UniTask.Awaiter awaiter, int i)
{ {
try try
{ {
await task; awaiter.GetResult();
} }
catch (Exception ex) catch (Exception ex)
{ {
exception = ExceptionDispatchInfo.Capture(ex); self.core.TrySetException(ex);
TryCallContinuation();
return; return;
} }
var count = Interlocked.Increment(ref completeCount); if (Interlocked.Increment(ref self.completedCount) == 1)
if (count == 1)
{ {
Volatile.Write(ref winArgumentIndex, index); self.core.TrySetResult(i);
TryCallContinuation();
} }
} }
void TryCallContinuation() public int GetResult(short token)
{ {
var action = Interlocked.Exchange(ref whenComplete, null); TaskTracker2.RemoveTracking(this);
if (action != null) GC.SuppressFinalize(this);
{ return core.GetResult(token);
action.Invoke();
}
} }
public Awaiter GetAwaiter() public UniTaskStatus GetStatus(short token)
{ {
return new Awaiter(this); return core.GetStatus(token);
} }
public struct Awaiter : ICriticalNotifyCompletion public void OnCompleted(Action<object> continuation, object state, short token)
{ {
WhenAnyPromise parent; core.OnCompleted(continuation, state, token);
}
public Awaiter(WhenAnyPromise parent) public UniTaskStatus UnsafeGetStatus()
{ {
this.parent = parent; return core.UnsafeGetStatus();
} }
public bool IsCompleted void IUniTaskSource.GetResult(short token)
{ {
get GetResult(token);
{ }
return parent.IsComplete;
}
}
public int GetResult() ~WhenAnyPromise()
{ {
if (parent.exception != null) core.Reset();
{
parent.exception.Throw();
}
return parent.winArgumentIndex;
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
} }
} }
} }

View File

@ -47,9 +47,8 @@ namespace UniRx.Async
TResult result; TResult result;
object error; // ExceptionDispatchInfo or OperationCanceledException object error; // ExceptionDispatchInfo or OperationCanceledException
short version; short version;
bool completed;
bool hasUnhandledError; bool hasUnhandledError;
int completedCount; // 0: completed == false
Action<object> continuation; Action<object> continuation;
object continuationState; object continuationState;
@ -61,7 +60,7 @@ namespace UniRx.Async
{ {
version += 1; // incr version. version += 1; // incr version.
} }
completed = false; completedCount = 0;
result = default; result = default;
error = null; error = null;
hasUnhandledError = false; hasUnhandledError = false;
@ -92,25 +91,59 @@ namespace UniRx.Async
/// <summary>Completes with a successful result.</summary> /// <summary>Completes with a successful result.</summary>
/// <param name="result">The result.</param> /// <param name="result">The result.</param>
public void SetResult(TResult result) public bool TrySetResult(TResult result)
{ {
this.result = result; if (Interlocked.Increment(ref completedCount) == 1)
SignalCompletion(); {
// setup result
this.result = result;
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
return true;
}
}
return false;
} }
/// <summary>Completes with an error.</summary> /// <summary>Completes with an error.</summary>
/// <param name="error">The exception.</param> /// <param name="error">The exception.</param>
public void SetException(Exception error) public bool TrySetException(Exception error)
{ {
this.hasUnhandledError = true; if (Interlocked.Increment(ref completedCount) == 1)
this.error = ExceptionDispatchInfo.Capture(error); {
SignalCompletion(); // setup result
this.hasUnhandledError = true;
this.error = ExceptionDispatchInfo.Capture(error);
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
return true;
}
}
return false;
} }
public void SetCanceled(CancellationToken cancellationToken = default) public bool TrySetCanceled(CancellationToken cancellationToken = default)
{ {
this.error = new OperationCanceledException(cancellationToken); if (Interlocked.Increment(ref completedCount) == 1)
SignalCompletion(); {
// setup result
this.hasUnhandledError = true;
this.error = new OperationCanceledException(cancellationToken);
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
return true;
}
}
return false;
} }
/// <summary>Gets the operation version.</summary> /// <summary>Gets the operation version.</summary>
@ -122,7 +155,7 @@ namespace UniRx.Async
public UniTaskStatus GetStatus(short token) public UniTaskStatus GetStatus(short token)
{ {
ValidateToken(token); ValidateToken(token);
return (continuation == null || !completed) ? UniTaskStatus.Pending return (continuation == null || (completedCount == 0)) ? UniTaskStatus.Pending
: (error == null) ? UniTaskStatus.Succeeded : (error == null) ? UniTaskStatus.Succeeded
: (error is OperationCanceledException) ? UniTaskStatus.Canceled : (error is OperationCanceledException) ? UniTaskStatus.Canceled
: UniTaskStatus.Faulted; : UniTaskStatus.Faulted;
@ -132,7 +165,7 @@ namespace UniRx.Async
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public UniTaskStatus UnsafeGetStatus() public UniTaskStatus UnsafeGetStatus()
{ {
return (continuation == null || !completed) ? UniTaskStatus.Pending return (continuation == null || (completedCount == 0)) ? UniTaskStatus.Pending
: (error == null) ? UniTaskStatus.Succeeded : (error == null) ? UniTaskStatus.Succeeded
: (error is OperationCanceledException) ? UniTaskStatus.Canceled : (error is OperationCanceledException) ? UniTaskStatus.Canceled
: UniTaskStatus.Faulted; : UniTaskStatus.Faulted;
@ -145,7 +178,7 @@ namespace UniRx.Async
public TResult GetResult(short token) public TResult GetResult(short token)
{ {
ValidateToken(token); ValidateToken(token);
if (!completed) if (!(completedCount == 0))
{ {
throw new InvalidOperationException("not yet completed."); throw new InvalidOperationException("not yet completed.");
} }
@ -183,6 +216,15 @@ namespace UniRx.Async
/* no use ValueTaskSourceOnCOmpletedFlags, always no capture ExecutionContext and SynchronizationContext. */ /* no use ValueTaskSourceOnCOmpletedFlags, always no capture ExecutionContext and SynchronizationContext. */
/*
PatternA: GetStatus=Pending => OnCompleted => TrySet*** => GetResult
PatternB: TrySet*** => GetStatus=!Pending => GetResult
PatternC: GetStatus=Pending => TrySet/OnCompleted(race condition) => GetResult
C.1: win OnCompleted -> TrySet invoke saved continuation
C.2: win TrySet -> should invoke continuation here.
*/
// not set continuation yet.
object oldContinuation = this.continuation; object oldContinuation = this.continuation;
if (oldContinuation == null) if (oldContinuation == null)
{ {
@ -192,10 +234,11 @@ namespace UniRx.Async
if (oldContinuation != null) if (oldContinuation != null)
{ {
// Operation already completed, so we need to queue the supplied callback. // already running continuation in TrySet.
// It will cause call OnCompleted multiple time, invalid.
if (!ReferenceEquals(oldContinuation, UniTaskCompletionSourceCoreShared.s_sentinel)) if (!ReferenceEquals(oldContinuation, UniTaskCompletionSourceCoreShared.s_sentinel))
{ {
throw new InvalidOperationException("already completed."); throw new InvalidOperationException();
} }
continuation(state); continuation(state);
@ -210,23 +253,6 @@ namespace UniRx.Async
throw new InvalidOperationException("token version is not matched."); throw new InvalidOperationException("token version is not matched.");
} }
} }
/// <summary>Signals that the operation has completed. Invoked after the result or error has been set.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private void SignalCompletion()
{
if (completed)
{
throw new InvalidOperationException();
}
completed = true;
if (continuation != null || Interlocked.CompareExchange(ref this.continuation, UniTaskCompletionSourceCoreShared.s_sentinel, null) != null)
{
continuation(continuationState);
}
}
} }
internal static class UniTaskCompletionSourceCoreShared // separated out of generic to avoid unnecessary duplication internal static class UniTaskCompletionSourceCoreShared // separated out of generic to avoid unnecessary duplication
@ -277,17 +303,17 @@ namespace UniRx.Async
public void SetResult() public void SetResult()
{ {
core.SetResult(AsyncUnit.Default); core.TrySetResult(AsyncUnit.Default);
} }
public void SetCanceled(CancellationToken cancellationToken = default) public void SetCanceled(CancellationToken cancellationToken = default)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
} }
public void SetException(Exception exception) public void SetException(Exception exception)
{ {
core.SetException(exception); core.TrySetException(exception);
} }
public void GetResult(short token) public void GetResult(short token)
@ -369,17 +395,17 @@ namespace UniRx.Async
public void SetResult() public void SetResult()
{ {
core.SetResult(AsyncUnit.Default); core.TrySetResult(AsyncUnit.Default);
} }
public void SetCanceled(CancellationToken cancellationToken = default) public void SetCanceled(CancellationToken cancellationToken = default)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
} }
public void SetException(Exception exception) public void SetException(Exception exception)
{ {
core.SetException(exception); core.TrySetException(exception);
} }
public void GetResult(short token) public void GetResult(short token)
@ -463,17 +489,17 @@ namespace UniRx.Async
public void SetResult(T result) public void SetResult(T result)
{ {
core.SetResult(result); core.TrySetResult(result);
} }
public void SetCanceled(CancellationToken cancellationToken = default) public void SetCanceled(CancellationToken cancellationToken = default)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
} }
public void SetException(Exception exception) public void SetException(Exception exception)
{ {
core.SetException(exception); core.TrySetException(exception);
} }
public T GetResult(short token) public T GetResult(short token)
@ -560,17 +586,17 @@ namespace UniRx.Async
public void SetResult(T result) public void SetResult(T result)
{ {
core.SetResult(result); core.TrySetResult(result);
} }
public void SetCanceled(CancellationToken cancellationToken = default) public void SetCanceled(CancellationToken cancellationToken = default)
{ {
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
} }
public void SetException(Exception exception) public void SetException(Exception exception)
{ {
core.SetException(exception); core.TrySetException(exception);
} }
public T GetResult(short token) public T GetResult(short token)
@ -616,7 +642,6 @@ namespace UniRx.Async
if (pool.TryReturn(this)) if (pool.TryReturn(this))
{ {
GC.ReRegisterForFinalize(this); GC.ReRegisterForFinalize(this);
return;
} }
} }
} }

View File

@ -76,7 +76,16 @@ namespace UniRx.Async
{ {
try try
{ {
var awaiter = task.GetAwaiter(); UniTask<T>.Awaiter awaiter;
try
{
awaiter = task.GetAwaiter();
}
catch (Exception ex)
{
return Task.FromException<T>(ex);
}
if (awaiter.IsCompleted) if (awaiter.IsCompleted)
{ {
try try
@ -121,7 +130,16 @@ namespace UniRx.Async
{ {
try try
{ {
var awaiter = task.GetAwaiter(); UniTask.Awaiter awaiter;
try
{
awaiter = task.GetAwaiter();
}
catch (Exception ex)
{
return Task.FromException(ex);
}
if (awaiter.IsCompleted) if (awaiter.IsCompleted)
{ {
try try

View File

@ -573,7 +573,7 @@ namespace UniRx.Async
{ {
// TODO:Remove Tracking // TODO:Remove Tracking
// TaskTracker.RemoveTracking(); // TaskTracker.RemoveTracking();
core.SetCanceled(cancellationToken); core.TrySetCanceled(cancellationToken);
return false; return false;
} }
@ -586,7 +586,7 @@ namespace UniRx.Async
{ {
// TODO:Remove Tracking // TODO:Remove Tracking
// TaskTracker.RemoveTracking(); // TaskTracker.RemoveTracking();
core.SetResult(asyncOperation.asset); core.TrySetResult(asyncOperation.asset);
return false; return false;
} }