diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs index 069ee5a..dc24ad0 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.cs @@ -28,22 +28,15 @@ namespace Cysharp.Threading.Tasks public static UniTask WithCancellation(this AsyncOperation asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - if (handleImmediately) - { - Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); - if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); - if (asyncOperation.isDone) return UniTask.CompletedTask; - return new UniTask(AsyncOperationCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); - } - return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + return ToUniTask(asyncOperation, handleImmediately: handleImmediately, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this AsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool handleImmediately = false, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.CompletedTask; - return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, handleImmediately, cancellationToken, out var token), token); } public struct AsyncOperationAwaiter : ICriticalNotifyCompletion @@ -86,106 +79,6 @@ namespace Cysharp.Threading.Tasks } } - sealed class AsyncOperationCallbackHandlerSource : IUniTaskSource, ITaskPoolNode - { - static TaskPool pool; - AsyncOperationCallbackHandlerSource nextNode; - public ref AsyncOperationCallbackHandlerSource NextNode => ref nextNode; - - static AsyncOperationCallbackHandlerSource() - { - TaskPool.RegisterSizeGetter(typeof(AsyncOperationCallbackHandlerSource), () => pool.Size); - } - - AsyncOperation asyncOperation; - IProgress progress; - CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; - - UniTaskCompletionSourceCore core; - - AsyncOperationCallbackHandlerSource() - { - } - - public static IUniTaskSource Create(AsyncOperation asyncOperation, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new AsyncOperationCallbackHandlerSource(); - } - - result.asyncOperation = asyncOperation; - result.cancellationToken = cancellationToken; - - asyncOperation.completed += result.AsyncOperationCompletedHandler; - - if (cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (AsyncOperationCallbackHandlerSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } - - TaskTracker.TrackActiveTask(result, 3); - - token = result.core.Version; - return result; - } - - public void GetResult(short token) - { - try - { - core.GetResult(token); - } - finally - { - TryReturn(); - } - } - - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - asyncOperation.completed -= AsyncOperationCompletedHandler; - asyncOperation = default; - progress = default; - cancellationToken = default; - cancellationTokenRegistration.Dispose(); - return pool.TryPush(this); - } - - void AsyncOperationCompletedHandler(AsyncOperation _) - { - core.TrySetResult(AsyncUnit.Default); - } - } - sealed class AsyncOperationConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -200,14 +93,18 @@ namespace Cysharp.Threading.Tasks AsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AsyncOperationConfiguredSource() { } - public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, bool handleImmediately, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -222,6 +119,22 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + if (handleImmediately) + { + result.continuationAction = result.Continuation; + asyncOperation.completed += result.continuationAction; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AsyncOperationConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + } TaskTracker.TrackActiveTask(result, 3); @@ -285,11 +198,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(AsyncUnit.Default); + } + } + } } #endregion @@ -309,22 +244,15 @@ namespace Cysharp.Threading.Tasks public static UniTask WithCancellation(this ResourceRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - if (handleImmediately) - { - Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); - if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); - if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(ResourceRequestCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); - } - return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + return ToUniTask(asyncOperation, handleImmediately: handleImmediately, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this ResourceRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool handleImmediately = false, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(ResourceRequestConfiguredSource.Create(asyncOperation, timing, progress, handleImmediately, cancellationToken, out var token), token); } public struct ResourceRequestAwaiter : ICriticalNotifyCompletion @@ -371,110 +299,6 @@ namespace Cysharp.Threading.Tasks } } - sealed class ResourceRequestCallbackHandlerSource : IUniTaskSource, ITaskPoolNode - { - static TaskPool pool; - ResourceRequestCallbackHandlerSource nextNode; - public ref ResourceRequestCallbackHandlerSource NextNode => ref nextNode; - - static ResourceRequestCallbackHandlerSource() - { - TaskPool.RegisterSizeGetter(typeof(ResourceRequestCallbackHandlerSource), () => pool.Size); - } - - ResourceRequest asyncOperation; - IProgress progress; - CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; - - UniTaskCompletionSourceCore core; - - ResourceRequestCallbackHandlerSource() - { - } - - public static IUniTaskSource Create(ResourceRequest asyncOperation, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new ResourceRequestCallbackHandlerSource(); - } - - result.asyncOperation = asyncOperation; - result.cancellationToken = cancellationToken; - - asyncOperation.completed += result.AsyncOperationCompletedHandler; - - if (cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (ResourceRequestCallbackHandlerSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } - - TaskTracker.TrackActiveTask(result, 3); - - token = result.core.Version; - return result; - } - - public UnityEngine.Object GetResult(short token) - { - try - { - return core.GetResult(token); - } - finally - { - TryReturn(); - } - } - - void IUniTaskSource.GetResult(short token) - { - GetResult(token); - } - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - asyncOperation.completed -= AsyncOperationCompletedHandler; - asyncOperation = default; - progress = default; - cancellationToken = default; - cancellationTokenRegistration.Dispose(); - return pool.TryPush(this); - } - - void AsyncOperationCompletedHandler(AsyncOperation _) - { - core.TrySetResult(asyncOperation.asset); - } - } - sealed class ResourceRequestConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -489,14 +313,18 @@ namespace Cysharp.Threading.Tasks ResourceRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + ResourceRequestConfiguredSource() { } - public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(ResourceRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, bool handleImmediately, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -511,6 +339,22 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + if (handleImmediately) + { + result.continuationAction = result.Continuation; + asyncOperation.completed += result.continuationAction; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (ResourceRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + } TaskTracker.TrackActiveTask(result, 3); @@ -578,11 +422,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.asset); + } + } + } } #endregion @@ -603,22 +469,15 @@ namespace Cysharp.Threading.Tasks public static UniTask WithCancellation(this AssetBundleRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - if (handleImmediately) - { - Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); - if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); - if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(AssetBundleRequestCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); - } - return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + return ToUniTask(asyncOperation, handleImmediately: handleImmediately, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this AssetBundleRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool handleImmediately = false, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.asset); - return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleRequestConfiguredSource.Create(asyncOperation, timing, progress, handleImmediately, cancellationToken, out var token), token); } public struct AssetBundleRequestAwaiter : ICriticalNotifyCompletion @@ -665,110 +524,6 @@ namespace Cysharp.Threading.Tasks } } - sealed class AssetBundleRequestCallbackHandlerSource : IUniTaskSource, ITaskPoolNode - { - static TaskPool pool; - AssetBundleRequestCallbackHandlerSource nextNode; - public ref AssetBundleRequestCallbackHandlerSource NextNode => ref nextNode; - - static AssetBundleRequestCallbackHandlerSource() - { - TaskPool.RegisterSizeGetter(typeof(AssetBundleRequestCallbackHandlerSource), () => pool.Size); - } - - AssetBundleRequest asyncOperation; - IProgress progress; - CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; - - UniTaskCompletionSourceCore core; - - AssetBundleRequestCallbackHandlerSource() - { - } - - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new AssetBundleRequestCallbackHandlerSource(); - } - - result.asyncOperation = asyncOperation; - result.cancellationToken = cancellationToken; - - asyncOperation.completed += result.AsyncOperationCompletedHandler; - - if (cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (AssetBundleRequestCallbackHandlerSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } - - TaskTracker.TrackActiveTask(result, 3); - - token = result.core.Version; - return result; - } - - public UnityEngine.Object GetResult(short token) - { - try - { - return core.GetResult(token); - } - finally - { - TryReturn(); - } - } - - void IUniTaskSource.GetResult(short token) - { - GetResult(token); - } - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - asyncOperation.completed -= AsyncOperationCompletedHandler; - asyncOperation = default; - progress = default; - cancellationToken = default; - cancellationTokenRegistration.Dispose(); - return pool.TryPush(this); - } - - void AsyncOperationCompletedHandler(AsyncOperation _) - { - core.TrySetResult(asyncOperation.asset); - } - } - sealed class AssetBundleRequestConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -783,14 +538,18 @@ namespace Cysharp.Threading.Tasks AssetBundleRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AssetBundleRequestConfiguredSource() { } - public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, bool handleImmediately, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -805,6 +564,22 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + if (handleImmediately) + { + result.continuationAction = result.Continuation; + asyncOperation.completed += result.continuationAction; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + } TaskTracker.TrackActiveTask(result, 3); @@ -872,11 +647,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.asset); + } + } + } } #endregion @@ -898,22 +695,15 @@ namespace Cysharp.Threading.Tasks public static UniTask WithCancellation(this AssetBundleCreateRequest asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - if (handleImmediately) - { - Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); - if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); - if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.assetBundle); - return new UniTask(AssetBundleCreateRequestCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); - } - return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + return ToUniTask(asyncOperation, handleImmediately: handleImmediately, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this AssetBundleCreateRequest asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool handleImmediately = false, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); if (asyncOperation.isDone) return UniTask.FromResult(asyncOperation.assetBundle); - return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(AssetBundleCreateRequestConfiguredSource.Create(asyncOperation, timing, progress, handleImmediately, cancellationToken, out var token), token); } public struct AssetBundleCreateRequestAwaiter : ICriticalNotifyCompletion @@ -960,110 +750,6 @@ namespace Cysharp.Threading.Tasks } } - sealed class AssetBundleCreateRequestCallbackHandlerSource : IUniTaskSource, ITaskPoolNode - { - static TaskPool pool; - AssetBundleCreateRequestCallbackHandlerSource nextNode; - public ref AssetBundleCreateRequestCallbackHandlerSource NextNode => ref nextNode; - - static AssetBundleCreateRequestCallbackHandlerSource() - { - TaskPool.RegisterSizeGetter(typeof(AssetBundleCreateRequestCallbackHandlerSource), () => pool.Size); - } - - AssetBundleCreateRequest asyncOperation; - IProgress progress; - CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; - - UniTaskCompletionSourceCore core; - - AssetBundleCreateRequestCallbackHandlerSource() - { - } - - public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new AssetBundleCreateRequestCallbackHandlerSource(); - } - - result.asyncOperation = asyncOperation; - result.cancellationToken = cancellationToken; - - asyncOperation.completed += result.AsyncOperationCompletedHandler; - - if (cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (AssetBundleCreateRequestCallbackHandlerSource)state; - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } - - TaskTracker.TrackActiveTask(result, 3); - - token = result.core.Version; - return result; - } - - public AssetBundle GetResult(short token) - { - try - { - return core.GetResult(token); - } - finally - { - TryReturn(); - } - } - - void IUniTaskSource.GetResult(short token) - { - GetResult(token); - } - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - asyncOperation.completed -= AsyncOperationCompletedHandler; - asyncOperation = default; - progress = default; - cancellationToken = default; - cancellationTokenRegistration.Dispose(); - return pool.TryPush(this); - } - - void AsyncOperationCompletedHandler(AsyncOperation _) - { - core.TrySetResult(asyncOperation.assetBundle); - } - } - sealed class AssetBundleCreateRequestConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -1078,14 +764,18 @@ namespace Cysharp.Threading.Tasks AssetBundleCreateRequest asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + AssetBundleCreateRequestConfiguredSource() { } - public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(AssetBundleCreateRequest asyncOperation, PlayerLoopTiming timing, IProgress progress, bool handleImmediately, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -1100,6 +790,22 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + if (handleImmediately) + { + result.continuationAction = result.Continuation; + asyncOperation.completed += result.continuationAction; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (AssetBundleCreateRequestConfiguredSource)state; + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + } TaskTracker.TrackActiveTask(result, 3); @@ -1167,11 +873,33 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else + { + core.TrySetResult(asyncOperation.assetBundle); + } + } + } } #endregion @@ -1193,24 +921,10 @@ namespace Cysharp.Threading.Tasks public static UniTask WithCancellation(this UnityWebRequestAsyncOperation asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - if (handleImmediately) - { - Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); - if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); - if (asyncOperation.isDone) - { - if (asyncOperation.webRequest.IsError()) - { - return UniTask.FromException(new UnityWebRequestException(asyncOperation.webRequest)); - } - return UniTask.FromResult(asyncOperation.webRequest); - } - return new UniTask(UnityWebRequestAsyncOperationCallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); - } - return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + return ToUniTask(asyncOperation, handleImmediately: handleImmediately, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this UnityWebRequestAsyncOperation asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool handleImmediately = false, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled(cancellationToken); @@ -1222,7 +936,7 @@ namespace Cysharp.Threading.Tasks } return UniTask.FromResult(asyncOperation.webRequest); } - return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new UniTask(UnityWebRequestAsyncOperationConfiguredSource.Create(asyncOperation, timing, progress, handleImmediately, cancellationToken, out var token), token); } public struct UnityWebRequestAsyncOperationAwaiter : ICriticalNotifyCompletion @@ -1277,118 +991,6 @@ namespace Cysharp.Threading.Tasks } } - sealed class UnityWebRequestAsyncOperationCallbackHandlerSource : IUniTaskSource, ITaskPoolNode - { - static TaskPool pool; - UnityWebRequestAsyncOperationCallbackHandlerSource nextNode; - public ref UnityWebRequestAsyncOperationCallbackHandlerSource NextNode => ref nextNode; - - static UnityWebRequestAsyncOperationCallbackHandlerSource() - { - TaskPool.RegisterSizeGetter(typeof(UnityWebRequestAsyncOperationCallbackHandlerSource), () => pool.Size); - } - - UnityWebRequestAsyncOperation asyncOperation; - IProgress progress; - CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; - - UniTaskCompletionSourceCore core; - - UnityWebRequestAsyncOperationCallbackHandlerSource() - { - } - - public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new UnityWebRequestAsyncOperationCallbackHandlerSource(); - } - - result.asyncOperation = asyncOperation; - result.cancellationToken = cancellationToken; - - asyncOperation.completed += result.AsyncOperationCompletedHandler; - - if (cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (UnityWebRequestAsyncOperationCallbackHandlerSource)state; - source.asyncOperation.webRequest.Abort(); - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } - - TaskTracker.TrackActiveTask(result, 3); - - token = result.core.Version; - return result; - } - - public UnityWebRequest GetResult(short token) - { - try - { - return core.GetResult(token); - } - finally - { - TryReturn(); - } - } - - void IUniTaskSource.GetResult(short token) - { - GetResult(token); - } - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - asyncOperation.completed -= AsyncOperationCompletedHandler; - asyncOperation = default; - progress = default; - cancellationToken = default; - cancellationTokenRegistration.Dispose(); - return pool.TryPush(this); - } - - void AsyncOperationCompletedHandler(AsyncOperation _) - { - if (asyncOperation.webRequest.IsError()) - { - core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); - } - else - { - core.TrySetResult(asyncOperation.webRequest); - } - } - } - sealed class UnityWebRequestAsyncOperationConfiguredSource : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -1403,14 +1005,18 @@ namespace Cysharp.Threading.Tasks UnityWebRequestAsyncOperation asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore core; + Action continuationAction; + UnityWebRequestAsyncOperationConfiguredSource() { } - public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static IUniTaskSource Create(UnityWebRequestAsyncOperation asyncOperation, PlayerLoopTiming timing, IProgress progress, bool handleImmediately, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -1425,6 +1031,23 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + if (handleImmediately) + { + result.continuationAction = result.Continuation; + asyncOperation.completed += result.continuationAction; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (UnityWebRequestAsyncOperationConfiguredSource)state; + source.asyncOperation.webRequest.Abort(); + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + } TaskTracker.TrackActiveTask(result, 3); @@ -1500,11 +1123,37 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } + else if (asyncOperation.webRequest.IsError()) + { + core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); + } + else + { + core.TrySetResult(asyncOperation.webRequest); + } + } + } } #endregion diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt index f74fa0b..105893d 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UnityAsyncExtensions.tt @@ -56,28 +56,10 @@ namespace Cysharp.Threading.Tasks public static <#= ToUniTaskReturnType(t.returnType) #> WithCancellation(this <#= t.typeName #> asyncOperation, bool handleImmediately, CancellationToken cancellationToken) { - if (handleImmediately) - { - Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); - if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled<#= IsVoid(t) ? "" : "<" + t.returnType + ">" #>(cancellationToken); -<# if(IsUnityWebRequest(t)) { #> - if (asyncOperation.isDone) - { - if (asyncOperation.webRequest.IsError()) - { - return UniTask.FromException(new UnityWebRequestException(asyncOperation.webRequest)); - } - return UniTask.FromResult(asyncOperation.webRequest); - } -<# } else { #> - if (asyncOperation.isDone) return <#= IsVoid(t) ? "UniTask.CompletedTask" : $"UniTask.FromResult(asyncOperation.{t.returnField})" #>; -<# } #> - return new <#= ToUniTaskReturnType(t.returnType) #>(<#= t.typeName #>CallbackHandlerSource.Create(asyncOperation, cancellationToken, out var token), token); - } - return ToUniTask(asyncOperation, cancellationToken: cancellationToken); + return ToUniTask(asyncOperation, handleImmediately: handleImmediately, cancellationToken: cancellationToken); } - public static <#= ToUniTaskReturnType(t.returnType) #> ToUniTask(this <#= t.typeName #> asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static <#= ToUniTaskReturnType(t.returnType) #> ToUniTask(this <#= t.typeName #> asyncOperation, IProgress progress = null, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool handleImmediately = false, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(asyncOperation, nameof(asyncOperation)); if (cancellationToken.IsCancellationRequested) return UniTask.FromCanceled<#= IsVoid(t) ? "" : "<" + t.returnType + ">" #>(cancellationToken); @@ -93,7 +75,7 @@ namespace Cysharp.Threading.Tasks <# } else { #> if (asyncOperation.isDone) return <#= IsVoid(t) ? "UniTask.CompletedTask" : $"UniTask.FromResult(asyncOperation.{t.returnField})" #>; <# } #> - return new <#= ToUniTaskReturnType(t.returnType) #>(<#= t.typeName #>ConfiguredSource.Create(asyncOperation, timing, progress, cancellationToken, out var token), token); + return new <#= ToUniTaskReturnType(t.returnType) #>(<#= t.typeName #>ConfiguredSource.Create(asyncOperation, timing, progress, handleImmediately, cancellationToken, out var token), token); } public struct <#= t.typeName #>Awaiter : ICriticalNotifyCompletion @@ -160,130 +142,6 @@ namespace Cysharp.Threading.Tasks } } - sealed class <#= t.typeName #>CallbackHandlerSource : <#= ToIUniTaskSourceReturnType(t.returnType) #>, ITaskPoolNode<<#= t.typeName #>CallbackHandlerSource> - { - static TaskPool<<#= t.typeName #>CallbackHandlerSource> pool; - <#= t.typeName #>CallbackHandlerSource nextNode; - public ref <#= t.typeName #>CallbackHandlerSource NextNode => ref nextNode; - - static <#= t.typeName #>CallbackHandlerSource() - { - TaskPool.RegisterSizeGetter(typeof(<#= t.typeName #>CallbackHandlerSource), () => pool.Size); - } - - <#= t.typeName #> asyncOperation; - IProgress progress; - CancellationToken cancellationToken; - CancellationTokenRegistration cancellationTokenRegistration; - - UniTaskCompletionSourceCore<<#= IsVoid(t) ? "AsyncUnit" : t.returnType #>> core; - - <#= t.typeName #>CallbackHandlerSource() - { - } - - public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource<#= IsVoid(t) ? "" : $"<{t.returnType}>" #>.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new <#= t.typeName #>CallbackHandlerSource(); - } - - result.asyncOperation = asyncOperation; - result.cancellationToken = cancellationToken; - - asyncOperation.completed += result.AsyncOperationCompletedHandler; - - if (cancellationToken.CanBeCanceled) - { - result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => - { - var source = (<#= t.typeName #>CallbackHandlerSource)state; -<# if(IsUnityWebRequest(t)) { #> - source.asyncOperation.webRequest.Abort(); -<# } #> - source.core.TrySetCanceled(source.cancellationToken); - }, result); - } - - TaskTracker.TrackActiveTask(result, 3); - - token = result.core.Version; - return result; - } - - public <#= t.returnType #> GetResult(short token) - { - try - { -<# if (!IsVoid(t)) { #> - return core.GetResult(token); -<# } else { #> - core.GetResult(token); -<# } #> - } - finally - { - TryReturn(); - } - } - -<# if (!IsVoid(t)) { #> - void IUniTaskSource.GetResult(short token) - { - GetResult(token); - } -<# } #> - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - asyncOperation.completed -= AsyncOperationCompletedHandler; - asyncOperation = default; - progress = default; - cancellationToken = default; - cancellationTokenRegistration.Dispose(); - return pool.TryPush(this); - } - - void AsyncOperationCompletedHandler(AsyncOperation _) - { -<# if(IsUnityWebRequest(t)) { #> - if (asyncOperation.webRequest.IsError()) - { - core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); - } - else - { - core.TrySetResult(asyncOperation.webRequest); - } -<# } else { #> - core.TrySetResult(<#= IsVoid(t) ? "AsyncUnit.Default" : $"asyncOperation.{t.returnField}" #>); -<# } #> - } - } - sealed class <#= t.typeName #>ConfiguredSource : <#= ToIUniTaskSourceReturnType(t.returnType) #>, IPlayerLoopItem, ITaskPoolNode<<#= t.typeName #>ConfiguredSource> { static TaskPool<<#= t.typeName #>ConfiguredSource> pool; @@ -298,14 +156,18 @@ namespace Cysharp.Threading.Tasks <#= t.typeName #> asyncOperation; IProgress progress; CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool completed; UniTaskCompletionSourceCore<<#= IsVoid(t) ? "AsyncUnit" : t.returnType #>> core; + Action continuationAction; + <#= t.typeName #>ConfiguredSource() { } - public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, PlayerLoopTiming timing, IProgress progress, CancellationToken cancellationToken, out short token) + public static <#= ToIUniTaskSourceReturnType(t.returnType) #> Create(<#= t.typeName #> asyncOperation, PlayerLoopTiming timing, IProgress progress, bool handleImmediately, CancellationToken cancellationToken, out short token) { if (cancellationToken.IsCancellationRequested) { @@ -320,6 +182,25 @@ namespace Cysharp.Threading.Tasks result.asyncOperation = asyncOperation; result.progress = progress; result.cancellationToken = cancellationToken; + result.completed = false; + + if (handleImmediately) + { + result.continuationAction = result.Continuation; + asyncOperation.completed += result.continuationAction; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var source = (<#= t.typeName #>ConfiguredSource)state; +<# if(IsUnityWebRequest(t)) { #> + source.asyncOperation.webRequest.Abort(); +<# } #> + source.core.TrySetCanceled(source.cancellationToken); + }, result); + } + } TaskTracker.TrackActiveTask(result, 3); @@ -407,11 +288,44 @@ namespace Cysharp.Threading.Tasks { TaskTracker.RemoveTracking(this); core.Reset(); + asyncOperation.completed -= continuationAction; asyncOperation = default; progress = default; cancellationToken = default; + cancellationTokenRegistration.Dispose(); return pool.TryPush(this); } + + void Continuation(AsyncOperation _) + { + if (completed) + { + TryReturn(); + } + else + { + completed = true; + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + } +<# if(IsUnityWebRequest(t)) { #> + else if (asyncOperation.webRequest.IsError()) + { + core.TrySetException(new UnityWebRequestException(asyncOperation.webRequest)); + } + else + { + core.TrySetResult(asyncOperation.webRequest); + } +<# } else { #> + else + { + core.TrySetResult(<#= IsVoid(t) ? "AsyncUnit.Default" : $"asyncOperation.{t.returnField}" #>); + } +<# } #> + } + } } #endregion