diff --git a/src/UniTask.NetCoreSandbox/Program.cs b/src/UniTask.NetCoreSandbox/Program.cs index ed6e222..3f5a671 100644 --- a/src/UniTask.NetCoreSandbox/Program.cs +++ b/src/UniTask.NetCoreSandbox/Program.cs @@ -246,11 +246,25 @@ namespace NetCoreSandbox Console.WriteLine("FooBarAsync End"); } + static async UniTask WhereSelect() + { + await foreach (var item in UniTaskAsyncEnumerable.Range(1, 10) + .SelectAwait(async x => + { + await UniTask.Yield(); + return x; + }) + .Where(x => x % 2 == 0)) + { + Console.WriteLine(item); + } + } + static async Task Main(string[] args) { #if !DEBUG - + @@ -264,6 +278,7 @@ namespace NetCoreSandbox // await new AllocationCheck().ViaUniTaskVoid(); // AsyncTest().Forget(); + await WhereSelect(); SynchronizationContext.SetSynchronizationContext(new MySyncContext()); diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/AsyncEnumeratorBase.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/AsyncEnumeratorBase.cs index 9b3c19e..e7f9968 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/AsyncEnumeratorBase.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/AsyncEnumeratorBase.cs @@ -3,7 +3,9 @@ using System.Threading; namespace Cysharp.Threading.Tasks.Linq { - public abstract class AsyncEnumeratorBase : MoveNextSource, IUniTaskAsyncEnumerator + // note: refactor all inherit class and should remove this. + // see Select and Where. + internal abstract class AsyncEnumeratorBase : MoveNextSource, IUniTaskAsyncEnumerator { static readonly Action moveNextCallbackDelegate = MoveNextCallBack; @@ -129,7 +131,7 @@ namespace Cysharp.Threading.Tasks.Linq } } - public abstract class AsyncEnumeratorAwaitSelectorBase : MoveNextSource, IUniTaskAsyncEnumerator + internal abstract class AsyncEnumeratorAwaitSelectorBase : MoveNextSource, IUniTaskAsyncEnumerator { static readonly Action moveNextCallbackDelegate = MoveNextCallBack; static readonly Action setCurrentCallbackDelegate = SetCurrentCallBack; @@ -351,5 +353,4 @@ namespace Cysharp.Threading.Tasks.Linq return default; } } - } \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Where.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Where.cs index 2f98860..8e9ec62 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Where.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Where.cs @@ -71,36 +71,101 @@ namespace Cysharp.Threading.Tasks.Linq return new _Where(source, predicate, cancellationToken); } - class _Where : AsyncEnumeratorBase + sealed class _Where : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func predicate; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + Action moveNextAction; public _Where(IUniTaskAsyncEnumerable source, Func predicate, CancellationToken cancellationToken) - - : base(source, cancellationToken) { + this.source = source; this.predicate = predicate; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - if (sourceHasCurrent) + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); + } + + void MoveNext() + { + REPEAT: + try { - if (predicate(SourceCurrent)) + switch (state) { - Current = SourceCurrent; - result = true; - return true; - } - else - { - result = default; - return false; + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = enumerator.Current; + if (predicate(Current)) + { + goto CONTINUE; + } + else + { + state = 0; + goto REPEAT; + } + } + else + { + goto DONE; + } + default: + goto DONE; } } + catch (Exception ex) + { + state = -2; + completionSource.TrySetException(ex); + return; + } - result = false; - return true; + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -118,40 +183,105 @@ namespace Cysharp.Threading.Tasks.Linq public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _WhereInt(source, predicate, cancellationToken); + return new _Where(source, predicate, cancellationToken); } - class _WhereInt : AsyncEnumeratorBase + sealed class _Where : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func predicate; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + Action moveNextAction; int index; - public _WhereInt(IUniTaskAsyncEnumerable source, Func predicate, CancellationToken cancellationToken) - - : base(source, cancellationToken) + public _Where(IUniTaskAsyncEnumerable source, Func predicate, CancellationToken cancellationToken) { + this.source = source; this.predicate = predicate; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - if (sourceHasCurrent) + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); + } + + void MoveNext() + { + REPEAT: + try { - if (predicate(SourceCurrent, checked(index++))) + switch (state) { - Current = SourceCurrent; - result = true; - return true; - } - else - { - result = default; - return false; + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = enumerator.Current; + if (predicate(Current, checked(index++))) + { + goto CONTINUE; + } + else + { + state = 0; + goto REPEAT; + } + } + else + { + goto DONE; + } + default: + goto DONE; } } + catch (Exception ex) + { + state = -2; + completionSource.TrySetException(ex); + return; + } - result = false; - return true; + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -172,34 +302,115 @@ namespace Cysharp.Threading.Tasks.Linq return new _WhereAwait(source, predicate, cancellationToken); } - class _WhereAwait : AsyncEnumeratorAwaitSelectorBase + sealed class _WhereAwait : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> predicate; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; public _WhereAwait(IUniTaskAsyncEnumerable source, Func> predicate, CancellationToken cancellationToken) - - : base(source, cancellationToken) { + this.source = source; this.predicate = predicate; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - return predicate(sourceCurrent); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(bool awaitResult, out bool terminateIteration) + void MoveNext() { - terminateIteration = false; - if (awaitResult) + REPEAT: + try { - Current = SourceCurrent; - return true; + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = enumerator.Current; + + awaiter2 = predicate(Current).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + if (awaiter2.GetResult()) + { + goto CONTINUE; + } + else + { + state = 0; + goto REPEAT; + } + default: + goto DONE; + } } - else + catch (Exception ex) { - return false; + state = -2; + completionSource.TrySetException(ex); + return; } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -217,44 +428,123 @@ namespace Cysharp.Threading.Tasks.Linq public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _WhereIntAwait(source, predicate, cancellationToken); + return new _WhereAwait(source, predicate, cancellationToken); } - class _WhereIntAwait : AsyncEnumeratorAwaitSelectorBase + sealed class _WhereAwait : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> predicate; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; int index; - public _WhereIntAwait(IUniTaskAsyncEnumerable source, Func> predicate, CancellationToken cancellationToken) - - : base(source, cancellationToken) + public _WhereAwait(IUniTaskAsyncEnumerable source, Func> predicate, CancellationToken cancellationToken) { + this.source = source; this.predicate = predicate; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - return predicate(sourceCurrent, checked(index++)); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(bool awaitResult, out bool terminateIteration) + void MoveNext() { - terminateIteration = false; - if (awaitResult) + REPEAT: + try { - Current = SourceCurrent; - return true; + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = enumerator.Current; + + awaiter2 = predicate(Current, checked(index++)).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + if (awaiter2.GetResult()) + { + goto CONTINUE; + } + else + { + state = 0; + goto REPEAT; + } + default: + goto DONE; + } } - else + catch (Exception ex) { - return false; + state = -2; + completionSource.TrySetException(ex); + return; } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } - - internal sealed class WhereAwaitWithCancellation : IUniTaskAsyncEnumerable { readonly IUniTaskAsyncEnumerable source; @@ -271,34 +561,115 @@ namespace Cysharp.Threading.Tasks.Linq return new _WhereAwaitWithCancellation(source, predicate, cancellationToken); } - class _WhereAwaitWithCancellation : AsyncEnumeratorAwaitSelectorBase + sealed class _WhereAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> predicate; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; public _WhereAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> predicate, CancellationToken cancellationToken) - - : base(source, cancellationToken) { + this.source = source; this.predicate = predicate; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - return predicate(sourceCurrent, cancellationToken); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(bool awaitResult, out bool terminateIteration) + void MoveNext() { - terminateIteration = false; - if (awaitResult) + REPEAT: + try { - Current = SourceCurrent; - return true; + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = enumerator.Current; + + awaiter2 = predicate(Current, cancellationToken).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + if (awaiter2.GetResult()) + { + goto CONTINUE; + } + else + { + state = 0; + goto REPEAT; + } + default: + goto DONE; + } } - else + catch (Exception ex) { - return false; + state = -2; + completionSource.TrySetException(ex); + return; } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -316,40 +687,120 @@ namespace Cysharp.Threading.Tasks.Linq public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _WhereIntAwaitWithCancellation(source, predicate, cancellationToken); + return new _WhereAwaitWithCancellation(source, predicate, cancellationToken); } - class _WhereIntAwaitWithCancellation : AsyncEnumeratorAwaitSelectorBase + sealed class _WhereAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> predicate; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; int index; - public _WhereIntAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> predicate, CancellationToken cancellationToken) - - : base(source, cancellationToken) + public _WhereAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> predicate, CancellationToken cancellationToken) { + this.source = source; this.predicate = predicate; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - return predicate(sourceCurrent, checked(index++), cancellationToken); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(bool awaitResult, out bool terminateIteration) + void MoveNext() { - terminateIteration = false; - if (awaitResult) + REPEAT: + try { - Current = SourceCurrent; - return true; + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = enumerator.Current; + + awaiter2 = predicate(Current, checked(index++), cancellationToken).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + if (awaiter2.GetResult()) + { + goto CONTINUE; + } + else + { + state = 0; + goto REPEAT; + } + default: + goto DONE; + } } - else + catch (Exception ex) { - return false; + state = -2; + completionSource.TrySetException(ex); + return; } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } - } \ No newline at end of file