From d9e20de8a59a75a0d16e84ba769458f13f778507 Mon Sep 17 00:00:00 2001 From: neuecc Date: Mon, 29 Jun 2020 01:10:18 +0900 Subject: [PATCH] Add UniTaskAsyncEnumerable.SkipUntil, TakeUntil. Fix SkipUntilCanceled behaviour. --- .../Linq/TakeInfinityTest.cs | 65 +++++- .../Plugins/UniTask/Runtime/Linq/SkipUntil.cs | 187 +++++++++++++++++ .../UniTask/Runtime/Linq/SkipUntilCanceled.cs | 75 ++++--- .../Plugins/UniTask/Runtime/Linq/TakeUntil.cs | 190 ++++++++++++++++++ src/UniTask/Assets/Scenes/SandboxMain.cs | 8 +- 5 files changed, 495 insertions(+), 30 deletions(-) create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntil.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntil.cs diff --git a/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs b/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs index 1c2ab66..dd92445 100644 --- a/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs +++ b/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs @@ -43,7 +43,7 @@ namespace NetCoreTests.Linq } [Fact] - public async Task TakeUntil() + public async Task TakeUntilCanceled() { var cts = new CancellationTokenSource(); @@ -72,7 +72,7 @@ namespace NetCoreTests.Linq } [Fact] - public async Task SkipUntil() + public async Task SkipUntilCanceled() { var cts = new CancellationTokenSource(); @@ -85,7 +85,7 @@ namespace NetCoreTests.Linq await c; var foo = await xs; - foo.Should().BeEquivalentTo(new[] { 30, 40 }); + foo.Should().BeEquivalentTo(new[] { 20, 30, 40 }); async Task CancelAsync() { @@ -102,5 +102,64 @@ namespace NetCoreTests.Linq } } + [Fact] + public async Task TakeUntil() + { + var cts = new AsyncReactiveProperty(0); + + var rp = new AsyncReactiveProperty(1); + + var xs = rp.TakeUntil(cts.WaitAsync()).ToArrayAsync(); + + var c = CancelAsync(); + + await c; + var foo = await xs; + + foo.Should().BeEquivalentTo(new[] { 1, 10, 20 }); + + async Task CancelAsync() + { + rp.Value = 10; + await Task.Yield(); + rp.Value = 20; + await Task.Yield(); + cts.Value = 9999; + rp.Value = 30; + await Task.Yield(); + rp.Value = 40; + } + } + + [Fact] + public async Task SkipUntil() + { + var cts = new AsyncReactiveProperty(0); + + var rp = new AsyncReactiveProperty(1); + + var xs = rp.SkipUntil(cts.WaitAsync()).ToArrayAsync(); + + var c = CancelAsync(); + + await c; + var foo = await xs; + + foo.Should().BeEquivalentTo(new[] { 20, 30, 40 }); + + async Task CancelAsync() + { + rp.Value = 10; + await Task.Yield(); + rp.Value = 20; + await Task.Yield(); + cts.Value = 9999; + rp.Value = 30; + await Task.Yield(); + rp.Value = 40; + + rp.Dispose(); // complete. + } + } } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntil.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntil.cs new file mode 100644 index 0000000..5a707bb --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntil.cs @@ -0,0 +1,187 @@ +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable SkipUntil(this IUniTaskAsyncEnumerable source, UniTask other) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + return new SkipUntil(source, other, null); + } + + public static IUniTaskAsyncEnumerable SkipUntil(this IUniTaskAsyncEnumerable source, Func other) + { + Error.ThrowArgumentNullException(source, nameof(source)); + Error.ThrowArgumentNullException(source, nameof(other)); + + return new SkipUntil(source, default, other); + } + } + + internal sealed class SkipUntil : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly UniTask other; + readonly Func other2; + + public SkipUntil(IUniTaskAsyncEnumerable source, UniTask other, Func other2) + { + this.source = source; + this.other = other; + this.other2 = other2; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + if (other2 != null) + { + return new _SkipUntil(source, this.other2(cancellationToken), cancellationToken); + } + else + { + return new _SkipUntil(source, this.other, cancellationToken); + } + } + + sealed class _SkipUntil : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action CancelDelegate1 = OnCanceled1; + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + CancellationToken cancellationToken1; + + bool completed; + CancellationTokenRegistration cancellationTokenRegistration1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + bool continueNext; + Exception exception; + + public _SkipUntil(IUniTaskAsyncEnumerable source, UniTask other, CancellationToken cancellationToken1) + { + this.source = source; + this.cancellationToken1 = cancellationToken1; + if (cancellationToken1.CanBeCanceled) + { + this.cancellationTokenRegistration1 = cancellationToken1.RegisterWithoutCaptureExecutionContext(CancelDelegate1, this); + } + + TaskTracker.TrackActiveTask(this, 3); + RunOther(other).Forget(); + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + if (exception != null) + { + return UniTask.FromException(exception); + } + + if (cancellationToken1.IsCancellationRequested) + { + return UniTask.FromCanceled(cancellationToken1); + } + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken1); + } + completionSource.Reset(); + + if (completed) + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + try + { + LOOP: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + static void MoveNextCore(object state) + { + var self = (_SkipUntil)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + self.Current = self.enumerator.Current; + self.completionSource.TrySetResult(true); + if (self.continueNext) + { + self.SourceMoveNext(); + } + } + else + { + self.completionSource.TrySetResult(false); + } + } + } + + async UniTaskVoid RunOther(UniTask other) + { + try + { + await other; + completed = true; + SourceMoveNext(); + } + catch (Exception ex) + { + exception = ex; + completionSource.TrySetException(ex); + } + } + + static void OnCanceled1(object state) + { + var self = (_SkipUntil)state; + self.completionSource.TrySetCanceled(self.cancellationToken1); + } + + public UniTask DisposeAsync() + { + TaskTracker.RemoveTracking(this); + cancellationTokenRegistration1.Dispose(); + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs index f8c2b30..f4c9679 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs @@ -32,13 +32,17 @@ namespace Cysharp.Threading.Tasks.Linq sealed class _SkipUntilCanceled : MoveNextSource, IUniTaskAsyncEnumerator { + static readonly Action CancelDelegate1 = OnCanceled1; + static readonly Action CancelDelegate2 = OnCanceled2; static readonly Action MoveNextCoreDelegate = MoveNextCore; readonly IUniTaskAsyncEnumerable source; CancellationToken cancellationToken1; CancellationToken cancellationToken2; + CancellationTokenRegistration cancellationTokenRegistration1; + CancellationTokenRegistration cancellationTokenRegistration2; - bool isCanceled; + int isCanceled; IUniTaskAsyncEnumerator enumerator; UniTask.Awaiter awaiter; bool continueNext; @@ -48,6 +52,14 @@ namespace Cysharp.Threading.Tasks.Linq this.source = source; this.cancellationToken1 = cancellationToken1; this.cancellationToken2 = cancellationToken2; + if (cancellationToken1.CanBeCanceled) + { + this.cancellationTokenRegistration1 = cancellationToken1.RegisterWithoutCaptureExecutionContext(CancelDelegate1, this); + } + if (cancellationToken1 != cancellationToken2 && cancellationToken2.CanBeCanceled) + { + this.cancellationTokenRegistration2 = cancellationToken2.RegisterWithoutCaptureExecutionContext(CancelDelegate2, this); + } TaskTracker.TrackActiveTask(this, 3); } @@ -55,15 +67,18 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - if (cancellationToken1.IsCancellationRequested) isCanceled = true; - if (cancellationToken2.IsCancellationRequested) isCanceled = true; - if (enumerator == null) { + if (cancellationToken1.IsCancellationRequested) isCanceled = 1; + if (cancellationToken2.IsCancellationRequested) isCanceled = 1; enumerator = source.GetAsyncEnumerator(cancellationToken2); // use only AsyncEnumerator provided token. } completionSource.Reset(); - SourceMoveNext(); + + if (isCanceled != 0) + { + SourceMoveNext(); + } return new UniTask(this, completionSource.Version); } @@ -102,25 +117,11 @@ namespace Cysharp.Threading.Tasks.Linq { if (result) { - AGAIN: - - if (self.isCanceled) + self.Current = self.enumerator.Current; + self.completionSource.TrySetResult(true); + if (self.continueNext) { - self.continueNext = false; - self.Current = self.enumerator.Current; - self.completionSource.TrySetResult(true); - } - else - { - if (self.cancellationToken1.IsCancellationRequested) self.isCanceled = true; - if (self.cancellationToken2.IsCancellationRequested) self.isCanceled = true; - - if (self.isCanceled) goto AGAIN; - - if (!self.continueNext) - { - self.SourceMoveNext(); - } + self.SourceMoveNext(); } } else @@ -130,9 +131,37 @@ namespace Cysharp.Threading.Tasks.Linq } } + static void OnCanceled1(object state) + { + var self = (_SkipUntilCanceled)state; + if (self.isCanceled == 0) + { + if (Interlocked.Increment(ref self.isCanceled) == 1) + { + self.cancellationTokenRegistration2.Dispose(); + self.SourceMoveNext(); + } + } + } + + static void OnCanceled2(object state) + { + var self = (_SkipUntilCanceled)state; + if (self.isCanceled == 0) + { + if (Interlocked.Increment(ref self.isCanceled) == 1) + { + self.cancellationTokenRegistration2.Dispose(); + self.SourceMoveNext(); + } + } + } + public UniTask DisposeAsync() { TaskTracker.RemoveTracking(this); + cancellationTokenRegistration1.Dispose(); + cancellationTokenRegistration2.Dispose(); if (enumerator != null) { return enumerator.DisposeAsync(); diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntil.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntil.cs new file mode 100644 index 0000000..25371ad --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntil.cs @@ -0,0 +1,190 @@ +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable TakeUntil(this IUniTaskAsyncEnumerable source, UniTask other) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + return new TakeUntil(source, other, null); + } + + public static IUniTaskAsyncEnumerable TakeUntil(this IUniTaskAsyncEnumerable source, Func other) + { + Error.ThrowArgumentNullException(source, nameof(source)); + Error.ThrowArgumentNullException(source, nameof(other)); + + return new TakeUntil(source, default, other); + } + } + + internal sealed class TakeUntil : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly UniTask other; + readonly Func other2; + + public TakeUntil(IUniTaskAsyncEnumerable source, UniTask other, Func other2) + { + this.source = source; + this.other = other; + this.other2 = other2; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + if (other2 != null) + { + return new _TakeUntil(source, this.other2(cancellationToken), cancellationToken); + } + else + { + return new _TakeUntil(source, this.other, cancellationToken); + } + } + + sealed class _TakeUntil : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action CancelDelegate1 = OnCanceled1; + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + CancellationToken cancellationToken1; + CancellationTokenRegistration cancellationTokenRegistration1; + + bool completed; + Exception exception; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + + public _TakeUntil(IUniTaskAsyncEnumerable source, UniTask other, CancellationToken cancellationToken1) + { + this.source = source; + this.cancellationToken1 = cancellationToken1; + + if (cancellationToken1.CanBeCanceled) + { + this.cancellationTokenRegistration1 = cancellationToken1.RegisterWithoutCaptureExecutionContext(CancelDelegate1, this); + } + + TaskTracker.TrackActiveTask(this, 3); + + RunOther(other).Forget(); + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + if (completed) + { + return CompletedTasks.False; + } + + if (exception != null) + { + return UniTask.FromException(exception); + } + + if (cancellationToken1.IsCancellationRequested) + { + return UniTask.FromCanceled(cancellationToken1); + } + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken1); + } + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + try + { + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + MoveNextCore(this); + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + static void MoveNextCore(object state) + { + var self = (_TakeUntil)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + if (self.exception != null) + { + self.completionSource.TrySetException(self.exception); + } + else if (self.cancellationToken1.IsCancellationRequested) + { + self.completionSource.TrySetCanceled(self.cancellationToken1); + } + else + { + self.Current = self.enumerator.Current; + self.completionSource.TrySetResult(true); + } + } + else + { + self.completionSource.TrySetResult(false); + } + } + } + + async UniTaskVoid RunOther(UniTask other) + { + try + { + await other; + completed = true; + completionSource.TrySetResult(false); + } + catch (Exception ex) + { + exception = ex; + completionSource.TrySetException(ex); + } + } + + static void OnCanceled1(object state) + { + var self = (_TakeUntil)state; + self.completionSource.TrySetCanceled(self.cancellationToken1); + } + + public UniTask DisposeAsync() + { + TaskTracker.RemoveTracking(this); + cancellationTokenRegistration1.Dispose(); + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Scenes/SandboxMain.cs b/src/UniTask/Assets/Scenes/SandboxMain.cs index 8a23443..beb16ab 100644 --- a/src/UniTask/Assets/Scenes/SandboxMain.cs +++ b/src/UniTask/Assets/Scenes/SandboxMain.cs @@ -119,7 +119,7 @@ public class AsyncMessageBroker : IDisposable public class SandboxMain : MonoBehaviour { - public Camera camera; + public Camera mycamera; public Button okButton; public Button cancelButton; @@ -998,11 +998,11 @@ public class SandboxMain : MonoBehaviour var height = 240; var depth = 24; - camera.targetTexture = new RenderTexture(width, height, depth, RenderTextureFormat.ARGB32, RenderTextureReadWrite.Default) + mycamera.targetTexture = new RenderTexture(width, height, depth, RenderTextureFormat.ARGB32, RenderTextureReadWrite.Default) { antiAliasing = 8 }; - camera.enabled = true; + mycamera.enabled = true; //myRenderTexture = new RenderTexture(width, height, depth, RenderTextureFormat.ARGB32, RenderTextureReadWrite.Default) //{ @@ -1014,7 +1014,7 @@ public class SandboxMain : MonoBehaviour async UniTask ShootAsync() { - var rt = camera.targetTexture; + var rt = mycamera.targetTexture;