Fix AsyncLazy can not await multiple times when task is not completed
parent
769b5c6bab
commit
0640f278cc
|
@ -0,0 +1,167 @@
|
||||||
|
using Cysharp.Threading.Tasks;
|
||||||
|
using FluentAssertions;
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Linq;
|
||||||
|
using System.Text;
|
||||||
|
using System.Threading;
|
||||||
|
using System.Threading.Channels;
|
||||||
|
using Cysharp.Threading.Tasks.Linq;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace NetCoreTests
|
||||||
|
{
|
||||||
|
public class AsyncLazyTest
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task LazyLazy()
|
||||||
|
{
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => After());
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
await a;
|
||||||
|
await b;
|
||||||
|
await c;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => AfterException());
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task LazyImmediate()
|
||||||
|
{
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => UniTask.FromResult(1).AsUniTask());
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
await a;
|
||||||
|
await b;
|
||||||
|
await c;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => UniTask.FromException(new TaskTestException()));
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static async UniTask AwaitAwait(UniTask t)
|
||||||
|
{
|
||||||
|
await t;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async UniTask After()
|
||||||
|
{
|
||||||
|
await UniTask.Yield();
|
||||||
|
Thread.Sleep(TimeSpan.FromSeconds(1));
|
||||||
|
await UniTask.Yield();
|
||||||
|
await UniTask.Yield();
|
||||||
|
}
|
||||||
|
|
||||||
|
async UniTask AfterException()
|
||||||
|
{
|
||||||
|
await UniTask.Yield();
|
||||||
|
Thread.Sleep(TimeSpan.FromSeconds(1));
|
||||||
|
await UniTask.Yield();
|
||||||
|
throw new TaskTestException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class AsyncLazyTest2
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task LazyLazy()
|
||||||
|
{
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => After());
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
var a2 = await a;
|
||||||
|
var b2 = await b;
|
||||||
|
var c2 = await c;
|
||||||
|
(a2, b2, c2).Should().Be((10, 10, 10));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => AfterException());
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task LazyImmediate()
|
||||||
|
{
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => UniTask.FromResult(1));
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
var a2 = await a;
|
||||||
|
var b2 = await b;
|
||||||
|
var c2 = await c;
|
||||||
|
(a2, b2, c2).Should().Be((1, 1, 1));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
var l = UniTask.Lazy(() => UniTask.FromException<int>(new TaskTestException()));
|
||||||
|
var a = AwaitAwait(l.Task);
|
||||||
|
var b = AwaitAwait(l.Task);
|
||||||
|
var c = AwaitAwait(l.Task);
|
||||||
|
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
|
||||||
|
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static async UniTask<int> AwaitAwait(UniTask<int> t)
|
||||||
|
{
|
||||||
|
return await t;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async UniTask<int> After()
|
||||||
|
{
|
||||||
|
await UniTask.Yield();
|
||||||
|
Thread.Sleep(TimeSpan.FromSeconds(1));
|
||||||
|
await UniTask.Yield();
|
||||||
|
await UniTask.Yield();
|
||||||
|
return 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
async UniTask<int> AfterException()
|
||||||
|
{
|
||||||
|
await UniTask.Yield();
|
||||||
|
Thread.Sleep(TimeSpan.FromSeconds(1));
|
||||||
|
await UniTask.Yield();
|
||||||
|
throw new TaskTestException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,113 +7,239 @@ namespace Cysharp.Threading.Tasks
|
||||||
{
|
{
|
||||||
public class AsyncLazy
|
public class AsyncLazy
|
||||||
{
|
{
|
||||||
Func<UniTask> valueFactory;
|
static Action<object> continuation = SetCompletionSource;
|
||||||
UniTask target;
|
|
||||||
|
Func<UniTask> taskFactory;
|
||||||
|
UniTaskCompletionSource completionSource;
|
||||||
|
UniTask.Awaiter awaiter;
|
||||||
|
|
||||||
object syncLock;
|
object syncLock;
|
||||||
bool initialized;
|
bool initialized;
|
||||||
|
|
||||||
public AsyncLazy(Func<UniTask> valueFactory)
|
public AsyncLazy(Func<UniTask> taskFactory)
|
||||||
{
|
{
|
||||||
this.valueFactory = valueFactory;
|
this.taskFactory = taskFactory;
|
||||||
this.target = default;
|
this.completionSource = new UniTaskCompletionSource();
|
||||||
this.syncLock = new object();
|
this.syncLock = new object();
|
||||||
this.initialized = false;
|
this.initialized = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
internal AsyncLazy(UniTask value)
|
internal AsyncLazy(UniTask task)
|
||||||
{
|
{
|
||||||
this.valueFactory = null;
|
this.taskFactory = null;
|
||||||
this.target = value;
|
this.completionSource = new UniTaskCompletionSource();
|
||||||
this.syncLock = null;
|
this.syncLock = null;
|
||||||
this.initialized = true;
|
this.initialized = true;
|
||||||
|
|
||||||
|
var awaiter = task.GetAwaiter();
|
||||||
|
if (awaiter.IsCompleted)
|
||||||
|
{
|
||||||
|
SetCompletionSource(awaiter);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this.awaiter = awaiter;
|
||||||
|
awaiter.SourceOnCompleted(continuation, this);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public UniTask Task => EnsureInitialized();
|
public UniTask Task
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
EnsureInitialized();
|
||||||
|
return completionSource.Task;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public UniTask.Awaiter GetAwaiter() => EnsureInitialized().GetAwaiter();
|
|
||||||
|
|
||||||
UniTask EnsureInitialized()
|
public UniTask.Awaiter GetAwaiter() => Task.GetAwaiter();
|
||||||
|
|
||||||
|
void EnsureInitialized()
|
||||||
{
|
{
|
||||||
if (Volatile.Read(ref initialized))
|
if (Volatile.Read(ref initialized))
|
||||||
{
|
{
|
||||||
return target;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
return EnsureInitializedCore();
|
EnsureInitializedCore();
|
||||||
}
|
}
|
||||||
|
|
||||||
UniTask EnsureInitializedCore()
|
void EnsureInitializedCore()
|
||||||
{
|
{
|
||||||
lock (syncLock)
|
lock (syncLock)
|
||||||
{
|
{
|
||||||
if (!Volatile.Read(ref initialized))
|
if (!Volatile.Read(ref initialized))
|
||||||
{
|
{
|
||||||
var f = Interlocked.Exchange(ref valueFactory, null);
|
var f = Interlocked.Exchange(ref taskFactory, null);
|
||||||
if (f != null)
|
if (f != null)
|
||||||
{
|
{
|
||||||
target = f().Preserve(); // with preserve(allow multiple await).
|
var task = f();
|
||||||
|
var awaiter = task.GetAwaiter();
|
||||||
|
if (awaiter.IsCompleted)
|
||||||
|
{
|
||||||
|
SetCompletionSource(awaiter);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this.awaiter = awaiter;
|
||||||
|
awaiter.SourceOnCompleted(continuation, this);
|
||||||
|
}
|
||||||
|
|
||||||
Volatile.Write(ref initialized, true);
|
Volatile.Write(ref initialized, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return target;
|
void SetCompletionSource(in UniTask.Awaiter awaiter)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
awaiter.GetResult();
|
||||||
|
completionSource.TrySetResult();
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
completionSource.TrySetException(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void SetCompletionSource(object state)
|
||||||
|
{
|
||||||
|
var self = (AsyncLazy)state;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
self.awaiter.GetResult();
|
||||||
|
self.completionSource.TrySetResult();
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
self.completionSource.TrySetException(ex);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
self.awaiter = default;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class AsyncLazy<T>
|
public class AsyncLazy<T>
|
||||||
{
|
{
|
||||||
Func<UniTask<T>> valueFactory;
|
static Action<object> continuation = SetCompletionSource;
|
||||||
UniTask<T> target;
|
|
||||||
|
Func<UniTask<T>> taskFactory;
|
||||||
|
UniTaskCompletionSource<T> completionSource;
|
||||||
|
UniTask<T>.Awaiter awaiter;
|
||||||
|
|
||||||
object syncLock;
|
object syncLock;
|
||||||
bool initialized;
|
bool initialized;
|
||||||
|
|
||||||
public AsyncLazy(Func<UniTask<T>> valueFactory)
|
public AsyncLazy(Func<UniTask<T>> taskFactory)
|
||||||
{
|
{
|
||||||
this.valueFactory = valueFactory;
|
this.taskFactory = taskFactory;
|
||||||
this.target = default;
|
this.completionSource = new UniTaskCompletionSource<T>();
|
||||||
this.syncLock = new object();
|
this.syncLock = new object();
|
||||||
this.initialized = false;
|
this.initialized = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
internal AsyncLazy(UniTask<T> value)
|
internal AsyncLazy(UniTask<T> task)
|
||||||
{
|
{
|
||||||
this.valueFactory = null;
|
this.taskFactory = null;
|
||||||
this.target = value;
|
this.completionSource = new UniTaskCompletionSource<T>();
|
||||||
this.syncLock = null;
|
this.syncLock = null;
|
||||||
this.initialized = true;
|
this.initialized = true;
|
||||||
|
|
||||||
|
var awaiter = task.GetAwaiter();
|
||||||
|
if (awaiter.IsCompleted)
|
||||||
|
{
|
||||||
|
SetCompletionSource(awaiter);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this.awaiter = awaiter;
|
||||||
|
awaiter.SourceOnCompleted(continuation, this);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public UniTask<T> Task => EnsureInitialized();
|
public UniTask<T> Task
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
EnsureInitialized();
|
||||||
|
return completionSource.Task;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public UniTask<T>.Awaiter GetAwaiter() => EnsureInitialized().GetAwaiter();
|
|
||||||
|
|
||||||
UniTask<T> EnsureInitialized()
|
public UniTask<T>.Awaiter GetAwaiter() => Task.GetAwaiter();
|
||||||
|
|
||||||
|
void EnsureInitialized()
|
||||||
{
|
{
|
||||||
if (Volatile.Read(ref initialized))
|
if (Volatile.Read(ref initialized))
|
||||||
{
|
{
|
||||||
return target;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
return EnsureInitializedCore();
|
EnsureInitializedCore();
|
||||||
}
|
}
|
||||||
|
|
||||||
UniTask<T> EnsureInitializedCore()
|
void EnsureInitializedCore()
|
||||||
{
|
{
|
||||||
lock (syncLock)
|
lock (syncLock)
|
||||||
{
|
{
|
||||||
if (!Volatile.Read(ref initialized))
|
if (!Volatile.Read(ref initialized))
|
||||||
{
|
{
|
||||||
var f = Interlocked.Exchange(ref valueFactory, null);
|
var f = Interlocked.Exchange(ref taskFactory, null);
|
||||||
if (f != null)
|
if (f != null)
|
||||||
{
|
{
|
||||||
target = f().Preserve(); // with preserve(allow multiple await).
|
var task = f();
|
||||||
|
var awaiter = task.GetAwaiter();
|
||||||
|
if (awaiter.IsCompleted)
|
||||||
|
{
|
||||||
|
SetCompletionSource(awaiter);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this.awaiter = awaiter;
|
||||||
|
awaiter.SourceOnCompleted(continuation, this);
|
||||||
|
}
|
||||||
|
|
||||||
Volatile.Write(ref initialized, true);
|
Volatile.Write(ref initialized, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return target;
|
void SetCompletionSource(in UniTask<T>.Awaiter awaiter)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var result = awaiter.GetResult();
|
||||||
|
completionSource.TrySetResult(result);
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
completionSource.TrySetException(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void SetCompletionSource(object state)
|
||||||
|
{
|
||||||
|
var self = (AsyncLazy<T>)state;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var result = self.awaiter.GetResult();
|
||||||
|
self.completionSource.TrySetResult(result);
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
self.completionSource.TrySetException(ex);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
self.awaiter = default;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -696,6 +696,7 @@ namespace Cysharp.Threading.Tasks
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[DebuggerHidden]
|
||||||
bool TrySignalCompletion(UniTaskStatus status)
|
bool TrySignalCompletion(UniTaskStatus status)
|
||||||
{
|
{
|
||||||
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
|
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
|
||||||
|
@ -886,6 +887,7 @@ namespace Cysharp.Threading.Tasks
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[DebuggerHidden]
|
||||||
bool TrySignalCompletion(UniTaskStatus status)
|
bool TrySignalCompletion(UniTaskStatus status)
|
||||||
{
|
{
|
||||||
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
|
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
|
||||||
|
|
|
@ -181,12 +181,12 @@ namespace Cysharp.Threading.Tasks
|
||||||
|
|
||||||
public static AsyncLazy ToAsyncLazy(this UniTask task)
|
public static AsyncLazy ToAsyncLazy(this UniTask task)
|
||||||
{
|
{
|
||||||
return new AsyncLazy(task.Preserve()); // require Preserve
|
return new AsyncLazy(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static AsyncLazy<T> ToAsyncLazy<T>(this UniTask<T> task)
|
public static AsyncLazy<T> ToAsyncLazy<T>(this UniTask<T> task)
|
||||||
{
|
{
|
||||||
return new AsyncLazy<T>(task.Preserve()); // require Preserve
|
return new AsyncLazy<T>(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if UNITY_2018_3_OR_NEWER
|
#if UNITY_2018_3_OR_NEWER
|
||||||
|
|
Loading…
Reference in New Issue