diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Select.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Select.cs index 97167f1..3a646d2 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Select.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Select.cs @@ -71,29 +71,92 @@ namespace Cysharp.Threading.Tasks.Linq return new _Select(source, selector, cancellationToken); } - sealed class _Select : AsyncEnumeratorBase + sealed class _Select : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func selector; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + Action moveNextAction; public _Select(IUniTaskAsyncEnumerable source, Func selector, CancellationToken cancellationToken) - : base(source, cancellationToken) { + this.source = source; this.selector = selector; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() { - if (sourceHasCurrent) + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); + } + + void MoveNext() + { + try { - Current = selector(SourceCurrent); - result = true; - return true; + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = selector(enumerator.Current); + goto CONTINUE; + } + else + { + goto DONE; + } + default: + goto DONE; + } } - else + catch (Exception ex) { - result = false; - return true; + state = -2; + completionSource.TrySetException(ex); + return; } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -111,33 +174,96 @@ namespace Cysharp.Threading.Tasks.Linq public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _SelectInt(source, selector, cancellationToken); + return new _Select(source, selector, cancellationToken); } - sealed class _SelectInt : AsyncEnumeratorBase + sealed class _Select : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func selector; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + Action moveNextAction; int index; - public _SelectInt(IUniTaskAsyncEnumerable source, Func selector, CancellationToken cancellationToken) - : base(source, cancellationToken) + public _Select(IUniTaskAsyncEnumerable source, Func selector, CancellationToken cancellationToken) { + this.source = source; this.selector = selector; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() { - if (sourceHasCurrent) + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); + } + + void MoveNext() + { + try { - Current = selector(SourceCurrent, checked(index++)); - result = true; - return true; + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + Current = selector(enumerator.Current, checked(index++)); + goto CONTINUE; + } + else + { + goto DONE; + } + default: + goto DONE; + } } - else + catch (Exception ex) { - result = false; - return true; + state = -2; + completionSource.TrySetException(ex); + return; } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -158,26 +284,105 @@ namespace Cysharp.Threading.Tasks.Linq return new _SelectAwait(source, selector, cancellationToken); } - sealed class _SelectAwait : AsyncEnumeratorAwaitSelectorBase + sealed class _SelectAwait : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> selector; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; public _SelectAwait(IUniTaskAsyncEnumerable source, Func> selector, CancellationToken cancellationToken) - : base(source, cancellationToken) { + this.source = source; this.selector = selector; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() { - return selector(sourceCurrent); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(TResult awaitResult, out bool terminateIteration) + void MoveNext() { - Current = awaitResult; - terminateIteration = false; - return true; + try + { + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + awaiter2 = selector(enumerator.Current).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + Current = awaiter2.GetResult(); + goto CONTINUE; + default: + goto DONE; + } + } + catch (Exception ex) + { + state = -2; + completionSource.TrySetException(ex); + return; + } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -195,30 +400,109 @@ namespace Cysharp.Threading.Tasks.Linq public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _SelectIntAwait(source, selector, cancellationToken); + return new _SelectAwait(source, selector, cancellationToken); } - sealed class _SelectIntAwait : AsyncEnumeratorAwaitSelectorBase + sealed class _SelectAwait : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> selector; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; int index; - public _SelectIntAwait(IUniTaskAsyncEnumerable source, Func> selector, CancellationToken cancellationToken) - : base(source, cancellationToken) + public _SelectAwait(IUniTaskAsyncEnumerable source, Func> selector, CancellationToken cancellationToken) { + this.source = source; this.selector = selector; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() { - return selector(sourceCurrent, checked(index++)); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(TResult awaitResult, out bool terminateIteration) + void MoveNext() { - Current = awaitResult; - terminateIteration = false; - return true; + try + { + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + awaiter2 = selector(enumerator.Current, checked(index++)).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + Current = awaiter2.GetResult(); + goto CONTINUE; + default: + goto DONE; + } + } + catch (Exception ex) + { + state = -2; + completionSource.TrySetException(ex); + return; + } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -239,26 +523,105 @@ namespace Cysharp.Threading.Tasks.Linq return new _SelectAwaitWithCancellation(source, selector, cancellationToken); } - sealed class _SelectAwaitWithCancellation : AsyncEnumeratorAwaitSelectorBase + sealed class _SelectAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> selector; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; public _SelectAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> selector, CancellationToken cancellationToken) - : base(source, cancellationToken) { + this.source = source; this.selector = selector; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() { - return selector(sourceCurrent, cancellationToken); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(TResult awaitResult, out bool terminateIteration) + void MoveNext() { - Current = awaitResult; - terminateIteration = false; - return true; + try + { + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + awaiter2 = selector(enumerator.Current, cancellationToken).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + Current = awaiter2.GetResult(); + goto CONTINUE; + default: + goto DONE; + } + } + catch (Exception ex) + { + state = -2; + completionSource.TrySetException(ex); + return; + } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } @@ -276,32 +639,110 @@ namespace Cysharp.Threading.Tasks.Linq public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - return new _SelectIntAwaitWithCancellation(source, selector, cancellationToken); + return new _SelectAwaitWithCancellation(source, selector, cancellationToken); } - sealed class _SelectIntAwaitWithCancellation : AsyncEnumeratorAwaitSelectorBase + sealed class _SelectAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator { + readonly IUniTaskAsyncEnumerable source; readonly Func> selector; + readonly CancellationToken cancellationToken; + + int state = -1; + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + UniTask.Awaiter awaiter2; + Action moveNextAction; int index; - public _SelectIntAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> selector, CancellationToken cancellationToken) - : base(source, cancellationToken) + public _SelectAwaitWithCancellation(IUniTaskAsyncEnumerable source, Func> selector, CancellationToken cancellationToken) { + this.source = source; this.selector = selector; + this.cancellationToken = cancellationToken; + this.moveNextAction = MoveNext; } - protected override UniTask TransformAsync(TSource sourceCurrent) + public TResult Current { get; private set; } + + public UniTask MoveNextAsync() { - return selector(sourceCurrent, checked(index++), cancellationToken); + if (state == -2) return default; + + completionSource.Reset(); + MoveNext(); + return new UniTask(this, completionSource.Version); } - protected override bool TrySetCurrentCore(TResult awaitResult, out bool terminateIteration) + void MoveNext() { - Current = awaitResult; - terminateIteration = false; - return true; + try + { + switch (state) + { + case -1: // init + enumerator = source.GetAsyncEnumerator(cancellationToken); + goto case 0; + case 0: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + goto case 1; + } + else + { + state = 1; + awaiter.UnsafeOnCompleted(moveNextAction); + return; + } + case 1: + if (awaiter.GetResult()) + { + awaiter2 = selector(enumerator.Current, checked(index++), cancellationToken).GetAwaiter(); + if (awaiter2.IsCompleted) + { + goto case 2; + } + else + { + state = 2; + awaiter2.UnsafeOnCompleted(moveNextAction); + return; + } + } + else + { + goto DONE; + } + case 2: + Current = awaiter2.GetResult(); + goto CONTINUE; + default: + goto DONE; + } + } + catch (Exception ex) + { + state = -2; + completionSource.TrySetException(ex); + return; + } + + DONE: + state = -2; + completionSource.TrySetResult(false); + return; + + CONTINUE: + state = 0; + completionSource.TrySetResult(true); + return; + } + + public UniTask DisposeAsync() + { + return enumerator.DisposeAsync(); } } } - } \ No newline at end of file diff --git a/src/UniTask/Assets/Scenes/SandboxMain.cs b/src/UniTask/Assets/Scenes/SandboxMain.cs index 09983dd..0c754b7 100644 --- a/src/UniTask/Assets/Scenes/SandboxMain.cs +++ b/src/UniTask/Assets/Scenes/SandboxMain.cs @@ -326,6 +326,10 @@ public class SandboxMain : MonoBehaviour + await UniTaskAsyncEnumerable.EveryUpdate().Select((x, _) => x).ForEachAsync(x => + { + Debug.Log("test"); + });