diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs index db22575..235faa0 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs @@ -215,12 +215,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this ResourceRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct ResourceRequestAwaiter : ICriticalNotifyCompletion @@ -281,15 +286,15 @@ namespace Cysharp.Threading.Tasks ResourceRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; ResourceRequestConfiguredSource() { - } - public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -304,6 +309,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (ResourceRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -374,6 +388,7 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -394,12 +409,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AssetBundleRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AssetBundleRequestAwaiter : ICriticalNotifyCompletion @@ -460,15 +480,15 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; AssetBundleRequestConfiguredSource() { - } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -483,6 +503,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -553,6 +582,7 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -574,12 +604,17 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AssetBundleCreateRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.assetBundle); - return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AssetBundleCreateRequestAwaiter : ICriticalNotifyCompletion @@ -640,15 +675,15 @@ namespace Cysharp.Threading.Tasks AssetBundleCreateRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; AssetBundleCreateRequestConfiguredSource() { - } - public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -663,6 +698,15 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleCreateRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -733,6 +777,7 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } }