diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs index 95f00b2..f6fccc0 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.AssetBundleRequestAllAssets.cs @@ -24,17 +24,24 @@ namespace Cysharp.Threading.Tasks return AwaitForAllAssets(asyncOperation, null, PlayerLoopTiming.Update, cancellationToken: cancellationToken); } - public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - return AwaitForAllAssets(asyncOperation, null, PlayerLoopTiming.Update, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); + if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.allAssets); + return new UniTask(AssetBundleRequestAllAssetsCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return AwaitForAllAssets(asyncOperation, progress: null, cancellationToken: cancellationToken); } - public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + public static UniTask AwaitForAllAssets(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.allAssets); - return new UniTask(AssetBundleRequestAllAssetsConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); + return new UniTask(AssetBundleRequestAllAssetsConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); } public struct AssetBundleRequestAllAssetsAwaiter : ICriticalNotifyCompletion @@ -85,6 +92,108 @@ namespace Cysharp.Threading.Tasks asyncOperation.completed += continuationAction; } } + + sealed class AssetBundleRequestAllAssetsCallbackHandlerSource : IUniTaskSource, ITaskPoolNode + { + static TaskPool pool; + AssetBundleRequestAllAssetsCallbackHandlerSource nextNode; + public ref AssetBundleRequestAllAssetsCallbackHandlerSource NextNode => ref nextNode; + + static AssetBundleRequestAllAssetsCallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(AssetBundleRequestConfiguredSource), () => pool.Size); + } + + AssetBundleRequest asyncOperation; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore core; + + AssetBundleRequestAllAssetsCallbackHandlerSource() + { + } + + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new AssetBundleRequestAllAssetsCallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleRequestAllAssetsCallbackHandlerSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public UnityEngine.Object[] GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { + core.TrySetResult(asyncOperation.allAssets); + } + } sealed class AssetBundleRequestAllAssetsConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { @@ -100,7 +209,6 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -108,7 +216,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -123,15 +231,6 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; - - if (cancelImmediately && cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var promise = (AssetBundleRequestAllAssetsConfiguredSource)state; - promise.core.TrySetCanceled(promise.cancellationToken); - }, result); - } TaskTracker.TrackActiveTask(result, 3); @@ -202,7 +301,6 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; - cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs index 235faa0..069ee5a 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs @@ -15,32 +15,35 @@ namespace Cysharp.Threading.Tasks { #region AsyncOperation -#if !UNITY_2023_1_OR_NEWER - // from Unity2023.1.0a15, AsyncOperationAwaitableExtensions.GetAwaiter is defined in UnityEngine. - public static AsyncOperationAwaiter GetAwaiter(this AsyncOperation asyncOperation) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); return new AsyncOperationAwaiter(asyncOperation); } -#endif public static UniTask WithCancellation(this AsyncOperation asyncOperation, CancellationToken cancellationToken) { return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - - public static UniTask WithCancellation(this AsyncOperation asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + + public static UniTask WithCancellation(this AsyncOperation asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); + if (asyncOperation.isDone) return UniTask.CompletedTask; + return new UniTask(AsyncOperationCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.CompletedTask; - return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); + return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); } public struct AsyncOperationAwaiter : ICriticalNotifyCompletion @@ -83,6 +86,106 @@ namespace Cysharp.Threading.Tasks } } + sealed class AsyncOperationCallbackHandlerSource : IUniTaskSource, ITaskPoolNode + { + static TaskPool pool; + AsyncOperationCallbackHandlerSource nextNode; + public ref AsyncOperationCallbackHandlerSource NextNode => ref nextNode; + + static AsyncOperationCallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(AsyncOperationCallbackHandlerSource), () => pool.Size); + } + + AsyncOperation asyncOperation; + IProgress progress; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore core; + + AsyncOperationCallbackHandlerSource() + { + } + + public static IUniTaskSource Create(AsyncOperation asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new AsyncOperationCallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AsyncOperationCallbackHandlerSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + progress = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { + core.TrySetResult(AsyncUnit.Default); + } + } + sealed class AsyncOperationConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -97,7 +200,6 @@ namespace Cysharp.Threading.Tasks AsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -105,7 +207,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) + public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -116,16 +218,6 @@ namespace Cysharp.Threading.Tasks { result = new AsyncOperationConfiguredSource(); } - - - if (cancelImmediately && cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var promise = (AsyncOperationConfiguredSource)state; - promise.core.TrySetCanceled(promise.cancellationToken); - }, result); - } result.asyncOperation = asyncOperation; result.progress = progress; @@ -151,6 +243,7 @@ namespace Cysharp.Threading.Tasks } } + public UniTaskStatus GetStatus(short token) { return core.GetStatus(token); @@ -195,7 +288,6 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; - cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -215,17 +307,24 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask WithCancellation(this ResourceRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + public static UniTask WithCancellation(this ResourceRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); + if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); + return new UniTask(ResourceRequestCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); + return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); } public struct ResourceRequestAwaiter : ICriticalNotifyCompletion @@ -272,6 +371,110 @@ namespace Cysharp.Threading.Tasks } } + sealed class ResourceRequestCallbackHandlerSource : IUniTaskSource, ITaskPoolNode + { + static TaskPool pool; + ResourceRequestCallbackHandlerSource nextNode; + public ref ResourceRequestCallbackHandlerSource NextNode => ref nextNode; + + static ResourceRequestCallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(ResourceRequestCallbackHandlerSource), () => pool.Size); + } + + ResourceRequest asyncOperation; + IProgress progress; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore core; + + ResourceRequestCallbackHandlerSource() + { + } + + public static IUniTaskSource Create(ResourceRequest asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new ResourceRequestCallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (ResourceRequestCallbackHandlerSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public UnityEngine.Object GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + progress = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { + core.TrySetResult(asyncOperation.asset); + } + } + sealed class ResourceRequestConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -286,7 +489,6 @@ namespace Cysharp.Threading.Tasks ResourceRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -294,7 +496,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) + public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -309,15 +511,6 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; - - if (cancelImmediately && cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (ResourceRequestConfiguredSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } TaskTracker.TrackActiveTask(result, 3); @@ -388,7 +581,6 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; - cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -409,17 +601,24 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask WithCancellation(this AssetBundleRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + public static UniTask WithCancellation(this AssetBundleRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); + if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); + return new UniTask(AssetBundleRequestCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); + return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); } public struct AssetBundleRequestAwaiter : ICriticalNotifyCompletion @@ -466,6 +665,110 @@ namespace Cysharp.Threading.Tasks } } + sealed class AssetBundleRequestCallbackHandlerSource : IUniTaskSource, ITaskPoolNode + { + static TaskPool pool; + AssetBundleRequestCallbackHandlerSource nextNode; + public ref AssetBundleRequestCallbackHandlerSource NextNode => ref nextNode; + + static AssetBundleRequestCallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(AssetBundleRequestCallbackHandlerSource), () => pool.Size); + } + + AssetBundleRequest asyncOperation; + IProgress progress; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore core; + + AssetBundleRequestCallbackHandlerSource() + { + } + + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new AssetBundleRequestCallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleRequestCallbackHandlerSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public UnityEngine.Object GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + progress = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { + core.TrySetResult(asyncOperation.asset); + } + } + sealed class AssetBundleRequestConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -480,7 +783,6 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -488,7 +790,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -503,15 +805,6 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; - - if (cancelImmediately && cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (AssetBundleRequestConfiguredSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } TaskTracker.TrackActiveTask(result, 3); @@ -582,7 +875,6 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; - cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -604,17 +896,24 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask WithCancellation(this AssetBundleCreateRequest asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + public static UniTask WithCancellation(this AssetBundleCreateRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); + if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.assetBundle); + return new UniTask(AssetBundleCreateRequestCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.assetBundle); - return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); + return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); } public struct AssetBundleCreateRequestAwaiter : ICriticalNotifyCompletion @@ -661,6 +960,110 @@ namespace Cysharp.Threading.Tasks } } + sealed class AssetBundleCreateRequestCallbackHandlerSource : IUniTaskSource, ITaskPoolNode + { + static TaskPool pool; + AssetBundleCreateRequestCallbackHandlerSource nextNode; + public ref AssetBundleCreateRequestCallbackHandlerSource NextNode => ref nextNode; + + static AssetBundleCreateRequestCallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(AssetBundleCreateRequestCallbackHandlerSource), () => pool.Size); + } + + AssetBundleCreateRequest asyncOperation; + IProgress progress; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore core; + + AssetBundleCreateRequestCallbackHandlerSource() + { + } + + public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new AssetBundleCreateRequestCallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleCreateRequestCallbackHandlerSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public AssetBundle GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + progress = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { + core.TrySetResult(asyncOperation.assetBundle); + } + } + sealed class AssetBundleCreateRequestConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -675,7 +1078,6 @@ namespace Cysharp.Threading.Tasks AssetBundleCreateRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -683,7 +1085,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) + public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -698,15 +1100,6 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; - - if (cancelImmediately && cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (AssetBundleCreateRequestConfiguredSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } TaskTracker.TrackActiveTask(result, 3); @@ -777,7 +1170,6 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; - cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } @@ -799,12 +1191,26 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask WithCancellation(this UnityWebRequestAsyncOperation asyncOperation, CancellationToken cancellationToken, bool cancelImmediately) + public static UniTask WithCancellation(this UnityWebRequestAsyncOperation asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - return ToUniTask(asyncOperation, cancellationToken: cancellationToken, cancelImmediately: cancelImmediately); + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); + if (asyncOperation.isDone) + { + if (asyncOperation.webRequest.IsError()) + { + return UniTask.FromException(new UnityWebRequestException(asyncOperation.webRequest)); + } + return UniTask.FromResult(asyncOperation.webRequest); + } + return new UniTask(UnityWebRequestAsyncOperationCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -816,7 +1222,7 @@ namespace Cysharp.Threading.Tasks } return UniTask.FromResult(asyncOperation.webRequest); } - return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, cancelImmediately, out var token), token); + return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); } public struct UnityWebRequestAsyncOperationAwaiter : ICriticalNotifyCompletion @@ -871,6 +1277,118 @@ namespace Cysharp.Threading.Tasks } } + sealed class UnityWebRequestAsyncOperationCallbackHandlerSource : IUniTaskSource, ITaskPoolNode + { + static TaskPool pool; + UnityWebRequestAsyncOperationCallbackHandlerSource nextNode; + public ref UnityWebRequestAsyncOperationCallbackHandlerSource NextNode => ref nextNode; + + static UnityWebRequestAsyncOperationCallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(UnityWebRequestAsyncOperationCallbackHandlerSource), () => pool.Size); + } + + UnityWebRequestAsyncOperation asyncOperation; + IProgress progress; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore core; + + UnityWebRequestAsyncOperationCallbackHandlerSource() + { + } + + public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new UnityWebRequestAsyncOperationCallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (UnityWebRequestAsyncOperationCallbackHandlerSource)state; + source.asyncOperation.webRequest.Abort(); + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public UnityWebRequest GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + progress = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { + if (asyncOperation.webRequest.IsError()) + { + core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); + } + else + { + core.TrySetResult(asyncOperation.webRequest); + } + } + } + sealed class UnityWebRequestAsyncOperationConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -885,7 +1403,6 @@ namespace Cysharp.Threading.Tasks UnityWebRequestAsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; UniTaskCompletionSourceCore core; @@ -893,7 +1410,7 @@ namespace Cysharp.Threading.Tasks { } - public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, bool cancelImmediately, out short token) + public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -908,15 +1425,6 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; - - if (cancelImmediately && cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var promise = (UnityWebRequestAsyncOperationConfiguredSource)state; - promise.core.TrySetCanceled(promise.cancellationToken); - }, result); - } TaskTracker.TrackActiveTask(result, 3); @@ -974,11 +1482,7 @@ namespace Cysharp.Threading.Tasks if (asyncOperation.isDone) { - if (asyncOperation.webRequest == null) - { - core.TrySetException(new ObjectDisposedException("The webRequest has been destroyed.")); - } - else if (asyncOperation.webRequest.IsError()) + if (asyncOperation.webRequest.IsError()) { core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); } @@ -999,7 +1503,6 @@ namespace Cysharp.Threading.Tasks asyncOperation = default; progress = default; cancellationToken = default; - cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt index 65dac9e..f74fa0b 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt @@ -54,6 +54,29 @@ namespace Cysharp.Threading.Tasks return ToUniTask(asyncOperation, cancellationToken: cancellationToken); } + public static <#= ToUniTaskReturnType(t.returnType) #> WithCancellation(this <#= t.typeName #> asyncOperation, bool handleImmediately, CancellationToken cancellationToken) + { + if (handleImmediately) + { + Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); + if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled<#= IsVoid(t) ? "" : "<" + t.returnType + ">" #>(cancellationToken); +<# if(IsUnityWebRequest(t)) { #> + if (asyncOperation.isDone) + { + if (asyncOperation.webRequest.IsError()) + { + return UniTask.FromException(new UnityWebRequestException(asyncOperation.webRequest)); + } + return UniTask.FromResult(asyncOperation.webRequest); + } +<# } else { #> + if (asyncOperation.isDone) return <#= IsVoid(t) ? "UniTask.CompletedTask" : $"UniTask.FromResult(asyncOperation.{t.returnField})" #>; +<# } #> + return new <#= ToUniTaskReturnType(t.returnType) #>(<#= t.typeName #>CallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); + } + return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + } + public static <#= ToUniTaskReturnType(t.returnType) #> ToUniTask(this <#= t.typeName #> asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); @@ -137,6 +160,130 @@ namespace Cysharp.Threading.Tasks } } + sealed class <#= t.typeName #>CallbackHandlerSource : <#= ToIUniTaskSourceReturnType(t.returnType) #>, ITaskPoolNode<<#= t.typeName #>CallbackHandlerSource> + { + static TaskPool<<#= t.typeName #>CallbackHandlerSource> pool; + <#= t.typeName #>CallbackHandlerSource nextNode; + public ref <#= t.typeName #>CallbackHandlerSource NextNode => ref nextNode; + + static <#= t.typeName #>CallbackHandlerSource() + { + TaskPool.RegisterSizeGetter(typeof(<#= t.typeName #>CallbackHandlerSource), () => pool.Size); + } + + <#= t.typeName #> asyncOperation; + IProgress progress; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + + UniTaskCompletionSourceCore<<#= IsVoid(t) ? "AsyncUnit" : t.returnType #>> core; + + <#= t.typeName #>CallbackHandlerSource() + { + } + + public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource<#= IsVoid(t) ? "" : $"<{t.returnType}>" #>.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new <#= t.typeName #>CallbackHandlerSource(); + } + + result.asyncOperation = asyncOperation; + result.cancellationToken = cancellationToken; + + asyncOperation.completed += result.AsyncOperationCompletedHandler; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (<#= t.typeName #>CallbackHandlerSource)state; +<# if(IsUnityWebRequest(t)) { #> + source.asyncOperation.webRequest.Abort(); +<# } #> + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + public <#= t.returnType #> GetResult(short token) + { + try + { +<# if (!IsVoid(t)) { #> + return core.GetResult(token); +<# } else { #> + core.GetResult(token); +<# } #> + } + finally + { + TryReturn(); + } + } + +<# if (!IsVoid(t)) { #> + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } +<# } #> + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + asyncOperation.completed -= AsyncOperationCompletedHandler; + asyncOperation = default; + progress = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + return pool.TryPush(this); + } + + void AsyncOperationCompletedHandler(AsyncOperation _) + { +<# if(IsUnityWebRequest(t)) { #> + if (asyncOperation.webRequest.IsError()) + { + core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); + } + else + { + core.TrySetResult(asyncOperation.webRequest); + } +<# } else { #> + core.TrySetResult(<#= IsVoid(t) ? "AsyncUnit.Default" : $"asyncOperation.{t.returnField}" #>); +<# } #> + } + } + sealed class <#= t.typeName #>ConfiguredSource : <#= ToIUniTaskSourceReturnType(t.returnType) #>, IPlayerLoopItem, ITaskPoolNode<<#= t.typeName #>ConfiguredSource> { static TaskPool<<#= t.typeName #>ConfiguredSource> pool; @@ -156,7 +303,6 @@ namespace Cysharp.Threading.Tasks <#= t.typeName #>ConfiguredSource() { - } public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token)