From 6e99accf994e1a4ac67cc3f1059cc10ff5b6020f Mon Sep 17 00:00:00 2001 From: hadashiA Date: Fri, 8 Sep 2023 23:38:34 +0900 Subject: [PATCH] Fix race condition (todo: too wide lock range?) --- src/UniTask.NetCoreTests/Linq/Merge.cs | 60 +++++++------------ .../Plugins/UniTask/Runtime/Linq/Merge.cs | 60 ++++++++++--------- 2 files changed, 54 insertions(+), 66 deletions(-) diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs index 049ae5a..e669580 100644 --- a/src/UniTask.NetCoreTests/Linq/Merge.cs +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -13,77 +13,61 @@ namespace NetCoreTests.Linq [Fact] public async Task TwoSource() { - var semaphore = new SemaphoreSlim(1, 1); - var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); - semaphore.Release(); - - await semaphore.WaitAsync(); + await Task.Delay(TimeSpan.FromMilliseconds(20)); await writer.YieldAsync("A2"); - semaphore.Release(); }); - + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(10)); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - semaphore.Release(); - - await semaphore.WaitAsync(); + await Task.Delay(TimeSpan.FromMilliseconds(30)); await writer.YieldAsync("B3"); - semaphore.Release(); }); var result = await a.Merge(b).ToArrayAsync(); result.Should().Equal("A1", "B1", "B2", "A2", "B3"); } - + [Fact] public async Task ThreeSource() { - var semaphore = new SemaphoreSlim(0, 1); - var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(10)); await writer.YieldAsync("A1"); - semaphore.Release(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(30)); await writer.YieldAsync("A2"); - semaphore.Release(); }); - + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(20)); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - semaphore.Release(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(40)); await writer.YieldAsync("B3"); - semaphore.Release(); }); - + var c = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - + await writer.YieldAsync("C1"); - semaphore.Release(); }); var result = await a.Merge(b, c).ToArrayAsync(); @@ -107,15 +91,15 @@ namespace NetCoreTests.Linq var enumerator = a.Merge(b).GetAsyncEnumerator(); (await enumerator.MoveNextAsync()).Should().Be(true); enumerator.Current.Should().Be("A1"); - + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); } - + [Fact] public async Task Cancel() { var cts = new CancellationTokenSource(); - + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await writer.YieldAsync("A1"); @@ -129,7 +113,7 @@ namespace NetCoreTests.Linq var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); (await enumerator.MoveNextAsync()).Should().Be(true); enumerator.Current.Should().Be("A1"); - + cts.Cancel(); await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index 5bc7649..f8a5fb0 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -15,7 +15,7 @@ namespace Cysharp.Threading.Tasks.Linq return new Merge(new [] { first, second }); } - + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second, IUniTaskAsyncEnumerable third) { Error.ThrowArgumentNullException(first, nameof(first)); @@ -24,7 +24,7 @@ namespace Cysharp.Threading.Tasks.Linq return new Merge(new[] { first, second, third }); } - + public static IUniTaskAsyncEnumerable Merge(this IEnumerable> sources) { return new Merge(sources.ToArray()); @@ -35,11 +35,11 @@ namespace Cysharp.Threading.Tasks.Linq return new Merge(sources); } } - + internal sealed class Merge : IUniTaskAsyncEnumerable { readonly IUniTaskAsyncEnumerable[] sources; - + public Merge(IUniTaskAsyncEnumerable[] sources) { if (sources.Length <= 0) @@ -49,7 +49,7 @@ namespace Cysharp.Threading.Tasks.Linq this.sources = sources; } - public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => new _Merge(sources, cancellationToken); enum MergeSourceState @@ -82,27 +82,30 @@ namespace Cysharp.Threading.Tasks.Linq enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); states[i] = MergeSourceState.Pending; } - } + } public UniTask MoveNextAsync() { cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); - if (TryDequeue(out var queuedValue, out var queuedException)) + lock (states) { - if (queuedException != null) + if (TryDequeue(out var queuedValue, out var queuedException)) { - completionSource.TrySetException(queuedException); + if (queuedException != null) + { + completionSource.TrySetException(queuedException); + } + else + { + Current = queuedValue; + completionSource.TrySetResult(!IsCompletedAll()); + } + return new UniTask(this, completionSource.Version); } - else - { - Current = queuedValue; - completionSource.TrySetResult(!IsCompletedAll()); - } - return new UniTask(this, completionSource.Version); } - + for (var i = 0; i < length; i++) { lock (states) @@ -113,7 +116,7 @@ namespace Cysharp.Threading.Tasks.Linq } states[i] = MergeSourceState.Running; } - + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); if (awaiter.IsCompleted) { @@ -159,7 +162,8 @@ namespace Cysharp.Threading.Tasks.Linq { if (!completionSource.TrySetException(ex)) { - lock (resultQueue) + // + lock (states) { resultQueue.Enqueue((default, ex)); } @@ -167,27 +171,27 @@ namespace Cysharp.Threading.Tasks.Linq return; } - var completed = IsCompletedAll(); - if (hasNext || completed) + var completedAll = IsCompletedAll(); + if (hasNext || completedAll) { - if (completionSource.GetStatus(completionSource.Version).IsCompleted()) + lock (states) { - lock (resultQueue) + if (completionSource.GetStatus(completionSource.Version).IsCompleted()) { resultQueue.Enqueue((enumerators[index].Current, null)); } - } - else - { - Current = enumerators[index].Current; - completionSource.TrySetResult(!completed); + else + { + Current = enumerators[index].Current; + completionSource.TrySetResult(!completedAll); + } } } } bool TryDequeue(out T value, out Exception ex) { - lock (resultQueue) + lock (states) { if (resultQueue.Count > 0) {