diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs index 585fff9..09c5cf9 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryUpdate.cs @@ -4,38 +4,50 @@ namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { - public static IUniTaskAsyncEnumerable EveryUpdate(PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable EveryUpdate(PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { - return new EveryUpdate(updateTiming); + return new EveryUpdate(updateTiming, cancelImmediately); } } internal class EveryUpdate : IUniTaskAsyncEnumerable { readonly PlayerLoopTiming updateTiming; + readonly bool cancelImmediately; - public EveryUpdate(PlayerLoopTiming updateTiming) + public EveryUpdate(PlayerLoopTiming updateTiming, bool cancelImmediately) { this.updateTiming = updateTiming; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _EveryUpdate(updateTiming, cancellationToken); + return new _EveryUpdate(updateTiming, cancellationToken, cancelImmediately); } class _EveryUpdate : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem { readonly PlayerLoopTiming updateTiming; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; bool disposed; - public _EveryUpdate(PlayerLoopTiming updateTiming, CancellationToken cancellationToken) + public _EveryUpdate(PlayerLoopTiming updateTiming, CancellationToken cancellationToken, bool cancelImmediately) { this.updateTiming = updateTiming; this.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_EveryUpdate)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(updateTiming, this); } @@ -55,6 +67,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs index f678e7a..8b19f64 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/EveryValueChanged.cs @@ -7,7 +7,7 @@ namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { - public static IUniTaskAsyncEnumerable EveryValueChanged(TTarget target, Func propertySelector, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null) + public static IUniTaskAsyncEnumerable EveryValueChanged(TTarget target, Func propertySelector, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null, bool cancelImmediately = false) where TTarget : class { var unityObject = target as UnityEngine.Object; @@ -15,11 +15,11 @@ namespace Cysharp.Threading.Tasks.Linq if (isUnityObject) { - return new EveryValueChangedUnityObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming); + return new EveryValueChangedUnityObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming, cancelImmediately); } else { - return new EveryValueChangedStandardObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming); + return new EveryValueChangedStandardObject(target, propertySelector, equalityComparer ?? UnityEqualityComparer.GetDefault(), monitorTiming, cancelImmediately); } } } @@ -30,18 +30,20 @@ namespace Cysharp.Threading.Tasks.Linq readonly Func propertySelector; readonly IEqualityComparer equalityComparer; readonly PlayerLoopTiming monitorTiming; + readonly bool cancelImmediately; - public EveryValueChangedUnityObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming) + public EveryValueChangedUnityObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, bool cancelImmediately) { this.target = target; this.propertySelector = propertySelector; this.equalityComparer = equalityComparer; this.monitorTiming = monitorTiming; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken); + return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken, cancelImmediately); } sealed class _EveryValueChanged : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem @@ -50,13 +52,14 @@ namespace Cysharp.Threading.Tasks.Linq readonly UnityEngine.Object targetAsUnityObject; readonly IEqualityComparer equalityComparer; readonly Func propertySelector; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; bool first; TProperty currentValue; bool disposed; - public _EveryValueChanged(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken) + public _EveryValueChanged(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken, bool cancelImmediately) { this.target = target; this.targetAsUnityObject = target as UnityEngine.Object; @@ -64,6 +67,16 @@ namespace Cysharp.Threading.Tasks.Linq this.equalityComparer = equalityComparer; this.cancellationToken = cancellationToken; this.first = true; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_EveryValueChanged)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(monitorTiming, this); } @@ -139,18 +152,20 @@ namespace Cysharp.Threading.Tasks.Linq readonly Func propertySelector; readonly IEqualityComparer equalityComparer; readonly PlayerLoopTiming monitorTiming; + readonly bool cancelImmediately; - public EveryValueChangedStandardObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming) + public EveryValueChangedStandardObject(TTarget target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, bool cancelImmediately) { this.target = new WeakReference(target, false); this.propertySelector = propertySelector; this.equalityComparer = equalityComparer; this.monitorTiming = monitorTiming; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken); + return new _EveryValueChanged(target, propertySelector, equalityComparer, monitorTiming, cancellationToken, cancelImmediately); } sealed class _EveryValueChanged : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem @@ -158,19 +173,30 @@ namespace Cysharp.Threading.Tasks.Linq readonly WeakReference target; readonly IEqualityComparer equalityComparer; readonly Func propertySelector; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; bool first; TProperty currentValue; bool disposed; - public _EveryValueChanged(WeakReference target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken) + public _EveryValueChanged(WeakReference target, Func propertySelector, IEqualityComparer equalityComparer, PlayerLoopTiming monitorTiming, CancellationToken cancellationToken, bool cancelImmediately) { this.target = target; this.propertySelector = propertySelector; this.equalityComparer = equalityComparer; this.cancellationToken = cancellationToken; this.first = true; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_EveryValueChanged)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(monitorTiming, this); } @@ -200,6 +226,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs index 53ecfcd..a3bab3c 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/UnityExtensions/Timer.cs @@ -6,32 +6,32 @@ namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { - public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false) + public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false, bool cancelImmediately = false) { - return new Timer(dueTime, null, updateTiming, ignoreTimeScale); + return new Timer(dueTime, null, updateTiming, ignoreTimeScale, cancelImmediately); } - public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false) + public static IUniTaskAsyncEnumerable Timer(TimeSpan dueTime, TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false, bool cancelImmediately = false) { - return new Timer(dueTime, period, updateTiming, ignoreTimeScale); + return new Timer(dueTime, period, updateTiming, ignoreTimeScale, cancelImmediately); } - public static IUniTaskAsyncEnumerable Interval(TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false) + public static IUniTaskAsyncEnumerable Interval(TimeSpan period, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool ignoreTimeScale = false, bool cancelImmediately = false) { - return new Timer(period, period, updateTiming, ignoreTimeScale); + return new Timer(period, period, updateTiming, ignoreTimeScale, cancelImmediately); } - public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { if (dueTimeFrameCount < 0) { throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. dueTimeFrameCount:" + dueTimeFrameCount); } - return new TimerFrame(dueTimeFrameCount, null, updateTiming); + return new TimerFrame(dueTimeFrameCount, null, updateTiming, cancelImmediately); } - public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, int periodFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable TimerFrame(int dueTimeFrameCount, int periodFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { if (dueTimeFrameCount < 0) { @@ -42,16 +42,16 @@ namespace Cysharp.Threading.Tasks.Linq throw new ArgumentOutOfRangeException("Delay does not allow minus periodFrameCount. periodFrameCount:" + dueTimeFrameCount); } - return new TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming); + return new TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming, cancelImmediately); } - public static IUniTaskAsyncEnumerable IntervalFrame(int intervalFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update) + public static IUniTaskAsyncEnumerable IntervalFrame(int intervalFrameCount, PlayerLoopTiming updateTiming = PlayerLoopTiming.Update, bool cancelImmediately = false) { if (intervalFrameCount < 0) { throw new ArgumentOutOfRangeException("Delay does not allow minus intervalFrameCount. intervalFrameCount:" + intervalFrameCount); } - return new TimerFrame(intervalFrameCount, intervalFrameCount, updateTiming); + return new TimerFrame(intervalFrameCount, intervalFrameCount, updateTiming, cancelImmediately); } } @@ -61,18 +61,20 @@ namespace Cysharp.Threading.Tasks.Linq readonly TimeSpan dueTime; readonly TimeSpan? period; readonly bool ignoreTimeScale; + readonly bool cancelImmediately; - public Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale) + public Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale, bool cancelImmediately) { this.updateTiming = updateTiming; this.dueTime = dueTime; this.period = period; this.ignoreTimeScale = ignoreTimeScale; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _Timer(dueTime, period, updateTiming, ignoreTimeScale, cancellationToken); + return new _Timer(dueTime, period, updateTiming, ignoreTimeScale, cancellationToken, cancelImmediately); } class _Timer : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem @@ -81,7 +83,8 @@ namespace Cysharp.Threading.Tasks.Linq readonly float? period; readonly PlayerLoopTiming updateTiming; readonly bool ignoreTimeScale; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; int initialFrame; float elapsed; @@ -89,7 +92,7 @@ namespace Cysharp.Threading.Tasks.Linq bool completed; bool disposed; - public _Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale, CancellationToken cancellationToken) + public _Timer(TimeSpan dueTime, TimeSpan? period, PlayerLoopTiming updateTiming, bool ignoreTimeScale, CancellationToken cancellationToken, bool cancelImmediately) { this.dueTime = (float)dueTime.TotalSeconds; this.period = (period == null) ? null : (float?)period.Value.TotalSeconds; @@ -105,6 +108,17 @@ namespace Cysharp.Threading.Tasks.Linq this.updateTiming = updateTiming; this.ignoreTimeScale = ignoreTimeScale; this.cancellationToken = cancellationToken; + + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_Timer)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } + TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(updateTiming, this); } @@ -127,6 +141,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } @@ -187,24 +202,27 @@ namespace Cysharp.Threading.Tasks.Linq readonly PlayerLoopTiming updateTiming; readonly int dueTimeFrameCount; readonly int? periodFrameCount; + readonly bool cancelImmediately; - public TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming) + public TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming, bool cancelImmediately) { this.updateTiming = updateTiming; this.dueTimeFrameCount = dueTimeFrameCount; this.periodFrameCount = periodFrameCount; + this.cancelImmediately = cancelImmediately; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming, cancellationToken); + return new _TimerFrame(dueTimeFrameCount, periodFrameCount, updateTiming, cancellationToken, cancelImmediately); } class _TimerFrame : MoveNextSource, IUniTaskAsyncEnumerator, IPlayerLoopItem { readonly int dueTimeFrameCount; readonly int? periodFrameCount; - CancellationToken cancellationToken; + readonly CancellationToken cancellationToken; + readonly CancellationTokenRegistration cancellationTokenRegistration; int initialFrame; int currentFrame; @@ -212,7 +230,7 @@ namespace Cysharp.Threading.Tasks.Linq bool completed; bool disposed; - public _TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming, CancellationToken cancellationToken) + public _TimerFrame(int dueTimeFrameCount, int? periodFrameCount, PlayerLoopTiming updateTiming, CancellationToken cancellationToken, bool cancelImmediately) { if (dueTimeFrameCount <= 0) dueTimeFrameCount = 0; if (periodFrameCount != null) @@ -225,6 +243,15 @@ namespace Cysharp.Threading.Tasks.Linq this.dueTimeFrameCount = dueTimeFrameCount; this.periodFrameCount = periodFrameCount; this.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (_TimerFrame)state; + source.completionSource.TrySetCanceled(source.cancellationToken); + }, this); + } TaskTracker.TrackActiveTask(this, 2); PlayerLoopHelper.AddAction(updateTiming, this); @@ -249,6 +276,7 @@ namespace Cysharp.Threading.Tasks.Linq { if (!disposed) { + cancellationTokenRegistration.Dispose(); disposed = true; TaskTracker.RemoveTracking(this); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs index fadd63c..7f02a1a 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Delay.cs @@ -33,14 +33,14 @@ namespace Cysharp.Threading.Tasks return new YieldAwaitable(timing); } - public static UniTask Yield(CancellationToken cancellationToken) + public static UniTask Yield(CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(YieldPromise.Create(PlayerLoopTiming.Update, cancellationToken, out var token), token); + return new UniTask(YieldPromise.Create(PlayerLoopTiming.Update, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask Yield(PlayerLoopTiming timing, CancellationToken cancellationToken) + public static UniTask Yield(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(YieldPromise.Create(timing, cancellationToken, out var token), token); + return new UniTask(YieldPromise.Create(timing, cancellationToken, cancelImmediately, out var token), token); } /// @@ -48,7 +48,7 @@ namespace Cysharp.Threading.Tasks /// public static UniTask NextFrame() { - return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, CancellationToken.None, out var token), token); + return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, CancellationToken.None, false, out var token), token); } /// @@ -56,23 +56,23 @@ namespace Cysharp.Threading.Tasks /// public static UniTask NextFrame(PlayerLoopTiming timing) { - return new UniTask(NextFramePromise.Create(timing, CancellationToken.None, out var token), token); + return new UniTask(NextFramePromise.Create(timing, CancellationToken.None, false, out var token), token); } /// /// Similar as UniTask.Yield but guaranteed run on next frame. /// - public static UniTask NextFrame(CancellationToken cancellationToken) + public static UniTask NextFrame(CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, cancellationToken, out var token), token); + return new UniTask(NextFramePromise.Create(PlayerLoopTiming.Update, cancellationToken, cancelImmediately, out var token), token); } /// /// Similar as UniTask.Yield but guaranteed run on next frame. /// - public static UniTask NextFrame(PlayerLoopTiming timing, CancellationToken cancellationToken) + public static UniTask NextFrame(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately = false) { - return new UniTask(NextFramePromise.Create(timing, cancellationToken, out var token), token); + return new UniTask(NextFramePromise.Create(timing, cancellationToken, cancelImmediately, out var token), token); } #if UNITY_2023_1_OR_NEWER @@ -88,15 +88,21 @@ namespace Cysharp.Threading.Tasks } [Obsolete("Use WaitForEndOfFrame(MonoBehaviour) instead or UniTask.Yield(PlayerLoopTiming.LastPostLateUpdate). Equivalent for coroutine's WaitForEndOfFrame requires MonoBehaviour(runner of Coroutine).")] - public static UniTask WaitForEndOfFrame(CancellationToken cancellationToken) + public static UniTask WaitForEndOfFrame(CancellationToken cancellationToken, bool cancelImmediately = false) { - return UniTask.Yield(PlayerLoopTiming.LastPostLateUpdate, cancellationToken); + return UniTask.Yield(PlayerLoopTiming.LastPostLateUpdate, cancellationToken, cancelImmediately); } #endif - public static UniTask WaitForEndOfFrame(MonoBehaviour coroutineRunner, CancellationToken cancellationToken = default) + public static UniTask WaitForEndOfFrame(MonoBehaviour coroutineRunner) { - var source = WaitForEndOfFramePromise.Create(coroutineRunner, cancellationToken, out var token); + var source = WaitForEndOfFramePromise.Create(coroutineRunner, CancellationToken.None, false, out var token); + return new UniTask(source, token); + } + + public static UniTask WaitForEndOfFrame(MonoBehaviour coroutineRunner, CancellationToken cancellationToken, bool cancelImmediately = false) + { + var source = WaitForEndOfFramePromise.Create(coroutineRunner, cancellationToken, cancelImmediately, out var token); return new UniTask(source, token); } @@ -113,50 +119,50 @@ namespace Cysharp.Threading.Tasks /// /// Same as UniTask.Yield(PlayerLoopTiming.LastFixedUpdate, cancellationToken). /// - public static UniTask WaitForFixedUpdate(CancellationToken cancellationToken) + public static UniTask WaitForFixedUpdate(CancellationToken cancellationToken, bool cancelImmediately = false) { - return UniTask.Yield(PlayerLoopTiming.LastFixedUpdate, cancellationToken); + return UniTask.Yield(PlayerLoopTiming.LastFixedUpdate, cancellationToken, cancelImmediately); } - public static UniTask WaitForSeconds(float duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitForSeconds(float duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return Delay(Mathf.RoundToInt(1000 * duration), ignoreTimeScale, delayTiming, cancellationToken); + return Delay(Mathf.RoundToInt(1000 * duration), ignoreTimeScale, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask WaitForSeconds(int duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitForSeconds(int duration, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return Delay(1000 * duration, ignoreTimeScale, delayTiming, cancellationToken); + return Delay(1000 * duration, ignoreTimeScale, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask DelayFrame(int delayFrameCount, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (delayFrameCount < 0) { throw new ArgumentOutOfRangeException("Delay does not allow minus delayFrameCount. delayFrameCount:" + delayFrameCount); } - return new UniTask(DelayFramePromise.Create(delayFrameCount, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayFramePromise.Create(delayFrameCount, delayTiming, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask Delay(int millisecondsDelay, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(int millisecondsDelay, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { var delayTimeSpan = TimeSpan.FromMilliseconds(millisecondsDelay); - return Delay(delayTimeSpan, ignoreTimeScale, delayTiming, cancellationToken); + return Delay(delayTimeSpan, ignoreTimeScale, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask Delay(TimeSpan delayTimeSpan, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(TimeSpan delayTimeSpan, bool ignoreTimeScale = false, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { var delayType = ignoreTimeScale ? DelayType.UnscaledDeltaTime : DelayType.DeltaTime; - return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken); + return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask Delay(int millisecondsDelay, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(int millisecondsDelay, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { var delayTimeSpan = TimeSpan.FromMilliseconds(millisecondsDelay); - return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken); + return Delay(delayTimeSpan, delayType, delayTiming, cancellationToken, cancelImmediately); } - public static UniTask Delay(TimeSpan delayTimeSpan, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask Delay(TimeSpan delayTimeSpan, DelayType delayType, PlayerLoopTiming delayTiming = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (delayTimeSpan < TimeSpan.Zero) { @@ -175,16 +181,16 @@ namespace Cysharp.Threading.Tasks { case DelayType.UnscaledDeltaTime: { - return new UniTask(DelayIgnoreTimeScalePromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayIgnoreTimeScalePromise.Create(delayTimeSpan, delayTiming, cancellationToken, cancelImmediately, out var token), token); } case DelayType.Realtime: { - return new UniTask(DelayRealtimePromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayRealtimePromise.Create(delayTimeSpan, delayTiming, cancellationToken, cancelImmediately, out var token), token); } case DelayType.DeltaTime: default: { - return new UniTask(DelayPromise.Create(delayTimeSpan, delayTiming, cancellationToken, out var token), token); + return new UniTask(DelayPromise.Create(delayTimeSpan, delayTiming, cancellationToken, cancelImmediately, out var token), token); } } } @@ -201,13 +207,14 @@ namespace Cysharp.Threading.Tasks } CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; YieldPromise() { } - public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -219,8 +226,16 @@ namespace Cysharp.Threading.Tasks result = new YieldPromise(); } - result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (YieldPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -274,6 +289,7 @@ namespace Cysharp.Threading.Tasks TaskTracker.RemoveTracking(this); core.Reset(); cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -290,14 +306,15 @@ namespace Cysharp.Threading.Tasks } int frameCount; - CancellationToken cancellationToken; UniTaskCompletionSourceCore core; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; NextFramePromise() { } - public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -312,6 +329,15 @@ namespace Cysharp.Threading.Tasks result.frameCount = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (NextFramePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -369,6 +395,7 @@ namespace Cysharp.Threading.Tasks TaskTracker.RemoveTracking(this); core.Reset(); cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -384,14 +411,15 @@ namespace Cysharp.Threading.Tasks TaskPool.RegisterSizeGetter(typeof(WaitForEndOfFramePromise), () => pool.Size); } - CancellationToken cancellationToken; UniTaskCompletionSourceCore core; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; WaitForEndOfFramePromise() { } - public static IUniTaskSource Create(MonoBehaviour coroutineRunner, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(MonoBehaviour coroutineRunner, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -405,6 +433,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitForEndOfFramePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); coroutineRunner.StartCoroutine(result); @@ -446,6 +483,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); Reset(); // Reset Enumerator cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } @@ -494,6 +532,7 @@ namespace Cysharp.Threading.Tasks int initialFrame; int delayFrameCount; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; int currentFrameCount; UniTaskCompletionSourceCore core; @@ -502,7 +541,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(int delayFrameCount, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(int delayFrameCount, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -518,6 +557,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; result.initialFrame = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayFramePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -604,6 +652,7 @@ namespace Cysharp.Threading.Tasks currentFrameCount = default; delayFrameCount = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -623,6 +672,7 @@ namespace Cysharp.Threading.Tasks float delayTimeSpan; float elapsed; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -630,7 +680,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -647,6 +697,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; result.initialFrame = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -715,6 +774,7 @@ namespace Cysharp.Threading.Tasks delayTimeSpan = default; elapsed = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -734,6 +794,7 @@ namespace Cysharp.Threading.Tasks float elapsed; int initialFrame; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -741,7 +802,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(TimeSpan delayFrameTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(TimeSpan delayFrameTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -758,6 +819,15 @@ namespace Cysharp.Threading.Tasks result.initialFrame = PlayerLoopHelper.IsMainThread ? Time.frameCount : -1; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayIgnoreTimeScalePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -826,6 +896,7 @@ namespace Cysharp.Threading.Tasks delayFrameTimeSpan = default; elapsed = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -844,6 +915,7 @@ namespace Cysharp.Threading.Tasks long delayTimeSpanTicks; ValueStopwatch stopwatch; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -851,7 +923,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(TimeSpan delayTimeSpan, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -867,6 +939,15 @@ namespace Cysharp.Threading.Tasks result.delayTimeSpanTicks = delayTimeSpan.Ticks; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (DelayRealtimePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -931,6 +1012,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); stopwatch = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs index 0a09fe0..b28a529 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs @@ -9,30 +9,30 @@ namespace Cysharp.Threading.Tasks { public partial struct UniTask { - public static UniTask WaitUntil(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitUntil(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, out var token), token); + return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask WaitWhile(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitWhile(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { - return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, out var token), token); + return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, cancelImmediately, out var token), token); } - public static UniTask WaitUntilCanceled(CancellationToken cancellationToken, PlayerLoopTiming timing = PlayerLoopTiming.Update) + public static UniTask WaitUntilCanceled(CancellationToken cancellationToken, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool completeImmediately = false) { - return new UniTask(WaitUntilCanceledPromise.Create(cancellationToken, timing, out var token), token); + return new UniTask(WaitUntilCanceledPromise.Create(cancellationToken, timing, completeImmediately, out var token), token); } - public static UniTask WaitUntilValueChanged(T target, Func monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WaitUntilValueChanged(T target, Func monitorFunction, PlayerLoopTiming monitorTiming = PlayerLoopTiming.Update, IEqualityComparer equalityComparer = null, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) where T : class { var unityObject = target as UnityEngine.Object; var isUnityObject = target is UnityEngine.Object; // don't use (unityObject == null) return new UniTask(isUnityObject - ? WaitUntilValueChangedUnityObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out var token) - : WaitUntilValueChangedStandardObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, out token), token); + ? WaitUntilValueChangedUnityObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, cancelImmediately, out var token) + : WaitUntilValueChangedStandardObjectPromise.Create(target, monitorFunction, equalityComparer, monitorTiming, cancellationToken, cancelImmediately, out token), token); } sealed class WaitUntilPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode @@ -48,6 +48,7 @@ namespace Cysharp.Threading.Tasks Func predicate; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -55,7 +56,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -70,6 +71,15 @@ namespace Cysharp.Threading.Tasks result.predicate = predicate; result.cancellationToken = cancellationToken; + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -136,6 +146,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); predicate = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -153,6 +164,7 @@ namespace Cysharp.Threading.Tasks Func predicate; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -160,7 +172,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -174,6 +186,15 @@ namespace Cysharp.Threading.Tasks result.predicate = predicate; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitWhilePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -241,6 +262,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); predicate = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -257,6 +279,7 @@ namespace Cysharp.Threading.Tasks } CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -264,7 +287,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(CancellationToken cancellationToken, PlayerLoopTiming timing, out short token) + public static IUniTaskSource Create(CancellationToken cancellationToken, PlayerLoopTiming timing, bool completeImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -278,6 +301,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; + if (completeImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilCanceledPromise)state; + promise.core.TrySetResult(null); + }, result); + } + TaskTracker.TrackActiveTask(result, 3); PlayerLoopHelper.AddAction(timing, result); @@ -329,6 +361,7 @@ namespace Cysharp.Threading.Tasks TaskTracker.RemoveTracking(this); core.Reset(); cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -351,6 +384,7 @@ namespace Cysharp.Threading.Tasks Func monitorFunction; IEqualityComparer equalityComparer; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -358,7 +392,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -376,6 +410,15 @@ namespace Cysharp.Threading.Tasks result.currentValue = monitorFunction(target); result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault(); result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilValueChangedUnityObjectPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -453,6 +496,7 @@ namespace Cysharp.Threading.Tasks monitorFunction = default; equalityComparer = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -474,6 +518,7 @@ namespace Cysharp.Threading.Tasks Func monitorFunction; IEqualityComparer equalityComparer; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -481,7 +526,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(T target, Func monitorFunction, IEqualityComparer equalityComparer, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -498,6 +543,15 @@ namespace Cysharp.Threading.Tasks result.currentValue = monitorFunction(target); result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault(); result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilValueChangedStandardObjectPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -575,6 +629,7 @@ namespace Cysharp.Threading.Tasks monitorFunction = default; equalityComparer = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs index 1a1e011..95f00b2 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs @@ -24,12 +24,17 @@ namespace Cysharp.Threading.Tasks return AwaitForAllAssets(asyncOperation, null, PlayerLoopTiming.Update, cancellationToken: cancellationToken); } - public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return AwaitForAllAssets(asyncOperation, null, PlayerLoopTiming.Update, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.allAssets); - return new UniTask(AssetBundleRequestAllAssetsConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleRequestAllAssetsConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AssetBundleRequestAllAssetsAwaiter : ICriticalNotifyCompletion @@ -95,15 +100,15 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; AssetBundleRequestAllAssetsConfiguredSource() { - } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -118,6 +123,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AssetBundleRequestAllAssetsConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -188,6 +202,7 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs index 5805dbb..0faf0be 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AsyncGPUReadback.cs @@ -20,10 +20,15 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncGPUReadbackRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (asyncOperation.done) return UniTask.FromResult(asyncOperation); - return new UniTask(AsyncGPUReadbackRequestAwaiterConfiguredSource.Create(asyncOperation, timing, cancellationToken, out var token), token); + return new UniTask(AsyncGPUReadbackRequestAwaiterConfiguredSource.Create(asyncOperation, timing, cancellationToken, cancelImmediately, out var token), token); } sealed class AsyncGPUReadbackRequestAwaiterConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode @@ -39,15 +44,15 @@ namespace Cysharp.Threading.Tasks AsyncGPUReadbackRequest asyncOperation; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; AsyncGPUReadbackRequestAwaiterConfiguredSource() { - } - public static IUniTaskSource Create(AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncGPUReadbackRequest asyncOperation, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -61,6 +66,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncGPUReadbackRequestAwaiterConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -131,6 +145,7 @@ namespace Cysharp.Threading.Tasks core.Reset(); asyncOperation = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs index 00afe66..db22575 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs @@ -29,13 +29,18 @@ namespace Cysharp.Threading.Tasks { return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } + + public static UniTask WithCancellation(this AsyncOperation asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } - public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.CompletedTask; - return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AsyncOperationAwaiter : ICriticalNotifyCompletion @@ -92,15 +97,15 @@ namespace Cysharp.Threading.Tasks AsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; AsyncOperationConfiguredSource() { - } - public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -111,6 +116,16 @@ namespace Cysharp.Threading.Tasks { result = new AsyncOperationConfiguredSource(); } + + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncOperationConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } result.asyncOperation = asyncOperation; result.progress = progress; @@ -136,7 +151,6 @@ namespace Cysharp.Threading.Tasks } } - public UniTaskStatus GetStatus(short token) { return core.GetStatus(token); @@ -181,6 +195,7 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -739,7 +754,12 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this UnityWebRequestAsyncOperation asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -751,7 +771,7 @@ namespace Cysharp.Threading.Tasks } return UniTask.FromResult(asyncOperation.webRequest); } - return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct UnityWebRequestAsyncOperationAwaiter : ICriticalNotifyCompletion @@ -820,15 +840,15 @@ namespace Cysharp.Threading.Tasks UnityWebRequestAsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; UnityWebRequestAsyncOperationConfiguredSource() { - } - public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -843,6 +863,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (UnityWebRequestAsyncOperationConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -925,6 +954,7 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } }