Fix IUniTaskAsyncEnumerable.Take

master
neuecc 2020-05-18 11:29:35 +09:00
parent 6f4d1183cc
commit 49ba57f20a
1 changed files with 67 additions and 14 deletions

View File

@ -30,40 +30,93 @@ namespace Cysharp.Threading.Tasks.Linq
return new _Take(source, count, cancellationToken); return new _Take(source, count, cancellationToken);
} }
sealed class _Take : AsyncEnumeratorBase<TSource, TSource> sealed class _Take : MoveNextSource, IUniTaskAsyncEnumerator<TSource>
{ {
readonly int count; static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;
readonly IUniTaskAsyncEnumerable<TSource> source;
readonly int count;
CancellationToken cancellationToken;
IUniTaskAsyncEnumerator<TSource> enumerator;
UniTask<bool>.Awaiter awaiter;
int index; int index;
public _Take(IUniTaskAsyncEnumerable<TSource> source, int count, CancellationToken cancellationToken) public _Take(IUniTaskAsyncEnumerable<TSource> source, int count, CancellationToken cancellationToken)
: base(source, cancellationToken)
{ {
this.source = source;
this.count = count; this.count = count;
this.cancellationToken = cancellationToken;
} }
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) public TSource Current { get; private set; }
public UniTask<bool> MoveNextAsync()
{ {
if (sourceHasCurrent) cancellationToken.ThrowIfCancellationRequested();
if (enumerator == null)
{ {
if (checked(index++) < count) enumerator = source.GetAsyncEnumerator(cancellationToken);
}
if (checked(index) >= count)
{ {
Current = SourceCurrent; return CompletedTasks.False;
result = true; }
return true;
completionSource.Reset();
SourceMoveNext();
return new UniTask<bool>(this, completionSource.Version);
}
void SourceMoveNext()
{
try
{
awaiter = enumerator.MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted)
{
MoveNextCore(this);
} }
else else
{ {
result = false; awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
return true;
} }
} }
catch (Exception ex)
{
completionSource.TrySetException(ex);
}
}
static void MoveNextCore(object state)
{
var self = (_Take)state;
if (self.TryGetResult(self.awaiter, out var result))
{
if (result)
{
self.index++;
self.Current = self.enumerator.Current;
self.completionSource.TrySetResult(true);
}
else else
{ {
result = false; self.completionSource.TrySetResult(false);
return true;
} }
} }
} }
public UniTask DisposeAsync()
{
if (enumerator != null)
{
return enumerator.DisposeAsync();
}
return default;
}
}
} }
} }