From 1999d94b33e53eac175856401cc75cfac98ca6ec Mon Sep 17 00:00:00 2001 From: neuecc Date: Thu, 30 Jul 2020 08:10:16 +0900 Subject: [PATCH] Add UniTask.WithCancellation --- src/UniTask.NetCore/UniTask.NetCore.csproj | 1 + .../WithCancellationTest.cs | 40 +++++ .../UniTask/Runtime/UniTaskExtensions.cs | 168 ++++++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 src/UniTask.NetCoreTests/WithCancellationTest.cs diff --git a/src/UniTask.NetCore/UniTask.NetCore.csproj b/src/UniTask.NetCore/UniTask.NetCore.csproj index bc5be56..1bb7900 100644 --- a/src/UniTask.NetCore/UniTask.NetCore.csproj +++ b/src/UniTask.NetCore/UniTask.NetCore.csproj @@ -38,6 +38,7 @@ ..\UniTask\Assets\Plugins\UniTask\Runtime\Internal\ContinuationQueue.cs; ..\UniTask\Assets\Plugins\UniTask\Runtime\Internal\UnityWebRequestExtensions.cs; +..\UniTask\Assets\Plugins\UniTask\Runtime\UniTaskSynchronizationContext.cs; ..\UniTask\Assets\Plugins\UniTask\Runtime\CancellationTokenSourceExtensions.cs; ..\UniTask\Assets\Plugins\UniTask\Runtime\EnumeratorAsyncExtensions.cs; ..\UniTask\Assets\Plugins\UniTask\Runtime\PlayerLoopHelper.cs; diff --git a/src/UniTask.NetCoreTests/WithCancellationTest.cs b/src/UniTask.NetCoreTests/WithCancellationTest.cs new file mode 100644 index 0000000..90236d2 --- /dev/null +++ b/src/UniTask.NetCoreTests/WithCancellationTest.cs @@ -0,0 +1,40 @@ +using Cysharp.Threading.Tasks; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace NetCoreTests +{ + public class WithCancellationTest + { + [Fact] + public async Task Standard() + { + CancellationTokenSource cts = new CancellationTokenSource(); + + var v = await UniTask.Run(() => 10).WithCancellation(cts.Token); + + v.Should().Be(10); + } + + [Fact] + public async Task Cancel() + { + CancellationTokenSource cts = new CancellationTokenSource(); + + var t = UniTask.Create(async () => + { + await Task.Delay(TimeSpan.FromSeconds(1)); + return 10; + }).WithCancellation(cts.Token); + + cts.Cancel(); + + (await Assert.ThrowsAsync(async () => await t)).CancellationToken.Should().Be(cts.Token); + } + } +} diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs index 627b231..880cb96 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs @@ -189,6 +189,174 @@ namespace Cysharp.Threading.Tasks return new AsyncLazy(task); } + /// + /// Ignore task result when cancel raised first. + /// + public static UniTask WithCancellation(this UniTask task, CancellationToken cancellationToken) + { + if (!cancellationToken.CanBeCanceled) + { + return task; + } + + if (cancellationToken.IsCancellationRequested) + { + return UniTask.FromCanceled(cancellationToken); + } + + if (task.Status.IsCompleted()) + { + return task; + } + + return new UniTask(new WithCancellationSource(task, cancellationToken), 0); + } + + /// + /// Ignore task result when cancel raised first. + /// + public static UniTask WithCancellation(this UniTask task, CancellationToken cancellationToken) + { + if (!cancellationToken.CanBeCanceled) + { + return task; + } + + if (cancellationToken.IsCancellationRequested) + { + return UniTask.FromCanceled(cancellationToken); + } + + if (task.Status.IsCompleted()) + { + return task; + } + + return new UniTask(new WithCancellationSource(task, cancellationToken), 0); + } + + sealed class WithCancellationSource : IUniTaskSource + { + static readonly Action cancellationCallbackDelegate = CancellationCallback; + + CancellationToken cancellationToken; + CancellationTokenRegistration tokenRegistration; + UniTaskCompletionSourceCore core; + + public WithCancellationSource(UniTask task, CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + this.tokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(cancellationCallbackDelegate, this); + RunTask(task).Forget(); + } + + async UniTaskVoid RunTask(UniTask task) + { + try + { + await task; + core.TrySetResult(AsyncUnit.Default); + } + catch (Exception ex) + { + core.TrySetException(ex); + } + finally + { + tokenRegistration.Dispose(); + } + } + + static void CancellationCallback(object state) + { + var self = (WithCancellationSource)state; + self.core.TrySetCanceled(self.cancellationToken); + } + + public void GetResult(short token) + { + core.GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + } + + sealed class WithCancellationSource : IUniTaskSource + { + static readonly Action cancellationCallbackDelegate = CancellationCallback; + + CancellationToken cancellationToken; + CancellationTokenRegistration tokenRegistration; + UniTaskCompletionSourceCore core; + + public WithCancellationSource(UniTask task, CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + this.tokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(cancellationCallbackDelegate, this); + RunTask(task).Forget(); + } + + async UniTaskVoid RunTask(UniTask task) + { + try + { + core.TrySetResult(await task); + } + catch (Exception ex) + { + core.TrySetException(ex); + } + finally + { + tokenRegistration.Dispose(); + } + } + + static void CancellationCallback(object state) + { + var self = (WithCancellationSource)state; + self.core.TrySetCanceled(self.cancellationToken); + } + + void IUniTaskSource.GetResult(short token) + { + core.GetResult(token); + } + + public T GetResult(short token) + { + return core.GetResult(token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + } + #if UNITY_2018_3_OR_NEWER public static IEnumerator ToCoroutine(this UniTask task, Action resultHandler = null, Action exceptionHandler = null)