diff --git a/src/UniTask/Assets/Plugins/UniTask/Triggers/AsyncTriggerBase.cs b/src/UniTask/Assets/Plugins/UniTask/Triggers/AsyncTriggerBase.cs index 8568d03..926bebc 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Triggers/AsyncTriggerBase.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Triggers/AsyncTriggerBase.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Threading; using Cysharp.Threading.Tasks.Internal; +using Cysharp.Threading.Tasks.Linq; using UnityEngine; namespace Cysharp.Threading.Tasks.Triggers @@ -87,7 +88,7 @@ namespace Cysharp.Threading.Tasks.Triggers } } - public sealed partial class AsyncTriggerHandler : IUniTaskSource, IResolvePromise, ICancelPromise, IDisposable + public sealed partial class AsyncTriggerHandler : IUniTaskSource, IResolveCancelPromise, IDisposable { static Action cancellationCallback = CancellationCallback; @@ -207,17 +208,101 @@ namespace Cysharp.Threading.Tasks.Triggers } } - public sealed class TriggerEvent : IResolvePromise, ICancelPromise + public sealed class TriggerAsyncEnumerable : IUniTaskAsyncEnumerable + { + readonly TriggerEvent triggerEvent; + + public TriggerAsyncEnumerable(TriggerEvent triggerEvent) + { + this.triggerEvent = triggerEvent; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(triggerEvent, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator, IResolveCancelPromise + { + static Action cancellationCallback = CancellationCallback; + + readonly TriggerEvent triggerEvent; + CancellationToken cancellationToken; + CancellationTokenRegistration registration; + bool called; + bool isDisposed; + + public Enumerator(TriggerEvent triggerEvent, CancellationToken cancellationToken) + { + this.triggerEvent = triggerEvent; + this.cancellationToken = cancellationToken; + } + + public bool TrySetCanceled(CancellationToken cancellationToken = default) + { + return completionSource.TrySetCanceled(cancellationToken); + } + + public bool TrySetResult(T value) + { + Current = value; + return completionSource.TrySetResult(true); + } + + static void CancellationCallback(object state) + { + var self = (Enumerator)state; + self.DisposeAsync().Forget(); // sync + + self.completionSource.TrySetCanceled(self.cancellationToken); + } + + public T Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + + if (!called) + { + TaskTracker.TrackActiveTask(this, 3); + triggerEvent.Add(this); + if (cancellationToken.CanBeCanceled) + { + registration = cancellationToken.RegisterWithoutCaptureExecutionContext(cancellationCallback, this); + } + } + + completionSource.Reset(); + return new UniTask(this, completionSource.Version); + } + + public UniTask DisposeAsync() + { + if (!isDisposed) + { + isDisposed = true; + TaskTracker.RemoveTracking(this); + registration.Dispose(); + triggerEvent.Remove(this); + } + + return default; + } + } + } + + public sealed class TriggerEvent : IResolveCancelPromise { // optimize: many cases, handler is single. - AsyncTriggerHandler singleHandler; + IResolveCancelPromise singleHandler; - AsyncTriggerHandler[] handlers; + IResolveCancelPromise[] handlers; // when running(in TrySetResult), does not add immediately. bool isRunning; - AsyncTriggerHandler waitHandler; - MinimumQueue> waitQueue; + IResolveCancelPromise waitHandler; + MinimumQueue> waitQueue; public bool TrySetResult(T value) { @@ -227,7 +312,7 @@ namespace Cysharp.Threading.Tasks.Triggers { try { - ((IResolvePromise)singleHandler).TrySetResult(value); + singleHandler.TrySetResult(value); } catch (Exception ex) { @@ -243,7 +328,7 @@ namespace Cysharp.Threading.Tasks.Triggers { try { - ((IResolvePromise)handlers[i]).TrySetResult(value); + handlers[i].TrySetResult(value); } catch (Exception ex) { @@ -329,7 +414,7 @@ namespace Cysharp.Threading.Tasks.Triggers return true; } - public void Add(AsyncTriggerHandler handler) + public void Add(IResolveCancelPromise handler) { if (isRunning) { @@ -341,7 +426,7 @@ namespace Cysharp.Threading.Tasks.Triggers if (waitQueue == null) { - waitQueue = new MinimumQueue>(4); + waitQueue = new MinimumQueue>(4); } waitQueue.Enqueue(handler); return; @@ -355,7 +440,7 @@ namespace Cysharp.Threading.Tasks.Triggers { if (handlers == null) { - handlers = new AsyncTriggerHandler[4]; + handlers = new IResolveCancelPromise[4]; } // check empty @@ -377,15 +462,15 @@ namespace Cysharp.Threading.Tasks.Triggers } } - static void EnsureCapacity(ref AsyncTriggerHandler[] array) + static void EnsureCapacity(ref IResolveCancelPromise[] array) { var newSize = array.Length * 2; - var newArray = new AsyncTriggerHandler[newSize]; + var newArray = new IResolveCancelPromise[newSize]; Array.Copy(array, 0, newArray, 0, array.Length); array = newArray; } - public void Remove(AsyncTriggerHandler handler) + public void Remove(IResolveCancelPromise handler) { if (singleHandler == handler) { diff --git a/src/UniTask/Assets/Plugins/UniTask/UniTaskCompletionSource.cs b/src/UniTask/Assets/Plugins/UniTask/UniTaskCompletionSource.cs index cba8dfa..cefc851 100644 --- a/src/UniTask/Assets/Plugins/UniTask/UniTaskCompletionSource.cs +++ b/src/UniTask/Assets/Plugins/UniTask/UniTaskCompletionSource.cs @@ -38,6 +38,14 @@ namespace Cysharp.Threading.Tasks { } + public interface IResolveCancelPromise : IResolvePromise, ICancelPromise + { + } + + public interface IResolveCancelPromise : IResolvePromise, ICancelPromise + { + } + [StructLayout(LayoutKind.Auto)] public struct UniTaskCompletionSourceCore {