diff --git a/src/UniTask.NetCoreTests/Linq/PulbishTest.cs b/src/UniTask.NetCoreTests/Linq/PulbishTest.cs new file mode 100644 index 0000000..c96bf47 --- /dev/null +++ b/src/UniTask.NetCoreTests/Linq/PulbishTest.cs @@ -0,0 +1,78 @@ +using Cysharp.Threading.Tasks; +using Cysharp.Threading.Tasks.Linq; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace NetCoreTests.Linq +{ + public class PublishTest + { + [Fact] + public async Task Normal() + { + var rp = new AsyncReactiveProperty(1); + + var multicast = rp.Publish(); + + var a = multicast.ToArrayAsync(); + var b = multicast.Take(2).ToArrayAsync(); + + var disp = multicast.Connect(); + + rp.Value = 2; + + (await b).Should().BeEquivalentTo(1, 2); + + var c = multicast.ToArrayAsync(); + + rp.Value = 3; + rp.Value = 4; + rp.Value = 5; + + rp.Dispose(); + + (await a).Should().BeEquivalentTo(1, 2, 3, 4, 5); + (await c).Should().BeEquivalentTo(3, 4, 5); + + disp.Dispose(); + } + + [Fact] + public async Task Cancel() + { + var rp = new AsyncReactiveProperty(1); + + var multicast = rp.Publish(); + + var a = multicast.ToArrayAsync(); + var b = multicast.Take(2).ToArrayAsync(); + + var disp = multicast.Connect(); + + rp.Value = 2; + + (await b).Should().BeEquivalentTo(1, 2); + + var c = multicast.ToArrayAsync(); + + rp.Value = 3; + + disp.Dispose(); + + rp.Value = 4; + rp.Value = 5; + + rp.Dispose(); + + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await c); + } + + + } +} diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs index e196222..f065760 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs @@ -144,6 +144,11 @@ namespace Cysharp.Threading.Tasks completionSource.TrySetResult(false); } + public void OnError(Exception ex) + { + completionSource.TrySetException(ex); + } + static void CancellationCallback(object state) { var self = (Enumerator)state; diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/IUniTaskAsyncEnumerable.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/IUniTaskAsyncEnumerable.cs index 6807908..847d430 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/IUniTaskAsyncEnumerable.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/IUniTaskAsyncEnumerable.cs @@ -28,6 +28,12 @@ namespace Cysharp.Threading.Tasks IUniTaskOrderedAsyncEnumerable CreateOrderedEnumerable(Func> keySelector, IComparer comparer, bool descending); } + public interface IConnectableUniTaskAsyncEnumerable : IUniTaskAsyncEnumerable + { + IDisposable Connect(); + } + + // don't use AsyncGrouping. //public interface IUniTaskAsyncGrouping : IUniTaskAsyncEnumerable //{ // TKey Key { get; } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Publish.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Publish.cs new file mode 100644 index 0000000..8b6d950 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Publish.cs @@ -0,0 +1,171 @@ +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IConnectableUniTaskAsyncEnumerable Publish(this IUniTaskAsyncEnumerable source) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + return new Publish(source); + } + } + + internal sealed class Publish : IConnectableUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly CancellationTokenSource cancellationTokenSource; + + TriggerEvent trigger; + IUniTaskAsyncEnumerator enumerator; + IDisposable connectedDisposable; + bool isCompleted; + + public Publish(IUniTaskAsyncEnumerable source) + { + this.source = source; + this.cancellationTokenSource = new CancellationTokenSource(); + } + + public IDisposable Connect() + { + if (connectedDisposable != null) return connectedDisposable; + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationTokenSource.Token); + } + + ConsumeEnumerator().Forget(); + + connectedDisposable = new ConnectDisposable(cancellationTokenSource); + return connectedDisposable; + } + + async UniTaskVoid ConsumeEnumerator() + { + try + { + try + { + while (await enumerator.MoveNextAsync()) + { + trigger.SetResult(enumerator.Current); + } + trigger.SetCompleted(); + } + catch (Exception ex) + { + trigger.SetError(ex); + } + } + finally + { + isCompleted = true; + await enumerator.DisposeAsync(); + } + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new _Publish(this, cancellationToken); + } + + sealed class ConnectDisposable : IDisposable + { + readonly CancellationTokenSource cancellationTokenSource; + + public ConnectDisposable(CancellationTokenSource cancellationTokenSource) + { + this.cancellationTokenSource = cancellationTokenSource; + } + + public void Dispose() + { + this.cancellationTokenSource.Cancel(); + } + } + + sealed class _Publish : MoveNextSource, IUniTaskAsyncEnumerator, ITriggerHandler + { + static readonly Action CancelDelegate = OnCanceled; + + readonly Publish parent; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool isDisposed; + + public _Publish(Publish parent, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) return; + + this.parent = parent; + this.cancellationToken = cancellationToken; + + if (cancellationToken.CanBeCanceled) + { + this.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(CancelDelegate, this); + } + + parent.trigger.Add(this); + TaskTracker.TrackActiveTask(this, 3); + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + + if (parent.isCompleted) return CompletedTasks.False; + + completionSource.Reset(); + return new UniTask(this, completionSource.Version); + } + + static void OnCanceled(object state) + { + var self = (_Publish)state; + self.completionSource.TrySetCanceled(self.cancellationToken); + self.DisposeAsync().Forget(); + } + + public UniTask DisposeAsync() + { + if (!isDisposed) + { + isDisposed = true; + TaskTracker.RemoveTracking(this); + cancellationTokenRegistration.Dispose(); + parent.trigger.Remove(this); + } + + return default; + } + + public void OnNext(TSource value) + { + Current = value; + completionSource.TrySetResult(true); + } + + public void OnCanceled(CancellationToken cancellationToken) + { + completionSource.TrySetCanceled(cancellationToken); + } + + public void OnCompleted() + { + completionSource.TrySetResult(false); + } + + public void OnError(Exception ex) + { + completionSource.TrySetException(ex); + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Publish.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Publish.cs.meta new file mode 100644 index 0000000..f3a81ba --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Publish.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 93c684d1e88c09d4e89b79437d97b810 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/TriggerEvent.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/TriggerEvent.cs index 1e95590..426f0cc 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/TriggerEvent.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/TriggerEvent.cs @@ -7,8 +7,9 @@ namespace Cysharp.Threading.Tasks public interface ITriggerHandler { void OnNext(T value); - void OnCanceled(CancellationToken cancellationToken); + void OnError(Exception ex); void OnCompleted(); + void OnCanceled(CancellationToken cancellationToken); } // be careful to use, itself is struct. @@ -207,6 +208,67 @@ namespace Cysharp.Threading.Tasks } } + public void SetError(Exception exception) + { + isRunning = true; + + if (singleHandler != null) + { + try + { + singleHandler.OnError(exception); + } + catch (Exception ex) + { +#if UNITY_2018_3_OR_NEWER + UnityEngine.Debug.LogException(ex); +#else + Console.WriteLine(ex); +#endif + } + } + + if (handlers != null) + { + for (int i = 0; i < handlers.Length; i++) + { + if (handlers[i] != null) + { + try + { + handlers[i].OnError(exception); + } + catch (Exception ex) + { + handlers[i] = null; +#if UNITY_2018_3_OR_NEWER + UnityEngine.Debug.LogException(ex); +#else + Console.WriteLine(ex); +#endif + } + } + } + } + + isRunning = false; + + if (waitHandler != null) + { + var h = waitHandler; + waitHandler = null; + Add(h); + } + + if (waitQueue != null) + { + while (waitQueue.Count != 0) + { + Add(waitQueue.Dequeue()); + } + } + } + public void Add(ITriggerHandler handler) { if (isRunning) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Triggers/AsyncTriggerBase.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Triggers/AsyncTriggerBase.cs index f75fac8..6160cde 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Triggers/AsyncTriggerBase.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Triggers/AsyncTriggerBase.cs @@ -89,6 +89,11 @@ namespace Cysharp.Threading.Tasks.Triggers completionSource.TrySetResult(false); } + public void OnError(Exception ex) + { + completionSource.TrySetException(ex); + } + static void CancellationCallback(object state) { var self = (AsyncTriggerEnumerator)state; @@ -273,6 +278,11 @@ namespace Cysharp.Threading.Tasks.Triggers core.TrySetCanceled(CancellationToken.None); } + void ITriggerHandler.OnError(Exception ex) + { + core.TrySetException(ex); + } + void IUniTaskSource.GetResult(short token) { ((IUniTaskSource)this).GetResult(token); diff --git a/src/UniTask/Assets/Scenes/SandboxMain.cs b/src/UniTask/Assets/Scenes/SandboxMain.cs index 512a600..c6f48cc 100644 --- a/src/UniTask/Assets/Scenes/SandboxMain.cs +++ b/src/UniTask/Assets/Scenes/SandboxMain.cs @@ -49,23 +49,15 @@ public static partial class UnityUIComponentExtensions public class AsyncMessageBroker : IDisposable { Channel channel; - List> asyncEvents; + + IConnectableUniTaskAsyncEnumerable multicastSource; + IDisposable connection; public AsyncMessageBroker() { channel = Channel.CreateSingleConsumerUnbounded(); - asyncEvents = new List>(); - } - - async UniTaskVoid PublishAll() - { - await channel.Reader.ReadAllAsync().ForEachAwaitAsync(async x => - { - foreach (var item in asyncEvents) - { - await item.Invoke(x); - } - }); + multicastSource = channel.Reader.ReadAllAsync().Publish(); + connection = multicastSource.Connect(); } public void Publish(T value) @@ -73,33 +65,15 @@ public class AsyncMessageBroker : IDisposable channel.Writer.TryWrite(value); } - public Subscription Subscribe(Func func) + public IUniTaskAsyncEnumerable Subscribe() { - asyncEvents.Add(func); - return new Subscription(this, func); + return multicastSource; } public void Dispose() { channel.Writer.TryComplete(); - asyncEvents.Clear(); - } - - public readonly struct Subscription : IDisposable - { - readonly AsyncMessageBroker broker; - readonly Func func; - - public Subscription(AsyncMessageBroker broker, Func func) - { - this.broker = broker; - this.func = func; - } - - public void Dispose() - { - broker.asyncEvents.Remove(func); - } + connection.Dispose(); } } @@ -205,9 +179,22 @@ public class SandboxMain : MonoBehaviour //await channel.Reader.ReadAllAsync(this.GetCancellationTokenOnDestroy()).ForEachAsync(_ => { }); - var rp = new AsyncReactiveProperty(10); + var pubsub = new AsyncMessageBroker(); + + pubsub.Subscribe().ForEachAsync(x => Debug.Log("A:" + x)).Forget(); + pubsub.Subscribe().ForEachAsync(x => Debug.Log("B:" + x)).Forget(); + + + int i = 0; + okButton.OnClickAsAsyncEnumerable().ForEachAsync(_ => + { + + Debug.Log("foo"); + pubsub.Publish(i++); + + + }).Forget(); - rp.Append(10).Select(x => x * 100).Take(30).Prepend(99).SkipLast(9).Where(x => x % 2 == 0).ForEachAsync(_ => { }).Forget(); }