From 94be2e748b650baefa506f1def9a880e2a918c07 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Fri, 27 Oct 2023 15:06:12 +0900 Subject: [PATCH] Add cancelImmediately flag for addressable extensions --- .../AddressablesAsyncExtensions.cs | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs index f321bdb..a0ca8a1 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/External/Addressables/AddressablesAsyncExtensions.cs @@ -25,7 +25,12 @@ namespace Cysharp.Threading.Tasks return ToUniTask(handle, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncOperationHandle handle, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(handle, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -44,7 +49,7 @@ namespace Cysharp.Threading.Tasks return UniTask.CompletedTask; } - return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, cancelImmediately, out var token), token); } public struct AsyncOperationHandleAwaiter : ICriticalNotifyCompletion @@ -106,6 +111,7 @@ namespace Cysharp.Threading.Tasks readonly Action continuationAction; AsyncOperationHandle handle; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; IProgress progress; bool completed; @@ -116,7 +122,7 @@ namespace Cysharp.Threading.Tasks continuationAction = Continuation; } - public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -132,6 +138,15 @@ namespace Cysharp.Threading.Tasks result.progress = progress; result.cancellationToken = cancellationToken; result.completed = false; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncOperationHandleConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -219,6 +234,7 @@ namespace Cysharp.Threading.Tasks handle = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -237,7 +253,12 @@ namespace Cysharp.Threading.Tasks return ToUniTask(handle, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask WithCancellation(this AsyncOperationHandle handle, CancellationToken cancellationToken, bool cancelImmediately) + { + return ToUniTask(handle, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + } + + public static UniTask ToUniTask(this AsyncOperationHandle handle, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -255,7 +276,7 @@ namespace Cysharp.Threading.Tasks return UniTask.FromResult(handle.Result); } - return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationHandleConfiguredSource.Create(handle, timing, progress, cancellationToken, cancelImmediately, out var token), token); } sealed class AsyncOperationHandleConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode> @@ -272,6 +293,7 @@ namespace Cysharp.Threading.Tasks readonly Action> continuationAction; AsyncOperationHandle handle; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; IProgress progress; bool completed; @@ -282,7 +304,7 @@ namespace Cysharp.Threading.Tasks continuationAction = Continuation; } - public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperationHandle handle, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -298,6 +320,15 @@ namespace Cysharp.Threading.Tasks result.cancellationToken = cancellationToken; result.completed = false; result.progress = progress; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (AsyncOperationHandleConfiguredSource)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } TaskTracker.TrackActiveTask(result, 3); @@ -390,6 +421,7 @@ namespace Cysharp.Threading.Tasks handle = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } }