diff --git a/README.md b/README.md index 487aec0..d1cacac 100644 --- a/README.md +++ b/README.md @@ -336,6 +336,18 @@ if (isCanceled) Note: Only suppress throws if you call directly into the most source method. Otherwise, the return value will be converted, but the entire pipeline will not suppress throws. +Some features that use Unity's player loop, such as `UniTask.Yield` and `UniTask.Delay` etc, determines CancellationToken state on the player loop. +This means it does not cancel immediately upon `CancellationToken` fired. + +If you want to change this behaviour, the cancellation to be immediate, set the `cancelImmediately` flag as an argument. + +```csharp +await UniTask.Yield(cancellationToken, cancelImmediately: true); +``` + +Note: Setting `cancelImmediately` to true and detecting an immediate cancellation is more costly than the default behavior. +This is because it uses `CancellationToken.Register`; it is heavier than checking CancellationToken on the player loop. + Timeout handling --- Timeout is a variation of cancellation. You can set timeout by `CancellationTokenSouce.CancelAfterSlim(TimeSpan)` and pass CancellationToken to async methods. @@ -363,7 +375,7 @@ If you want to use timeout with other source of cancellation, use `CancellationT ```csharp var cancelToken = new CancellationTokenSource(); -cancelButton.onClick.AddListener(()=> +cancelButton.onClick.AddListener(() => { cancelToken.Cancel(); // cancel from button click. }); diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs index f321bdb..a0ca8a1 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs @@ -25,7 +25,12 @@ namespace Cysharp.Threading.Tasks return ToUniTask(handle, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncOperationHandle handle, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(handle, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -44,7 +49,7 @@ namespace Cysharp.Threading.Tasks return UniTask.CompletedTask; } - return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AsyncOperationHandleAwaiter : ICriticalNotifyCompletion @@ -106,6 +111,7 @@ namespace Cysharp.Threading.Tasks readonly Action continuationAction; AsyncOperationHandle handle; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; IProgress progress; bool completed; @@ -116,7 +122,7 @@ namespace Cysharp.Threading.Tasks continuationAction = Continuation; } - public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -132,6 +138,15 @@ namespace Cysharp.Threading.Tasks result.progress = progress; result.cancellationToken = cancellationToken; result.completed = false; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncOperationHandleConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -219,6 +234,7 @@ namespace Cysharp.Threading.Tasks handle = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -237,7 +253,12 @@ namespace Cysharp.Threading.Tasks return ToUniTask(handle, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncOperationHandle handle, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(handle, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -255,7 +276,7 @@ namespace Cysharp.Threading.Tasks return UniTask.FromResult(handle.Result); } - return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, cancelImmediately, out var token), token); } sealed class AsyncOperationHandleConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode> @@ -272,6 +293,7 @@ namespace Cysharp.Threading.Tasks readonly Action> continuationAction; AsyncOperationHandle handle; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; IProgress progress; bool completed; @@ -282,7 +304,7 @@ namespace Cysharp.Threading.Tasks continuationAction = Continuation; } - public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -298,6 +320,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; result.completed = false; result.progress = progress; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncOperationHandleConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -390,6 +421,7 @@ namespace Cysharp.Threading.Tasks handle = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs index 585fff9..8f09110 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs @@ -4,38 +4,50 @@ namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { - public static IUniTaskAsyncEnumerable EveryUpdate(PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable EveryUpdate(PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { - return new EveryUpdate(updateTiming); + return new EveryUpdate(updateTiming, cancelImmediately); } } internal class EveryUpdate : IUniTaskAsyncEnumerable { readonly PlayerLoopTiming updateTiming; + readonly bool cancelImmediately; - public EveryUpdate(PlayerLoopTiming updateTiming) + public EveryUpdate(PlayerLoopTiming updateTiming, bool cancelImmediately) { this.updateTiming = updateTiming; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _EveryUpdate(updateTiming, cancellationToken); + return new _EveryUpdate(updateTiming, cancellationToken, cancelImmediately); } class _EveryUpdate : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem { readonly PlayerLoopTiming updateTiming; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; bool disposed; - public _EveryUpdate(PlayerLoopTiming updateTiming, CancellationToken cancellationToken) + public _EveryUpdate(PlayerLoopTiming updateTiming, CancellationToken cancellationToken, bool cancelImmediately) { this.updateTiming = updateTiming; this.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_EveryUpdate)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(updateTiming, this); } @@ -44,10 +56,14 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - // return false instead of throw - if (disposed || cancellationToken.IsCancellationRequested) return CompletedTasks.False; - + if (disposed) return CompletedTasks.False; + completionSource.Reset(); + + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + } return new UniTask(this, completionSource.Version); } @@ -55,6 +71,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } @@ -63,7 +80,13 @@ namespace Cysharp.Threading.Tasks.Linq public bool MoveNext() { - if (disposed || cancellationToken.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return false; + } + + if (disposed) { completionSource.TrySetResult(false); return false; diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs index f678e7a..ef5739c 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs @@ -7,7 +7,7 @@ namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { - public static IUniTaskAsyncEnumerable EveryValueChanged(TTarget target, Func propertySelector, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null) + public static IUniTaskAsyncEnumerable EveryValueChanged(TTarget target, Func propertySelector, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null, bool cancelImmediately = false) where TTarget : class { var unityObject = target as UnityEngine.Object; @@ -15,11 +15,11 @@ namespace Cysharp.Threading.Tasks.Linq if (isUnityObject) { - return new EveryValueChangedUnityObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming); + return new EveryValueChangedUnityObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming, cancelImmediately); } else { - return new EveryValueChangedStandardObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming); + return new EveryValueChangedStandardObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming, cancelImmediately); } } } @@ -30,18 +30,20 @@ namespace Cysharp.Threading.Tasks.Linq readonly Func propertySelector; readonly IEqualityComparer equalityComparer; readonly PlayerLoopTiming monitorTiming; + readonly bool cancelImmediately; - public EveryValueChangedUnityObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming) + public EveryValueChangedUnityObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, bool cancelImmediately) { this.target = target; this.propertySelector = propertySelector; this.equalityComparer = equalityComparer; this.monitorTiming = monitorTiming; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken); + return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken, cancelImmediately); } sealed class _EveryValueChanged : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem @@ -50,13 +52,14 @@ namespace Cysharp.Threading.Tasks.Linq readonly UnityEngine.Object targetAsUnityObject; readonly IEqualityComparer equalityComparer; readonly Func propertySelector; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; bool first; TProperty currentValue; bool disposed; - public _EveryValueChanged(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken) + public _EveryValueChanged(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken, bool cancelImmediately) { this.target = target; this.targetAsUnityObject = target as UnityEngine.Object; @@ -64,6 +67,16 @@ namespace Cysharp.Threading.Tasks.Linq this.equalityComparer = equalityComparer; this.cancellationToken = cancellationToken; this.first = true; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_EveryValueChanged)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(monitorTiming, this); } @@ -72,8 +85,15 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - // return false instead of throw - if (disposed || cancellationToken.IsCancellationRequested) return CompletedTasks.False; + if (disposed) return CompletedTasks.False; + + completionSource.Reset(); + + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return new UniTask(this, completionSource.Version); + } if (first) { @@ -86,7 +106,6 @@ namespace Cysharp.Threading.Tasks.Linq return CompletedTasks.True; } - completionSource.Reset(); return new UniTask(this, completionSource.Version); } @@ -94,6 +113,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } @@ -102,13 +122,18 @@ namespace Cysharp.Threading.Tasks.Linq public bool MoveNext() { - if (disposed || cancellationToken.IsCancellationRequested || targetAsUnityObject == null) // destroyed = cancel. + if (disposed || targetAsUnityObject == null) { completionSource.TrySetResult(false); DisposeAsync().Forget(); return false; } - + + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return false; + } TProperty nextValue = default(TProperty); try { @@ -139,18 +164,20 @@ namespace Cysharp.Threading.Tasks.Linq readonly Func propertySelector; readonly IEqualityComparer equalityComparer; readonly PlayerLoopTiming monitorTiming; + readonly bool cancelImmediately; - public EveryValueChangedStandardObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming) + public EveryValueChangedStandardObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, bool cancelImmediately) { this.target = new WeakReference(target, false); this.propertySelector = propertySelector; this.equalityComparer = equalityComparer; this.monitorTiming = monitorTiming; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken); + return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken, cancelImmediately); } sealed class _EveryValueChanged : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem @@ -158,19 +185,30 @@ namespace Cysharp.Threading.Tasks.Linq readonly WeakReference target; readonly IEqualityComparer equalityComparer; readonly Func propertySelector; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; bool first; TProperty currentValue; bool disposed; - public _EveryValueChanged(WeakReference target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken) + public _EveryValueChanged(WeakReference target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken, bool cancelImmediately) { this.target = target; this.propertySelector = propertySelector; this.equalityComparer = equalityComparer; this.cancellationToken = cancellationToken; this.first = true; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_EveryValueChanged)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(monitorTiming, this); } @@ -179,8 +217,16 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - if (disposed || cancellationToken.IsCancellationRequested) return CompletedTasks.False; + if (disposed) return CompletedTasks.False; + completionSource.Reset(); + + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return new UniTask(this, completionSource.Version); + } + if (first) { first = false; @@ -192,7 +238,6 @@ namespace Cysharp.Threading.Tasks.Linq return CompletedTasks.True; } - completionSource.Reset(); return new UniTask(this, completionSource.Version); } @@ -200,6 +245,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } @@ -208,13 +254,19 @@ namespace Cysharp.Threading.Tasks.Linq public bool MoveNext() { - if (disposed || cancellationToken.IsCancellationRequested || !target.TryGetTarget(out var t)) + if (disposed || !target.TryGetTarget(out var t)) { completionSource.TrySetResult(false); DisposeAsync().Forget(); return false; } + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return false; + } + TProperty nextValue = default(TProperty); try { diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs index 53ecfcd..b8aabf2 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs @@ -6,32 +6,32 @@ namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { - public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false) + public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false, bool cancelImmediately = false) { - return new Timer(dueTime, null, updateTiming, ignoreTimeScale); + return new Timer(dueTime, null, updateTiming, ignoreTimeScale, cancelImmediately); } - public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false) + public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false, bool cancelImmediately = false) { - return new Timer(dueTime, period, updateTiming, ignoreTimeScale); + return new Timer(dueTime, period, updateTiming, ignoreTimeScale, cancelImmediately); } - public static IUniTaskAsyncEnumerable Interval(TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false) + public static IUniTaskAsyncEnumerable Interval(TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false, bool cancelImmediately = false) { - return new Timer(period, period, updateTiming, ignoreTimeScale); + return new Timer(period, period, updateTiming, ignoreTimeScale, cancelImmediately); } - public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { if (dueTimeFrameCount < 0) { throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. dueTimeFrameCount:" + dueTimeFrameCount); } - return new TimerFrame(dueTimeFrameCount, null, updateTiming); + return new TimerFrame(dueTimeFrameCount, null, updateTiming, cancelImmediately); } - public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, int periodFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, int periodFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { if (dueTimeFrameCount < 0) { @@ -42,16 +42,16 @@ namespace Cysharp.Threading.Tasks.Linq throw new ArgumentOutOfRangeException("Delay does not allow minus periodFrameCount. periodFrameCount:" + dueTimeFrameCount); } - return new TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming); + return new TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming, cancelImmediately); } - public static IUniTaskAsyncEnumerable IntervalFrame(int intervalFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable IntervalFrame(int intervalFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { if (intervalFrameCount < 0) { throw new ArgumentOutOfRangeException("Delay does not allow minus intervalFrameCount. intervalFrameCount:" + intervalFrameCount); } - return new TimerFrame(intervalFrameCount, intervalFrameCount, updateTiming); + return new TimerFrame(intervalFrameCount, intervalFrameCount, updateTiming, cancelImmediately); } } @@ -61,18 +61,20 @@ namespace Cysharp.Threading.Tasks.Linq readonly TimeSpan dueTime; readonly TimeSpan? period; readonly bool ignoreTimeScale; + readonly bool cancelImmediately; - public Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale) + public Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale, bool cancelImmediately) { this.updateTiming = updateTiming; this.dueTime = dueTime; this.period = period; this.ignoreTimeScale = ignoreTimeScale; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _Timer(dueTime, period, updateTiming, ignoreTimeScale, cancellationToken); + return new _Timer(dueTime, period, updateTiming, ignoreTimeScale, cancellationToken, cancelImmediately); } class _Timer : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem @@ -81,7 +83,8 @@ namespace Cysharp.Threading.Tasks.Linq readonly float? period; readonly PlayerLoopTiming updateTiming; readonly bool ignoreTimeScale; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; int initialFrame; float elapsed; @@ -89,7 +92,7 @@ namespace Cysharp.Threading.Tasks.Linq bool completed; bool disposed; - public _Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale, CancellationToken cancellationToken) + public _Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale, CancellationToken cancellationToken, bool cancelImmediately) { this.dueTime = (float)dueTime.TotalSeconds; this.period = (period == null) ? null : (float?)period.Value.TotalSeconds; @@ -105,6 +108,16 @@ namespace Cysharp.Threading.Tasks.Linq this.updateTiming = updateTiming; this.ignoreTimeScale = ignoreTimeScale; this.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_Timer)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(updateTiming, this); } @@ -114,12 +127,16 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { // return false instead of throw - if (disposed || cancellationToken.IsCancellationRequested || completed) return CompletedTasks.False; + if (disposed || completed) return CompletedTasks.False; // reset value here. this.elapsed = 0; completionSource.Reset(); + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + } return new UniTask(this, completionSource.Version); } @@ -127,6 +144,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } @@ -135,11 +153,16 @@ namespace Cysharp.Threading.Tasks.Linq public bool MoveNext() { - if (disposed || cancellationToken.IsCancellationRequested) + if (disposed) { completionSource.TrySetResult(false); return false; } + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return false; + } if (dueTimePhase) { @@ -187,24 +210,27 @@ namespace Cysharp.Threading.Tasks.Linq readonly PlayerLoopTiming updateTiming; readonly int dueTimeFrameCount; readonly int? periodFrameCount; + readonly bool cancelImmediately; - public TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming) + public TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming, bool cancelImmediately) { this.updateTiming = updateTiming; this.dueTimeFrameCount = dueTimeFrameCount; this.periodFrameCount = periodFrameCount; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming, cancellationToken); + return new _TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming, cancellationToken, cancelImmediately); } class _TimerFrame : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem { readonly int dueTimeFrameCount; readonly int? periodFrameCount; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; int initialFrame; int currentFrame; @@ -212,7 +238,7 @@ namespace Cysharp.Threading.Tasks.Linq bool completed; bool disposed; - public _TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming, CancellationToken cancellationToken) + public _TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming, CancellationToken cancellationToken, bool cancelImmediately) { if (dueTimeFrameCount <= 0) dueTimeFrameCount = 0; if (periodFrameCount != null) @@ -225,6 +251,15 @@ namespace Cysharp.Threading.Tasks.Linq this.dueTimeFrameCount = dueTimeFrameCount; this.periodFrameCount = periodFrameCount; this.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_TimerFrame)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(updateTiming, this); @@ -234,13 +269,15 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - // return false instead of throw - if (disposed || cancellationToken.IsCancellationRequested || completed) return CompletedTasks.False; + if (disposed || completed) return CompletedTasks.False; + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + } // reset value here. this.currentFrame = 0; - completionSource.Reset(); return new UniTask(this, completionSource.Version); } @@ -249,6 +286,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } @@ -257,7 +295,12 @@ namespace Cysharp.Threading.Tasks.Linq public bool MoveNext() { - if (disposed || cancellationToken.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) + { + completionSource.TrySetCanceled(cancellationToken); + return false; + } + if (disposed) { completionSource.TrySetResult(false); return false; diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs index fadd63c..7f02a1a 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs @@ -33,14 +33,14 @@ namespace Cysharp.Threading.Tasks return new YieldAwaitable(timing); } - public static UniTask Yield(CancellationToken cancellationToken) + public static UniTask Yield(CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(YieldPromise.Create(PlayerLoopTiming.Update, cancellationToken, out var token), token); + return new UniTask(YieldPromise.Create(PlayerLoopTiming.Update, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask Yield(PlayerLoopTiming timing, CancellationToken cancellationToken) + public static UniTask Yield(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(YieldPromise.Create(timing, cancellationToken, out var token), token); + return new UniTask(YieldPromise.Create(timing, cancellationToken, cancelImmediately, out var token), token); } /// @@ -48,7 +48,7 @@ namespace Cysharp.Threading.Tasks /// public static UniTask NextFrame() { - return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, CancellationToken.None, out var token), token); + return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, CancellationToken.None, false, out var token), token); } /// @@ -56,23 +56,23 @@ namespace Cysharp.Threading.Tasks /// public static UniTask NextFrame(PlayerLoopTiming timing) { - return new UniTask(NextFramePromise.Create(timing, CancellationToken.None, out var token), token); + return new UniTask(NextFramePromise.Create(timing, CancellationToken.None, false, out var token), token); } /// /// Similar as UniTask.Yield but guaranteed run on next frame. /// - public static UniTask NextFrame(CancellationToken cancellationToken) + public static UniTask NextFrame(CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, cancellationToken, out var token), token); + return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, cancellationToken, cancelImmediately, out var token), token); } /// /// Similar as UniTask.Yield but guaranteed run on next frame. /// - public static UniTask NextFrame(PlayerLoopTiming timing, CancellationToken cancellationToken) + public static UniTask NextFrame(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(NextFramePromise.Create(timing, cancellationToken, out var token), token); + return new UniTask(NextFramePromise.Create(timing, cancellationToken, cancelImmediately, out var token), token); } #if UNITY_2023_1_OR_NEWER @@ -88,15 +88,21 @@ namespace Cysharp.Threading.Tasks } [Obsolete("Use WaitForEndOfFrame(MonoBehaviour) instead or UniTask.Yield(PlayerLoopTiming.LastPostLateUpdate). Equivalent for coroutine's WaitForEndOfFrame requires MonoBehaviour(runner of Coroutine).")] - public static UniTask WaitForEndOfFrame(CancellationToken cancellationToken) + public static UniTask WaitForEndOfFrame(CancellationToken cancellationToken, bool cancelImmediately = false) { - return UniTask.Yield(PlayerLoopTiming.LastPostLateUpdate, cancellationToken); + return UniTask.Yield(PlayerLoopTiming.LastPostLateUpdate, cancellationToken, cancelImmediately); } #endif - public static UniTask WaitForEndOfFrame(MonoBehaviour coroutineRunner, CancellationToken cancellationToken = default) + public static UniTask WaitForEndOfFrame(MonoBehaviour coroutineRunner) { - var source = WaitForEndOfFramePromise.Create(coroutineRunner, cancellationToken, out var token); + var source = WaitForEndOfFramePromise.Create(coroutineRunner, CancellationToken.None, false, out var token); + return new UniTask(source, token); + } + + public static UniTask WaitForEndOfFrame(MonoBehaviour coroutineRunner, CancellationToken cancellationToken, bool cancelImmediately = false) + { + var source = WaitForEndOfFramePromise.Create(coroutineRunner, cancellationToken, cancelImmediately, out var token); return new UniTask(source, token); } @@ -113,50 +119,50 @@ namespace Cysharp.Threading.Tasks /// /// Same as UniTask.Yield(PlayerLoopTiming.LastFixedUpdate, cancellationToken). /// - public static UniTask WaitForFixedUpdate(CancellationToken cancellationToken) + public static UniTask WaitForFixedUpdate(CancellationToken cancellationToken, bool cancelImmediately = false) { - return UniTask.Yield(PlayerLoopTiming.LastFixedUpdate, cancellationToken); + return UniTask.Yield(PlayerLoopTiming.LastFixedUpdate, cancellationToken, cancelImmediately); } - public static UniTask WaitForSeconds(float duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitForSeconds(float duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return Delay(Mathf.RoundToInt(1000 * duration), ignoreTimeScale, delayTiming, cancellationToken); + return Delay(Mathf.RoundToInt(1000 * duration), ignoreTimeScale, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask WaitForSeconds(int duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitForSeconds(int duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return Delay(1000 * duration, ignoreTimeScale, delayTiming, cancellationToken); + return Delay(1000 * duration, ignoreTimeScale, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (delayFrameCount < 0) { throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. delayFrameCount:" + delayFrameCount); } - return new UniTask(DelayFramePromise.Create(delayFrameCount, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayFramePromise.Create(delayFrameCount, delayTiming, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask Delay(int millisecondsDelay, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(int millisecondsDelay, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { var delayTimeSpan = TimeSpan.FromMilliseconds(millisecondsDelay); - return Delay(delayTimeSpan, ignoreTimeScale, delayTiming, cancellationToken); + return Delay(delayTimeSpan, ignoreTimeScale, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask Delay(TimeSpan delayTimeSpan, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(TimeSpan delayTimeSpan, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { var delayType = ignoreTimeScale ? DelayType.UnscaledDeltaTime : DelayType.DeltaTime; - return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken); + return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask Delay(int millisecondsDelay, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(int millisecondsDelay, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { var delayTimeSpan = TimeSpan.FromMilliseconds(millisecondsDelay); - return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken); + return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask Delay(TimeSpan delayTimeSpan, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(TimeSpan delayTimeSpan, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (delayTimeSpan < TimeSpan.Zero) { @@ -175,16 +181,16 @@ namespace Cysharp.Threading.Tasks { case DelayType.UnscaledDeltaTime: { - return new UniTask(DelayIgnoreTimeScalePromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayIgnoreTimeScalePromise.Create(delayTimeSpan, delayTiming, cancellationToken, cancelImmediately, out var token), token); } case DelayType.Realtime: { - return new UniTask(DelayRealtimePromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayRealtimePromise.Create(delayTimeSpan, delayTiming, cancellationToken, cancelImmediately, out var token), token); } case DelayType.DeltaTime: default: { - return new UniTask(DelayPromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayPromise.Create(delayTimeSpan, delayTiming, cancellationToken, cancelImmediately, out var token), token); } } } @@ -201,13 +207,14 @@ namespace Cysharp.Threading.Tasks } CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; YieldPromise() { } - public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -219,8 +226,16 @@ namespace Cysharp.Threading.Tasks result = new YieldPromise(); } - result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (YieldPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -274,6 +289,7 @@ namespace Cysharp.Threading.Tasks TaskTracker.RemoveTracking(this); core.Reset(); cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -290,14 +306,15 @@ namespace Cysharp.Threading.Tasks } int frameCount; - CancellationToken cancellationToken; UniTaskCompletionSourceCore core; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; NextFramePromise() { } - public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -312,6 +329,15 @@ namespace Cysharp.Threading.Tasks result.frameCount = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (NextFramePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -369,6 +395,7 @@ namespace Cysharp.Threading.Tasks TaskTracker.RemoveTracking(this); core.Reset(); cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -384,14 +411,15 @@ namespace Cysharp.Threading.Tasks TaskPool.RegisterSizeGetter(typeof(WaitForEndOfFramePromise), () => pool.Size); } - CancellationToken cancellationToken; UniTaskCompletionSourceCore core; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; WaitForEndOfFramePromise() { } - public static IUniTaskSource Create(MonoBehaviour coroutineRunner, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(MonoBehaviour coroutineRunner, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -405,6 +433,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitForEndOfFramePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); coroutineRunner.StartCoroutine(result); @@ -446,6 +483,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); Reset(); // Reset Enumerator cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } @@ -494,6 +532,7 @@ namespace Cysharp.Threading.Tasks int initialFrame; int delayFrameCount; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; int currentFrameCount; UniTaskCompletionSourceCore core; @@ -502,7 +541,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(int delayFrameCount, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(int delayFrameCount, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -518,6 +557,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; result.initialFrame = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayFramePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -604,6 +652,7 @@ namespace Cysharp.Threading.Tasks currentFrameCount = default; delayFrameCount = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -623,6 +672,7 @@ namespace Cysharp.Threading.Tasks float delayTimeSpan; float elapsed; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -630,7 +680,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -647,6 +697,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; result.initialFrame = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -715,6 +774,7 @@ namespace Cysharp.Threading.Tasks delayTimeSpan = default; elapsed = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -734,6 +794,7 @@ namespace Cysharp.Threading.Tasks float elapsed; int initialFrame; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -741,7 +802,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(TimeSpan delayFrameTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(TimeSpan delayFrameTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -758,6 +819,15 @@ namespace Cysharp.Threading.Tasks result.initialFrame = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayIgnoreTimeScalePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -826,6 +896,7 @@ namespace Cysharp.Threading.Tasks delayFrameTimeSpan = default; elapsed = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -844,6 +915,7 @@ namespace Cysharp.Threading.Tasks long delayTimeSpanTicks; ValueStopwatch stopwatch; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -851,7 +923,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -867,6 +939,15 @@ namespace Cysharp.Threading.Tasks result.delayTimeSpanTicks = delayTimeSpan.Ticks; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayRealtimePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -931,6 +1012,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); stopwatch = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs index 0a09fe0..b28a529 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs @@ -9,30 +9,30 @@ namespace Cysharp.Threading.Tasks { public partial struct UniTask { - public static UniTask WaitUntil(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitUntil(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, out var token), token); + return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask WaitWhile(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitWhile(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, out var token), token); + return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask WaitUntilCanceled(CancellationToken cancellationToken, PlayerLoopTiming timing = PlayerLoopTiming.Update) + public static UniTask WaitUntilCanceled(CancellationToken cancellationToken, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool completeImmediately = false) { - return new UniTask(WaitUntilCanceledPromise.Create(cancellationToken, timing, out var token), token); + return new UniTask(WaitUntilCanceledPromise.Create(cancellationToken, timing, completeImmediately, out var token), token); } - public static UniTask WaitUntilValueChanged(T target, Func monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitUntilValueChanged(T target, Func monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) where T : class { var unityObject = target as UnityEngine.Object; var isUnityObject = target is UnityEngine.Object; // don't use (unityObject == null) return new UniTask(isUnityObject - ? WaitUntilValueChangedUnityObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out var token) - : WaitUntilValueChangedStandardObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out token), token); + ? WaitUntilValueChangedUnityObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, cancelImmediately, out var token) + : WaitUntilValueChangedStandardObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, cancelImmediately, out token), token); } sealed class WaitUntilPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode @@ -48,6 +48,7 @@ namespace Cysharp.Threading.Tasks Func predicate; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -55,7 +56,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -70,6 +71,15 @@ namespace Cysharp.Threading.Tasks result.predicate = predicate; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -136,6 +146,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); predicate = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -153,6 +164,7 @@ namespace Cysharp.Threading.Tasks Func predicate; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -160,7 +172,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -174,6 +186,15 @@ namespace Cysharp.Threading.Tasks result.predicate = predicate; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitWhilePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -241,6 +262,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); predicate = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -257,6 +279,7 @@ namespace Cysharp.Threading.Tasks } CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -264,7 +287,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(CancellationToken cancellationToken, PlayerLoopTiming timing, out short token) + public static IUniTaskSource Create(CancellationToken cancellationToken, PlayerLoopTiming timing, bool completeImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -278,6 +301,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; + if (completeImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilCanceledPromise)state; + promise.core.TrySetResult(null); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -329,6 +361,7 @@ namespace Cysharp.Threading.Tasks TaskTracker.RemoveTracking(this); core.Reset(); cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -351,6 +384,7 @@ namespace Cysharp.Threading.Tasks Func monitorFunction; IEqualityComparer equalityComparer; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -358,7 +392,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -376,6 +410,15 @@ namespace Cysharp.Threading.Tasks result.currentValue = monitorFunction(target); result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault(); result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilValueChangedUnityObjectPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -453,6 +496,7 @@ namespace Cysharp.Threading.Tasks monitorFunction = default; equalityComparer = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -474,6 +518,7 @@ namespace Cysharp.Threading.Tasks Func monitorFunction; IEqualityComparer equalityComparer; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -481,7 +526,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -498,6 +543,15 @@ namespace Cysharp.Threading.Tasks result.currentValue = monitorFunction(target); result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault(); result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilValueChangedStandardObjectPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -575,6 +629,7 @@ namespace Cysharp.Threading.Tasks monitorFunction = default; equalityComparer = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs index 1a1e011..5d34692 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs @@ -24,12 +24,17 @@ namespace Cysharp.Threading.Tasks return AwaitForAllAssets(asyncOperation, null, PlayerLoopTiming.Update, cancellationToken: cancellationToken); } - public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return AwaitForAllAssets(asyncOperation, progress: null, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.allAssets); - return new UniTask(AssetBundleRequestAllAssetsConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleRequestAllAssetsConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AssetBundleRequestAllAssetsAwaiter : ICriticalNotifyCompletion @@ -95,15 +100,19 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AssetBundleRequestAllAssetsConfiguredSource() { - + continuationAction = Continuation; } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -118,6 +127,18 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleRequestAllAssetsConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -161,6 +182,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { core.TrySetCanceled(cancellationToken); @@ -188,8 +215,29 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.allAssets); + } + } + } } } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs index 5805dbb..5d73dc1 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs @@ -20,12 +20,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncGPUReadbackRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) { - if (asyncOperation.done) return UniTask.FromResult(asyncOperation); - return new UniTask(AsyncGPUReadbackRequestAwaiterConfiguredSource.Create(asyncOperation, timing, cancellationToken, out var token), token); + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); } + public static UniTask ToUniTask(this AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + { + if (asyncOperation.done) return UniTask.FromResult(asyncOperation); + return new UniTask(AsyncGPUReadbackRequestAwaiterConfiguredSource.Create(asyncOperation, timing, cancellationToken, cancelImmediately, out var token), token); + } + sealed class AsyncGPUReadbackRequestAwaiterConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -39,15 +44,15 @@ namespace Cysharp.Threading.Tasks AsyncGPUReadbackRequest asyncOperation; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; AsyncGPUReadbackRequestAwaiterConfiguredSource() { - } - public static IUniTaskSource Create(AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -61,6 +66,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncGPUReadbackRequestAwaiterConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -131,6 +145,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); asyncOperation = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs index 00afe66..b9cd1c9 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs @@ -17,7 +17,6 @@ namespace Cysharp.Threading.Tasks #if !UNITY_2023_1_OR_NEWER // from Unity2023.1.0a15, AsyncOperationAwaitableExtensions.GetAwaiter is defined in UnityEngine. - public static AsyncOperationAwaiter GetAwaiter(this AsyncOperation asyncOperation) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); @@ -30,12 +29,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncOperation asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.CompletedTask; - return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AsyncOperationAwaiter : ICriticalNotifyCompletion @@ -92,15 +96,19 @@ namespace Cysharp.Threading.Tasks AsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AsyncOperationConfiguredSource() { - + continuationAction = Continuation; } - public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -115,6 +123,18 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AsyncOperationConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -154,6 +174,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { core.TrySetCanceled(cancellationToken); @@ -178,11 +204,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(AsyncUnit.Default); + } + } + } } #endregion @@ -200,12 +248,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this ResourceRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct ResourceRequestAwaiter : ICriticalNotifyCompletion @@ -266,15 +319,19 @@ namespace Cysharp.Threading.Tasks ResourceRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + ResourceRequestConfiguredSource() { - + continuationAction = Continuation; } - public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -289,6 +346,18 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (ResourceRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -332,6 +401,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { core.TrySetCanceled(cancellationToken); @@ -356,11 +431,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.asset); + } + } + } } #endregion @@ -379,12 +476,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AssetBundleRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AssetBundleRequestAwaiter : ICriticalNotifyCompletion @@ -445,15 +547,19 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AssetBundleRequestConfiguredSource() { - + continuationAction = Continuation; } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -468,6 +574,18 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -511,6 +629,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { core.TrySetCanceled(cancellationToken); @@ -535,11 +659,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.asset); + } + } + } } #endregion @@ -559,12 +705,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AssetBundleCreateRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.assetBundle); - return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AssetBundleCreateRequestAwaiter : ICriticalNotifyCompletion @@ -625,15 +776,19 @@ namespace Cysharp.Threading.Tasks AssetBundleCreateRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AssetBundleCreateRequestConfiguredSource() { - + continuationAction = Continuation; } - public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -648,6 +803,18 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleCreateRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -691,6 +858,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { core.TrySetCanceled(cancellationToken); @@ -715,11 +888,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.assetBundle); + } + } + } } #endregion @@ -739,7 +934,12 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this UnityWebRequestAsyncOperation asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -751,7 +951,7 @@ namespace Cysharp.Threading.Tasks } return UniTask.FromResult(asyncOperation.webRequest); } - return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct UnityWebRequestAsyncOperationAwaiter : ICriticalNotifyCompletion @@ -820,15 +1020,19 @@ namespace Cysharp.Threading.Tasks UnityWebRequestAsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + UnityWebRequestAsyncOperationConfiguredSource() { - + continuationAction = Continuation; } - public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -843,6 +1047,19 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (UnityWebRequestAsyncOperationConfiguredSource)state; + source.asyncOperation.webRequest.Abort(); + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -886,6 +1103,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { asyncOperation.webRequest.Abort(); @@ -900,11 +1123,7 @@ namespace Cysharp.Threading.Tasks if (asyncOperation.isDone) { - if (asyncOperation.webRequest == null) - { - core.TrySetException(new ObjectDisposedException("The webRequest has been destroyed.")); - } - else if (asyncOperation.webRequest.IsError()) + if (asyncOperation.webRequest.IsError()) { core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); } @@ -922,11 +1141,37 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else if (asyncOperation.webRequest.IsError()) + { + core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); + } + else + { + core.TrySetResult(asyncOperation.webRequest); + } + } + } } #endregion diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt index 65dac9e..0516fef 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt @@ -16,6 +16,7 @@ Func ToUniTaskReturnType = x => (x == "void") ? "UniTask" : $"UniTask<{x}>"; Func ToIUniTaskSourceReturnType = x => (x == "void") ? "IUniTaskSource" : $"IUniTaskSource<{x}>"; + Func<(string typeName, string returnType, string returnField), bool> IsAsyncOperationBase = x => x.typeName == "AsyncOperation"; Func<(string typeName, string returnType, string returnField), bool> IsUnityWebRequest = x => x.returnType == "UnityWebRequest"; Func<(string typeName, string returnType, string returnField), bool> IsAssetBundleModule = x => x.typeName == "AssetBundleRequest" || x.typeName == "AssetBundleCreateRequest"; Func<(string typeName, string returnType, string returnField), bool> IsVoid = x => x.returnType == "void"; @@ -43,18 +44,30 @@ namespace Cysharp.Threading.Tasks <# } #> #region <#= t.typeName #> +<# if (IsAsyncOperationBase(t)) { #> +#if !UNITY_2023_1_OR_NEWER + // from Unity2023.1.0a15, AsyncOperationAwaitableExtensions.GetAwaiter is defined in UnityEngine. +<# } #> public static <#= t.typeName #>Awaiter GetAwaiter(this <#= t.typeName #> asyncOperation) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); return new <#= t.typeName #>Awaiter(asyncOperation); } +<# if (IsAsyncOperationBase(t)) { #> +#endif +<# } #> public static <#= ToUniTaskReturnType(t.returnType) #> WithCancellation(this <#= t.typeName #> asyncOperation, CancellationToken cancellationToken) { return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static <#= ToUniTaskReturnType(t.returnType) #> ToUniTask(this <#= t.typeName #> asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static <#= ToUniTaskReturnType(t.returnType) #> WithCancellation(this <#= t.typeName #> asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static <#= ToUniTaskReturnType(t.returnType) #> ToUniTask(this <#= t.typeName #> asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled<#= IsVoid(t) ? "" : "<" + t.returnType + ">" #>(cancellationToken); @@ -70,7 +83,7 @@ namespace Cysharp.Threading.Tasks <# } else { #> if (asyncOperation.isDone) return <#= IsVoid(t) ? "UniTask.CompletedTask" : $"UniTask.FromResult(asyncOperation.{t.returnField})" #>; <# } #> - return new <#= ToUniTaskReturnType(t.returnType) #>(<#= t.typeName #>ConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new <#= ToUniTaskReturnType(t.returnType) #>(<#= t.typeName #>ConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct <#= t.typeName #>Awaiter : ICriticalNotifyCompletion @@ -151,15 +164,19 @@ namespace Cysharp.Threading.Tasks <#= t.typeName #> asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore<<#= IsVoid(t) ? "AsyncUnit" : t.returnType #>> core; + Action continuationAction; + <#= t.typeName #>ConfiguredSource() { - + continuationAction = Continuation; } - public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -174,6 +191,21 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + asyncOperation.completed += result.continuationAction; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (<#= t.typeName #>ConfiguredSource)state; +<# if(IsUnityWebRequest(t)) { #> + source.asyncOperation.webRequest.Abort(); +<# } #> + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -223,6 +255,12 @@ namespace Cysharp.Threading.Tasks public bool MoveNext() { + // Already completed + if (completed || asyncOperation == null) + { + return false; + } + if (cancellationToken.IsCancellationRequested) { <# if(IsUnityWebRequest(t)) { #> @@ -261,11 +299,44 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } +<# if(IsUnityWebRequest(t)) { #> + else if (asyncOperation.webRequest.IsError()) + { + core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); + } + else + { + core.TrySetResult(asyncOperation.webRequest); + } +<# } else { #> + else + { + core.TrySetResult(<#= IsVoid(t) ? "AsyncUnit.Default" : $"asyncOperation.{t.returnField}" #>); + } +<# } #> + } + } } #endregion diff --git a/src/UniTask/Assets/Tests/AsyncOperationTest.cs b/src/UniTask/Assets/Tests/AsyncOperationTest.cs new file mode 100644 index 0000000..0eaa188 --- /dev/null +++ b/src/UniTask/Assets/Tests/AsyncOperationTest.cs @@ -0,0 +1,90 @@ +using System.Collections; +using System.Threading; +using Cysharp.Threading.Tasks; +using FluentAssertions; +using NUnit.Framework; +using UnityEngine; +using UnityEngine.Networking; +using UnityEngine.TestTools; + +#if CSHARP_7_OR_LATER || (UNITY_2018_3_OR_NEWER && (NET_STANDARD_2_0 || NET_4_6)) +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + +namespace Cysharp.Threading.TasksTests +{ + public class AsyncOperationTest + { + [UnityTest] + public IEnumerator ResourcesLoad_Completed() => UniTask.ToCoroutine(async () => + { + var asyncOperation = Resources.LoadAsync("sample_texture"); + await asyncOperation.ToUniTask(); + asyncOperation.isDone.Should().BeTrue(); + asyncOperation.asset.GetType().Should().Be(typeof(Texture2D)); + }); + + [UnityTest] + public IEnumerator ResourcesLoad_CancelOnPlayerLoop() => UniTask.ToCoroutine(async () => + { + var cts = new CancellationTokenSource(); + var task = Resources.LoadAsync("sample_texture").ToUniTask(cancellationToken: cts.Token, cancelImmediately: false); + + cts.Cancel(); + task.Status.Should().Be(UniTaskStatus.Pending); + + await UniTask.NextFrame(); + task.Status.Should().Be(UniTaskStatus.Canceled); + }); + + [Test] + public void ResourcesLoad_CancelImmediately() + { + { + var cts = new CancellationTokenSource(); + var task = Resources.LoadAsync("sample_texture").ToUniTask(cancellationToken: cts.Token, cancelImmediately: true); + + cts.Cancel(); + task.Status.Should().Be(UniTaskStatus.Canceled); + } + } + +#if ENABLE_UNITYWEBREQUEST && (!UNITY_2019_1_OR_NEWER || UNITASK_WEBREQUEST_SUPPORT) + [UnityTest] + public IEnumerator UnityWebRequest_Completed() => UniTask.ToCoroutine(async () => + { + var filePath = System.IO.Path.Combine(Application.dataPath, "Tests", "Resources", "sample_texture.png"); + var asyncOperation = UnityWebRequest.Get($"file://{filePath}").SendWebRequest(); + await asyncOperation.ToUniTask(); + + asyncOperation.isDone.Should().BeTrue(); + asyncOperation.webRequest.result.Should().Be(UnityWebRequest.Result.Success); + }); + + [UnityTest] + public IEnumerator UnityWebRequest_CancelOnPlayerLoop() => UniTask.ToCoroutine(async () => + { + var cts = new CancellationTokenSource(); + var filePath = System.IO.Path.Combine(Application.dataPath, "Tests", "Resources", "sample_texture.png"); + var task = UnityWebRequest.Get($"file://{filePath}").SendWebRequest().ToUniTask(cancellationToken: cts.Token); + + cts.Cancel(); + task.Status.Should().Be(UniTaskStatus.Pending); + + await UniTask.NextFrame(); + task.Status.Should().Be(UniTaskStatus.Canceled); + }); + + [Test] + public void UnityWebRequest_CancelImmediately() + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + var filePath = System.IO.Path.Combine(Application.dataPath, "Tests", "Resources", "sample_texture.png"); + var task = UnityWebRequest.Get($"file://{filePath}").SendWebRequest().ToUniTask(cancellationToken: cts.Token, cancelImmediately: true); + + task.Status.Should().Be(UniTaskStatus.Canceled); + } +#endif + } +} +#endif diff --git a/src/UniTask/Assets/Tests/AsyncOperationTest.cs.meta b/src/UniTask/Assets/Tests/AsyncOperationTest.cs.meta new file mode 100644 index 0000000..fed3f0d --- /dev/null +++ b/src/UniTask/Assets/Tests/AsyncOperationTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 295d574a16494d6aa4d02fcb32179e39 +timeCreated: 1698887128 \ No newline at end of file diff --git a/src/UniTask/Assets/Tests/Resources.meta b/src/UniTask/Assets/Tests/Resources.meta new file mode 100644 index 0000000..d568559 --- /dev/null +++ b/src/UniTask/Assets/Tests/Resources.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 8d82913edf6ac48aca30f66ae9ba42d6 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/src/UniTask/Assets/Tests/Resources/sample_texture.png b/src/UniTask/Assets/Tests/Resources/sample_texture.png new file mode 100644 index 0000000..2da8909 Binary files /dev/null and b/src/UniTask/Assets/Tests/Resources/sample_texture.png differ diff --git a/src/UniTask/Assets/Tests/Resources/sample_texture.png.meta b/src/UniTask/Assets/Tests/Resources/sample_texture.png.meta new file mode 100644 index 0000000..15a1de5 --- /dev/null +++ b/src/UniTask/Assets/Tests/Resources/sample_texture.png.meta @@ -0,0 +1,208 @@ +fileFormatVersion: 2 +guid: 535006a83baed4ebda99d24a909a2efe +TextureImporter: + internalIDToNameTable: + - first: + 213: -2664112245596591751 + second: sample_texture_0 + - first: + 213: -4606777057269188692 + second: sample_texture_1 + - first: + 213: 1950921086533113773 + second: sample_texture_2 + externalObjects: {} + serializedVersion: 12 + mipmaps: + mipMapMode: 0 + enableMipMap: 0 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + flipGreenChannel: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + vTOnly: 0 + ignoreMipmapLimit: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: 1 + aniso: 1 + mipBias: 0 + wrapU: 1 + wrapV: 1 + wrapW: 1 + nPOTScale: 0 + lightmap: 0 + compressionQuality: 50 + spriteMode: 2 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 1 + spriteTessellationDetail: -1 + textureType: 8 + textureShape: 1 + singleChannelComponent: 0 + flipbookRows: 1 + flipbookColumns: 1 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + ignorePngGamma: 0 + applyGammaDecoding: 0 + swizzle: 50462976 + cookieLightType: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: WebGL + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Standalone + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: iPhone + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: + - serializedVersion: 2 + name: sample_texture_0 + rect: + serializedVersion: 2 + x: 0 + y: 76 + width: 243 + height: 251 + alignment: 0 + pivot: {x: 0, y: 0} + border: {x: 0, y: 0, z: 0, w: 0} + outline: [] + physicsShape: [] + tessellationDetail: -1 + bones: [] + spriteID: 9796277170c270bd0800000000000000 + internalID: -2664112245596591751 + vertices: [] + indices: + edges: [] + weights: [] + - serializedVersion: 2 + name: sample_texture_1 + rect: + serializedVersion: 2 + x: 227 + y: 0 + width: 190 + height: 231 + alignment: 0 + pivot: {x: 0, y: 0} + border: {x: 0, y: 0, z: 0, w: 0} + outline: [] + physicsShape: [] + tessellationDetail: -1 + bones: [] + spriteID: ca7fc069ca07110c0800000000000000 + internalID: -4606777057269188692 + vertices: [] + indices: + edges: [] + weights: [] + - serializedVersion: 2 + name: sample_texture_2 + rect: + serializedVersion: 2 + x: 398 + y: 87 + width: 202 + height: 188 + alignment: 0 + pivot: {x: 0, y: 0} + border: {x: 0, y: 0, z: 0, w: 0} + outline: [] + physicsShape: [] + tessellationDetail: -1 + bones: [] + spriteID: da710ab4460131b10800000000000000 + internalID: 1950921086533113773 + vertices: [] + indices: + edges: [] + weights: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + nameFileIdTable: {} + mipmapLimitGroupName: + pSDRemoveMatte: 0 + userData: + assetBundleName: + assetBundleVariant: