diff --git a/src/UniTask.NetCore/Linq/Buffer.cs b/src/UniTask.NetCore/Linq/Buffer.cs new file mode 100644 index 0000000..f29af41 --- /dev/null +++ b/src/UniTask.NetCore/Linq/Buffer.cs @@ -0,0 +1,340 @@ +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable> Buffer(this IUniTaskAsyncEnumerable source, Int32 count) + { + Error.ThrowArgumentNullException(source, nameof(source)); + if (count <= 0) throw Error.ArgumentOutOfRange(nameof(count)); + + return new Buffer(source, count); + } + + public static IUniTaskAsyncEnumerable> Buffer(this IUniTaskAsyncEnumerable source, Int32 count, Int32 skip) + { + Error.ThrowArgumentNullException(source, nameof(source)); + if (count <= 0) throw Error.ArgumentOutOfRange(nameof(count)); + if (skip <= 0) throw Error.ArgumentOutOfRange(nameof(skip)); + + return new BufferSkip(source, count, skip); + } + } + + internal sealed class Buffer : IUniTaskAsyncEnumerable> + { + readonly IUniTaskAsyncEnumerable source; + readonly int count; + + public Buffer(IUniTaskAsyncEnumerable source, int count) + { + this.source = source; + this.count = count; + } + + public IUniTaskAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, count, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator> + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + readonly int count; + CancellationToken cancellationToken; + + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + bool continueNext; + + bool completed; + List buffer; + + public Enumerator(IUniTaskAsyncEnumerable source, int count, CancellationToken cancellationToken) + { + this.source = source; + this.count = count; + this.cancellationToken = cancellationToken; + } + + public IList Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken); + buffer = new List(count); + } + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + if (completed) + { + if (buffer != null && buffer.Count > 0) + { + var ret = buffer; + buffer = null; + Current = ret; + completionSource.TrySetResult(true); + return; + } + else + { + completionSource.TrySetResult(false); + return; + } + } + + try + { + + LOOP: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + self.buffer.Add(self.enumerator.Current); + + if (self.buffer.Count == self.count) + { + self.Current = self.buffer; + self.buffer = new List(self.count); + self.continueNext = false; + self.completionSource.TrySetResult(true); + return; + } + else + { + if (!self.continueNext) + { + self.SourceMoveNext(); + } + } + } + else + { + self.continueNext = false; + self.completed = true; + self.SourceMoveNext(); + } + } + else + { + self.continueNext = false; + } + } + + public UniTask DisposeAsync() + { + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } + + internal sealed class BufferSkip : IUniTaskAsyncEnumerable> + { + readonly IUniTaskAsyncEnumerable source; + readonly int count; + readonly int skip; + + public BufferSkip(IUniTaskAsyncEnumerable source, int count, int skip) + { + this.source = source; + this.count = count; + this.skip = skip; + } + + public IUniTaskAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, count, skip, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator> + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + readonly int count; + readonly int skip; + CancellationToken cancellationToken; + + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + bool continueNext; + + bool completed; + Queue> buffers; + int index = 0; + + public Enumerator(IUniTaskAsyncEnumerable source, int count, int skip, CancellationToken cancellationToken) + { + this.source = source; + this.count = count; + this.skip = skip; + this.cancellationToken = cancellationToken; + } + + public IList Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken); + buffers = new Queue>(); + } + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + if (completed) + { + if (buffers.Count > 0) + { + Current = buffers.Dequeue(); + completionSource.TrySetResult(true); + return; + } + else + { + completionSource.TrySetResult(false); + return; + } + } + + try + { + + LOOP: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + if (self.index++ % self.skip == 0) + { + self.buffers.Enqueue(new List(self.count)); + } + + var item = self.enumerator.Current; + foreach (var buffer in self.buffers) + { + buffer.Add(item); + } + + if (self.buffers.Count > 0 && self.buffers.Peek().Count == self.count) + { + self.Current = self.buffers.Dequeue(); + self.continueNext = false; + self.completionSource.TrySetResult(true); + return; + } + else + { + if (!self.continueNext) + { + self.SourceMoveNext(); + } + } + } + else + { + self.continueNext = false; + self.completed = true; + self.SourceMoveNext(); + } + } + else + { + self.continueNext = false; + } + } + + public UniTask DisposeAsync() + { + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask.NetCoreTests/Linq/Projection.cs b/src/UniTask.NetCoreTests/Linq/Projection.cs index 2ec717a..ec652ac 100644 --- a/src/UniTask.NetCoreTests/Linq/Projection.cs +++ b/src/UniTask.NetCoreTests/Linq/Projection.cs @@ -272,5 +272,52 @@ namespace NetCoreTests.Linq await Assert.ThrowsAsync(async () => await xs); } } + + [Theory] + // [InlineData(0, 0)] + [InlineData(0, 3)] + [InlineData(9, 1)] + [InlineData(9, 2)] + [InlineData(9, 3)] + [InlineData(17, 3)] + [InlineData(17, 16)] + [InlineData(17, 17)] + [InlineData(17, 27)] + public async Task Buffer(int rangeCount, int bufferCount) + { + var xs = await UniTaskAsyncEnumerable.Range(0, rangeCount).Buffer(bufferCount).Select(x => string.Join(",", x)).ToArrayAsync(); + var ys = await AsyncEnumerable.Range(0, rangeCount).Buffer(bufferCount).Select(x => string.Join(",", x)).ToArrayAsync(); + + xs.Should().BeEquivalentTo(ys); + } + + [Theory] + // [InlineData(0, 0)] + [InlineData(0, 3, 2)] + [InlineData(9, 1, 1)] + [InlineData(9, 2, 3)] + [InlineData(9, 3, 4)] + [InlineData(17, 3, 3)] + [InlineData(17, 16, 5)] + [InlineData(17, 17, 19)] + public async Task BufferSkip(int rangeCount, int bufferCount, int skipCount) + { + var xs = await UniTaskAsyncEnumerable.Range(0, rangeCount).Buffer(bufferCount, skipCount).Select(x => string.Join(",", x)).ToArrayAsync(); + var ys = await AsyncEnumerable.Range(0, rangeCount).Buffer(bufferCount, skipCount).Select(x => string.Join(",", x)).ToArrayAsync(); + + xs.Should().BeEquivalentTo(ys); + } + + [Fact] + public async Task BufferError() + { + foreach (var item in UniTaskTestException.Throws()) + { + var xs = item.Buffer(3).ToArrayAsync(); + var ys = item.Buffer(3, 2).ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + await Assert.ThrowsAsync(async () => await ys); + } + } } }