Fix race condition
parent
3bac16229f
commit
937d3adf66
|
@ -65,7 +65,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
|
||||
readonly int length;
|
||||
readonly IUniTaskAsyncEnumerator<T>[] enumerators;
|
||||
readonly int[] states;
|
||||
readonly MergeSourceState[] states;
|
||||
readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>();
|
||||
readonly CancellationToken cancellationToken;
|
||||
|
||||
|
@ -75,7 +75,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
{
|
||||
this.cancellationToken = cancellationToken;
|
||||
length = sources.Length;
|
||||
states = ArrayPool<int>.Shared.Rent(length);
|
||||
states = ArrayPool<MergeSourceState>.Shared.Rent(length);
|
||||
enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length);
|
||||
for (var i = 0; i < length; i++)
|
||||
{
|
||||
|
@ -112,18 +112,26 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
|
||||
for (var i = 0; i < length; i++)
|
||||
{
|
||||
if (Interlocked.CompareExchange(ref states[i], (int)MergeSourceState.Running, (int)MergeSourceState.Pending) == (int)MergeSourceState.Pending)
|
||||
lock (queuedResult)
|
||||
{
|
||||
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
|
||||
if (awaiter.IsCompleted)
|
||||
if (states[i] == (int)MergeSourceState.Pending)
|
||||
{
|
||||
GetResultAt(i, awaiter);
|
||||
states[i] = MergeSourceState.Running;
|
||||
}
|
||||
else
|
||||
{
|
||||
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
|
||||
if (awaiter.IsCompleted)
|
||||
{
|
||||
GetResultAt(i, awaiter);
|
||||
}
|
||||
else
|
||||
{
|
||||
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
|
||||
}
|
||||
}
|
||||
return new UniTask<bool>(this, completionSource.Version);
|
||||
}
|
||||
|
@ -135,7 +143,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
await enumerators[i].DisposeAsync();
|
||||
}
|
||||
|
||||
ArrayPool<int>.Shared.Return(states, true);
|
||||
ArrayPool<MergeSourceState>.Shared.Return(states, true);
|
||||
ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true);
|
||||
}
|
||||
|
||||
|
@ -153,7 +161,6 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
try
|
||||
{
|
||||
hasNext = awaiter.GetResult();
|
||||
Interlocked.Exchange(ref states[index], (int)(hasNext ? MergeSourceState.Pending : MergeSourceState.Completed));
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
|
@ -167,10 +174,11 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
return;
|
||||
}
|
||||
|
||||
var completedAll = IsCompletedAll();
|
||||
if (hasNext || completedAll)
|
||||
lock (queuedResult)
|
||||
{
|
||||
lock (queuedResult)
|
||||
states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed;
|
||||
var completedAll = !hasNext && IsCompletedAll();
|
||||
if (hasNext || completedAll)
|
||||
{
|
||||
if (completionSource.GetStatus(completionSource.Version).IsCompleted())
|
||||
{
|
||||
|
@ -189,15 +197,12 @@ namespace Cysharp.Threading.Tasks.Linq
|
|||
{
|
||||
for (var i = 0; i < length; i++)
|
||||
{
|
||||
if (states[i] != (int)MergeSourceState.Completed)
|
||||
if (states[i] != MergeSourceState.Completed)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
lock (queuedResult)
|
||||
{
|
||||
return queuedResult.Count <= 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue