Fix problem with part of await foreach not executing on break

master
hadashiA 2023-08-31 09:50:41 +09:00
parent d210e3d76a
commit b4486802f2
2 changed files with 35 additions and 1 deletions

View File

@ -159,6 +159,30 @@ namespace NetCoreTests.Linq
list.Should().Equal(100, 200, 300, 400); list.Should().Equal(100, 200, 300, 400);
} }
[Fact]
public async Task AwaitForeachBreak()
{
var finallyCalled = false;
var enumerable = UniTaskAsyncEnumerable.Create<int>(async (writer, _) =>
{
try
{
await writer.YieldAsync(1);
}
finally
{
finallyCalled = true;
}
});
await foreach (var x in enumerable)
{
x.Should().Be(1);
break;
}
finallyCalled.Should().BeTrue();
}
async IAsyncEnumerable<int> Range(int from, int count) async IAsyncEnumerable<int> Range(int from, int count)
{ {
for (int i = 0; i < count; i++) for (int i = 0; i < count; i++)

View File

@ -52,6 +52,7 @@ namespace Cysharp.Threading.Tasks.Linq
public UniTask DisposeAsync() public UniTask DisposeAsync()
{ {
TaskTracker.RemoveTracking(this); TaskTracker.RemoveTracking(this);
writer.Dispose();
return default; return default;
} }
@ -127,7 +128,7 @@ namespace Cysharp.Threading.Tasks.Linq
} }
} }
sealed class AsyncWriter : IUniTaskSource, IAsyncWriter<T> sealed class AsyncWriter : IUniTaskSource, IAsyncWriter<T>, IDisposable
{ {
readonly _Create enumerator; readonly _Create enumerator;
@ -138,6 +139,15 @@ namespace Cysharp.Threading.Tasks.Linq
this.enumerator = enumerator; this.enumerator = enumerator;
} }
public void Dispose()
{
var status = core.GetStatus(core.Version);
if (status == UniTaskStatus.Pending)
{
core.TrySetCanceled();
}
}
public void GetResult(short token) public void GetResult(short token)
{ {
core.GetResult(token); core.GetResult(token);