From d36e7987b342252b49af6873f794804f60e6ad2d Mon Sep 17 00:00:00 2001 From: neuecc Date: Tue, 19 May 2020 02:41:45 +0900 Subject: [PATCH] Add SkipUntilCanceled, TakeUntilCanceled --- .../Linq/TakeInfinityTest.cs | 62 +++++++ .../UniTask/Runtime/Linq/SkipUntilCanceled.cs | 142 +++++++++++++++ .../Runtime/Linq/SkipUntilCanceled.cs.meta | 11 ++ .../UniTask/Runtime/Linq/TakeUntilCanceled.cs | 162 ++++++++++++++++++ .../Runtime/Linq/TakeUntilCanceled.cs.meta | 11 ++ 5 files changed, 388 insertions(+) create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs.meta create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs.meta diff --git a/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs b/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs index 1c1ea48..1c2ab66 100644 --- a/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs +++ b/src/UniTask.NetCoreTests/Linq/TakeInfinityTest.cs @@ -4,6 +4,7 @@ using FluentAssertions; using System; using System.Collections.Generic; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -40,5 +41,66 @@ namespace NetCoreTests.Linq (await xs).Should().BeEquivalentTo(1, 2, 3, 4); } + + [Fact] + public async Task TakeUntil() + { + var cts = new CancellationTokenSource(); + + var rp = new AsyncReactiveProperty(1); + + var xs = rp.TakeUntilCanceled(cts.Token).ToArrayAsync(); + + var c = CancelAsync(); + + await c; + var foo = await xs; + + foo.Should().BeEquivalentTo(new[] { 1, 10, 20 }); + + async Task CancelAsync() + { + rp.Value = 10; + await Task.Yield(); + rp.Value = 20; + await Task.Yield(); + cts.Cancel(); + rp.Value = 30; + await Task.Yield(); + rp.Value = 40; + } + } + + [Fact] + public async Task SkipUntil() + { + var cts = new CancellationTokenSource(); + + var rp = new AsyncReactiveProperty(1); + + var xs = rp.SkipUntilCanceled(cts.Token).ToArrayAsync(); + + var c = CancelAsync(); + + await c; + var foo = await xs; + + foo.Should().BeEquivalentTo(new[] { 30, 40 }); + + async Task CancelAsync() + { + rp.Value = 10; + await Task.Yield(); + rp.Value = 20; + await Task.Yield(); + cts.Cancel(); + rp.Value = 30; + await Task.Yield(); + rp.Value = 40; + + rp.Dispose(); // complete. + } + } + } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs new file mode 100644 index 0000000..2c7f653 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs @@ -0,0 +1,142 @@ +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable SkipUntilCanceled(this IUniTaskAsyncEnumerable source, CancellationToken cancellationToken) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + return new SkipUntilCanceled(source, cancellationToken); + } + } + + internal sealed class SkipUntilCanceled : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly CancellationToken cancellationToken; + + public SkipUntilCanceled(IUniTaskAsyncEnumerable source, CancellationToken cancellationToken) + { + this.source = source; + this.cancellationToken = cancellationToken; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new _SkipUntilCanceled(source, this.cancellationToken, cancellationToken); + } + + sealed class _SkipUntilCanceled : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + CancellationToken cancellationToken1; + CancellationToken cancellationToken2; + + bool isCanceled; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + bool continueNext; + + public _SkipUntilCanceled(IUniTaskAsyncEnumerable source, CancellationToken cancellationToken1, CancellationToken cancellationToken2) + { + this.source = source; + this.cancellationToken1 = cancellationToken1; + this.cancellationToken2 = cancellationToken2; + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + if (cancellationToken1.IsCancellationRequested) isCanceled = true; + if (cancellationToken2.IsCancellationRequested) isCanceled = true; + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken2); // use only AsyncEnumerator provided token. + } + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + try + { + LOOP: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + static void MoveNextCore(object state) + { + var self = (_SkipUntilCanceled)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + AGAIN: + + if (self.isCanceled) + { + self.continueNext = false; + self.Current = self.enumerator.Current; + self.completionSource.TrySetResult(true); + } + else + { + if (self.cancellationToken1.IsCancellationRequested) self.isCanceled = true; + if (self.cancellationToken2.IsCancellationRequested) self.isCanceled = true; + + if (self.isCanceled) goto AGAIN; + + if (!self.continueNext) + { + self.SourceMoveNext(); + } + } + } + else + { + self.completionSource.TrySetResult(false); + } + } + } + + public UniTask DisposeAsync() + { + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs.meta new file mode 100644 index 0000000..9f67181 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/SkipUntilCanceled.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4b1a778aef7150d47b93a49aa1bc34ae +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs new file mode 100644 index 0000000..8604bf7 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs @@ -0,0 +1,162 @@ +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable TakeUntilCanceled(this IUniTaskAsyncEnumerable source, CancellationToken cancellationToken) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + return new TakeUntilCanceled(source, cancellationToken); + } + } + + internal sealed class TakeUntilCanceled : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly CancellationToken cancellationToken; + + public TakeUntilCanceled(IUniTaskAsyncEnumerable source, CancellationToken cancellationToken) + { + this.source = source; + this.cancellationToken = cancellationToken; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new _TakeUntilCanceled(source, this.cancellationToken, cancellationToken); + } + + sealed class _TakeUntilCanceled : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action CancelDelegate1 = OnCanceled1; + static readonly Action CancelDelegate2 = OnCanceled2; + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + CancellationToken cancellationToken1; + CancellationToken cancellationToken2; + CancellationTokenRegistration cancellationTokenRegistration1; + CancellationTokenRegistration cancellationTokenRegistration2; + + bool isCanceled; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + + public _TakeUntilCanceled(IUniTaskAsyncEnumerable source, CancellationToken cancellationToken1, CancellationToken cancellationToken2) + { + this.source = source; + this.cancellationToken1 = cancellationToken1; + this.cancellationToken2 = cancellationToken2; + + if (cancellationToken1.CanBeCanceled) + { + this.cancellationTokenRegistration1 = cancellationToken1.RegisterWithoutCaptureExecutionContext(CancelDelegate1, this); + } + + if (cancellationToken1 != cancellationToken2 && cancellationToken2.CanBeCanceled) + { + this.cancellationTokenRegistration2 = cancellationToken2.RegisterWithoutCaptureExecutionContext(CancelDelegate2, this); + } + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + if (cancellationToken1.IsCancellationRequested) isCanceled = true; + if (cancellationToken2.IsCancellationRequested) isCanceled = true; + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken2); // use only AsyncEnumerator provided token. + } + + if (isCanceled) return CompletedTasks.False; + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + try + { + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + MoveNextCore(this); + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + static void MoveNextCore(object state) + { + var self = (_TakeUntilCanceled)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + if (self.isCanceled) + { + self.completionSource.TrySetResult(false); + } + else + { + self.Current = self.enumerator.Current; + self.completionSource.TrySetResult(true); + } + } + else + { + self.completionSource.TrySetResult(false); + } + } + } + + static void OnCanceled1(object state) + { + var self = (_TakeUntilCanceled)state; + if (!self.isCanceled) + { + self.cancellationTokenRegistration2.Dispose(); + self.completionSource.TrySetResult(false); + } + } + + static void OnCanceled2(object state) + { + var self = (_TakeUntilCanceled)state; + if (!self.isCanceled) + { + self.cancellationTokenRegistration1.Dispose(); + self.completionSource.TrySetResult(false); + } + } + + public UniTask DisposeAsync() + { + cancellationTokenRegistration1.Dispose(); + cancellationTokenRegistration2.Dispose(); + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs.meta new file mode 100644 index 0000000..4a89be5 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/TakeUntilCanceled.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e82f498cf3a1df04cbf646773fc11319 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: