From 12c507574e21a16abd9781e22fc4bcdc2c0d4c91 Mon Sep 17 00:00:00 2001 From: neuecc Date: Mon, 11 May 2020 17:33:41 +0900 Subject: [PATCH] Join --- src/UniTask.NetCore/Linq/Join.cs | 717 ++++++++++++++++++++++++- src/UniTask.NetCore/Linq/ToLookup.cs | 2 +- src/UniTask.NetCore/Linq/_FileMaker.cs | 29 - src/UniTask.NetCoreTests/Linq/Joins.cs | 116 ++++ 4 files changed, 833 insertions(+), 31 deletions(-) create mode 100644 src/UniTask.NetCoreTests/Linq/Joins.cs diff --git a/src/UniTask.NetCore/Linq/Join.cs b/src/UniTask.NetCore/Linq/Join.cs index 5f28270..1b42e2d 100644 --- a/src/UniTask.NetCore/Linq/Join.cs +++ b/src/UniTask.NetCore/Linq/Join.cs @@ -1 +1,716 @@ - \ No newline at end of file +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable Join(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) + { + Error.ThrowArgumentNullException(outer, nameof(outer)); + Error.ThrowArgumentNullException(inner, nameof(inner)); + Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector)); + Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector)); + Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector)); + + return new Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable Join(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { + Error.ThrowArgumentNullException(outer, nameof(outer)); + Error.ThrowArgumentNullException(inner, nameof(inner)); + Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector)); + Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector)); + Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector)); + Error.ThrowArgumentNullException(comparer, nameof(comparer)); + + return new Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + public static IUniTaskAsyncEnumerable JoinAwait(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector) + { + Error.ThrowArgumentNullException(outer, nameof(outer)); + Error.ThrowArgumentNullException(inner, nameof(inner)); + Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector)); + Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector)); + Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector)); + + return new JoinAwait(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable JoinAwait(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer) + { + Error.ThrowArgumentNullException(outer, nameof(outer)); + Error.ThrowArgumentNullException(inner, nameof(inner)); + Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector)); + Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector)); + Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector)); + Error.ThrowArgumentNullException(comparer, nameof(comparer)); + + return new JoinAwait(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + public static IUniTaskAsyncEnumerable JoinAwaitWithCancellation(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector) + { + Error.ThrowArgumentNullException(outer, nameof(outer)); + Error.ThrowArgumentNullException(inner, nameof(inner)); + Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector)); + Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector)); + Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector)); + + return new JoinAwaitWithCancellation(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer.Default); + } + + public static IUniTaskAsyncEnumerable JoinAwaitWithCancellation(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer) + { + Error.ThrowArgumentNullException(outer, nameof(outer)); + Error.ThrowArgumentNullException(inner, nameof(inner)); + Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector)); + Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector)); + Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector)); + Error.ThrowArgumentNullException(comparer, nameof(comparer)); + + return new JoinAwaitWithCancellation(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + } + + internal sealed class Join : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable outer; + readonly IUniTaskAsyncEnumerable inner; + readonly Func outerKeySelector; + readonly Func innerKeySelector; + readonly Func resultSelector; + readonly IEqualityComparer comparer; + + public Join(IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { + this.outer = outer; + this.inner = inner; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable outer; + readonly IUniTaskAsyncEnumerable inner; + readonly Func outerKeySelector; + readonly Func innerKeySelector; + readonly Func resultSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + ILookup lookup; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + TOuter currentOuterValue; + IEnumerator valueEnumerator; + + bool continueNext; + + public Enumerator(IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.outer = outer; + this.inner = inner; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + this.cancellationToken = cancellationToken; + } + + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (lookup == null) + { + CreateInnerHashSet().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateInnerHashSet() + { + try + { + lookup = await inner.ToLookupAsync(innerKeySelector, cancellationToken); + enumerator = outer.GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + LOOP: + if (valueEnumerator != null) + { + if (valueEnumerator.MoveNext()) + { + Current = resultSelector(currentOuterValue, valueEnumerator.Current); + goto TRY_SET_RESULT_TRUE; + } + else + { + valueEnumerator = null; + } + } + + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + + return; + + TRY_SET_RESULT_TRUE: + completionSource.TrySetResult(true); + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + self.currentOuterValue = self.enumerator.Current; + var key = self.outerKeySelector(self.currentOuterValue); + self.valueEnumerator = self.lookup[key].GetEnumerator(); + + if (self.continueNext) + { + return; + } + else + { + self.SourceMoveNext(); + } + } + else + { + self.continueNext = false; + self.completionSource.TrySetResult(false); + } + } + else + { + self.continueNext = false; + } + } + + public UniTask DisposeAsync() + { + if (valueEnumerator != null) + { + valueEnumerator.Dispose(); + } + + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + + return default; + } + } + } + + internal sealed class JoinAwait : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable outer; + readonly IUniTaskAsyncEnumerable inner; + readonly Func> outerKeySelector; + readonly Func> innerKeySelector; + readonly Func> resultSelector; + readonly IEqualityComparer comparer; + + public JoinAwait(IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer) + { + this.outer = outer; + this.inner = inner; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + static readonly Action OuterSelectCoreDelegate = OuterSelectCore; + static readonly Action ResultSelectCoreDelegate = ResultSelectCore; + + readonly IUniTaskAsyncEnumerable outer; + readonly IUniTaskAsyncEnumerable inner; + readonly Func> outerKeySelector; + readonly Func> innerKeySelector; + readonly Func> resultSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + ILookup lookup; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + TOuter currentOuterValue; + IEnumerator valueEnumerator; + + UniTask.Awaiter resultAwaiter; + UniTask.Awaiter outerKeyAwaiter; + + bool continueNext; + + public Enumerator(IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.outer = outer; + this.inner = inner; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + this.cancellationToken = cancellationToken; + } + + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (lookup == null) + { + CreateInnerHashSet().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateInnerHashSet() + { + try + { + lookup = await inner.ToLookupAwaitAsync(innerKeySelector, cancellationToken); + enumerator = outer.GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + LOOP: + if (valueEnumerator != null) + { + if (valueEnumerator.MoveNext()) + { + resultAwaiter = resultSelector(currentOuterValue, valueEnumerator.Current).GetAwaiter(); + if (resultAwaiter.IsCompleted) + { + ResultSelectCore(this); + } + else + { + resultAwaiter.SourceOnCompleted(ResultSelectCoreDelegate, this); + } + return; + } + else + { + valueEnumerator = null; + } + } + + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + self.currentOuterValue = self.enumerator.Current; + + self.outerKeyAwaiter = self.outerKeySelector(self.currentOuterValue).GetAwaiter(); + + if (self.outerKeyAwaiter.IsCompleted) + { + OuterSelectCore(self); + } + else + { + self.continueNext = false; + self.outerKeyAwaiter.SourceOnCompleted(OuterSelectCoreDelegate, self); + } + } + else + { + self.continueNext = false; + self.completionSource.TrySetResult(false); + } + } + else + { + self.continueNext = false; + } + } + + static void OuterSelectCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.outerKeyAwaiter, out var key)) + { + self.valueEnumerator = self.lookup[key].GetEnumerator(); + + if (self.continueNext) + { + return; + } + else + { + self.SourceMoveNext(); + } + } + else + { + self.continueNext = false; + } + } + + static void ResultSelectCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.resultAwaiter, out var result)) + { + self.Current = result; + self.completionSource.TrySetResult(true); + } + } + + public UniTask DisposeAsync() + { + if (valueEnumerator != null) + { + valueEnumerator.Dispose(); + } + + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + + return default; + } + } + } + + internal sealed class JoinAwaitWithCancellation : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable outer; + readonly IUniTaskAsyncEnumerable inner; + readonly Func> outerKeySelector; + readonly Func> innerKeySelector; + readonly Func> resultSelector; + readonly IEqualityComparer comparer; + + public JoinAwaitWithCancellation(IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer) + { + this.outer = outer; + this.inner = inner; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + this.comparer = comparer; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + static readonly Action OuterSelectCoreDelegate = OuterSelectCore; + static readonly Action ResultSelectCoreDelegate = ResultSelectCore; + + readonly IUniTaskAsyncEnumerable outer; + readonly IUniTaskAsyncEnumerable inner; + readonly Func> outerKeySelector; + readonly Func> innerKeySelector; + readonly Func> resultSelector; + readonly IEqualityComparer comparer; + CancellationToken cancellationToken; + + ILookup lookup; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + TOuter currentOuterValue; + IEnumerator valueEnumerator; + + UniTask.Awaiter resultAwaiter; + UniTask.Awaiter outerKeyAwaiter; + + bool continueNext; + + public Enumerator(IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer, CancellationToken cancellationToken) + { + this.outer = outer; + this.inner = inner; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + this.cancellationToken = cancellationToken; + } + + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (lookup == null) + { + CreateInnerHashSet().Forget(); + } + else + { + SourceMoveNext(); + } + return new UniTask(this, completionSource.Version); + } + + async UniTaskVoid CreateInnerHashSet() + { + try + { + lookup = await inner.ToLookupAwaitWithCancellationAsync(innerKeySelector, cancellationToken: cancellationToken); + enumerator = outer.GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + return; + } + SourceMoveNext(); + } + + void SourceMoveNext() + { + try + { + LOOP: + if (valueEnumerator != null) + { + if (valueEnumerator.MoveNext()) + { + resultAwaiter = resultSelector(currentOuterValue, valueEnumerator.Current, cancellationToken).GetAwaiter(); + if (resultAwaiter.IsCompleted) + { + ResultSelectCore(this); + } + else + { + resultAwaiter.SourceOnCompleted(ResultSelectCoreDelegate, this); + } + return; + } + else + { + valueEnumerator = null; + } + } + + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + self.currentOuterValue = self.enumerator.Current; + + self.outerKeyAwaiter = self.outerKeySelector(self.currentOuterValue, self.cancellationToken).GetAwaiter(); + + if (self.outerKeyAwaiter.IsCompleted) + { + OuterSelectCore(self); + } + else + { + self.continueNext = false; + self.outerKeyAwaiter.SourceOnCompleted(OuterSelectCoreDelegate, self); + } + } + else + { + self.continueNext = false; + self.completionSource.TrySetResult(false); + } + } + else + { + self.continueNext = false; + } + } + + static void OuterSelectCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.outerKeyAwaiter, out var key)) + { + self.valueEnumerator = self.lookup[key].GetEnumerator(); + + if (self.continueNext) + { + return; + } + else + { + self.SourceMoveNext(); + } + } + else + { + self.continueNext = false; + } + } + + static void ResultSelectCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.resultAwaiter, out var result)) + { + self.Current = result; + self.completionSource.TrySetResult(true); + } + } + + public UniTask DisposeAsync() + { + if (valueEnumerator != null) + { + valueEnumerator.Dispose(); + } + + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + + return default; + } + } + } + +} \ No newline at end of file diff --git a/src/UniTask.NetCore/Linq/ToLookup.cs b/src/UniTask.NetCore/Linq/ToLookup.cs index d0ccbaa..dfd9faa 100644 --- a/src/UniTask.NetCore/Linq/ToLookup.cs +++ b/src/UniTask.NetCore/Linq/ToLookup.cs @@ -494,7 +494,7 @@ namespace Cysharp.Threading.Tasks.Linq return new Lookup(dict); } - public IEnumerable this[TKey key] => dict[key]; + public IEnumerable this[TKey key] => dict.TryGetValue(key, out var g) ? g : Enumerable.Empty(); public int Count => dict.Count; diff --git a/src/UniTask.NetCore/Linq/_FileMaker.cs b/src/UniTask.NetCore/Linq/_FileMaker.cs index 96f13ab..840e767 100644 --- a/src/UniTask.NetCore/Linq/_FileMaker.cs +++ b/src/UniTask.NetCore/Linq/_FileMaker.cs @@ -180,35 +180,6 @@ namespace ___Dummy throw new NotImplementedException(); } - public static IUniTaskAsyncEnumerable Join(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable Join(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable JoinAwait(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable JoinAwait(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable JoinAwaitWithCancellation(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector) - { - throw new NotImplementedException(); - } - - public static IUniTaskAsyncEnumerable JoinAwaitWithCancellation(this IUniTaskAsyncEnumerable outer, IUniTaskAsyncEnumerable inner, Func> outerKeySelector, Func> innerKeySelector, Func> resultSelector, IEqualityComparer comparer) - { - throw new NotImplementedException(); - } diff --git a/src/UniTask.NetCoreTests/Linq/Joins.cs b/src/UniTask.NetCoreTests/Linq/Joins.cs new file mode 100644 index 0000000..bfd1099 --- /dev/null +++ b/src/UniTask.NetCoreTests/Linq/Joins.cs @@ -0,0 +1,116 @@ +using Cysharp.Threading.Tasks; +using Cysharp.Threading.Tasks.Linq; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Concurrency; +using System.Reactive.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + + +namespace NetCoreTests.Linq +{ + public class Joins + { + static int rd; + + static UniTask RandomRun(T value) + { + if (Interlocked.Increment(ref rd) % 2 == 0) + { + return UniTask.Run(() => value); + } + else + { + return UniTask.FromResult(value); + } + } + + [Fact] + public async Task Join() + { + var outer = new[] { 1, 2, 4, 5, 8, 10, 14, 4, 8, 1, 2, 10 }; + var inner = new[] { 1, 2, 1, 2, 1, 14, 2 }; + + var xs = await outer.ToUniTaskAsyncEnumerable().Join(inner.ToUniTaskAsyncEnumerable(), x => x, x => x, (x, y) => (x, y)).ToArrayAsync(); + var ys = outer.Join(inner, x => x, x => x, (x, y) => (x, y)).ToArray(); + + xs.Should().BeEquivalentTo(ys); + } + + + [Fact] + public async Task JoinThrow() + { + var outer = new[] { 1, 2, 4, 5, 8, 10, 14, 4, 8, 1, 2, 10 }; + var inner = new[] { 1, 2, 1, 2, 1, 14, 2 }; + foreach (var item in UniTaskTestException.Throws()) + { + var xs = outer.ToUniTaskAsyncEnumerable().Join(item, x => x, x => x, (x, y) => x + y).ToArrayAsync(); + var ys = item.Join(inner.ToUniTaskAsyncEnumerable(), x => x, x => x, (x, y) => x + y).ToArrayAsync(); + + await Assert.ThrowsAsync(async () => await xs); + await Assert.ThrowsAsync(async () => await ys); + } + } + + [Fact] + public async Task JoinAwait() + { + var outer = new[] { 1, 2, 4, 5, 8, 10, 14, 4, 8, 1, 2, 10 }; + var inner = new[] { 1, 2, 1, 2, 1, 14, 2 }; + + var xs = await outer.ToUniTaskAsyncEnumerable().JoinAwait(inner.ToUniTaskAsyncEnumerable(), x => RandomRun(x), x => RandomRun(x), (x, y) => RandomRun((x, y))).ToArrayAsync(); + var ys = outer.Join(inner, x => x, x => x, (x, y) => (x, y)).ToArray(); + + xs.Should().BeEquivalentTo(ys); + } + + + [Fact] + public async Task JoinAwaitThrow() + { + var outer = new[] { 1, 2, 4, 5, 8, 10, 14, 4, 8, 1, 2, 10 }; + var inner = new[] { 1, 2, 1, 2, 1, 14, 2 }; + foreach (var item in UniTaskTestException.Throws()) + { + var xs = outer.ToUniTaskAsyncEnumerable().JoinAwait(item, x => RandomRun(x), x => RandomRun(x), (x, y) => RandomRun(x + y)).ToArrayAsync(); + var ys = item.Join(inner.ToUniTaskAsyncEnumerable(), x => x, x => x, (x, y) => x + y).ToArrayAsync(); + + await Assert.ThrowsAsync(async () => await xs); + await Assert.ThrowsAsync(async () => await ys); + } + } + + [Fact] + public async Task JoinAwaitCt() + { + var outer = new[] { 1, 2, 4, 5, 8, 10, 14, 4, 8, 1, 2, 10 }; + var inner = new[] { 1, 2, 1, 2, 1, 14, 2 }; + + var xs = await outer.ToUniTaskAsyncEnumerable().JoinAwaitWithCancellation(inner.ToUniTaskAsyncEnumerable(), (x, _) => RandomRun(x), (x, _) => RandomRun(x), (x, y, _) => RandomRun((x, y))).ToArrayAsync(); + var ys = outer.Join(inner, x => x, x => x, (x, y) => (x, y)).ToArray(); + + xs.Should().BeEquivalentTo(ys); + } + + + [Fact] + public async Task JoinAwaitCtThrow() + { + var outer = new[] { 1, 2, 4, 5, 8, 10, 14, 4, 8, 1, 2, 10 }; + var inner = new[] { 1, 2, 1, 2, 1, 14, 2 }; + foreach (var item in UniTaskTestException.Throws()) + { + var xs = outer.ToUniTaskAsyncEnumerable().JoinAwaitWithCancellation(item, (x, _) => RandomRun(x), (x, _) => RandomRun(x), (x, y, _) => RandomRun(x + y)).ToArrayAsync(); + var ys = item.JoinAwaitWithCancellation(inner.ToUniTaskAsyncEnumerable(), (x, _) => RandomRun(x), (x, _) => RandomRun(x), (x, y, _) => RandomRun(x + y)).ToArrayAsync(); + + await Assert.ThrowsAsync(async () => await xs); + await Assert.ThrowsAsync(async () => await ys); + } + } + } +}