diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index 86240a6..f129082 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -65,7 +65,7 @@ namespace Cysharp.Threading.Tasks.Linq readonly int length; readonly IUniTaskAsyncEnumerator[] enumerators; - readonly int[] states; + readonly MergeSourceState[] states; readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); readonly CancellationToken cancellationToken; @@ -75,7 +75,7 @@ namespace Cysharp.Threading.Tasks.Linq { this.cancellationToken = cancellationToken; length = sources.Length; - states = ArrayPool.Shared.Rent(length); + states = ArrayPool.Shared.Rent(length); enumerators = ArrayPool>.Shared.Rent(length); for (var i = 0; i < length; i++) { @@ -112,18 +112,26 @@ namespace Cysharp.Threading.Tasks.Linq for (var i = 0; i < length; i++) { - if (Interlocked.CompareExchange(ref states[i], (int)MergeSourceState.Running, (int)MergeSourceState.Pending) == (int)MergeSourceState.Pending) + lock (queuedResult) { - var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); - if (awaiter.IsCompleted) + if (states[i] == (int)MergeSourceState.Pending) { - GetResultAt(i, awaiter); + states[i] = MergeSourceState.Running; } else { - awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + continue; } } + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + GetResultAt(i, awaiter); + } + else + { + awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + } } return new UniTask(this, completionSource.Version); } @@ -135,7 +143,7 @@ namespace Cysharp.Threading.Tasks.Linq await enumerators[i].DisposeAsync(); } - ArrayPool.Shared.Return(states, true); + ArrayPool.Shared.Return(states, true); ArrayPool>.Shared.Return(enumerators, true); } @@ -153,7 +161,6 @@ namespace Cysharp.Threading.Tasks.Linq try { hasNext = awaiter.GetResult(); - Interlocked.Exchange(ref states[index], (int)(hasNext ? MergeSourceState.Pending : MergeSourceState.Completed)); } catch (Exception ex) { @@ -167,10 +174,11 @@ namespace Cysharp.Threading.Tasks.Linq return; } - var completedAll = IsCompletedAll(); - if (hasNext || completedAll) + lock (queuedResult) { - lock (queuedResult) + states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; + var completedAll = !hasNext && IsCompletedAll(); + if (hasNext || completedAll) { if (completionSource.GetStatus(completionSource.Version).IsCompleted()) { @@ -189,15 +197,12 @@ namespace Cysharp.Threading.Tasks.Linq { for (var i = 0; i < length; i++) { - if (states[i] != (int)MergeSourceState.Completed) + if (states[i] != MergeSourceState.Completed) { return false; } } - lock (queuedResult) - { - return queuedResult.Count <= 0; - } + return true; } } }