Fix race condition

master
hadashiA 2023-09-12 14:34:53 +09:00
parent 3bac16229f
commit 937d3adf66
1 changed files with 22 additions and 17 deletions

View File

@ -65,7 +65,7 @@ namespace Cysharp.Threading.Tasks.Linq
readonly int length; readonly int length;
readonly IUniTaskAsyncEnumerator<T>[] enumerators; readonly IUniTaskAsyncEnumerator<T>[] enumerators;
readonly int[] states; readonly MergeSourceState[] states;
readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>();
readonly CancellationToken cancellationToken; readonly CancellationToken cancellationToken;
@ -75,7 +75,7 @@ namespace Cysharp.Threading.Tasks.Linq
{ {
this.cancellationToken = cancellationToken; this.cancellationToken = cancellationToken;
length = sources.Length; length = sources.Length;
states = ArrayPool<int>.Shared.Rent(length); states = ArrayPool<MergeSourceState>.Shared.Rent(length);
enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length); enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length);
for (var i = 0; i < length; i++) for (var i = 0; i < length; i++)
{ {
@ -112,18 +112,26 @@ namespace Cysharp.Threading.Tasks.Linq
for (var i = 0; i < length; i++) for (var i = 0; i < length; i++)
{ {
if (Interlocked.CompareExchange(ref states[i], (int)MergeSourceState.Running, (int)MergeSourceState.Pending) == (int)MergeSourceState.Pending) lock (queuedResult)
{ {
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); if (states[i] == (int)MergeSourceState.Pending)
if (awaiter.IsCompleted)
{ {
GetResultAt(i, awaiter); states[i] = MergeSourceState.Running;
} }
else else
{ {
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); continue;
} }
} }
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted)
{
GetResultAt(i, awaiter);
}
else
{
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
}
} }
return new UniTask<bool>(this, completionSource.Version); return new UniTask<bool>(this, completionSource.Version);
} }
@ -135,7 +143,7 @@ namespace Cysharp.Threading.Tasks.Linq
await enumerators[i].DisposeAsync(); await enumerators[i].DisposeAsync();
} }
ArrayPool<int>.Shared.Return(states, true); ArrayPool<MergeSourceState>.Shared.Return(states, true);
ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true); ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true);
} }
@ -153,7 +161,6 @@ namespace Cysharp.Threading.Tasks.Linq
try try
{ {
hasNext = awaiter.GetResult(); hasNext = awaiter.GetResult();
Interlocked.Exchange(ref states[index], (int)(hasNext ? MergeSourceState.Pending : MergeSourceState.Completed));
} }
catch (Exception ex) catch (Exception ex)
{ {
@ -167,10 +174,11 @@ namespace Cysharp.Threading.Tasks.Linq
return; return;
} }
var completedAll = IsCompletedAll(); lock (queuedResult)
if (hasNext || completedAll)
{ {
lock (queuedResult) states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed;
var completedAll = !hasNext && IsCompletedAll();
if (hasNext || completedAll)
{ {
if (completionSource.GetStatus(completionSource.Version).IsCompleted()) if (completionSource.GetStatus(completionSource.Version).IsCompleted())
{ {
@ -189,15 +197,12 @@ namespace Cysharp.Threading.Tasks.Linq
{ {
for (var i = 0; i < length; i++) for (var i = 0; i < length; i++)
{ {
if (states[i] != (int)MergeSourceState.Completed) if (states[i] != MergeSourceState.Completed)
{ {
return false; return false;
} }
} }
lock (queuedResult) return true;
{
return queuedResult.Count <= 0;
}
} }
} }
} }