diff --git a/Assets/UniRx.Async/UniTask.Bridge.cs b/Assets/UniRx.Async/UniTask.Bridge.cs index 814bd02..a9ecbf6 100644 --- a/Assets/UniRx.Async/UniTask.Bridge.cs +++ b/Assets/UniRx.Async/UniTask.Bridge.cs @@ -8,6 +8,14 @@ namespace UniRx.Async { // UnityEngine Bridges. + public partial struct UniTask2 + { + public static IEnumerator ToCoroutine(Func taskFactory) + { + return taskFactory().ToCoroutine(); + } + } + public partial struct UniTask { public static IEnumerator ToCoroutine(Func taskFactory) diff --git a/Assets/UniRx.Async/UniTask.Delay.cs b/Assets/UniRx.Async/UniTask.Delay.cs index d828358..25e8b5f 100644 --- a/Assets/UniRx.Async/UniTask.Delay.cs +++ b/Assets/UniRx.Async/UniTask.Delay.cs @@ -9,6 +9,449 @@ using UnityEngine; namespace UniRx.Async { + // TODO:rename + public partial struct UniTask2 + { + public static YieldAwaitable2 Yield(PlayerLoopTiming timing = PlayerLoopTiming.Update) + { + // optimized for single continuation + return new YieldAwaitable2(timing); + } + + public static UniTask2 Yield(PlayerLoopTiming timing, CancellationToken cancellationToken) + { + return new UniTask2(YieldPromise.Create(timing, cancellationToken, out var token), token); + } + + public static UniTask2 DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + { + if (delayFrameCount < 0) + { + throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. delayFrameCount:" + delayFrameCount); + } + + return new UniTask2(DelayFramePromise.Create(delayFrameCount, delayTiming, cancellationToken, out var token), token); + } + + public static UniTask2 Delay(int millisecondsDelay, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + { + var delayTimeSpan = TimeSpan.FromMilliseconds(millisecondsDelay); + if (delayTimeSpan < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException("Delay does not allow minus millisecondsDelay. millisecondsDelay:" + millisecondsDelay); + } + + return (ignoreTimeScale) + ? new UniTask2(DelayIgnoreTimeScalePromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token) + : new UniTask2(DelayPromise.Create(delayTimeSpan, delayTiming, cancellationToken, out token), token); + } + + public static UniTask2 Delay(TimeSpan delayTimeSpan, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + { + if (delayTimeSpan < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException("Delay does not allow minus delayTimeSpan. delayTimeSpan:" + delayTimeSpan); + } + + return (ignoreTimeScale) + ? new UniTask2(DelayIgnoreTimeScalePromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token) + : new UniTask2(DelayPromise.Create(delayTimeSpan, delayTiming, cancellationToken, out token), token); + } + + class YieldPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem + { + static readonly PromisePool pool = new PromisePool(); + + CancellationToken cancellationToken; + UniTaskCompletionSourceCore core; + + YieldPromise() + { + } + + public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + var result = pool.TryRent() ?? new YieldPromise(); + + result.cancellationToken = cancellationToken; + + TaskTracker2.TrackActiveTask(result, 3); + + PlayerLoopHelper.AddAction(timing, result); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + TaskTracker2.RemoveTracking(this); + core.GetResult(token); + } + finally + { + pool.TryReturn(this); + } + } + + public AwaiterStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public AwaiterStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public bool MoveNext() + { + if (cancellationToken.IsCancellationRequested) + { + core.SetCanceled(cancellationToken); + return false; + } + + core.SetResult(null); + return false; + } + + public void Reset() + { + core.Reset(); + cancellationToken = default; + } + } + + class DelayFramePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem + { + static readonly PromisePool pool = new PromisePool(); + + int delayFrameCount; + CancellationToken cancellationToken; + + int currentFrameCount; + UniTaskCompletionSourceCore core; + + DelayFramePromise() + { + } + + public static IUniTaskSource Create(int delayFrameCount, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + var result = pool.TryRent() ?? new DelayFramePromise(); + + result.delayFrameCount = delayFrameCount; + result.cancellationToken = cancellationToken; + + TaskTracker2.TrackActiveTask(result, 3); + + PlayerLoopHelper.AddAction(timing, result); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + TaskTracker2.RemoveTracking(this); + core.GetResult(token); + } + finally + { + pool.TryReturn(this); + } + } + + public AwaiterStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public AwaiterStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public bool MoveNext() + { + if (cancellationToken.IsCancellationRequested) + { + core.SetCanceled(cancellationToken); + return false; + } + + if (currentFrameCount == delayFrameCount) + { + core.SetResult(null); + return false; + } + + currentFrameCount++; + return true; + } + + public void Reset() + { + core.Reset(); + currentFrameCount = default; + delayFrameCount = default; + cancellationToken = default; + } + } + + class DelayPromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem + { + static readonly PromisePool pool = new PromisePool(); + + float delayFrameTimeSpan; + float elapsed; + CancellationToken cancellationToken; + + UniTaskCompletionSourceCore core; + + DelayPromise() + { + } + + public static IUniTaskSource Create(TimeSpan delayFrameTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + var result = pool.TryRent() ?? new DelayPromise(); + + result.elapsed = 0.0f; + result.delayFrameTimeSpan = (float)delayFrameTimeSpan.TotalSeconds; + result.cancellationToken = cancellationToken; + + TaskTracker2.TrackActiveTask(result, 3); + + PlayerLoopHelper.AddAction(timing, result); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + TaskTracker2.RemoveTracking(this); + core.GetResult(token); + } + finally + { + pool.TryReturn(this); + } + } + + public AwaiterStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public AwaiterStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public bool MoveNext() + { + if (cancellationToken.IsCancellationRequested) + { + core.SetCanceled(cancellationToken); + return false; + } + + elapsed += Time.deltaTime; + if (elapsed >= delayFrameTimeSpan) + { + core.SetResult(null); + return false; + } + + return true; + } + + public void Reset() + { + core.Reset(); + delayFrameTimeSpan = default; + elapsed = default; + cancellationToken = default; + } + } + + class DelayIgnoreTimeScalePromise : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem + { + static readonly PromisePool pool = new PromisePool(); + + float delayFrameTimeSpan; + float elapsed; + CancellationToken cancellationToken; + + UniTaskCompletionSourceCore core; + + DelayIgnoreTimeScalePromise() + { + } + + public static IUniTaskSource Create(TimeSpan delayFrameTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + var result = pool.TryRent() ?? new DelayIgnoreTimeScalePromise(); + + result.elapsed = 0.0f; + result.delayFrameTimeSpan = (float)delayFrameTimeSpan.TotalSeconds; + result.cancellationToken = cancellationToken; + + TaskTracker2.TrackActiveTask(result, 3); + + PlayerLoopHelper.AddAction(timing, result); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + TaskTracker2.RemoveTracking(this); + core.GetResult(token); + } + finally + { + pool.TryReturn(this); + } + } + + public AwaiterStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public AwaiterStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public bool MoveNext() + { + if (cancellationToken.IsCancellationRequested) + { + core.SetCanceled(cancellationToken); + return false; + } + + elapsed += Time.unscaledDeltaTime; + if (elapsed >= delayFrameTimeSpan) + { + core.SetResult(null); + return false; + } + + return true; + } + + public void Reset() + { + core.Reset(); + delayFrameTimeSpan = default; + elapsed = default; + cancellationToken = default; + } + } + } + + // TODO:rename + public struct YieldAwaitable2 + { + readonly PlayerLoopTiming timing; + + public YieldAwaitable2(PlayerLoopTiming timing) + { + this.timing = timing; + } + + public Awaiter GetAwaiter() + { + return new Awaiter(timing); + } + + public UniTask2 ToUniTask() + { + return UniTask2.Yield(timing, CancellationToken.None); + } + + public struct Awaiter : ICriticalNotifyCompletion + { + readonly PlayerLoopTiming timing; + + public Awaiter(PlayerLoopTiming timing) + { + this.timing = timing; + } + + public bool IsCompleted => false; + + public void GetResult() { } + + public void OnCompleted(Action continuation) + { + PlayerLoopHelper.AddContinuation(timing, continuation); + } + + public void UnsafeOnCompleted(Action continuation) + { + PlayerLoopHelper.AddContinuation(timing, continuation); + } + } + } + + + // TODO:remove public partial struct UniTask { public static YieldAwaitable Yield(PlayerLoopTiming timing = PlayerLoopTiming.Update) @@ -199,6 +642,7 @@ namespace UniRx.Async } } + // TODO:remove public struct YieldAwaitable { readonly PlayerLoopTiming timing; diff --git a/Assets/UniRx.Async/UniTask.Factory.cs b/Assets/UniRx.Async/UniTask.Factory.cs index 5bc0b84..9893ad2 100644 --- a/Assets/UniRx.Async/UniTask.Factory.cs +++ b/Assets/UniRx.Async/UniTask.Factory.cs @@ -7,6 +7,121 @@ using UnityEngine.Events; namespace UniRx.Async { + public partial struct UniTask2 + { + static readonly UniTask2 CanceledUniTask = new Func(() => + { + var promise = new UniTaskCompletionSource2(); + promise.SetCanceled(CancellationToken.None); + promise.MarkHandled(); + return promise.Task; + })(); + + static class CanceledUniTaskCache + { + public static readonly UniTask2 Task; + + static CanceledUniTaskCache() + { + var promise = new UniTaskCompletionSource2(); + promise.SetCanceled(CancellationToken.None); + promise.MarkHandled(); + Task = promise.Task; + } + } + + public static readonly UniTask2 CompletedTask = new UniTask2(); + + public static UniTask2 FromException(Exception ex) + { + var promise = new UniTaskCompletionSource2(); + promise.SetException(ex); + promise.MarkHandled(); + return promise.Task; + } + + public static UniTask2 FromException(Exception ex) + { + var promise = new UniTaskCompletionSource2(); + promise.SetException(ex); + promise.MarkHandled(); + return promise.Task; + } + + public static UniTask2 FromResult(T value) + { + return new UniTask2(value); + } + + public static UniTask2 FromCanceled(CancellationToken cancellationToken = default) + { + if (cancellationToken == CancellationToken.None) + { + return CanceledUniTask; + } + else + { + var promise = new UniTaskCompletionSource2(); + promise.SetCanceled(cancellationToken); + promise.MarkHandled(); + return promise.Task; + } + } + + public static UniTask2 FromCanceled(CancellationToken cancellationToken = default) + { + if (cancellationToken == CancellationToken.None) + { + return CanceledUniTaskCache.Task; + } + else + { + var promise = new UniTaskCompletionSource2(); + promise.SetCanceled(cancellationToken); + promise.MarkHandled(); + return promise.Task; + } + } + + // TODO:... + + /// shorthand of new UniTask[T](Func[UniTask[T]] factory) + public static UniTask Lazy(Func> factory) + { + return new UniTask(factory); + } + + /// + /// helper of create add UniTaskVoid to delegate. + /// For example: FooEvent += () => UniTask.Void(async () => { /* */ }) + /// + public static void Void(Func asyncAction) + { + asyncAction().Forget(); + } + + public static Action VoidAction(Func asyncAction) + { + return () => Void(asyncAction); + } + + public static UnityAction VoidUnityAction(Func asyncAction) + { + return () => Void(asyncAction); + } + + /// + /// helper of create add UniTaskVoid to delegate. + /// For example: FooEvent += (sender, e) => UniTask.Void(async arg => { /* */ }, (sender, e)) + /// + public static void Void(Func asyncAction, T state) + { + asyncAction(state).Forget(); + } + } + + + // TODO:remove public partial struct UniTask { static readonly UniTask CanceledUniTask = new Func(() => @@ -120,6 +235,8 @@ namespace UniRx.Async } } + + // TODO:remove internal static class CompletedTasks { public static readonly UniTask True = UniTask.FromResult(true); @@ -130,7 +247,7 @@ namespace UniRx.Async } - + // TODO:rename internal static class CompletedTasks2 { public static readonly UniTask2 Completed = new UniTask2(); diff --git a/Assets/UniRx.Async/UniTask.cs b/Assets/UniRx.Async/UniTask.cs index 78cd6e2..bf71764 100644 --- a/Assets/UniRx.Async/UniTask.cs +++ b/Assets/UniRx.Async/UniTask.cs @@ -3,136 +3,15 @@ #pragma warning disable CS0436 using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; -using System.Runtime.InteropServices; -using System.Threading; -using System.Threading.Tasks; -using System.Threading.Tasks.Sources; using UniRx.Async.CompilerServices; using UniRx.Async.Internal; namespace UniRx.Async { - - - public partial struct UniTask2 - { - public static UniTask2 DelayFrame(int frameCount, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default) - { - return new UniTask2(DelayPromiseCore2.Create(frameCount, timing, cancellationToken, out var token), token); - - - //return new ValueTask(DelayPromiseCore2.Create(frameCount, timing, cancellationToken, out var token), token); - } - - public static readonly UniTask2 CompletedTask = new UniTask2(); - - public static UniTask2 FromResult(T result) - { - return new UniTask2(result); - } - } - - - - - public class DelayPromiseCore2 : IUniTaskSource, IPlayerLoopItem, IPromisePoolItem - { - static readonly PromisePool pool = new PromisePool(); - - int delayFrameCount; - CancellationToken cancellationToken; - - int currentFrameCount; - UniTaskCompletionSourceCore core; - - DelayPromiseCore2() - { - } - - public static IUniTaskSource Create(int delayFrameCount, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - var result = pool.TryRent() ?? new DelayPromiseCore2(); - - result.delayFrameCount = delayFrameCount; - result.cancellationToken = cancellationToken; - - TaskTracker2.TrackActiveTask(result, 3); - - PlayerLoopHelper.AddAction(timing, result); - - token = result.core.Version; - return result; - } - - public void GetResult(short token) - { - try - { - TaskTracker2.RemoveTracking(this); - core.GetResult(token); - } - finally - { - pool.TryReturn(this); - } - } - - public AwaiterStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public AwaiterStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - public bool MoveNext() - { - if (cancellationToken.IsCancellationRequested) - { - core.SetCancellation(cancellationToken); - return false; - } - - if (currentFrameCount == delayFrameCount) - { - core.SetResult(null); - return false; - } - - currentFrameCount++; - return true; - } - - public void Reset() - { - core.Reset(); - currentFrameCount = default; - delayFrameCount = default; - cancellationToken = default; - } - } - - - - - internal static class AwaiterActions { internal static readonly Action InvokeActionDelegate = InvokeAction; @@ -195,7 +74,20 @@ namespace UniRx.Async return "(" + source.UnsafeGetStatus() + ")"; } - // TODO:AsTask??? + /// + /// Memoizing inner IValueTaskSource. The result UniTask can await multiple. + /// + public UniTask2 Preserve() + { + if (source == null) + { + return this; + } + else + { + return new UniTask2(new MemoizeSource(source), token); + } + } public static implicit operator UniTask2(UniTask2 task) { @@ -287,6 +179,86 @@ namespace UniRx.Async } } + class MemoizeSource : IUniTaskSource + { + IUniTaskSource source; + ExceptionDispatchInfo exception; + AwaiterStatus status; + + public MemoizeSource(IUniTaskSource source) + { + this.source = source; + } + + public void GetResult(short token) + { + if (source == null) + { + if (exception != null) + { + exception.Throw(); + } + } + else + { + try + { + source.GetResult(token); + status = AwaiterStatus.Succeeded; + } + catch (Exception ex) + { + exception = ExceptionDispatchInfo.Capture(ex); + if (ex is OperationCanceledException) + { + status = AwaiterStatus.Canceled; + } + else + { + status = AwaiterStatus.Faulted; + } + throw; + } + finally + { + source = null; + } + } + } + + public AwaiterStatus GetStatus(short token) + { + if (source == null) + { + return status; + } + + return source.GetStatus(token); + } + + public void OnCompleted(Action continuation, object state, short token) + { + if (source == null) + { + continuation(state); + } + else + { + source.OnCompleted(continuation, state, token); + } + } + + public AwaiterStatus UnsafeGetStatus() + { + if (source == null) + { + return status; + } + + return source.UnsafeGetStatus(); + } + } + public readonly struct Awaiter : ICriticalNotifyCompletion { readonly UniTask2 task; @@ -343,6 +315,21 @@ namespace UniRx.Async task.source.OnCompleted(AwaiterActions.InvokeActionDelegate, continuation, task.token); } } + + /// + /// If register manually continuation, you can use it instead of for compiler OnCompleted methods. + /// + public void SourceOnCompleted(Action continuation, object state) + { + if (task.source == null) + { + continuation(state); + } + else + { + task.source.OnCompleted(continuation, state, task.token); + } + } } } @@ -391,7 +378,33 @@ namespace UniRx.Async return new Awaiter(this); } - // TODO:AsTask??? + /// + /// Memoizing inner IValueTaskSource. The result UniTask can await multiple. + /// + public UniTask2 Preserve() + { + if (source == null) + { + return this; + } + else + { + return new UniTask2(new MemoizeSource(source), token); + } + } + + public static implicit operator UniTask2(UniTask2 task) + { + if (task.source == null) return UniTask2.CompletedTask; + + var status = task.source.GetStatus(task.token); + if (status.IsCompletedSuccessfully()) + { + return UniTask2.CompletedTask; + } + + return new UniTask2(task.source, task.token); + } /// /// returns (bool IsCanceled, T Result) instead of throws OperationCanceledException. @@ -465,6 +478,94 @@ namespace UniRx.Async } } + class MemoizeSource : IUniTaskSource + { + IUniTaskSource source; + T result; + ExceptionDispatchInfo exception; + AwaiterStatus status; + + public MemoizeSource(IUniTaskSource source) + { + this.source = source; + } + + public T GetResult(short token) + { + if (source == null) + { + if (exception != null) + { + exception.Throw(); + } + return result; + } + else + { + try + { + result = source.GetResult(token); + status = AwaiterStatus.Succeeded; + return result; + } + catch (Exception ex) + { + exception = ExceptionDispatchInfo.Capture(ex); + if (ex is OperationCanceledException) + { + status = AwaiterStatus.Canceled; + } + else + { + status = AwaiterStatus.Faulted; + } + throw; + } + finally + { + source = null; + } + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public AwaiterStatus GetStatus(short token) + { + if (source == null) + { + return status; + } + + return source.GetStatus(token); + } + + public void OnCompleted(Action continuation, object state, short token) + { + if (source == null) + { + continuation(state); + } + else + { + source.OnCompleted(continuation, state, token); + } + } + + public AwaiterStatus UnsafeGetStatus() + { + if (source == null) + { + return status; + } + + return source.UnsafeGetStatus(); + } + } + public readonly struct Awaiter : ICriticalNotifyCompletion { readonly UniTask2 task; @@ -530,6 +631,22 @@ namespace UniRx.Async s.OnCompleted(AwaiterActions.InvokeActionDelegate, continuation, task.token); } } + + /// + /// If register manually continuation, you can use it instead of for compiler OnCompleted methods. + /// + public void SourceOnCompleted(Action continuation, object state) + { + var s = task.source; + if (s == null) + { + continuation(state); + } + else + { + s.OnCompleted(continuation, state, task.token); + } + } } } @@ -1000,4 +1117,4 @@ namespace UniRx.Async } } -#endif \ No newline at end of file +#endif diff --git a/Assets/UniRx.Async/UniTaskCompletionSource.cs b/Assets/UniRx.Async/UniTaskCompletionSource.cs index aa42750..27355a4 100644 --- a/Assets/UniRx.Async/UniTaskCompletionSource.cs +++ b/Assets/UniRx.Async/UniTaskCompletionSource.cs @@ -480,7 +480,7 @@ namespace UniRx.Async SignalCompletion(); } - public void SetCancellation(CancellationToken cancellationToken) + public void SetCanceled(CancellationToken cancellationToken = default) { this.error = new OperationCanceledException(cancellationToken); SignalCompletion(); @@ -613,7 +613,7 @@ namespace UniRx.Async } [Conditional("UNITY_EDITOR")] - void MarkHandled() + internal void MarkHandled() { if (!handled) { @@ -643,9 +643,9 @@ namespace UniRx.Async core.SetResult(AsyncUnit.Default); } - public void SetCancellation(CancellationToken cancellationToken) + public void SetCanceled(CancellationToken cancellationToken = default) { - core.SetCancellation(cancellationToken); + core.SetCanceled(cancellationToken); } public void SetException(Exception exception) @@ -701,7 +701,7 @@ namespace UniRx.Async public static AutoResetUniTaskCompletionSource CreateFromCanceled(CancellationToken cancellationToken, out short token) { var source = Create(); - source.SetCancellation(cancellationToken); + source.SetCanceled(cancellationToken); token = source.core.Version; return source; } @@ -735,7 +735,7 @@ namespace UniRx.Async core.SetResult(AsyncUnit.Default); } - public void SetCancellation(CancellationToken cancellationToken) + public void SetCanceled(CancellationToken cancellationToken = default) { core.SetCancellation(cancellationToken); } @@ -800,7 +800,7 @@ namespace UniRx.Async } [Conditional("UNITY_EDITOR")] - void MarkHandled() + internal void MarkHandled() { if (!handled) { @@ -829,7 +829,7 @@ namespace UniRx.Async core.SetResult(result); } - public void SetCancellation(CancellationToken cancellationToken) + public void SetCanceled(CancellationToken cancellationToken = default) { core.SetCancellation(cancellationToken); } @@ -884,7 +884,7 @@ namespace UniRx.Async public static AutoResetUniTaskCompletionSource Create() { - var result = pool.TryRent() ?? new AutoResetUniTaskCompletionSource(); + var result = pool.TryRent() ?? new AutoResetUniTaskCompletionSource(); TaskTracker2.TrackActiveTask(result, 2); return result; } @@ -926,7 +926,7 @@ namespace UniRx.Async core.SetResult(result); } - public void SetCancellation(CancellationToken cancellationToken) + public void SetCanceled(CancellationToken cancellationToken = default) { core.SetCancellation(cancellationToken); } diff --git a/Assets/UniRx.Async/UniTaskExtensions.cs b/Assets/UniRx.Async/UniTaskExtensions.cs index 06642df..6cdfaf8 100644 --- a/Assets/UniRx.Async/UniTaskExtensions.cs +++ b/Assets/UniRx.Async/UniTaskExtensions.cs @@ -10,6 +10,574 @@ using UniRx.Async.Internal; namespace UniRx.Async { + public static partial class UniTaskExtensions2 + { + /// + /// Convert UniTask -> UniTask[AsyncUnit]. + /// + public static UniTask2 AsAsyncUnitUniTask(this UniTask2 task) + { + // use implicit conversion + return task; + } + + /// + /// Convert UniTask[T] -> UniTask. + /// + public static UniTask2 AsUniTask(this UniTask2 task) + { + // use implicit conversion + return task; + } + + /// + /// Convert Task[T] -> UniTask[T]. + /// + public static UniTask2 AsUniTask(this Task task, bool useCurrentSynchronizationContext = true) + { + var promise = new UniTaskCompletionSource2(); + + task.ContinueWith((x, state) => + { + var p = (UniTaskCompletionSource2)state; + + switch (x.Status) + { + case TaskStatus.Canceled: + p.SetCanceled(); + break; + case TaskStatus.Faulted: + p.SetException(x.Exception); + break; + case TaskStatus.RanToCompletion: + p.SetResult(x.Result); + break; + default: + throw new NotSupportedException(); + } + }, promise, useCurrentSynchronizationContext ? TaskScheduler.FromCurrentSynchronizationContext() : TaskScheduler.Current); + + return promise.Task; + } + + /// + /// Convert Task -> UniTask. + /// + public static UniTask2 AsUniTask(this Task task, bool useCurrentSynchronizationContext = true) + { + var promise = new UniTaskCompletionSource2(); + + task.ContinueWith((x, state) => + { + var p = (UniTaskCompletionSource2)state; + + switch (x.Status) + { + case TaskStatus.Canceled: + p.SetCanceled(); + break; + case TaskStatus.Faulted: + p.SetException(x.Exception); + break; + case TaskStatus.RanToCompletion: + p.SetResult(); + break; + default: + throw new NotSupportedException(); + } + }, promise, useCurrentSynchronizationContext ? TaskScheduler.FromCurrentSynchronizationContext() : TaskScheduler.Current); + + return promise.Task; + } + + public static Task AsTask(this UniTask2 task) + { + try + { + var awaiter = task.GetAwaiter(); + if (awaiter.IsCompleted) + { + try + { + var result = awaiter.GetResult(); + return Task.FromResult(result); + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + var tcs = new TaskCompletionSource(); + + awaiter.SourceOnCompleted(state => + { + var (inTcs, inAwaiter) = ((TaskCompletionSource, UniTask2.Awaiter))state; + try + { + var result = inAwaiter.GetResult(); + inTcs.SetResult(result); + } + catch (Exception ex) + { + inTcs.SetException(ex); + } + }, (tcs, awaiter)); + + return tcs.Task; + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + public static Task AsTask(this UniTask2 task) + { + try + { + var awaiter = task.GetAwaiter(); + if (awaiter.IsCompleted) + { + try + { + awaiter.GetResult(); // check token valid on Succeeded + return Task.CompletedTask; + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + var tcs = new TaskCompletionSource(); + + awaiter.SourceOnCompleted(state => + { + var (inTcs, inAwaiter) = ((TaskCompletionSource, UniTask2.Awaiter))state; + try + { + inAwaiter.GetResult(); + inTcs.SetResult(null); + } + catch (Exception ex) + { + inTcs.SetException(ex); + } + }, (tcs, awaiter)); + + return tcs.Task; + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + public static IEnumerator ToCoroutine(this UniTask2 task, Action resultHandler = null, Action exceptionHandler = null) + { + return new ToCoroutineEnumerator(task, resultHandler, exceptionHandler); + } + + public static IEnumerator ToCoroutine(this UniTask2 task, Action exceptionHandler = null) + { + return new ToCoroutineEnumerator(task, exceptionHandler); + } + + public static UniTask Timeout(this UniTask2 task, TimeSpan timeout, bool ignoreTimeScale = true, PlayerLoopTiming timeoutCheckTiming = PlayerLoopTiming.Update, CancellationTokenSource taskCancellationTokenSource = null) + { + return Timeout(task.AsAsyncUnitUniTask(), timeout, ignoreTimeScale, timeoutCheckTiming, taskCancellationTokenSource); + } + + // TODO: require UniTask2.Delay, WhenAny, etc... + + public static async UniTask Timeout(this UniTask2 task, TimeSpan timeout, bool ignoreTimeScale = true, PlayerLoopTiming timeoutCheckTiming = PlayerLoopTiming.Update, CancellationTokenSource taskCancellationTokenSource = null) + { + // left, right both suppress operation canceled exception. + + var delayCancellationTokenSource = new CancellationTokenSource(); + var timeoutTask = (UniTask)UniTask.Delay(timeout, ignoreTimeScale, timeoutCheckTiming).SuppressCancellationThrow(); + + var (hasValue, value) = await UniTask.WhenAny(task.SuppressCancellationThrow(), timeoutTask); + + if (!hasValue) + { + if (taskCancellationTokenSource != null) + { + taskCancellationTokenSource.Cancel(); + taskCancellationTokenSource.Dispose(); + } + + throw new TimeoutException("Exceed Timeout:" + timeout); + } + else + { + delayCancellationTokenSource.Cancel(); + delayCancellationTokenSource.Dispose(); + } + + if (value.IsCanceled) + { + Error.ThrowOperationCanceledException(); + } + + return value.Result; + } + + /// + /// Timeout with suppress OperationCanceledException. Returns (bool, IsCacneled). + /// + public static async UniTask2 TimeoutWithoutException(this UniTask2 task, TimeSpan timeout, bool ignoreTimeScale = true, PlayerLoopTiming timeoutCheckTiming = PlayerLoopTiming.Update, CancellationTokenSource taskCancellationTokenSource = null) + { + var v = await TimeoutWithoutException(task.AsAsyncUnitUniTask(), timeout, ignoreTimeScale, timeoutCheckTiming, taskCancellationTokenSource); + return v.IsTimeout; + } + + + /// + /// Timeout with suppress OperationCanceledException. Returns (bool IsTimeout, T Result). + /// + public static async UniTask2<(bool IsTimeout, T Result)> TimeoutWithoutException(this UniTask2 task, TimeSpan timeout, bool ignoreTimeScale = true, PlayerLoopTiming timeoutCheckTiming = PlayerLoopTiming.Update, CancellationTokenSource taskCancellationTokenSource = null) + { + // left, right both suppress operation canceled exception. + + var delayCancellationTokenSource = new CancellationTokenSource(); + var timeoutTask = (UniTask)UniTask.Delay(timeout, ignoreTimeScale, timeoutCheckTiming).SuppressCancellationThrow(); + + var (hasValue, value) = await UniTask.WhenAny(task.SuppressCancellationThrow(), timeoutTask); + + if (!hasValue) + { + if (taskCancellationTokenSource != null) + { + taskCancellationTokenSource.Cancel(); + taskCancellationTokenSource.Dispose(); + } + + return (true, default(T)); + } + else + { + delayCancellationTokenSource.Cancel(); + delayCancellationTokenSource.Dispose(); + } + + if (value.IsCanceled) + { + Error.ThrowOperationCanceledException(); + } + + return (false, value.Result); + } + + public static void Forget(this UniTask2 task) + { + ForgetCore(task).Forget(); + } + + public static void Forget(this UniTask2 task, Action exceptionHandler, bool handleExceptionOnMainThread = true) + { + if (exceptionHandler == null) + { + ForgetCore(task).Forget(); + } + else + { + ForgetCoreWithCatch(task, exceptionHandler, handleExceptionOnMainThread).Forget(); + } + } + + // UniTask to UniTaskVoid + static async UniTaskVoid ForgetCore(UniTask2 task) + { + await task; + } + + static async UniTaskVoid ForgetCoreWithCatch(UniTask2 task, Action exceptionHandler, bool handleExceptionOnMainThread) + { + try + { + await task; + } + catch (Exception ex) + { + try + { + if (handleExceptionOnMainThread) + { + await UniTask2.SwitchToMainThread(); + } + exceptionHandler(ex); + } + catch (Exception ex2) + { + UniTaskScheduler.PublishUnobservedTaskException(ex2); + } + } + } + + public static void Forget(this UniTask2 task) + { + ForgetCore(task).Forget(); + } + + public static void Forget(this UniTask2 task, Action exceptionHandler, bool handleExceptionOnMainThread = true) + { + if (exceptionHandler == null) + { + ForgetCore(task).Forget(); + } + else + { + ForgetCoreWithCatch(task, exceptionHandler, handleExceptionOnMainThread).Forget(); + } + } + + // UniTask to UniTaskVoid + static async UniTaskVoid ForgetCore(UniTask2 task) + { + await task; + } + + static async UniTaskVoid ForgetCoreWithCatch(UniTask2 task, Action exceptionHandler, bool handleExceptionOnMainThread) + { + try + { + await task; + } + catch (Exception ex) + { + try + { + if (handleExceptionOnMainThread) + { + await UniTask.SwitchToMainThread(); + } + exceptionHandler(ex); + } + catch (Exception ex2) + { + UniTaskScheduler.PublishUnobservedTaskException(ex2); + } + } + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Action continuationFunction) + { + continuationFunction(await task); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Func continuationFunction) + { + await continuationFunction(await task); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Func continuationFunction) + { + return continuationFunction(await task); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Func> continuationFunction) + { + return await continuationFunction(await task); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Action continuationFunction) + { + await task; + continuationFunction(); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Func continuationFunction) + { + await task; + await continuationFunction(); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Func continuationFunction) + { + await task; + return continuationFunction(); + } + + public static async UniTask2 ContinueWith(this UniTask2 task, Func> continuationFunction) + { + await task; + return await continuationFunction(); + } + + public static async UniTask2 ConfigureAwait(this Task task, PlayerLoopTiming timing) + { + await task.ConfigureAwait(false); + await UniTask2.Yield(timing); + } + + public static async UniTask2 ConfigureAwait(this Task task, PlayerLoopTiming timing) + { + var v = await task.ConfigureAwait(false); + await UniTask2.Yield(timing); + return v; + } + + public static async UniTask2 ConfigureAwait(this UniTask2 task, PlayerLoopTiming timing) + { + await task; + await UniTask2.Yield(timing); + } + + public static async UniTask2 ConfigureAwait(this UniTask2 task, PlayerLoopTiming timing) + { + var v = await task; + await UniTask2.Yield(timing); + return v; + } + + public static async UniTask2 Unwrap(this UniTask2> task) + { + return await await task; + } + + public static async UniTask2 Unwrap(this UniTask2 task) + { + await await task; + } + + class ToCoroutineEnumerator : IEnumerator + { + bool completed; + UniTask2 task; + Action exceptionHandler = null; + bool isStarted = false; + ExceptionDispatchInfo exception; + + public ToCoroutineEnumerator(UniTask2 task, Action exceptionHandler) + { + completed = false; + this.exceptionHandler = exceptionHandler; + this.task = task; + } + + async UniTaskVoid RunTask(UniTask2 task) + { + try + { + await task; + } + catch (Exception ex) + { + if (exceptionHandler != null) + { + exceptionHandler(ex); + } + else + { + this.exception = ExceptionDispatchInfo.Capture(ex); + } + } + finally + { + completed = true; + } + } + + public object Current => null; + + public bool MoveNext() + { + if (!isStarted) + { + isStarted = true; + RunTask(task).Forget(); + } + + if (exception != null) + { + // throw exception on iterator (main)thread. + // unfortunately unity test-runner can not handle throw exception on hand-write IEnumerator.MoveNext. + UnityEngine.Debug.LogException(exception.SourceException); + } + + return !completed; + } + + public void Reset() + { + } + } + + class ToCoroutineEnumerator : IEnumerator + { + bool completed; + Action resultHandler = null; + Action exceptionHandler = null; + bool isStarted = false; + UniTask2 task; + object current = null; + ExceptionDispatchInfo exception; + + public ToCoroutineEnumerator(UniTask2 task, Action resultHandler, Action exceptionHandler) + { + completed = false; + this.task = task; + this.resultHandler = resultHandler; + this.exceptionHandler = exceptionHandler; + } + + async UniTaskVoid RunTask(UniTask2 task) + { + try + { + var value = await task; + current = value; // boxed if T is struct... + if (resultHandler != null) + { + resultHandler(value); + } + } + catch (Exception ex) + { + if (exceptionHandler != null) + { + exceptionHandler(ex); + } + else + { + this.exception = ExceptionDispatchInfo.Capture(ex); + } + } + finally + { + completed = true; + } + } + + public object Current => current; + + public bool MoveNext() + { + if (!isStarted) + { + isStarted = true; + RunTask(task).Forget(); + } + + if (exception != null) + { + // throw exception on iterator (main)thread. + // unfortunately unity test-runner can not handle throw exception on hand-write IEnumerator.MoveNext. + UnityEngine.Debug.LogException(exception.SourceException); + } + + return !completed; + } + + public void Reset() + { + } + } + } + + // TODO:remove public static partial class UniTaskExtensions { /// diff --git a/Assets/UniRx.Async/UnityAsyncExtensions.cs b/Assets/UniRx.Async/UnityAsyncExtensions.cs index f14bbe5..47314cb 100644 --- a/Assets/UniRx.Async/UnityAsyncExtensions.cs +++ b/Assets/UniRx.Async/UnityAsyncExtensions.cs @@ -573,7 +573,7 @@ namespace UniRx.Async { // TODO:Remove Tracking // TaskTracker.RemoveTracking(); - core.SetCancellation(cancellationToken); + core.SetCanceled(cancellationToken); return false; }