diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs new file mode 100644 index 0000000..049ae5a --- /dev/null +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -0,0 +1,137 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Cysharp.Threading.Tasks; +using Cysharp.Threading.Tasks.Linq; +using FluentAssertions; +using Xunit; + +namespace NetCoreTests.Linq +{ + public class MergeTest + { + [Fact] + public async Task TwoSource() + { + var semaphore = new SemaphoreSlim(1, 1); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A2"); + semaphore.Release(); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B1"); + await writer.YieldAsync("B2"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B3"); + semaphore.Release(); + }); + + var result = await a.Merge(b).ToArrayAsync(); + result.Should().Equal("A1", "B1", "B2", "A2", "B3"); + } + + [Fact] + public async Task ThreeSource() + { + var semaphore = new SemaphoreSlim(0, 1); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A2"); + semaphore.Release(); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B1"); + await writer.YieldAsync("B2"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B3"); + semaphore.Release(); + }); + + var c = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await writer.YieldAsync("C1"); + semaphore.Release(); + }); + + var result = await a.Merge(b, c).ToArrayAsync(); + result.Should().Equal("C1", "A1", "B1", "B2", "A2", "B3"); + } + + [Fact] + public async Task Throw() + { + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("A1"); + + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + throw new UniTaskTestException(); + }); + + var enumerator = a.Merge(b).GetAsyncEnumerator(); + (await enumerator.MoveNextAsync()).Should().Be(true); + enumerator.Current.Should().Be("A1"); + + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } + + [Fact] + public async Task Cancel() + { + var cts = new CancellationTokenSource(); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("A1"); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("B1"); + }); + + var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); + (await enumerator.MoveNextAsync()).Should().Be(true); + enumerator.Current.Should().Be("A1"); + + cts.Cancel(); + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs index 5c7bc93..9664491 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs @@ -39,7 +39,7 @@ namespace Cysharp.Threading.Tasks.Internal } [MethodImpl(MethodImplOptions.NoInlining)] - public static void ThrowArgumentException(string message) + public static void ThrowArgumentException(string message) { throw new ArgumentException(message); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs new file mode 100644 index 0000000..5bc7649 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -0,0 +1,221 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Cysharp.Threading.Tasks.Internal; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second) + { + Error.ThrowArgumentNullException(first, nameof(first)); + Error.ThrowArgumentNullException(second, nameof(second)); + + return new Merge(new [] { first, second }); + } + + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second, IUniTaskAsyncEnumerable third) + { + Error.ThrowArgumentNullException(first, nameof(first)); + Error.ThrowArgumentNullException(second, nameof(second)); + Error.ThrowArgumentNullException(third, nameof(third)); + + return new Merge(new[] { first, second, third }); + } + + public static IUniTaskAsyncEnumerable Merge(this IEnumerable> sources) + { + return new Merge(sources.ToArray()); + } + + public static IUniTaskAsyncEnumerable Merge(params IUniTaskAsyncEnumerable[] sources) + { + return new Merge(sources); + } + } + + internal sealed class Merge : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable[] sources; + + public Merge(IUniTaskAsyncEnumerable[] sources) + { + if (sources.Length <= 0) + { + Error.ThrowArgumentException("No source async enumerable to merge"); + } + this.sources = sources; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => new _Merge(sources, cancellationToken); + + enum MergeSourceState + { + Pending, + Running, + Completed, + } + + sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action GetResultAtAction = GetResultAt; + + readonly int length; + readonly IUniTaskAsyncEnumerator[] enumerators; + readonly MergeSourceState[] states; + readonly Queue<(T, Exception)> resultQueue = new Queue<(T, Exception)>(); + readonly CancellationToken cancellationToken; + + public T Current { get; private set; } + + public _Merge(IUniTaskAsyncEnumerable[] sources, CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + length = sources.Length; + states = ArrayPool.Shared.Rent(length); + enumerators = ArrayPool>.Shared.Rent(length); + for (var i = 0; i < length; i++) + { + enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); + states[i] = MergeSourceState.Pending; + } + } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (TryDequeue(out var queuedValue, out var queuedException)) + { + if (queuedException != null) + { + completionSource.TrySetException(queuedException); + } + else + { + Current = queuedValue; + completionSource.TrySetResult(!IsCompletedAll()); + } + return new UniTask(this, completionSource.Version); + } + + for (var i = 0; i < length; i++) + { + lock (states) + { + if (states[i] != MergeSourceState.Pending) + { + continue; + } + states[i] = MergeSourceState.Running; + } + + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + GetResultAt(i, awaiter); + } + else + { + awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + } + } + return new UniTask(this, completionSource.Version); + } + + public async UniTask DisposeAsync() + { + for (var i = 0; i < length; i++) + { + await enumerators[i].DisposeAsync(); + } + + ArrayPool.Shared.Return(states, true); + ArrayPool>.Shared.Return(enumerators, true); + } + + static void GetResultAt(object state) + { + var tuple = (StateTuple<_Merge, int, UniTask.Awaiter>)state; + tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3); + } + + void GetResultAt(int index, UniTask.Awaiter awaiter) + { + bool hasNext; + try + { + hasNext = awaiter.GetResult(); + lock (states) + { + states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; + } + } + catch (Exception ex) + { + if (!completionSource.TrySetException(ex)) + { + lock (resultQueue) + { + resultQueue.Enqueue((default, ex)); + } + } + return; + } + + var completed = IsCompletedAll(); + if (hasNext || completed) + { + if (completionSource.GetStatus(completionSource.Version).IsCompleted()) + { + lock (resultQueue) + { + resultQueue.Enqueue((enumerators[index].Current, null)); + } + } + else + { + Current = enumerators[index].Current; + completionSource.TrySetResult(!completed); + } + } + } + + bool TryDequeue(out T value, out Exception ex) + { + lock (resultQueue) + { + if (resultQueue.Count > 0) + { + var result = resultQueue.Dequeue(); + value = result.Item1; + ex = result.Item2; + return true; + } + } + value = default; + ex = default; + return false; + } + + bool IsCompletedAll() + { + lock (states) + { + for (var i = 0; i < length; i++) + { + if (states[i] != MergeSourceState.Completed) + { + return false; + } + } + return true; + } + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta new file mode 100644 index 0000000..2f671f4 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: ca56812f160c45d0bacb4339819edf1a +timeCreated: 1694133666 \ No newline at end of file