diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Threading.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Threading.cs index 39cd70f..4735dad 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Threading.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Threading.cs @@ -15,33 +15,33 @@ namespace Cysharp.Threading.Tasks /// /// If running on mainthread, do nothing. Otherwise, same as UniTask.Yield(PlayerLoopTiming.Update). /// - public static SwitchToMainThreadAwaitable SwitchToMainThread() + public static SwitchToMainThreadAwaitable SwitchToMainThread(CancellationToken cancellationToken = default) { - return new SwitchToMainThreadAwaitable(PlayerLoopTiming.Update); + return new SwitchToMainThreadAwaitable(PlayerLoopTiming.Update, cancellationToken); } /// /// If running on mainthread, do nothing. Otherwise, same as UniTask.Yield(timing). /// - public static SwitchToMainThreadAwaitable SwitchToMainThread(PlayerLoopTiming timing) + public static SwitchToMainThreadAwaitable SwitchToMainThread(PlayerLoopTiming timing, CancellationToken cancellationToken = default) { - return new SwitchToMainThreadAwaitable(timing); + return new SwitchToMainThreadAwaitable(timing, cancellationToken); } /// /// Return to mainthread(same as await SwitchToMainThread) after using scope is closed. /// - public static ReturnToMainThread ReturnToMainThread() + public static ReturnToMainThread ReturnToMainThread(CancellationToken cancellationToken = default) { - return new ReturnToMainThread(PlayerLoopTiming.Update); + return new ReturnToMainThread(PlayerLoopTiming.Update, cancellationToken); } /// /// Return to mainthread(same as await SwitchToMainThread) after using scope is closed. /// - public static ReturnToMainThread ReturnToMainThread(PlayerLoopTiming timing) + public static ReturnToMainThread ReturnToMainThread(PlayerLoopTiming timing, CancellationToken cancellationToken = default) { - return new ReturnToMainThread(timing); + return new ReturnToMainThread(timing, cancellationToken); } /// @@ -67,20 +67,20 @@ namespace Cysharp.Threading.Tasks return new SwitchToTaskPoolAwaitable(); } - public static SwitchToSynchronizationContextAwaitable SwitchToSynchronizationContext(SynchronizationContext synchronizationContext) + public static SwitchToSynchronizationContextAwaitable SwitchToSynchronizationContext(SynchronizationContext synchronizationContext, CancellationToken cancellationToken = default) { Error.ThrowArgumentNullException(synchronizationContext, nameof(synchronizationContext)); - return new SwitchToSynchronizationContextAwaitable(synchronizationContext); + return new SwitchToSynchronizationContextAwaitable(synchronizationContext, cancellationToken); } - public static ReturnToSynchronizationContext ReturnToSynchronizationContext(SynchronizationContext synchronizationContext) + public static ReturnToSynchronizationContext ReturnToSynchronizationContext(SynchronizationContext synchronizationContext, CancellationToken cancellationToken = default) { - return new ReturnToSynchronizationContext(synchronizationContext, false); + return new ReturnToSynchronizationContext(synchronizationContext, false, cancellationToken); } - public static ReturnToSynchronizationContext ReturnToCurrentSynchronizationContext(bool dontPostWhenSameContext = true) + public static ReturnToSynchronizationContext ReturnToCurrentSynchronizationContext(bool dontPostWhenSameContext = true, CancellationToken cancellationToken = default) { - return new ReturnToSynchronizationContext(SynchronizationContext.Current, dontPostWhenSameContext); + return new ReturnToSynchronizationContext(SynchronizationContext.Current, dontPostWhenSameContext, cancellationToken); } } @@ -89,21 +89,25 @@ namespace Cysharp.Threading.Tasks public struct SwitchToMainThreadAwaitable { readonly PlayerLoopTiming playerLoopTiming; + readonly CancellationToken cancellationToken; - public SwitchToMainThreadAwaitable(PlayerLoopTiming playerLoopTiming) + public SwitchToMainThreadAwaitable(PlayerLoopTiming playerLoopTiming, CancellationToken cancellationToken) { this.playerLoopTiming = playerLoopTiming; + this.cancellationToken = cancellationToken; } - public Awaiter GetAwaiter() => new Awaiter(playerLoopTiming); + public Awaiter GetAwaiter() => new Awaiter(playerLoopTiming, cancellationToken); public struct Awaiter : ICriticalNotifyCompletion { readonly PlayerLoopTiming playerLoopTiming; + readonly CancellationToken cancellationToken; - public Awaiter(PlayerLoopTiming playerLoopTiming) + public Awaiter(PlayerLoopTiming playerLoopTiming, CancellationToken cancellationToken) { this.playerLoopTiming = playerLoopTiming; + this.cancellationToken = cancellationToken; } public bool IsCompleted @@ -122,7 +126,7 @@ namespace Cysharp.Threading.Tasks } } - public void GetResult() { } + public void GetResult() { cancellationToken.ThrowIfCancellationRequested(); } public void OnCompleted(Action continuation) { @@ -139,31 +143,35 @@ namespace Cysharp.Threading.Tasks public struct ReturnToMainThread { readonly PlayerLoopTiming playerLoopTiming; + readonly CancellationToken cancellationToken; - public ReturnToMainThread(PlayerLoopTiming playerLoopTiming) + public ReturnToMainThread(PlayerLoopTiming playerLoopTiming, CancellationToken cancellationToken) { this.playerLoopTiming = playerLoopTiming; + this.cancellationToken = cancellationToken; } public Awaiter DisposeAsync() { - return new Awaiter(playerLoopTiming); // run immediate. + return new Awaiter(playerLoopTiming, cancellationToken); // run immediate. } public readonly struct Awaiter : ICriticalNotifyCompletion { readonly PlayerLoopTiming timing; + readonly CancellationToken cancellationToken; - public Awaiter(PlayerLoopTiming timing) + public Awaiter(PlayerLoopTiming timing, CancellationToken cancellationToken) { this.timing = timing; + this.cancellationToken = cancellationToken; } public Awaiter GetAwaiter() => this; public bool IsCompleted => PlayerLoopHelper.MainThreadId == System.Threading.Thread.CurrentThread.ManagedThreadId; - public void GetResult() { } + public void GetResult() { cancellationToken.ThrowIfCancellationRequested(); } public void OnCompleted(Action continuation) { @@ -285,26 +293,30 @@ namespace Cysharp.Threading.Tasks public struct SwitchToSynchronizationContextAwaitable { readonly SynchronizationContext synchronizationContext; + readonly CancellationToken cancellationToken; - public SwitchToSynchronizationContextAwaitable(SynchronizationContext synchronizationContext) + public SwitchToSynchronizationContextAwaitable(SynchronizationContext synchronizationContext, CancellationToken cancellationToken) { this.synchronizationContext = synchronizationContext; + this.cancellationToken = cancellationToken; } - public Awaiter GetAwaiter() => new Awaiter(synchronizationContext); + public Awaiter GetAwaiter() => new Awaiter(synchronizationContext, cancellationToken); public struct Awaiter : ICriticalNotifyCompletion { static readonly SendOrPostCallback switchToCallback = Callback; readonly SynchronizationContext synchronizationContext; + readonly CancellationToken cancellationToken; - public Awaiter(SynchronizationContext synchronizationContext) + public Awaiter(SynchronizationContext synchronizationContext, CancellationToken cancellationToken) { this.synchronizationContext = synchronizationContext; + this.cancellationToken = cancellationToken; } public bool IsCompleted => false; - public void GetResult() { } + public void GetResult() { cancellationToken.ThrowIfCancellationRequested(); } public void OnCompleted(Action continuation) { @@ -328,16 +340,18 @@ namespace Cysharp.Threading.Tasks { readonly SynchronizationContext syncContext; readonly bool dontPostWhenSameContext; + readonly CancellationToken cancellationToken; - public ReturnToSynchronizationContext(SynchronizationContext syncContext, bool dontPostWhenSameContext) + public ReturnToSynchronizationContext(SynchronizationContext syncContext, bool dontPostWhenSameContext, CancellationToken cancellationToken) { this.syncContext = syncContext; this.dontPostWhenSameContext = dontPostWhenSameContext; + this.cancellationToken = cancellationToken; } public Awaiter DisposeAsync() { - return new Awaiter(syncContext, dontPostWhenSameContext); + return new Awaiter(syncContext, dontPostWhenSameContext, cancellationToken); } public struct Awaiter : ICriticalNotifyCompletion @@ -346,11 +360,13 @@ namespace Cysharp.Threading.Tasks readonly SynchronizationContext synchronizationContext; readonly bool dontPostWhenSameContext; + readonly CancellationToken cancellationToken; - public Awaiter(SynchronizationContext synchronizationContext, bool dontPostWhenSameContext) + public Awaiter(SynchronizationContext synchronizationContext, bool dontPostWhenSameContext, CancellationToken cancellationToken) { this.synchronizationContext = synchronizationContext; this.dontPostWhenSameContext = dontPostWhenSameContext; + this.cancellationToken = cancellationToken; } public Awaiter GetAwaiter() => this; @@ -373,7 +389,7 @@ namespace Cysharp.Threading.Tasks } } - public void GetResult() { } + public void GetResult() { cancellationToken.ThrowIfCancellationRequested(); } public void OnCompleted(Action continuation) {