diff --git a/src/UniTask.NetCore/Linq/SkipLast.cs b/src/UniTask.NetCore/Linq/SkipLast.cs index 5f79cf2..665e15d 100644 --- a/src/UniTask.NetCore/Linq/SkipLast.cs +++ b/src/UniTask.NetCore/Linq/SkipLast.cs @@ -1,775 +1,153 @@ -namespace Cysharp.Threading.Tasks.Linq +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq { - internal sealed class SkipLast + public static partial class UniTaskAsyncEnumerable { + public static IUniTaskAsyncEnumerable SkipLast(this IUniTaskAsyncEnumerable source, Int32 count) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + // non skip. + if (count <= 0) + { + return source; + } + + return new SkipLast(source, count); + } } - -} - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + internal sealed class SkipLast : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly int count; + + public SkipLast(IUniTaskAsyncEnumerable source, int count) + { + this.source = source; + this.count = count; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, count, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + readonly int count; + CancellationToken cancellationToken; + + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + Queue queue; + + bool continueNext; + + public Enumerator(IUniTaskAsyncEnumerable source, int count, CancellationToken cancellationToken) + { + this.source = source; + this.count = count; + this.cancellationToken = cancellationToken; + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken); + queue = new Queue(); + } + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + try + { + + LOOP: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + if (self.queue.Count == self.count) + { + self.continueNext = false; + + var deq = self.queue.Dequeue(); + self.Current = deq; + self.queue.Enqueue(self.enumerator.Current); + + self.completionSource.TrySetResult(true); + } + else + { + self.queue.Enqueue(self.enumerator.Current); + + if (!self.continueNext) + { + self.SourceMoveNext(); + } + } + } + else + { + self.continueNext = false; + self.completionSource.TrySetResult(false); + } + } + } + + public UniTask DisposeAsync() + { + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask.NetCore/Linq/_FileMaker.cs b/src/UniTask.NetCore/Linq/_FileMaker.cs index 11e8d7a..7fca4f5 100644 --- a/src/UniTask.NetCore/Linq/_FileMaker.cs +++ b/src/UniTask.NetCore/Linq/_FileMaker.cs @@ -307,12 +307,6 @@ namespace ___Dummy - - public static IUniTaskAsyncEnumerable SkipLast(this IUniTaskAsyncEnumerable source, Int32 count) - { - throw new NotImplementedException(); - } - public static IUniTaskAsyncEnumerable TakeLast(this IUniTaskAsyncEnumerable source, Int32 count) diff --git a/src/UniTask.NetCoreSandbox/Program.cs b/src/UniTask.NetCoreSandbox/Program.cs index 024c8da..dc26574 100644 --- a/src/UniTask.NetCoreSandbox/Program.cs +++ b/src/UniTask.NetCoreSandbox/Program.cs @@ -39,9 +39,11 @@ namespace NetCoreSandbox static async Task Main(string[] args) { - await foreach (var item in UniTaskAsyncEnumerable.Range(1, 10).Do(x => Console.WriteLine("DO:" + x)) - //.TakeWhileAwait(x => UniTask.FromResult(x < 5)) - .Take(5) + await foreach (var item in UniTaskAsyncEnumerable.Range(1, 10) + .SelectAwait(x => UniTask.Run(() => x)) + .SkipLast(6) + + ) { @@ -53,6 +55,8 @@ namespace NetCoreSandbox + + } diff --git a/src/UniTask.NetCoreTests/Linq/Paging.cs b/src/UniTask.NetCoreTests/Linq/Paging.cs index d0e2b55..aef0600 100644 --- a/src/UniTask.NetCoreTests/Linq/Paging.cs +++ b/src/UniTask.NetCoreTests/Linq/Paging.cs @@ -39,6 +39,31 @@ namespace NetCoreTests.Linq await Assert.ThrowsAsync(async () => await xs); } } + [Theory] + [InlineData(0, 0)] + [InlineData(0, 1)] + [InlineData(9, 0)] + [InlineData(9, 1)] + [InlineData(9, 5)] + [InlineData(9, 9)] + [InlineData(9, 15)] + public async Task SkipLast(int collection, int skipCount) + { + var xs = await UniTaskAsyncEnumerable.Range(1, collection).SkipLast(skipCount).ToArrayAsync(); + var ys = Enumerable.Range(1, collection).SkipLast(skipCount).ToArray(); + + xs.Should().BeEquivalentTo(ys); + } + + [Fact] + public async Task SkipLastException() + { + foreach (var item in UniTaskTestException.Throws()) + { + var xs = item.SkipLast(5).ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } + } [Theory] [InlineData(0, 0)]