diff --git a/src/UniTask.NetCore/Linq/AsyncEnumeratorBase.cs b/src/UniTask.NetCore/Linq/AsyncEnumeratorBase.cs index b4b4493..3f3ba2f 100644 --- a/src/UniTask.NetCore/Linq/AsyncEnumeratorBase.cs +++ b/src/UniTask.NetCore/Linq/AsyncEnumeratorBase.cs @@ -212,7 +212,7 @@ namespace Cysharp.Threading.Tasks.Linq protected abstract bool TrySetCurrentCore(TAwait awaitResult, out bool terminateIteration); // Util - protected TSource SourceCurrent => enumerator.Current; + protected TSource SourceCurrent { get; private set; } protected (bool waitCallback, bool requireNextIteration) ActionCompleted(bool trySetCurrentResult, out bool moveNextResult) { @@ -287,7 +287,8 @@ namespace Cysharp.Threading.Tasks.Linq { if (sourceHasCurrent) { - var task = TransformAsync(enumerator.Current); + SourceCurrent = enumerator.Current; + var task = TransformAsync(SourceCurrent); if (UnwarapTask(task, out var taskResult)) { var currentResult = TrySetCurrentCore(taskResult, out var terminateIteration); diff --git a/src/UniTask.NetCore/Linq/Distinct.cs b/src/UniTask.NetCore/Linq/Distinct.cs index fc4aa48..2868fcd 100644 --- a/src/UniTask.NetCore/Linq/Distinct.cs +++ b/src/UniTask.NetCore/Linq/Distinct.cs @@ -9,8 +9,6 @@ namespace Cysharp.Threading.Tasks.Linq { public static IUniTaskAsyncEnumerable Distinct(this IUniTaskAsyncEnumerable source) { - Error.ThrowArgumentNullException(source, nameof(source)); - return Distinct(source, EqualityComparer.Default); } @@ -21,6 +19,48 @@ namespace Cysharp.Threading.Tasks.Linq return new Distinct(source, comparer); } + + public static IUniTaskAsyncEnumerable Distinct(this IUniTaskAsyncEnumerable source, Func keySelector) + { + return Distinct(source, keySelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable Distinct(this IUniTaskAsyncEnumerable source, Func keySelector, IEqualityComparer comparer) + { + Error.ThrowArgumentNullException(source, nameof(source)); + Error.ThrowArgumentNullException(keySelector, nameof(keySelector)); + Error.ThrowArgumentNullException(comparer, nameof(comparer)); + + return new Distinct(source, keySelector, comparer); + } + + public static IUniTaskAsyncEnumerable DistinctAwait(this IUniTaskAsyncEnumerable source, Func> keySelector) + { + return DistinctAwait(source, keySelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable DistinctAwait(this IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) + { + Error.ThrowArgumentNullException(source, nameof(source)); + Error.ThrowArgumentNullException(keySelector, nameof(keySelector)); + Error.ThrowArgumentNullException(comparer, nameof(comparer)); + + return new DistinctAwait(source, keySelector, comparer); + } + + public static IUniTaskAsyncEnumerable DistinctAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector) + { + return DistinctAwaitWithCancellation(source, keySelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable DistinctAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) + { + Error.ThrowArgumentNullException(source, nameof(source)); + Error.ThrowArgumentNullException(keySelector, nameof(keySelector)); + Error.ThrowArgumentNullException(comparer, nameof(comparer)); + + return new DistinctAwaitCancellation(source, keySelector, comparer); + } } internal sealed class Distinct : IUniTaskAsyncEnumerable @@ -73,4 +113,165 @@ namespace Cysharp.Threading.Tasks.Linq } } } + + internal sealed class Distinct : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly Func keySelector; + readonly IEqualityComparer comparer; + + public Distinct(IUniTaskAsyncEnumerable source, Func keySelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, comparer, cancellationToken); + } + + class Enumerator : AsyncEnumeratorBase + { + readonly HashSet set; + readonly Func keySelector; + + public Enumerator(IUniTaskAsyncEnumerable source, Func keySelector, IEqualityComparer comparer, CancellationToken cancellationToken) + + : base(source, cancellationToken) + { + this.set = new HashSet(comparer); + this.keySelector = keySelector; + } + + protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) + { + if (sourceHasCurrent) + { + var v = SourceCurrent; + if (set.Add(keySelector(v))) + { + Current = v; + result = true; + return true; + } + else + { + result = default; + return false; + } + } + + result = false; + return true; + } + } + } + + internal sealed class DistinctAwait : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly IEqualityComparer comparer; + + public DistinctAwait(IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, comparer, cancellationToken); + } + + class Enumerator : AsyncEnumeratorAwaitSelectorBase + { + readonly HashSet set; + readonly Func> keySelector; + + public Enumerator(IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer, CancellationToken cancellationToken) + + : base(source, cancellationToken) + { + this.set = new HashSet(comparer); + this.keySelector = keySelector; + } + + protected override UniTask TransformAsync(TSource sourceCurrent) + { + return keySelector(sourceCurrent); + } + + protected override bool TrySetCurrentCore(TKey awaitResult, out bool terminateIteration) + { + if (set.Add(awaitResult)) + { + Current = SourceCurrent; + terminateIteration = false; + return true; + } + else + { + terminateIteration = false; + return false; + } + } + } + } + + internal sealed class DistinctAwaitCancellation : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly Func> keySelector; + readonly IEqualityComparer comparer; + + public DistinctAwaitCancellation(IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer) + { + this.source = source; + this.keySelector = keySelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, keySelector, comparer, cancellationToken); + } + + class Enumerator : AsyncEnumeratorAwaitSelectorBase + { + readonly HashSet set; + readonly Func> keySelector; + + public Enumerator(IUniTaskAsyncEnumerable source, Func> keySelector, IEqualityComparer comparer, CancellationToken cancellationToken) + + : base(source, cancellationToken) + { + this.set = new HashSet(comparer); + this.keySelector = keySelector; + } + + protected override UniTask TransformAsync(TSource sourceCurrent) + { + return keySelector(sourceCurrent, cancellationToken); + } + + protected override bool TrySetCurrentCore(TKey awaitResult, out bool terminateIteration) + { + if (set.Add(awaitResult)) + { + Current = SourceCurrent; + terminateIteration = false; + return true; + } + else + { + terminateIteration = false; + return false; + } + } + } + } } \ No newline at end of file diff --git a/src/UniTask.NetCore/Linq/_FileMaker.cs b/src/UniTask.NetCore/Linq/_FileMaker.cs deleted file mode 100644 index 4f1b224..0000000 --- a/src/UniTask.NetCore/Linq/_FileMaker.cs +++ /dev/null @@ -1,51 +0,0 @@ -using Cysharp.Threading.Tasks; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; - -namespace ___Dummy -{ - - public interface IAsyncGrouping - { - } - - - public interface IOrderedAsyncEnumerable - { - - } - public static partial class _FileMaker - { - // Buffer,Distinct, DistinctUntilChanged, Do, MaxBy, MinBy, Never,Return, Throw - - - - - - - - - - - - - - - - - - - - - - - - - } - - -} - diff --git a/src/UniTask.NetCoreTests/Linq/Sets.cs b/src/UniTask.NetCoreTests/Linq/Sets.cs index 627c578..e730f61 100644 --- a/src/UniTask.NetCoreTests/Linq/Sets.cs +++ b/src/UniTask.NetCoreTests/Linq/Sets.cs @@ -32,10 +32,13 @@ namespace NetCoreTests.Linq [MemberData(nameof(array1))] public async Task Distinct(int[] array) { - var xs = await array.ToUniTaskAsyncEnumerable().Distinct().ToArrayAsync(); var ys = array.Distinct().ToArray(); - - xs.Should().BeEquivalentTo(ys); + { + (await array.ToUniTaskAsyncEnumerable().Distinct().ToArrayAsync()).Should().BeEquivalentTo(ys); + (await array.ToUniTaskAsyncEnumerable().Distinct(x => x).ToArrayAsync()).Should().BeEquivalentTo(ys); + (await array.ToUniTaskAsyncEnumerable().DistinctAwait(x => UniTask.Run(() => x)).ToArrayAsync()).Should().BeEquivalentTo(ys); + (await array.ToUniTaskAsyncEnumerable().DistinctAwaitWithCancellation((x, _) => UniTask.Run(() => x)).ToArrayAsync()).Should().BeEquivalentTo(ys); + } } [Fact] @@ -43,8 +46,22 @@ namespace NetCoreTests.Linq { foreach (var item in UniTaskTestException.Throws()) { - var xs = item.Distinct().ToArrayAsync(); - await Assert.ThrowsAsync(async () => await xs); + { + var xs = item.Distinct().ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } + { + var xs = item.Distinct(x => x).ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } + { + var xs = item.DistinctAwait(x => UniTask.Run(() => x)).ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } + { + var xs = item.DistinctAwaitWithCancellation((x, _) => UniTask.Run(() => x)).ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } } }