Distinct, Except, Intersect, Union

master
neuecc 2020-05-11 15:53:27 +09:00
parent 8ef7a66081
commit b20b37e7a5
11 changed files with 485 additions and 3912 deletions

View File

@ -100,10 +100,18 @@ namespace Cysharp.Threading.Tasks.Linq
} }
completionSource.Reset(); completionSource.Reset();
SourceMoveNext(); if (!OnFirstIteration())
{
SourceMoveNext();
}
return new UniTask<bool>(this, completionSource.Version); return new UniTask<bool>(this, completionSource.Version);
} }
protected virtual bool OnFirstIteration()
{
return false;
}
protected void SourceMoveNext() protected void SourceMoveNext()
{ {
CONTINUE: CONTINUE:

View File

@ -1,775 +1,76 @@
namespace Cysharp.Threading.Tasks.Linq using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Cysharp.Threading.Tasks.Linq
{ {
internal sealed class Distinct public static partial class UniTaskAsyncEnumerable
{ {
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source)
{
Error.ThrowArgumentNullException(source, nameof(source));
return Distinct(source, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(source, nameof(source));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return new Distinct<TSource>(source, comparer);
}
} }
internal sealed class Distinct<TSource> : IUniTaskAsyncEnumerable<TSource>
} {
readonly IUniTaskAsyncEnumerable<TSource> source;
readonly IEqualityComparer<TSource> comparer;
public Distinct(IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
this.source = source;
this.comparer = comparer;
}
public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(source, comparer, cancellationToken);
}
class Enumerator : AsyncEnumeratorBase<TSource, TSource>
{
readonly HashSet<TSource> set;
public Enumerator(IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
: base(source, cancellationToken)
{
this.set = new HashSet<TSource>(comparer);
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
{
if (sourceHasCurrent)
{
var v = SourceCurrent;
if (set.Add(v))
{
Current = v;
result = true;
return true;
}
else
{
result = default;
return false;
}
}
result = false;
return true;
}
}
}
}

View File

@ -1,775 +1,116 @@
namespace Cysharp.Threading.Tasks.Linq using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Cysharp.Threading.Tasks.Linq
{ {
internal sealed class Except public static partial class UniTaskAsyncEnumerable
{ {
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
return new Except<TSource>(first, second, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return new Except<TSource>(first, second, comparer);
}
} }
internal sealed class Except<TSource> : IUniTaskAsyncEnumerable<TSource>
} {
readonly IUniTaskAsyncEnumerable<TSource> first;
readonly IUniTaskAsyncEnumerable<TSource> second;
readonly IEqualityComparer<TSource> comparer;
public Except(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
this.first = first;
this.second = second;
this.comparer = comparer;
}
public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(first, second, comparer, cancellationToken);
}
class Enumerator : AsyncEnumeratorBase<TSource, TSource>
{
static Action<object> HashSetAsyncCoreDelegate = HashSetAsyncCore;
readonly IEqualityComparer<TSource> comparer;
readonly IUniTaskAsyncEnumerable<TSource> second;
HashSet<TSource> set;
UniTask<HashSet<TSource>>.Awaiter awaiter;
public Enumerator(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
: base(first, cancellationToken)
{
this.second = second;
this.comparer = comparer;
}
protected override bool OnFirstIteration()
{
if (set != null) return false;
awaiter = second.ToHashSetAsync(cancellationToken).GetAwaiter();
if (awaiter.IsCompleted)
{
set = awaiter.GetResult();
SourceMoveNext();
}
else
{
awaiter.SourceOnCompleted(HashSetAsyncCoreDelegate, this);
}
return true;
}
static void HashSetAsyncCore(object state)
{
var self = (Enumerator)state;
if (self.TryGetResult(self.awaiter, out var result))
{
self.set = result;
self.SourceMoveNext();
}
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
{
if (sourceHasCurrent)
{
var v = SourceCurrent;
if (set.Add(v))
{
Current = v;
result = true;
return true;
}
else
{
result = default;
return false;
}
}
result = false;
return true;
}
}
}
}

View File

@ -1,775 +1,117 @@
namespace Cysharp.Threading.Tasks.Linq using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Cysharp.Threading.Tasks.Linq
{ {
internal sealed class Intersect public static partial class UniTaskAsyncEnumerable
{ {
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
return new Intersect<TSource>(first, second, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return new Intersect<TSource>(first, second, comparer);
}
} }
internal sealed class Intersect<TSource> : IUniTaskAsyncEnumerable<TSource>
} {
readonly IUniTaskAsyncEnumerable<TSource> first;
readonly IUniTaskAsyncEnumerable<TSource> second;
readonly IEqualityComparer<TSource> comparer;
public Intersect(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
this.first = first;
this.second = second;
this.comparer = comparer;
}
public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(first, second, comparer, cancellationToken);
}
class Enumerator : AsyncEnumeratorBase<TSource, TSource>
{
static Action<object> HashSetAsyncCoreDelegate = HashSetAsyncCore;
readonly IEqualityComparer<TSource> comparer;
readonly IUniTaskAsyncEnumerable<TSource> second;
HashSet<TSource> set;
UniTask<HashSet<TSource>>.Awaiter awaiter;
public Enumerator(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
: base(first, cancellationToken)
{
this.second = second;
this.comparer = comparer;
}
protected override bool OnFirstIteration()
{
if (set != null) return false;
awaiter = second.ToHashSetAsync(cancellationToken).GetAwaiter();
if (awaiter.IsCompleted)
{
set = awaiter.GetResult();
SourceMoveNext();
}
else
{
awaiter.SourceOnCompleted(HashSetAsyncCoreDelegate, this);
}
return true;
}
static void HashSetAsyncCore(object state)
{
var self = (Enumerator)state;
if (self.TryGetResult(self.awaiter, out var result))
{
self.set = result;
self.SourceMoveNext();
}
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
{
if (sourceHasCurrent)
{
var v = SourceCurrent;
if (set.Remove(v))
{
Current = v;
result = true;
return true;
}
else
{
result = default;
return false;
}
}
result = false;
return true;
}
}
}
}

View File

@ -1,775 +1 @@
namespace Cysharp.Threading.Tasks.Linq 
{
internal sealed class Join
{
}
}

View File

@ -10,15 +10,23 @@ namespace Cysharp.Threading.Tasks.Linq
{ {
Error.ThrowArgumentNullException(source, nameof(source)); Error.ThrowArgumentNullException(source, nameof(source));
return Cysharp.Threading.Tasks.Linq.ToHashSet.InvokeAsync(source, cancellationToken); return Cysharp.Threading.Tasks.Linq.ToHashSet.InvokeAsync(source, EqualityComparer<TSource>.Default, cancellationToken);
}
public static UniTask<HashSet<TSource>> ToHashSetAsync<TSource>(this IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken = default)
{
Error.ThrowArgumentNullException(source, nameof(source));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return Cysharp.Threading.Tasks.Linq.ToHashSet.InvokeAsync(source, comparer, cancellationToken);
} }
} }
internal static class ToHashSet internal static class ToHashSet
{ {
internal static async UniTask<HashSet<TSource>> InvokeAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken) internal static async UniTask<HashSet<TSource>> InvokeAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
{ {
var set = new HashSet<TSource>(); var set = new HashSet<TSource>(comparer);
var e = source.GetAsyncEnumerator(cancellationToken); var e = source.GetAsyncEnumerator(cancellationToken);
try try

View File

@ -1,775 +1,26 @@
namespace Cysharp.Threading.Tasks.Linq using Cysharp.Threading.Tasks.Internal;
using System.Collections.Generic;
namespace Cysharp.Threading.Tasks.Linq
{ {
internal sealed class Union public static partial class UniTaskAsyncEnumerable
{ {
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
return Union<TSource>(first, second, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
// improv without combinate?
return first.Concat(second).Distinct(comparer);
}
} }
}
}

View File

@ -26,26 +26,6 @@ namespace ___Dummy
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
@ -200,16 +180,6 @@ namespace ___Dummy
throw new NotImplementedException(); throw new NotImplementedException();
} }
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector) public static IUniTaskAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
@ -309,10 +279,6 @@ namespace ___Dummy
public static IUniTaskAsyncEnumerable<TSource> TakeLast<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Int32 count)
{
throw new NotImplementedException();
}
public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector) public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
{ {
@ -380,19 +346,6 @@ namespace ___Dummy
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
} }

View File

@ -54,7 +54,7 @@ namespace NetCoreSandbox
//Enumerable.Range(1,10).ToHashSet(
} }

View File

@ -0,0 +1,138 @@
using Cysharp.Threading.Tasks;
using Cysharp.Threading.Tasks.Linq;
using FluentAssertions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Concurrency;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Xunit;
namespace NetCoreTests.Linq
{
public class Sets
{
public static IEnumerable<object[]> array1 = new object[][]
{
new object[] { new int[] { } }, // empty
new object[] { new int[] { 1, 2, 3 } }, // no dup
new object[] { new int[] { 1, 2, 3, 3, 4, 5, 2 } }, // dup
};
public static IEnumerable<object[]> array2 = new object[][]
{
new object[] { new int[] { } }, // empty
new object[] { new int[] { 1, 2 } },
new object[] { new int[] { 1, 2, 4, 5, 9 } }, // dup
};
[Theory]
[MemberData(nameof(array1))]
public async Task Distinct(int[] array)
{
var xs = await array.ToUniTaskAsyncEnumerable().Distinct().ToArrayAsync();
var ys = array.Distinct().ToArray();
xs.Should().BeEquivalentTo(ys);
}
[Fact]
public async Task DistinctThrow()
{
foreach (var item in UniTaskTestException.Throws())
{
var xs = item.Distinct().ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
}
[Fact]
public async Task Except()
{
foreach (var a1 in array1.First().Cast<int[]>())
{
foreach (var a2 in array2.First().Cast<int[]>())
{
var xs = await a1.ToUniTaskAsyncEnumerable().Except(a2.ToUniTaskAsyncEnumerable()).ToArrayAsync();
var ys = a1.Except(a2).ToArray();
xs.Should().BeEquivalentTo(ys);
}
}
}
[Fact]
public async Task ExceptThrow()
{
foreach (var item in UniTaskTestException.Throws())
{
var xs = item.Except(UniTaskAsyncEnumerable.Return(10)).ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
foreach (var item in UniTaskTestException.Throws())
{
var xs = UniTaskAsyncEnumerable.Return(10).Except(item).ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
}
[Fact]
public async Task Intersect()
{
foreach (var a1 in array1.First().Cast<int[]>())
{
foreach (var a2 in array2.First().Cast<int[]>())
{
var xs = await a1.ToUniTaskAsyncEnumerable().Intersect(a2.ToUniTaskAsyncEnumerable()).ToArrayAsync();
var ys = a1.Intersect(a2).ToArray();
xs.Should().BeEquivalentTo(ys);
}
}
}
[Fact]
public async Task IntersectThrow()
{
foreach (var item in UniTaskTestException.Throws())
{
var xs = item.Intersect(UniTaskAsyncEnumerable.Return(10)).ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
foreach (var item in UniTaskTestException.Throws())
{
var xs = UniTaskAsyncEnumerable.Return(10).Intersect(item).ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
}
[Fact]
public async Task Union()
{
foreach (var a1 in array1.First().Cast<int[]>())
{
foreach (var a2 in array2.First().Cast<int[]>())
{
var xs = await a1.ToUniTaskAsyncEnumerable().Union(a2.ToUniTaskAsyncEnumerable()).ToArrayAsync();
var ys = a1.Union(a2).ToArray();
xs.Should().BeEquivalentTo(ys);
}
}
}
[Fact]
public async Task UnionThrow()
{
foreach (var item in UniTaskTestException.Throws())
{
var xs = item.Union(UniTaskAsyncEnumerable.Return(10)).ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
foreach (var item in UniTaskTestException.Throws())
{
var xs = UniTaskAsyncEnumerable.Return(10).Union(item).ToArrayAsync();
await Assert.ThrowsAsync<UniTaskTestException>(async () => await xs);
}
}
}
}

View File

@ -15,6 +15,7 @@ namespace Cysharp.Threading.Tasks
{ {
internal static readonly Action<object> InvokeActionDelegate = InvokeAction; internal static readonly Action<object> InvokeActionDelegate = InvokeAction;
[DebuggerHidden]
static void InvokeAction(object state) static void InvokeAction(object state)
{ {
((Action)state).Invoke(); ((Action)state).Invoke();
@ -318,6 +319,8 @@ namespace Cysharp.Threading.Tasks
/// <summary> /// <summary>
/// If register manually continuation, you can use it instead of for compiler OnCompleted methods. /// If register manually continuation, you can use it instead of for compiler OnCompleted methods.
/// </summary> /// </summary>
[DebuggerHidden]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void SourceOnCompleted(Action<object> continuation, object state) public void SourceOnCompleted(Action<object> continuation, object state)
{ {
if (task.source == null) if (task.source == null)
@ -640,6 +643,8 @@ namespace Cysharp.Threading.Tasks
/// <summary> /// <summary>
/// If register manually continuation, you can use it instead of for compiler OnCompleted methods. /// If register manually continuation, you can use it instead of for compiler OnCompleted methods.
/// </summary> /// </summary>
[DebuggerHidden]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void SourceOnCompleted(Action<object> continuation, object state) public void SourceOnCompleted(Action<object> continuation, object state)
{ {
var s = task.source; var s = task.source;