Fix race condition (todo: too wide lock range?)

master
hadashiA 2023-09-08 23:38:34 +09:00
parent b195df9773
commit 6e99accf99
2 changed files with 54 additions and 66 deletions

View File

@ -13,77 +13,61 @@ namespace NetCoreTests.Linq
[Fact]
public async Task TwoSource()
{
var semaphore = new SemaphoreSlim(1, 1);
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("A1");
semaphore.Release();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(20));
await writer.YieldAsync("A2");
semaphore.Release();
});
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(10));
await writer.YieldAsync("B1");
await writer.YieldAsync("B2");
semaphore.Release();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(30));
await writer.YieldAsync("B3");
semaphore.Release();
});
var result = await a.Merge(b).ToArrayAsync();
result.Should().Equal("A1", "B1", "B2", "A2", "B3");
}
[Fact]
public async Task ThreeSource()
{
var semaphore = new SemaphoreSlim(0, 1);
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(10));
await writer.YieldAsync("A1");
semaphore.Release();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(30));
await writer.YieldAsync("A2");
semaphore.Release();
});
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(20));
await writer.YieldAsync("B1");
await writer.YieldAsync("B2");
semaphore.Release();
await semaphore.WaitAsync();
await Task.Delay(TimeSpan.FromMilliseconds(40));
await writer.YieldAsync("B3");
semaphore.Release();
});
var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await writer.YieldAsync("C1");
semaphore.Release();
});
var result = await a.Merge(b, c).ToArrayAsync();
@ -107,15 +91,15 @@ namespace NetCoreTests.Linq
var enumerator = a.Merge(b).GetAsyncEnumerator();
(await enumerator.MoveNextAsync()).Should().Be(true);
enumerator.Current.Should().Be("A1");
await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync());
}
[Fact]
public async Task Cancel()
{
var cts = new CancellationTokenSource();
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await writer.YieldAsync("A1");
@ -129,7 +113,7 @@ namespace NetCoreTests.Linq
var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token);
(await enumerator.MoveNextAsync()).Should().Be(true);
enumerator.Current.Should().Be("A1");
cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
}

View File

@ -15,7 +15,7 @@ namespace Cysharp.Threading.Tasks.Linq
return new Merge<T>(new [] { first, second });
}
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
{
Error.ThrowArgumentNullException(first, nameof(first));
@ -24,7 +24,7 @@ namespace Cysharp.Threading.Tasks.Linq
return new Merge<T>(new[] { first, second, third });
}
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
{
return new Merge<T>(sources.ToArray());
@ -35,11 +35,11 @@ namespace Cysharp.Threading.Tasks.Linq
return new Merge<T>(sources);
}
}
internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
{
readonly IUniTaskAsyncEnumerable<T>[] sources;
public Merge(IUniTaskAsyncEnumerable<T>[] sources)
{
if (sources.Length <= 0)
@ -49,7 +49,7 @@ namespace Cysharp.Threading.Tasks.Linq
this.sources = sources;
}
public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
=> new _Merge(sources, cancellationToken);
enum MergeSourceState
@ -82,27 +82,30 @@ namespace Cysharp.Threading.Tasks.Linq
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
states[i] = MergeSourceState.Pending;
}
}
}
public UniTask<bool> MoveNextAsync()
{
cancellationToken.ThrowIfCancellationRequested();
completionSource.Reset();
if (TryDequeue(out var queuedValue, out var queuedException))
lock (states)
{
if (queuedException != null)
if (TryDequeue(out var queuedValue, out var queuedException))
{
completionSource.TrySetException(queuedException);
if (queuedException != null)
{
completionSource.TrySetException(queuedException);
}
else
{
Current = queuedValue;
completionSource.TrySetResult(!IsCompletedAll());
}
return new UniTask<bool>(this, completionSource.Version);
}
else
{
Current = queuedValue;
completionSource.TrySetResult(!IsCompletedAll());
}
return new UniTask<bool>(this, completionSource.Version);
}
for (var i = 0; i < length; i++)
{
lock (states)
@ -113,7 +116,7 @@ namespace Cysharp.Threading.Tasks.Linq
}
states[i] = MergeSourceState.Running;
}
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted)
{
@ -159,7 +162,8 @@ namespace Cysharp.Threading.Tasks.Linq
{
if (!completionSource.TrySetException(ex))
{
lock (resultQueue)
//
lock (states)
{
resultQueue.Enqueue((default, ex));
}
@ -167,27 +171,27 @@ namespace Cysharp.Threading.Tasks.Linq
return;
}
var completed = IsCompletedAll();
if (hasNext || completed)
var completedAll = IsCompletedAll();
if (hasNext || completedAll)
{
if (completionSource.GetStatus(completionSource.Version).IsCompleted())
lock (states)
{
lock (resultQueue)
if (completionSource.GetStatus(completionSource.Version).IsCompleted())
{
resultQueue.Enqueue((enumerators[index].Current, null));
}
}
else
{
Current = enumerators[index].Current;
completionSource.TrySetResult(!completed);
else
{
Current = enumerators[index].Current;
completionSource.TrySetResult(!completedAll);
}
}
}
}
bool TryDequeue(out T value, out Exception ex)
{
lock (resultQueue)
lock (states)
{
if (resultQueue.Count > 0)
{