Fix IUniTaskAsyncEnumerable.Take

master
neuecc 2020-05-18 11:29:35 +09:00
parent 6f4d1183cc
commit 49ba57f20a
1 changed files with 67 additions and 14 deletions

View File

@ -30,40 +30,93 @@ namespace Cysharp.Threading.Tasks.Linq
return new _Take(source, count, cancellationToken);
}
sealed class _Take : AsyncEnumeratorBase<TSource, TSource>
sealed class _Take : MoveNextSource, IUniTaskAsyncEnumerator<TSource>
{
readonly int count;
static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;
readonly IUniTaskAsyncEnumerable<TSource> source;
readonly int count;
CancellationToken cancellationToken;
IUniTaskAsyncEnumerator<TSource> enumerator;
UniTask<bool>.Awaiter awaiter;
int index;
public _Take(IUniTaskAsyncEnumerable<TSource> source, int count, CancellationToken cancellationToken)
: base(source, cancellationToken)
{
this.source = source;
this.count = count;
this.cancellationToken = cancellationToken;
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
public TSource Current { get; private set; }
public UniTask<bool> MoveNextAsync()
{
if (sourceHasCurrent)
cancellationToken.ThrowIfCancellationRequested();
if (enumerator == null)
{
if (checked(index++) < count)
enumerator = source.GetAsyncEnumerator(cancellationToken);
}
if (checked(index) >= count)
{
return CompletedTasks.False;
}
completionSource.Reset();
SourceMoveNext();
return new UniTask<bool>(this, completionSource.Version);
}
void SourceMoveNext()
{
try
{
awaiter = enumerator.MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted)
{
Current = SourceCurrent;
result = true;
return true;
MoveNextCore(this);
}
else
{
result = false;
return true;
awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
}
}
else
catch (Exception ex)
{
result = false;
return true;
completionSource.TrySetException(ex);
}
}
static void MoveNextCore(object state)
{
var self = (_Take)state;
if (self.TryGetResult(self.awaiter, out var result))
{
if (result)
{
self.index++;
self.Current = self.enumerator.Current;
self.completionSource.TrySetResult(true);
}
else
{
self.completionSource.TrySetResult(false);
}
}
}
public UniTask DisposeAsync()
{
if (enumerator != null)
{
return enumerator.DisposeAsync();
}
return default;
}
}
}
}