Fix AsyncLazy can not await multiple times when task is not completed

master
neuecc 2020-06-18 03:02:01 +09:00
parent 769b5c6bab
commit 0640f278cc
4 changed files with 331 additions and 36 deletions

View File

@ -0,0 +1,167 @@
using Cysharp.Threading.Tasks;
using FluentAssertions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using Cysharp.Threading.Tasks.Linq;
using System.Threading.Tasks;
using Xunit;
namespace NetCoreTests
{
public class AsyncLazyTest
{
[Fact]
public async Task LazyLazy()
{
{
var l = UniTask.Lazy(() => After());
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
await a;
await b;
await c;
}
{
var l = UniTask.Lazy(() => AfterException());
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
}
}
[Fact]
public async Task LazyImmediate()
{
{
var l = UniTask.Lazy(() => UniTask.FromResult(1).AsUniTask());
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
await a;
await b;
await c;
}
{
var l = UniTask.Lazy(() => UniTask.FromException(new TaskTestException()));
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
}
}
static async UniTask AwaitAwait(UniTask t)
{
await t;
}
async UniTask After()
{
await UniTask.Yield();
Thread.Sleep(TimeSpan.FromSeconds(1));
await UniTask.Yield();
await UniTask.Yield();
}
async UniTask AfterException()
{
await UniTask.Yield();
Thread.Sleep(TimeSpan.FromSeconds(1));
await UniTask.Yield();
throw new TaskTestException();
}
}
public class AsyncLazyTest2
{
[Fact]
public async Task LazyLazy()
{
{
var l = UniTask.Lazy(() => After());
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
var a2 = await a;
var b2 = await b;
var c2 = await c;
(a2, b2, c2).Should().Be((10, 10, 10));
}
{
var l = UniTask.Lazy(() => AfterException());
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
}
}
[Fact]
public async Task LazyImmediate()
{
{
var l = UniTask.Lazy(() => UniTask.FromResult(1));
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
var a2 = await a;
var b2 = await b;
var c2 = await c;
(a2, b2, c2).Should().Be((1, 1, 1));
}
{
var l = UniTask.Lazy(() => UniTask.FromException<int>(new TaskTestException()));
var a = AwaitAwait(l.Task);
var b = AwaitAwait(l.Task);
var c = AwaitAwait(l.Task);
await Assert.ThrowsAsync<TaskTestException>(async () => await a);
await Assert.ThrowsAsync<TaskTestException>(async () => await b);
await Assert.ThrowsAsync<TaskTestException>(async () => await c);
}
}
static async UniTask<int> AwaitAwait(UniTask<int> t)
{
return await t;
}
async UniTask<int> After()
{
await UniTask.Yield();
Thread.Sleep(TimeSpan.FromSeconds(1));
await UniTask.Yield();
await UniTask.Yield();
return 10;
}
async UniTask<int> AfterException()
{
await UniTask.Yield();
Thread.Sleep(TimeSpan.FromSeconds(1));
await UniTask.Yield();
throw new TaskTestException();
}
}
}

View File

@ -7,113 +7,239 @@ namespace Cysharp.Threading.Tasks
{ {
public class AsyncLazy public class AsyncLazy
{ {
Func<UniTask> valueFactory; static Action<object> continuation = SetCompletionSource;
UniTask target;
Func<UniTask> taskFactory;
UniTaskCompletionSource completionSource;
UniTask.Awaiter awaiter;
object syncLock; object syncLock;
bool initialized; bool initialized;
public AsyncLazy(Func<UniTask> valueFactory) public AsyncLazy(Func<UniTask> taskFactory)
{ {
this.valueFactory = valueFactory; this.taskFactory = taskFactory;
this.target = default; this.completionSource = new UniTaskCompletionSource();
this.syncLock = new object(); this.syncLock = new object();
this.initialized = false; this.initialized = false;
} }
internal AsyncLazy(UniTask value) internal AsyncLazy(UniTask task)
{ {
this.valueFactory = null; this.taskFactory = null;
this.target = value; this.completionSource = new UniTaskCompletionSource();
this.syncLock = null; this.syncLock = null;
this.initialized = true; this.initialized = true;
var awaiter = task.GetAwaiter();
if (awaiter.IsCompleted)
{
SetCompletionSource(awaiter);
}
else
{
this.awaiter = awaiter;
awaiter.SourceOnCompleted(continuation, this);
}
} }
public UniTask Task => EnsureInitialized(); public UniTask Task
{
get
{
EnsureInitialized();
return completionSource.Task;
}
}
public UniTask.Awaiter GetAwaiter() => EnsureInitialized().GetAwaiter();
UniTask EnsureInitialized() public UniTask.Awaiter GetAwaiter() => Task.GetAwaiter();
void EnsureInitialized()
{ {
if (Volatile.Read(ref initialized)) if (Volatile.Read(ref initialized))
{ {
return target; return;
} }
return EnsureInitializedCore(); EnsureInitializedCore();
} }
UniTask EnsureInitializedCore() void EnsureInitializedCore()
{ {
lock (syncLock) lock (syncLock)
{ {
if (!Volatile.Read(ref initialized)) if (!Volatile.Read(ref initialized))
{ {
var f = Interlocked.Exchange(ref valueFactory, null); var f = Interlocked.Exchange(ref taskFactory, null);
if (f != null) if (f != null)
{ {
target = f().Preserve(); // with preserve(allow multiple await). var task = f();
var awaiter = task.GetAwaiter();
if (awaiter.IsCompleted)
{
SetCompletionSource(awaiter);
}
else
{
this.awaiter = awaiter;
awaiter.SourceOnCompleted(continuation, this);
}
Volatile.Write(ref initialized, true); Volatile.Write(ref initialized, true);
} }
} }
} }
}
return target; void SetCompletionSource(in UniTask.Awaiter awaiter)
{
try
{
awaiter.GetResult();
completionSource.TrySetResult();
}
catch (Exception ex)
{
completionSource.TrySetException(ex);
}
}
static void SetCompletionSource(object state)
{
var self = (AsyncLazy)state;
try
{
self.awaiter.GetResult();
self.completionSource.TrySetResult();
}
catch (Exception ex)
{
self.completionSource.TrySetException(ex);
}
finally
{
self.awaiter = default;
}
} }
} }
public class AsyncLazy<T> public class AsyncLazy<T>
{ {
Func<UniTask<T>> valueFactory; static Action<object> continuation = SetCompletionSource;
UniTask<T> target;
Func<UniTask<T>> taskFactory;
UniTaskCompletionSource<T> completionSource;
UniTask<T>.Awaiter awaiter;
object syncLock; object syncLock;
bool initialized; bool initialized;
public AsyncLazy(Func<UniTask<T>> valueFactory) public AsyncLazy(Func<UniTask<T>> taskFactory)
{ {
this.valueFactory = valueFactory; this.taskFactory = taskFactory;
this.target = default; this.completionSource = new UniTaskCompletionSource<T>();
this.syncLock = new object(); this.syncLock = new object();
this.initialized = false; this.initialized = false;
} }
internal AsyncLazy(UniTask<T> value) internal AsyncLazy(UniTask<T> task)
{ {
this.valueFactory = null; this.taskFactory = null;
this.target = value; this.completionSource = new UniTaskCompletionSource<T>();
this.syncLock = null; this.syncLock = null;
this.initialized = true; this.initialized = true;
var awaiter = task.GetAwaiter();
if (awaiter.IsCompleted)
{
SetCompletionSource(awaiter);
}
else
{
this.awaiter = awaiter;
awaiter.SourceOnCompleted(continuation, this);
}
} }
public UniTask<T> Task => EnsureInitialized(); public UniTask<T> Task
{
get
{
EnsureInitialized();
return completionSource.Task;
}
}
public UniTask<T>.Awaiter GetAwaiter() => EnsureInitialized().GetAwaiter();
UniTask<T> EnsureInitialized() public UniTask<T>.Awaiter GetAwaiter() => Task.GetAwaiter();
void EnsureInitialized()
{ {
if (Volatile.Read(ref initialized)) if (Volatile.Read(ref initialized))
{ {
return target; return;
} }
return EnsureInitializedCore(); EnsureInitializedCore();
} }
UniTask<T> EnsureInitializedCore() void EnsureInitializedCore()
{ {
lock (syncLock) lock (syncLock)
{ {
if (!Volatile.Read(ref initialized)) if (!Volatile.Read(ref initialized))
{ {
var f = Interlocked.Exchange(ref valueFactory, null); var f = Interlocked.Exchange(ref taskFactory, null);
if (f != null) if (f != null)
{ {
target = f().Preserve(); // with preserve(allow multiple await). var task = f();
var awaiter = task.GetAwaiter();
if (awaiter.IsCompleted)
{
SetCompletionSource(awaiter);
}
else
{
this.awaiter = awaiter;
awaiter.SourceOnCompleted(continuation, this);
}
Volatile.Write(ref initialized, true); Volatile.Write(ref initialized, true);
} }
} }
} }
}
return target; void SetCompletionSource(in UniTask<T>.Awaiter awaiter)
{
try
{
var result = awaiter.GetResult();
completionSource.TrySetResult(result);
}
catch (Exception ex)
{
completionSource.TrySetException(ex);
}
}
static void SetCompletionSource(object state)
{
var self = (AsyncLazy<T>)state;
try
{
var result = self.awaiter.GetResult();
self.completionSource.TrySetResult(result);
}
catch (Exception ex)
{
self.completionSource.TrySetException(ex);
}
finally
{
self.awaiter = default;
}
} }
} }
} }

View File

@ -696,6 +696,7 @@ namespace Cysharp.Threading.Tasks
} }
} }
[DebuggerHidden]
bool TrySignalCompletion(UniTaskStatus status) bool TrySignalCompletion(UniTaskStatus status)
{ {
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending) if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)
@ -886,6 +887,7 @@ namespace Cysharp.Threading.Tasks
} }
} }
[DebuggerHidden]
bool TrySignalCompletion(UniTaskStatus status) bool TrySignalCompletion(UniTaskStatus status)
{ {
if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending) if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending)

View File

@ -181,12 +181,12 @@ namespace Cysharp.Threading.Tasks
public static AsyncLazy ToAsyncLazy(this UniTask task) public static AsyncLazy ToAsyncLazy(this UniTask task)
{ {
return new AsyncLazy(task.Preserve()); // require Preserve return new AsyncLazy(task);
} }
public static AsyncLazy<T> ToAsyncLazy<T>(this UniTask<T> task) public static AsyncLazy<T> ToAsyncLazy<T>(this UniTask<T> task)
{ {
return new AsyncLazy<T>(task.Preserve()); // require Preserve return new AsyncLazy<T>(task);
} }
#if UNITY_2018_3_OR_NEWER #if UNITY_2018_3_OR_NEWER