diff --git a/src/UniTask.NetCore/IAsyncEnumerable.cs b/src/UniTask.NetCore/IAsyncEnumerable.cs index 9c67692..dfc88e1 100644 --- a/src/UniTask.NetCore/IAsyncEnumerable.cs +++ b/src/UniTask.NetCore/IAsyncEnumerable.cs @@ -1,4 +1,5 @@ -using System.Threading; +using System.Runtime.InteropServices; +using System.Threading; namespace Cysharp.Threading.Tasks { @@ -17,4 +18,55 @@ namespace Cysharp.Threading.Tasks { UniTask DisposeAsync(); } + + public static class UniTaskAsyncEnumerableExtensions + { + public static UniTaskCancelableAsyncEnumerable WithCancellation(this IUniTaskAsyncEnumerable source, CancellationToken cancellationToken) + { + return new UniTaskCancelableAsyncEnumerable(source, cancellationToken); + } + } + + [StructLayout(LayoutKind.Auto)] + public readonly struct UniTaskCancelableAsyncEnumerable + { + private readonly IUniTaskAsyncEnumerable enumerable; + private readonly CancellationToken cancellationToken; + + internal UniTaskCancelableAsyncEnumerable(IUniTaskAsyncEnumerable enumerable, CancellationToken cancellationToken) + { + this.enumerable = enumerable; + this.cancellationToken = cancellationToken; + } + + public Enumerator GetAsyncEnumerator() + { + cancellationToken.ThrowIfCancellationRequested(); + return new Enumerator(enumerable.GetAsyncEnumerator(cancellationToken)); + } + + [StructLayout(LayoutKind.Auto)] + public readonly struct Enumerator + { + private readonly IUniTaskAsyncEnumerator enumerator; + + internal Enumerator(IUniTaskAsyncEnumerator enumerator) + { + this.enumerator = enumerator; + } + + public T Current => enumerator.Current; + + public UniTask MoveNextAsync() + { + return enumerator.MoveNextAsync(); + } + + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); + } + } + } } \ No newline at end of file diff --git a/src/UniTask.NetCore/Linq/Range.cs b/src/UniTask.NetCore/Linq/Range.cs index c48e2da..d07e7d6 100644 --- a/src/UniTask.NetCore/Linq/Range.cs +++ b/src/UniTask.NetCore/Linq/Range.cs @@ -55,7 +55,7 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - if (cancellationToken.IsCancellationRequested) return CompletedTasks.False; + cancellationToken.ThrowIfCancellationRequested(); current++; diff --git a/src/UniTask.NetCore/Linq/Repeat.cs b/src/UniTask.NetCore/Linq/Repeat.cs index a4f6415..42cdc10 100644 --- a/src/UniTask.NetCore/Linq/Repeat.cs +++ b/src/UniTask.NetCore/Linq/Repeat.cs @@ -50,7 +50,7 @@ namespace Cysharp.Threading.Tasks.Linq public UniTask MoveNextAsync() { - if (cancellationToken.IsCancellationRequested) return CompletedTasks.False; + cancellationToken.ThrowIfCancellationRequested(); if (remaining-- != 0) { diff --git a/src/UniTask.NetCoreSandbox/Program.cs b/src/UniTask.NetCoreSandbox/Program.cs index 3d797d3..2aa34f9 100644 --- a/src/UniTask.NetCoreSandbox/Program.cs +++ b/src/UniTask.NetCoreSandbox/Program.cs @@ -32,17 +32,24 @@ namespace NetCoreSandbox static async Task Main(string[] args) { + var cts = new CancellationTokenSource(); + await foreach (var item in UniTaskAsyncEnumerable.Range(1, 3).WithCancellation(cts.Token)) + { + Console.WriteLine(item); + cts.Cancel(); + } - await UniTaskAsyncEnumerable.Range(1, 3).ForEachAsync(x => + /* + .ForEachAsync(x => { if (x == 2) throw new Exception(); Console.WriteLine(x); }); - + */