diff --git a/src/UniTask.NetCoreSandbox/Program.cs b/src/UniTask.NetCoreSandbox/Program.cs index 65a63e5..d982b12 100644 --- a/src/UniTask.NetCoreSandbox/Program.cs +++ b/src/UniTask.NetCoreSandbox/Program.cs @@ -16,8 +16,16 @@ namespace NetCoreSandbox static async UniTask outer() { - var v = await DoAsync(); - return v; + //await Task.WhenAll(); + + //var foo = await Task.WhenAny(Array.Empty>()); + + + await UniTask.WhenAny(new UniTask[0]); + + return 10; + //var v = await DoAsync(); + //return v; } diff --git a/src/UniTask/Assets/Plugins/UniTask/Internal/ArrayPoolUtil.cs b/src/UniTask/Assets/Plugins/UniTask/Internal/ArrayPoolUtil.cs index 0892dc6..016901d 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Internal/ArrayPoolUtil.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Internal/ArrayPoolUtil.cs @@ -32,11 +32,21 @@ namespace Cysharp.Threading.Tasks.Internal } } - public static RentArray CopyToRentArray(IEnumerable source) + public static RentArray Materialize(IEnumerable source) { + if (source is T[] array) + { + return new RentArray(array, array.Length, null); + } + var defaultCount = 32; if (source is ICollection coll) { + if (coll.Count == 0) + { + return new RentArray(Array.Empty(), 0, null); + } + defaultCount = coll.Count; var pool = ArrayPool.Shared; var buffer = pool.Rent(defaultCount); diff --git a/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAll.cs b/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAll.cs index 5f8f340..714da5b 100644 --- a/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAll.cs +++ b/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAll.cs @@ -11,12 +11,17 @@ namespace Cysharp.Threading.Tasks { public static UniTask WhenAll(params UniTask[] tasks) { + if (tasks.Length == 0) + { + return UniTask.FromResult(Array.Empty()); + } + return new UniTask(new WhenAllPromise(tasks, tasks.Length), 0); } public static UniTask WhenAll(IEnumerable> tasks) { - using (var span = ArrayPoolUtil.CopyToRentArray(tasks)) + using (var span = ArrayPoolUtil.Materialize(tasks)) { var promise = new WhenAllPromise(span.Array, span.Length); // consumed array in constructor. return new UniTask(promise, 0); @@ -25,12 +30,17 @@ namespace Cysharp.Threading.Tasks public static UniTask WhenAll(params UniTask[] tasks) { + if (tasks.Length == 0) + { + return UniTask.CompletedTask; + } + return new UniTask(new WhenAllPromise(tasks, tasks.Length), 0); } public static UniTask WhenAll(IEnumerable tasks) { - using (var span = ArrayPoolUtil.CopyToRentArray(tasks)) + using (var span = ArrayPoolUtil.Materialize(tasks)) { var promise = new WhenAllPromise(span.Array, span.Length); // consumed array in constructor. return new UniTask(promise, 0); @@ -48,6 +58,14 @@ namespace Cysharp.Threading.Tasks TaskTracker.TrackActiveTask(this, 3); this.completeCount = 0; + + if (tasksLength == 0) + { + this.result = Array.Empty(); + core.TrySetResult(result); + return; + } + this.result = new T[tasksLength]; for (int i = 0; i < tasksLength; i++) @@ -144,6 +162,12 @@ namespace Cysharp.Threading.Tasks this.tasksLength = tasksLength; this.completeCount = 0; + if (tasksLength == 0) + { + core.TrySetResult(AsyncUnit.Default); + return; + } + for (int i = 0; i < tasksLength; i++) { UniTask.Awaiter awaiter; diff --git a/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAny.cs b/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAny.cs index 64de31c..15ecd1f 100644 --- a/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAny.cs +++ b/src/UniTask/Assets/Plugins/UniTask/UniTask.WhenAny.cs @@ -21,7 +21,7 @@ namespace Cysharp.Threading.Tasks public static UniTask<(int winArgumentIndex, T result)> WhenAny(IEnumerable> tasks) { - using (var span = ArrayPoolUtil.CopyToRentArray(tasks)) + using (var span = ArrayPoolUtil.Materialize(tasks)) { return new UniTask<(int, T)>(new WhenAnyPromise(span.Array, span.Length), 0); } @@ -36,7 +36,7 @@ namespace Cysharp.Threading.Tasks /// Return value is winArgumentIndex public static UniTask WhenAny(IEnumerable tasks) { - using (var span = ArrayPoolUtil.CopyToRentArray(tasks)) + using (var span = ArrayPoolUtil.Materialize(tasks)) { return new UniTask(new WhenAnyPromise(span.Array, span.Length), 0); } @@ -186,6 +186,11 @@ namespace Cysharp.Threading.Tasks public WhenAnyPromise(UniTask[] tasks, int tasksLength) { + if (tasksLength == 0) + { + throw new ArgumentException("The tasks argument contains no tasks."); + } + TaskTracker.TrackActiveTask(this, 3); for (int i = 0; i < tasksLength; i++) @@ -277,6 +282,11 @@ namespace Cysharp.Threading.Tasks public WhenAnyPromise(UniTask[] tasks, int tasksLength) { + if (tasksLength == 0) + { + throw new ArgumentException("The tasks argument contains no tasks."); + } + TaskTracker.TrackActiveTask(this, 3); for (int i = 0; i < tasksLength; i++)