diff --git a/src/UniTask.NetCore/IAsyncEnumerable.cs b/src/UniTask.NetCore/IAsyncEnumerable.cs index 34d0218..917c521 100644 --- a/src/UniTask.NetCore/IAsyncEnumerable.cs +++ b/src/UniTask.NetCore/IAsyncEnumerable.cs @@ -19,6 +19,11 @@ namespace Cysharp.Threading.Tasks UniTask DisposeAsync(); } + //public interface IUniTaskAsyncGrouping : IUniTaskAsyncEnumerable + //{ + // TKey Key { get; } + //} + public static class UniTaskAsyncEnumerableExtensions { public static UniTaskCancelableAsyncEnumerable WithCancellation(this IUniTaskAsyncEnumerable source, CancellationToken cancellationToken) diff --git a/src/UniTask.NetCore/Linq/GroupBy.cs b/src/UniTask.NetCore/Linq/GroupBy.cs index 46629e6..5da393b 100644 --- a/src/UniTask.NetCore/Linq/GroupBy.cs +++ b/src/UniTask.NetCore/Linq/GroupBy.cs @@ -1,775 +1,827 @@ -namespace Cysharp.Threading.Tasks.Linq +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq { - internal sealed class GroupBy + public static partial class UniTaskAsyncEnumerable { + // Ix-Async returns IGrouping but it is competely waste, use standard IGrouping. + + public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector) + { + return new GroupBy(source, keySelector, x => x, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, IEqualityComparer comparer) + { + return new GroupBy(source, keySelector, x => x, comparer); + } + + public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector) + { + return new GroupBy(source, keySelector, elementSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + return new GroupBy(source, keySelector, elementSelector, comparer); + } + + public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func, TResult> resultSelector) + { + return new GroupBy(source, keySelector, x => x, resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + return new GroupBy(source, keySelector, x => x, resultSelector, comparer); + } + + public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector) + { + return new GroupBy(source, keySelector, elementSelector, resultSelector, EqualityComparer.Default); + } + public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + return new GroupBy(source, keySelector, elementSelector, resultSelector, comparer); + } + + // await + + public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector) + { + return new GroupByAwait(source, keySelector, x => UniTask.FromResult(x), EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) + { + return new GroupByAwait(source, keySelector, x => UniTask.FromResult(x), comparer); + } + + public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector) + { + return new GroupByAwait(source, keySelector, elementSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer) + { + return new GroupByAwait(source, keySelector, elementSelector, comparer); + } + + public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, UniTask> resultSelector) + { + return new GroupByAwait(source, keySelector, x => UniTask.FromResult(x), resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, UniTask> resultSelector) + { + return new GroupByAwait(source, keySelector, elementSelector, resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, UniTask> resultSelector, IEqualityComparer comparer) + { + return new GroupByAwait(source, keySelector, x => UniTask.FromResult(x), resultSelector, comparer); + } + + public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, UniTask> resultSelector, IEqualityComparer comparer) + { + return new GroupByAwait(source, keySelector, elementSelector, resultSelector, comparer); + } + + // with ct + + public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector) + { + return new GroupByAwaitWithCancellation(source, keySelector, (x, _) => UniTask.FromResult(x), EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) + { + return new GroupByAwaitWithCancellation(source, keySelector, (x, _) => UniTask.FromResult(x), comparer); + } + + public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector) + { + return new GroupByAwaitWithCancellation(source, keySelector, elementSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer) + { + return new GroupByAwaitWithCancellation(source, keySelector, elementSelector, comparer); + } + + public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, CancellationToken, UniTask> resultSelector) + { + return new GroupByAwaitWithCancellation(source, keySelector, (x, _) => UniTask.FromResult(x), resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, CancellationToken, UniTask> resultSelector) + { + return new GroupByAwaitWithCancellation(source, keySelector, elementSelector, resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, CancellationToken, UniTask> resultSelector, IEqualityComparer comparer) + { + return new GroupByAwaitWithCancellation(source, keySelector, (x, _) => UniTask.FromResult(x), resultSelector, comparer); + } + + public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, CancellationToken, UniTask> resultSelector, IEqualityComparer comparer) + { + return new GroupByAwaitWithCancellation(source, keySelector, elementSelector, resultSelector, comparer); + } } - -} - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + internal sealed class GroupBy : IUniTaskAsyncEnumerable> + { + readonly IUniTaskAsyncEnumerable source; + readonly Func keySelector; + readonly Func elementSelector; + readonly IEqualityComparer comparer; + + public GroupBy(IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, elementSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator> + { + readonly IUniTaskAsyncEnumerable source; + readonly Func keySelector; + readonly Func elementSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + IEnumerator> groupEnumerator; + + public Enumerator(IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.comparer = comparer; + this.cancellationToken = cancellationToken; + } + + public IGrouping Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (groupEnumerator == null) + { + CreateLookup().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateLookup() + { + try + { + var lookup = await source.ToLookupAsync(keySelector, elementSelector, comparer, cancellationToken); + groupEnumerator = lookup.GetEnumerator(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + if (groupEnumerator.MoveNext()) + { + Current = groupEnumerator.Current as IGrouping; + completionSource.TrySetResult(true); + } + else + { + completionSource.TrySetResult(false); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + } + + public UniTask DisposeAsync() + { + if (groupEnumerator != null) + { + groupEnumerator.Dispose(); + } + + return default; + } + } + } + + internal sealed class GroupBy : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly Func keySelector; + readonly Func elementSelector; + readonly Func, TResult> resultSelector; + readonly IEqualityComparer comparer; + + public GroupBy(IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, elementSelector, resultSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + readonly IUniTaskAsyncEnumerable source; + readonly Func keySelector; + readonly Func elementSelector; + readonly Func, TResult> resultSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + IEnumerator> groupEnumerator; + + public Enumerator(IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + this.cancellationToken = cancellationToken; + } + + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (groupEnumerator == null) + { + CreateLookup().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateLookup() + { + try + { + var lookup = await source.ToLookupAsync(keySelector, elementSelector, comparer, cancellationToken); + groupEnumerator = lookup.GetEnumerator(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + if (groupEnumerator.MoveNext()) + { + var current = groupEnumerator.Current; + Current = resultSelector(current.Key, current); + completionSource.TrySetResult(true); + } + else + { + completionSource.TrySetResult(false); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + } + + public UniTask DisposeAsync() + { + if (groupEnumerator != null) + { + groupEnumerator.Dispose(); + } + + return default; + } + } + } + + internal sealed class GroupByAwait : IUniTaskAsyncEnumerable> + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly IEqualityComparer comparer; + + public GroupByAwait(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, elementSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator> + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + IEnumerator> groupEnumerator; + + public Enumerator(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.comparer = comparer; + this.cancellationToken = cancellationToken; + } + + public IGrouping Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (groupEnumerator == null) + { + CreateLookup().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateLookup() + { + try + { + var lookup = await source.ToLookupAwaitAsync(keySelector, elementSelector, comparer, cancellationToken); + groupEnumerator = lookup.GetEnumerator(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + if (groupEnumerator.MoveNext()) + { + Current = groupEnumerator.Current as IGrouping; + completionSource.TrySetResult(true); + } + else + { + completionSource.TrySetResult(false); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + } + + public UniTask DisposeAsync() + { + if (groupEnumerator != null) + { + groupEnumerator.Dispose(); + } + + return default; + } + } + } + + internal sealed class GroupByAwait : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly Func, UniTask> resultSelector; + readonly IEqualityComparer comparer; + + public GroupByAwait(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, UniTask> resultSelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, elementSelector, resultSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + readonly static Action ResultSelectCoreDelegate = ResultSelectCore; + + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly Func, UniTask> resultSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + IEnumerator> groupEnumerator; + UniTask.Awaiter awaiter; + + public Enumerator(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, UniTask> resultSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + this.cancellationToken = cancellationToken; + } + + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (groupEnumerator == null) + { + CreateLookup().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateLookup() + { + try + { + var lookup = await source.ToLookupAwaitAsync(keySelector, elementSelector, comparer, cancellationToken); + groupEnumerator = lookup.GetEnumerator(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + if (groupEnumerator.MoveNext()) + { + var current = groupEnumerator.Current; + + awaiter = resultSelector(current.Key, current).GetAwaiter(); + if (awaiter.IsCompleted) + { + ResultSelectCore(this); + } + else + { + awaiter.SourceOnCompleted(ResultSelectCoreDelegate, this); + } + return; + } + else + { + completionSource.TrySetResult(false); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + } + + static void ResultSelectCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + self.Current = result; + self.completionSource.TrySetResult(true); + } + } + + public UniTask DisposeAsync() + { + if (groupEnumerator != null) + { + groupEnumerator.Dispose(); + } + + return default; + } + } + } + + internal sealed class GroupByAwaitWithCancellation : IUniTaskAsyncEnumerable> + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly IEqualityComparer comparer; + + public GroupByAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, elementSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator> + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + IEnumerator> groupEnumerator; + + public Enumerator(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.comparer = comparer; + this.cancellationToken = cancellationToken; + } + + public IGrouping Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (groupEnumerator == null) + { + CreateLookup().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateLookup() + { + try + { + var lookup = await source.ToLookupAwaitWithCancellationAsync(keySelector, elementSelector, comparer, cancellationToken); + groupEnumerator = lookup.GetEnumerator(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + if (groupEnumerator.MoveNext()) + { + Current = groupEnumerator.Current as IGrouping; + completionSource.TrySetResult(true); + } + else + { + completionSource.TrySetResult(false); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + } + + public UniTask DisposeAsync() + { + if (groupEnumerator != null) + { + groupEnumerator.Dispose(); + } + + return default; + } + } + } + + internal sealed class GroupByAwaitWithCancellation : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly Func, CancellationToken, UniTask> resultSelector; + readonly IEqualityComparer comparer; + + public GroupByAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, CancellationToken, UniTask> resultSelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, elementSelector, resultSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + readonly static Action ResultSelectCoreDelegate = ResultSelectCore; + + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly Func> elementSelector; + readonly Func, CancellationToken, UniTask> resultSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + IEnumerator> groupEnumerator; + UniTask.Awaiter awaiter; + + public Enumerator(IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, CancellationToken, UniTask> resultSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.source = source; + this.keySelector = keySelector; + this.elementSelector = elementSelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + this.cancellationToken = cancellationToken; + } + + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (groupEnumerator == null) + { + CreateLookup().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateLookup() + { + try + { + var lookup = await source.ToLookupAwaitWithCancellationAsync(keySelector, elementSelector, comparer, cancellationToken); + groupEnumerator = lookup.GetEnumerator(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + if (groupEnumerator.MoveNext()) + { + var current = groupEnumerator.Current; + + awaiter = resultSelector(current.Key, current, cancellationToken).GetAwaiter(); + if (awaiter.IsCompleted) + { + ResultSelectCore(this); + } + else + { + awaiter.SourceOnCompleted(ResultSelectCoreDelegate, this); + } + return; + } + else + { + completionSource.TrySetResult(false); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + } + + static void ResultSelectCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + self.Current = result; + self.completionSource.TrySetResult(true); + } + } + + public UniTask DisposeAsync() + { + if (groupEnumerator != null) + { + groupEnumerator.Dispose(); + } + + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask.NetCore/Linq/Join.cs b/src/UniTask.NetCore/Linq/Join.cs index 1b42e2d..ceebbc8 100644 --- a/src/UniTask.NetCore/Linq/Join.cs +++ b/src/UniTask.NetCore/Linq/Join.cs @@ -154,7 +154,7 @@ namespace Cysharp.Threading.Tasks.Linq { try { - lookup = await inner.ToLookupAsync(innerKeySelector, cancellationToken); + lookup = await inner.ToLookupAsync(innerKeySelector, comparer, cancellationToken); enumerator = outer.GetAsyncEnumerator(cancellationToken); } catch (Exception ex) @@ -179,6 +179,7 @@ namespace Cysharp.Threading.Tasks.Linq } else { + valueEnumerator.Dispose(); valueEnumerator = null; } } @@ -342,7 +343,7 @@ namespace Cysharp.Threading.Tasks.Linq { try { - lookup = await inner.ToLookupAwaitAsync(innerKeySelector, cancellationToken); + lookup = await inner.ToLookupAwaitAsync(innerKeySelector, comparer, cancellationToken); enumerator = outer.GetAsyncEnumerator(cancellationToken); } catch (Exception ex) @@ -375,6 +376,7 @@ namespace Cysharp.Threading.Tasks.Linq } else { + valueEnumerator.Dispose(); valueEnumerator = null; } } @@ -568,7 +570,7 @@ namespace Cysharp.Threading.Tasks.Linq { try { - lookup = await inner.ToLookupAwaitWithCancellationAsync(innerKeySelector, cancellationToken: cancellationToken); + lookup = await inner.ToLookupAwaitWithCancellationAsync(innerKeySelector, comparer, cancellationToken: cancellationToken); enumerator = outer.GetAsyncEnumerator(cancellationToken); } catch (Exception ex) @@ -601,6 +603,7 @@ namespace Cysharp.Threading.Tasks.Linq } else { + valueEnumerator.Dispose(); valueEnumerator = null; } } diff --git a/src/UniTask.NetCore/Linq/ToLookup.cs b/src/UniTask.NetCore/Linq/ToLookup.cs index dfd9faa..c93944b 100644 --- a/src/UniTask.NetCore/Linq/ToLookup.cs +++ b/src/UniTask.NetCore/Linq/ToLookup.cs @@ -514,7 +514,7 @@ namespace Cysharp.Threading.Tasks.Linq } } - class Grouping : IGrouping + class Grouping : IGrouping // , IUniTaskAsyncGrouping { readonly List elements; @@ -539,6 +539,11 @@ namespace Cysharp.Threading.Tasks.Linq { return elements.GetEnumerator(); } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return this.ToUniTaskAsyncEnumerable().GetAsyncEnumerator(cancellationToken); + } } } } \ No newline at end of file diff --git a/src/UniTask.NetCore/Linq/_FileMaker.cs b/src/UniTask.NetCore/Linq/_FileMaker.cs index 840e767..df91ad8 100644 --- a/src/UniTask.NetCore/Linq/_FileMaker.cs +++ b/src/UniTask.NetCore/Linq/_FileMaker.cs @@ -30,125 +30,6 @@ namespace ___Dummy - public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func, TResult> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupBy(this IUniTaskAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, UniTask> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, UniTask> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, UniTask> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, UniTask> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, CancellationToken, UniTask> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, CancellationToken, UniTask> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable> GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func, CancellationToken, UniTask> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable GroupByAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, Func> elementSelector, Func, CancellationToken, UniTask> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } public static IUniTaskAsyncEnumerable GroupJoin(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector) { diff --git a/src/UniTask.NetCoreSandbox/Program.cs b/src/UniTask.NetCoreSandbox/Program.cs index f9deafb..24807f2 100644 --- a/src/UniTask.NetCoreSandbox/Program.cs +++ b/src/UniTask.NetCoreSandbox/Program.cs @@ -52,6 +52,8 @@ namespace NetCoreSandbox } + // AsyncEnumerable.Range(1, 10).GroupBy(x=>x).Select(x=>x.first + //Enumerable.Range(1,10).ToHashSet( diff --git a/src/UniTask.NetCoreTests/Linq/Joins.cs b/src/UniTask.NetCoreTests/Linq/Joins.cs index bfd1099..bc276e0 100644 --- a/src/UniTask.NetCoreTests/Linq/Joins.cs +++ b/src/UniTask.NetCoreTests/Linq/Joins.cs @@ -112,5 +112,78 @@ namespace NetCoreTests.Linq await Assert.ThrowsAsync(async () => await ys); } } + + + [Fact] + public async Task GroupBy() + { + var arr = new[] { 1, 4, 10, 10, 4, 5, 10, 9 }; + { + var xs = await arr.ToUniTaskAsyncEnumerable().GroupBy(x => x).ToArrayAsync(); + var ys = arr.GroupBy(x => x).ToArray(); + + xs.Length.Should().Be(ys.Length); + xs.OrderBy(x => x.Key).Should().BeEquivalentTo(ys.OrderBy(x => x.Key)); + } + + { + var xs = await arr.ToUniTaskAsyncEnumerable().GroupBy(x => x, (key, xs) => (key, xs.ToArray())).ToArrayAsync(); + var ys = arr.GroupBy(x => x, (key, xs) => (key, xs.ToArray())).ToArray(); + + xs.Length.Should().Be(ys.Length); + xs.OrderBy(x => x.key).SelectMany(x => x.Item2).Should().BeEquivalentTo(ys.OrderBy(x => x.key).SelectMany(x => x.Item2)); + } + + { + var xs = await arr.ToUniTaskAsyncEnumerable().GroupByAwait(x => RandomRun(x)).ToArrayAsync(); + var ys = arr.GroupBy(x => x).ToArray(); + + xs.Length.Should().Be(ys.Length); + xs.OrderBy(x => x.Key).Should().BeEquivalentTo(ys.OrderBy(x => x.Key)); + } + + { + var xs = await arr.ToUniTaskAsyncEnumerable().GroupByAwait(x => RandomRun(x), (key, xs) => RandomRun((key, xs.ToArray()))).ToArrayAsync(); + var ys = arr.GroupBy(x => x, (key, xs) => (key, xs.ToArray())).ToArray(); + + xs.Length.Should().Be(ys.Length); + xs.OrderBy(x => x.key).SelectMany(x => x.Item2).Should().BeEquivalentTo(ys.OrderBy(x => x.key).SelectMany(x => x.Item2)); + } + + { + var xs = await arr.ToUniTaskAsyncEnumerable().GroupByAwaitWithCancellation((x, _) => RandomRun(x)).ToArrayAsync(); + var ys = arr.GroupBy(x => x).ToArray(); + + xs.Length.Should().Be(ys.Length); + xs.OrderBy(x => x.Key).Should().BeEquivalentTo(ys.OrderBy(x => x.Key)); + } + + { + var xs = await arr.ToUniTaskAsyncEnumerable().GroupByAwaitWithCancellation((x, _) => RandomRun(x), (key, xs, _) => RandomRun((key, xs.ToArray()))).ToArrayAsync(); + var ys = arr.GroupBy(x => x, (key, xs) => (key, xs.ToArray())).ToArray(); + + xs.Length.Should().Be(ys.Length); + xs.OrderBy(x => x.key).SelectMany(x => x.Item2).Should().BeEquivalentTo(ys.OrderBy(x => x.key).SelectMany(x => x.Item2)); + } + } + + + + + [Fact] + public async Task GroupByThrow() + { + var arr = new[] { 1, 4, 10, 10, 4, 5, 10, 9 }; + foreach (var item in UniTaskTestException.Throws()) + { + var xs = item.GroupBy(x => x).ToArrayAsync(); + var ys = item.GroupByAwait(x => RandomRun(x)).ToArrayAsync(); + var zs = item.GroupByAwaitWithCancellation((x, _) => RandomRun(x)).ToArrayAsync(); + + await Assert.ThrowsAsync(async () => await xs); + await Assert.ThrowsAsync(async () => await ys); + await Assert.ThrowsAsync(async () => await zs); + } + } } }