diff --git a/README.md b/README.md index 1dd27d0..ce50fae 100644 --- a/README.md +++ b/README.md @@ -723,7 +723,7 @@ Async LINQ is enabled when `using Cysharp.Threading.Tasks.Linq;`, and `UniTaskAs It's closer to UniRx (Reactive Extensions), but UniTaskAsyncEnumerable is a pull-based asynchronous stream, whereas Rx was a push-based asynchronous stream. Note that although similar, the characteristics are different and the details behave differently along with them. -`UniTaskAsyncEnumerable` is the entry point like `Enumerable`. In addition to the standard query operators, there are other generators for Unity such as `EveryUpdate`, `Timer`, `TimerFrame`, `Interval`, `IntervalFrame`, and `EveryValueChanged`. And also added additional UniTask original query operators like `Append`, `Prepend`, `DistinctUntilChanged`, `ToHashSet`, `Buffer`, `CombineLatest`, `Do`, `Never`, `ForEachAsync`, `Pairwise`, `Publish`, `Queue`, `Return`, `SkipUntil`, `TakeUntil`, `SkipUntilCanceled`, `TakeUntilCanceled`, `TakeLast`, `Subscribe`. +`UniTaskAsyncEnumerable` is the entry point like `Enumerable`. In addition to the standard query operators, there are other generators for Unity such as `EveryUpdate`, `Timer`, `TimerFrame`, `Interval`, `IntervalFrame`, and `EveryValueChanged`. And also added additional UniTask original query operators like `Append`, `Prepend`, `DistinctUntilChanged`, `ToHashSet`, `Buffer`, `CombineLatest`,`Merge` `Do`, `Never`, `ForEachAsync`, `Pairwise`, `Publish`, `Queue`, `Return`, `SkipUntil`, `TakeUntil`, `SkipUntilCanceled`, `TakeUntilCanceled`, `TakeLast`, `Subscribe`. The method with Func as an argument has three additional overloads, `***Await`, `***AwaitWithCancellation`. diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs new file mode 100644 index 0000000..7021d1d --- /dev/null +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -0,0 +1,137 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Cysharp.Threading.Tasks; +using Cysharp.Threading.Tasks.Linq; +using FluentAssertions; +using Xunit; + +namespace NetCoreTests.Linq +{ + public class MergeTest + { + [Fact] + public async Task TwoSource() + { + var semaphore = new SemaphoreSlim(1, 1); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A2"); + semaphore.Release(); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B1"); + await writer.YieldAsync("B2"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B3"); + semaphore.Release(); + }); + + var result = await a.Merge(b).ToArrayAsync(); + result.Should().Equal("A1", "B1", "B2", "A2", "B3"); + } + + [Fact] + public async Task ThreeSource() + { + var semaphore = new SemaphoreSlim(0, 1); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A2"); + semaphore.Release(); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B1"); + await writer.YieldAsync("B2"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B3"); + semaphore.Release(); + }); + + var c = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await writer.YieldAsync("C1"); + semaphore.Release(); + }); + + var result = await a.Merge(b, c).ToArrayAsync(); + result.Should().Equal("C1", "A1", "B1", "B2", "A2", "B3"); + } + + [Fact] + public async Task Throw() + { + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("A1"); + + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + throw new UniTaskTestException(); + }); + + var enumerator = a.Merge(b).GetAsyncEnumerator(); + (await enumerator.MoveNextAsync()).Should().Be(true); + enumerator.Current.Should().Be("A1"); + + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } + + [Fact] + public async Task Cancel() + { + var cts = new CancellationTokenSource(); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("A1"); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("B1"); + }); + + var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); + (await enumerator.MoveNextAsync()).Should().Be(true); + enumerator.Current.Should().Be("A1"); + + cts.Cancel(); + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs index 5c7bc93..9664491 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs @@ -39,7 +39,7 @@ namespace Cysharp.Threading.Tasks.Internal } [MethodImpl(MethodImplOptions.NoInlining)] - public static void ThrowArgumentException(string message) + public static void ThrowArgumentException(string message) { throw new ArgumentException(message); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs new file mode 100644 index 0000000..d4ea969 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -0,0 +1,232 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Cysharp.Threading.Tasks.Internal; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second) + { + Error.ThrowArgumentNullException(first, nameof(first)); + Error.ThrowArgumentNullException(second, nameof(second)); + + return new Merge(new [] { first, second }); + } + + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second, IUniTaskAsyncEnumerable third) + { + Error.ThrowArgumentNullException(first, nameof(first)); + Error.ThrowArgumentNullException(second, nameof(second)); + Error.ThrowArgumentNullException(third, nameof(third)); + + return new Merge(new[] { first, second, third }); + } + + public static IUniTaskAsyncEnumerable Merge(this IEnumerable> sources) + { + return new Merge(sources.ToArray()); + } + + public static IUniTaskAsyncEnumerable Merge(params IUniTaskAsyncEnumerable[] sources) + { + return new Merge(sources); + } + } + + internal sealed class Merge : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable[] sources; + + public Merge(IUniTaskAsyncEnumerable[] sources) + { + if (sources.Length <= 0) + { + Error.ThrowArgumentException("No source async enumerable to merge"); + } + this.sources = sources; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => new _Merge(sources, cancellationToken); + + enum MergeSourceState + { + Pending, + Running, + Completed, + } + + sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action GetResultAtAction = GetResultAt; + + readonly int length; + readonly IUniTaskAsyncEnumerator[] enumerators; + readonly MergeSourceState[] states; + readonly Queue<(T, Exception, bool)> queuedResult = new Queue<(T, Exception, bool)>(); + readonly CancellationToken cancellationToken; + + int moveNextCompleted; + + public T Current { get; private set; } + + public _Merge(IUniTaskAsyncEnumerable[] sources, CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + length = sources.Length; + states = ArrayPool.Shared.Rent(length); + enumerators = ArrayPool>.Shared.Rent(length); + for (var i = 0; i < length; i++) + { + enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); + states[i] = (int)MergeSourceState.Pending;; + } + } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + Interlocked.Exchange(ref moveNextCompleted, 0); + + if (HasQueuedResult() && Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) + { + (T, Exception, bool) value; + lock (states) + { + value = queuedResult.Dequeue(); + } + var resultValue = value.Item1; + var exception = value.Item2; + var hasNext = value.Item3; + if (exception != null) + { + completionSource.TrySetException(exception); + } + else + { + Current = resultValue; + completionSource.TrySetResult(hasNext); + } + return new UniTask(this, completionSource.Version); + } + + for (var i = 0; i < length; i++) + { + lock (states) + { + if (states[i] == MergeSourceState.Pending) + { + states[i] = MergeSourceState.Running; + } + else + { + continue; + } + } + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + GetResultAt(i, awaiter); + } + else + { + awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + } + } + return new UniTask(this, completionSource.Version); + } + + public async UniTask DisposeAsync() + { + for (var i = 0; i < length; i++) + { + await enumerators[i].DisposeAsync(); + } + + ArrayPool.Shared.Return(states, true); + ArrayPool>.Shared.Return(enumerators, true); + } + + static void GetResultAt(object state) + { + using (var tuple = (StateTuple<_Merge, int, UniTask.Awaiter>)state) + { + tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3); + } + } + + void GetResultAt(int index, UniTask.Awaiter awaiter) + { + bool hasNext; + bool completedAll; + try + { + hasNext = awaiter.GetResult(); + } + catch (Exception ex) + { + if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) + { + completionSource.TrySetException(ex); + } + else + { + lock (states) + { + queuedResult.Enqueue((default, ex, default)); + } + } + return; + } + + lock (states) + { + states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; + completedAll = !hasNext && IsCompletedAll(); + } + if (hasNext || completedAll) + { + if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) + { + Current = enumerators[index].Current; + completionSource.TrySetResult(!completedAll); + } + else + { + lock (states) + { + queuedResult.Enqueue((enumerators[index].Current, null, !completedAll)); + } + } + } + } + + bool HasQueuedResult() + { + lock (states) + { + return queuedResult.Count > 0; + } + } + + bool IsCompletedAll() + { + lock (states) + { + for (var i = 0; i < length; i++) + { + if (states[i] != MergeSourceState.Completed) + { + return false; + } + } + } + return true; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta new file mode 100644 index 0000000..2f671f4 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: ca56812f160c45d0bacb4339819edf1a +timeCreated: 1694133666 \ No newline at end of file