using LeanCloud.Storage.Internal;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace LeanCloud {
    /// <summary>
    /// Provides extension methods for <see cref="AVQuery{T}"/> to support
    /// Linq-style queries.
    /// </summary>
    /// <remarks>
    /// Translation happens in two steps: expressions are first normalized
    /// (field accesses become a single <c>AVObject.Get&lt;T&gt;("path")</c> call placed on
    /// the left-hand side, negations are folded into the operator), and the normalized
    /// tree is then mapped onto the matching <c>Where*</c> method of the query.
    /// </remarks>
    public static class AVQueryExtensions {
        private static readonly MethodInfo getMethod;
        private static readonly MethodInfo stringContains;
        private static readonly MethodInfo stringStartsWith;
        private static readonly MethodInfo stringEndsWith;
        private static readonly MethodInfo containsMethod;
        private static readonly MethodInfo notContainsMethod;
        private static readonly MethodInfo containsKeyMethod;
        private static readonly MethodInfo notContainsKeyMethod;

        // Maps supported System.String methods onto the AVQuery method implementing them.
        private static readonly Dictionary<MethodInfo, MethodInfo> functionMappings;

        static AVQueryExtensions() {
            // The lambdas below are never executed; they only exist so that the compiler
            // resolves the MethodInfo of the call in their body.
            getMethod = GetMethod<AVObject>(obj => obj.Get<int>(null)).GetGenericMethodDefinition();
            stringContains = GetMethod<string>(str => str.Contains(null));
            stringStartsWith = GetMethod<string>(str => str.StartsWith(null));
            stringEndsWith = GetMethod<string>(str => str.EndsWith(null));
            functionMappings = new Dictionary<MethodInfo, MethodInfo> {
                {
                    stringContains,
                    GetMethod<AVQuery<AVObject>>(q => q.WhereContains(null, null))
                },
                {
                    stringStartsWith,
                    GetMethod<AVQuery<AVObject>>(q => q.WhereStartsWith(null, null))
                },
                {
                    stringEndsWith,
                    GetMethod<AVQuery<AVObject>>(q => q.WhereEndsWith(null,null))
                },
            };
            containsMethod = GetMethod<object>(
                o => AVQueryExtensions.ContainsStub<object>(null, null)).GetGenericMethodDefinition();
            notContainsMethod = GetMethod<object>(
                o => AVQueryExtensions.NotContainsStub<object>(null, null))
                    .GetGenericMethodDefinition();

            containsKeyMethod = GetMethod<object>(o => AVQueryExtensions.ContainsKeyStub(null, null));
            notContainsKeyMethod = GetMethod<object>(
                o => AVQueryExtensions.NotContainsKeyStub(null, null));
        }

        /// <summary>
        /// Gets a MethodInfo for a top-level method call.
        /// </summary>
        /// <typeparam name="T">The type of the lambda's single parameter.</typeparam>
        /// <param name="expression">A lambda whose body is a single method call.</param>
        /// <returns>The method invoked by the body of <paramref name="expression"/>.</returns>
        private static MethodInfo GetMethod<T>(Expression<Action<T>> expression) {
            return (expression.Body as MethodCallExpression).Method;
        }

        /// <summary>
        /// When a query is normalized, this is a placeholder to indicate we should
        /// add a WhereContainedIn() clause.
        /// </summary>
        private static bool ContainsStub<T>(object collection, T value) {
            throw new NotImplementedException(
                "Exists only for expression translation as a placeholder.");
        }

        /// <summary>
        /// When a query is normalized, this is a placeholder to indicate we should
        /// add a WhereNotContainedIn() clause.
        /// </summary>
        private static bool NotContainsStub<T>(object collection, T value) {
            throw new NotImplementedException(
                "Exists only for expression translation as a placeholder.");
        }

        /// <summary>
        /// When a query is normalized, this is a placeholder to indicate that we should
        /// add a WhereExists() clause.
        /// </summary>
        private static bool ContainsKeyStub(AVObject obj, string key) {
            throw new NotImplementedException(
                "Exists only for expression translation as a placeholder.");
        }

        /// <summary>
        /// When a query is normalized, this is a placeholder to indicate that we should
        /// add a WhereDoesNotExist() clause.
        /// </summary>
        private static bool NotContainsKeyStub(AVObject obj, string key) {
            throw new NotImplementedException(
                "Exists only for expression translation as a placeholder.");
        }

        /// <summary>
        /// Evaluates an expression and throws if the expression has components that can't be
        /// evaluated (e.g. uses the parameter that's only represented by an object on the server).
        /// </summary>
        /// <param name="exp">The expression to compile and invoke.</param>
        /// <returns>The value produced by the expression.</returns>
        /// <exception cref="InvalidOperationException">The expression could not be
        /// compiled or invoked.</exception>
        private static object GetValue(Expression exp) {
            try {
                return Expression.Lambda(
                    typeof(Func<>).MakeGenericType(exp.Type), exp).Compile().DynamicInvoke();
            } catch (Exception e) {
                throw new InvalidOperationException("Unable to evaluate expression: " + exp, e);
            }
        }

        /// <summary>
        /// Checks whether the MethodCallExpression is a call to AVObject.Get(),
        /// which is the call we normalize all indexing into the AVObject to.
        /// </summary>
        /// <param name="node">The call to inspect; may be null.</param>
        /// <returns>True if <paramref name="node"/> is an instance call to the generic
        /// AVObject.Get() method on an AVObject (or subclass) receiver.</returns>
        private static bool IsAVObjectGet(MethodCallExpression node) {
            if (node == null || node.Object == null) {
                return false;
            }
            if (!typeof(AVObject).GetTypeInfo().IsAssignableFrom(node.Object.Type.GetTypeInfo())) {
                return false;
            }
            return node.Method.IsGenericMethod && node.Method.GetGenericMethodDefinition() == getMethod;
        }

        /// <summary>
        /// Visits an Expression, converting AVObject.Get/AVObject[]/AVObject.Property,
        /// and nested indices into a single call to AVObject.Get() with a "field path" like
        /// "foo.bar.baz"
        /// </summary>
        private class ObjectNormalizer : ExpressionVisitor {
            /// <summary>
            /// Folds an index applied to the result of an AVObject.Get() call into that
            /// call's field path.
            /// </summary>
            protected override Expression VisitIndex(IndexExpression node) {
                var visitedObject = Visit(node.Object);
                var indexer = visitedObject as MethodCallExpression;
                if (IsAVObjectGet(indexer)) {
                    var indexValue = GetValue(node.Arguments[0]) as string;
                    if (indexValue == null) {
                        throw new InvalidOperationException("Index must be a string");
                    }
                    var newPath = GetValue(indexer.Arguments[0]) + "." + indexValue;
                    return Expression.Call(indexer.Object,
                        getMethod.MakeGenericMethod(node.Type),
                        Expression.Constant(newPath, typeof(string)));
                }
                return base.VisitIndex(node);
            }

            /// <summary>
            /// Check for a AVFieldName attribute and use that as the path component, turning
            /// properties like foo.ObjectId into foo.Get("objectId")
            /// </summary>
            protected override Expression VisitMember(MemberExpression node) {
                var fieldName = node.Member.GetCustomAttribute<AVFieldNameAttribute>();
                // node.Expression is null for static members, which can never be a field
                // access on an AVObject instance.
                if (fieldName != null &&
                    node.Expression != null &&
                    typeof(AVObject).GetTypeInfo().IsAssignableFrom(node.Expression.Type.GetTypeInfo())) {
                    var newPath = fieldName.FieldName;
                    return Expression.Call(node.Expression,
                        getMethod.MakeGenericMethod(node.Type),
                        Expression.Constant(newPath, typeof(string)));
                }
                return base.VisitMember(node);
            }

            /// <summary>
            /// If a AVObject.Get() call has been cast, just change the generic parameter.
            /// </summary>
            protected override Expression VisitUnary(UnaryExpression node) {
                var methodCall = Visit(node.Operand) as MethodCallExpression;
                if ((node.NodeType == ExpressionType.Convert ||
                    node.NodeType == ExpressionType.ConvertChecked) &&
                    IsAVObjectGet(methodCall)) {
                    return Expression.Call(methodCall.Object,
                        getMethod.MakeGenericMethod(node.Type),
                        methodCall.Arguments);
                }
                return base.VisitUnary(node);
            }

            /// <summary>
            /// Rewrites indexer calls (get_Item) on the query parameter into AVObject.Get(),
            /// and merges nested Get()/indexer calls into one Get() with a dotted path.
            /// </summary>
            protected override Expression VisitMethodCall(MethodCallExpression node) {
                // Indexing directly into the lambda parameter: obj["foo"] --> obj.Get<object>("foo")
                if (node.Method.Name == "get_Item" && node.Object is ParameterExpression) {
                    var indexPath = GetValue(node.Arguments[0]) as string;
                    return Expression.Call(node.Object,
                        getMethod.MakeGenericMethod(typeof(object)),
                        Expression.Constant(indexPath, typeof(string)));
                }

                if (node.Method.Name == "get_Item" || IsAVObjectGet(node)) {
                    var visitedObject = Visit(node.Object);
                    var indexer = visitedObject as MethodCallExpression;
                    if (IsAVObjectGet(indexer)) {
                        var indexValue = GetValue(node.Arguments[0]) as string;
                        if (indexValue == null) {
                            throw new InvalidOperationException("Index must be a string");
                        }
                        var newPath = GetValue(indexer.Arguments[0]) + "." + indexValue;
                        return Expression.Call(indexer.Object,
                            getMethod.MakeGenericMethod(node.Type),
                            Expression.Constant(newPath, typeof(string)));
                    }
                }
                return base.VisitMethodCall(node);
            }
        }

        /// <summary>
        /// Normalizes Where expressions.
        /// </summary>
        private class WhereNormalizer : ExpressionVisitor {
            /// <summary>
            /// Normalizes binary operators. &lt;, &gt;, &lt;=, &gt;=, !=, and ==
            /// This puts the AVObject.Get() on the left side of the operation
            /// (reversing it if necessary), and normalizes the AVObject.Get()
            /// </summary>
            protected override Expression VisitBinary(BinaryExpression node) {
                var leftTransformed = new ObjectNormalizer().Visit(node.Left) as MethodCallExpression;
                var rightTransformed = new ObjectNormalizer().Visit(node.Right) as MethodCallExpression;

                MethodCallExpression objectExpression;
                Expression filterExpression;
                bool inverted;
                if (leftTransformed != null) {
                    objectExpression = leftTransformed;
                    filterExpression = node.Right;
                    inverted = false;
                } else {
                    objectExpression = rightTransformed;
                    filterExpression = node.Left;
                    inverted = true;
                }

                try {
                    switch (node.NodeType) {
                        case ExpressionType.GreaterThan:
                            if (inverted) {
                                return Expression.LessThan(objectExpression, filterExpression);
                            } else {
                                return Expression.GreaterThan(objectExpression, filterExpression);
                            }
                        case ExpressionType.GreaterThanOrEqual:
                            if (inverted) {
                                return Expression.LessThanOrEqual(objectExpression, filterExpression);
                            } else {
                                return Expression.GreaterThanOrEqual(objectExpression, filterExpression);
                            }
                        case ExpressionType.LessThan:
                            if (inverted) {
                                return Expression.GreaterThan(objectExpression, filterExpression);
                            } else {
                                return Expression.LessThan(objectExpression, filterExpression);
                            }
                        case ExpressionType.LessThanOrEqual:
                            if (inverted) {
                                return Expression.GreaterThanOrEqual(objectExpression, filterExpression);
                            } else {
                                return Expression.LessThanOrEqual(objectExpression, filterExpression);
                            }
                        case ExpressionType.Equal:
                            return Expression.Equal(objectExpression, filterExpression);
                        case ExpressionType.NotEqual:
                            return Expression.NotEqual(objectExpression, filterExpression);
                    }
                } catch (ArgumentException) {
                    // Raised by the Expression factories when the operands cannot be
                    // combined with the requested operator (including null operands).
                    throw new InvalidOperationException("Operation not supported: " + node);
                }
                return base.VisitBinary(node);
            }

            /// <summary>
            /// If a ! operator is used, this removes the ! and instead calls the equivalent
            /// function (so e.g. == becomes !=, &lt; becomes &gt;=, Contains becomes NotContains)
            /// </summary>
            protected override Expression VisitUnary(UnaryExpression node) {
                // Normalizes inversion
                if (node.NodeType == ExpressionType.Not) {
                    var visitedOperand = Visit(node.Operand);
                    var binaryOperand = visitedOperand as BinaryExpression;
                    if (binaryOperand != null) {
                        switch (binaryOperand.NodeType) {
                            case ExpressionType.GreaterThan:
                                return Expression.LessThanOrEqual(binaryOperand.Left, binaryOperand.Right);
                            case ExpressionType.GreaterThanOrEqual:
                                return Expression.LessThan(binaryOperand.Left, binaryOperand.Right);
                            case ExpressionType.LessThan:
                                return Expression.GreaterThanOrEqual(binaryOperand.Left, binaryOperand.Right);
                            case ExpressionType.LessThanOrEqual:
                                return Expression.GreaterThan(binaryOperand.Left, binaryOperand.Right);
                            case ExpressionType.Equal:
                                return Expression.NotEqual(binaryOperand.Left, binaryOperand.Right);
                            case ExpressionType.NotEqual:
                                return Expression.Equal(binaryOperand.Left, binaryOperand.Right);
                        }
                    }

                    var methodCallOperand = visitedOperand as MethodCallExpression;
                    if (methodCallOperand != null) {
                        if (methodCallOperand.Method.IsGenericMethod) {
                            if (methodCallOperand.Method.GetGenericMethodDefinition() == containsMethod) {
                                var genericNotContains = notContainsMethod.MakeGenericMethod(
                                    methodCallOperand.Method.GetGenericArguments());
                                return Expression.Call(genericNotContains, methodCallOperand.Arguments.ToArray());
                            }
                            if (methodCallOperand.Method.GetGenericMethodDefinition() == notContainsMethod) {
                                var genericContains = containsMethod.MakeGenericMethod(
                                    methodCallOperand.Method.GetGenericArguments());
                                return Expression.Call(genericContains, methodCallOperand.Arguments.ToArray());
                            }
                        }
                        if (methodCallOperand.Method == containsKeyMethod) {
                            return Expression.Call(notContainsKeyMethod, methodCallOperand.Arguments.ToArray());
                        }
                        if (methodCallOperand.Method == notContainsKeyMethod) {
                            return Expression.Call(containsKeyMethod, methodCallOperand.Arguments.ToArray());
                        }
                    }
                }
                return base.VisitUnary(node);
            }

            /// <summary>
            /// Normalizes .Equals into == and Contains() into the appropriate stub.
            /// </summary>
            protected override Expression VisitMethodCall(MethodCallExpression node) {
                var parameterCount = node.Method.GetParameters().Length;

                // Convert .Equals() into ==
                if (node.Method.Name == "Equals" &&
                    node.Method.ReturnType == typeof(bool) &&
                    parameterCount == 1) {
                    var obj = new ObjectNormalizer().Visit(node.Object) as MethodCallExpression;
                    var parameter = new ObjectNormalizer().Visit(node.Arguments[0]) as MethodCallExpression;
                    if ((IsAVObjectGet(obj) && (obj.Object is ParameterExpression)) ||
                        (IsAVObjectGet(parameter) && (parameter.Object is ParameterExpression))) {
                        return Expression.Equal(node.Object, node.Arguments[0]);
                    }
                }

                // Convert the .Contains() into a ContainsStub. One parameter means an
                // instance call (collection.Contains(value)); two parameters means a
                // static/extension call (Contains(collection, value)). A parameterless
                // Contains() has nothing to index into, so it is left untouched.
                if (node.Method != stringContains &&
                    node.Method.Name == "Contains" &&
                    node.Method.ReturnType == typeof(bool) &&
                    parameterCount >= 1 &&
                    parameterCount <= 2) {
                    var collection = parameterCount == 1 ?
                        node.Object :
                        node.Arguments[0];
                    var parameterIndex = parameterCount - 1;
                    var parameter = new ObjectNormalizer().Visit(node.Arguments[parameterIndex])
                        as MethodCallExpression;
                    if (IsAVObjectGet(parameter) && (parameter.Object is ParameterExpression)) {
                        var genericContains = containsMethod.MakeGenericMethod(parameter.Type);
                        return Expression.Call(genericContains, collection, parameter);
                    }
                    var target = new ObjectNormalizer().Visit(collection) as MethodCallExpression;
                    var element = node.Arguments[parameterIndex];
                    if (IsAVObjectGet(target) && (target.Object is ParameterExpression)) {
                        var genericContains = containsMethod.MakeGenericMethod(element.Type);
                        return Expression.Call(genericContains, target, element);
                    }
                }

                // Convert obj["foo.bar"].ContainsKey("baz") into obj.ContainsKey("foo.bar.baz")
                if (node.Method.Name == "ContainsKey" &&
                    node.Method.ReturnType == typeof(bool) &&
                    parameterCount == 1) {
                    var getter = new ObjectNormalizer().Visit(node.Object) as MethodCallExpression;
                    Expression target = null;
                    string path = null;
                    if (IsAVObjectGet(getter) && getter.Object is ParameterExpression) {
                        target = getter.Object;
                        path = GetValue(getter.Arguments[0]) + "." + GetValue(node.Arguments[0]);
                        return Expression.Call(containsKeyMethod, target, Expression.Constant(path));
                    } else if (node.Object is ParameterExpression) {
                        target = node.Object;
                        path = GetValue(node.Arguments[0]) as string;
                    }
                    if (target != null && path != null) {
                        return Expression.Call(containsKeyMethod, target, Expression.Constant(path));
                    }
                }
                return base.VisitMethodCall(node);
            }
        }

        /// <summary>
        /// Converts a normalized method call expression into the appropriate AVQuery clause.
        /// </summary>
        /// <typeparam name="T">The type of AVObject being queried for.</typeparam>
        /// <param name="source">The query the clause is added to.</param>
        /// <param name="expression">The predicate the call was taken from; its parameter
        /// identifies the object being queried.</param>
        /// <param name="node">The normalized method call.</param>
        /// <returns>The query with the translated clause applied.</returns>
        /// <exception cref="InvalidOperationException">The method call cannot be
        /// represented as a query clause.</exception>
        private static AVQuery<T> WhereMethodCall<T>(
            this AVQuery<T> source, Expression<Func<T, bool>> expression, MethodCallExpression node)
            where T : AVObject {
            if (IsAVObjectGet(node) && (node.Type == typeof(bool) || node.Type == typeof(bool?))) {
                // This is a raw boolean field access like 'where obj.Get<bool>("foo")'
                return source.WhereEqualTo(GetValue(node.Arguments[0]) as string, true);
            }

            MethodInfo translatedMethod;
            if (functionMappings.TryGetValue(node.Method, out translatedMethod)) {
                var objTransformed = new ObjectNormalizer().Visit(node.Object) as MethodCallExpression;
                if (!(IsAVObjectGet(objTransformed) &&
                    objTransformed.Object == expression.Parameters[0])) {
                    throw new InvalidOperationException(
                        "The left-hand side of a supported function call must be a AVObject field access.");
                }
                var fieldPath = GetValue(objTransformed.Arguments[0]);
                var containedIn = GetValue(node.Arguments[0]);
                // The mapping was captured against AVQuery<AVObject>; re-resolve the same
                // method on the query type closed over T before invoking it.
                var queryType = translatedMethod.DeclaringType.GetGenericTypeDefinition()
                    .MakeGenericType(typeof(T));
                translatedMethod = ReflectionHelpers.GetMethod(queryType,
                    translatedMethod.Name,
                    translatedMethod.GetParameters().Select(p => p.ParameterType).ToArray());
                return translatedMethod.Invoke(source, new[] { fieldPath, containedIn }) as AVQuery<T>;
            }

            // Guard on Count so that a parameterless method call falls through to the
            // "not supported" error below instead of failing on the index access.
            if (node.Arguments.Count > 0 && node.Arguments[0] == expression.Parameters[0]) {
                // obj.ContainsKey("foo") --> query.WhereExists("foo")
                if (node.Method == containsKeyMethod) {
                    return source.WhereExists(GetValue(node.Arguments[1]) as string);
                }
                // !obj.ContainsKey("foo") --> query.WhereDoesNotExist("foo")
                if (node.Method == notContainsKeyMethod) {
                    return source.WhereDoesNotExist(GetValue(node.Arguments[1]) as string);
                }
            }

            if (node.Method.IsGenericMethod) {
                if (node.Method.GetGenericMethodDefinition() == containsMethod) {
                    // obj.Get<IList<T>>("path").Contains(someValue)
                    if (IsAVObjectGet(node.Arguments[0] as MethodCallExpression)) {
                        return source.WhereEqualTo(
                            GetValue(((MethodCallExpression)node.Arguments[0]).Arguments[0]) as string,
                            GetValue(node.Arguments[1]));
                    }
                    // someList.Contains(obj.Get<T>("path"))
                    if (IsAVObjectGet(node.Arguments[1] as MethodCallExpression)) {
                        var collection = GetValue(node.Arguments[0]) as System.Collections.IEnumerable;
                        return source.WhereContainedIn(
                            GetValue(((MethodCallExpression)node.Arguments[1]).Arguments[0]) as string,
                            collection.Cast<object>());
                    }
                }

                if (node.Method.GetGenericMethodDefinition() == notContainsMethod) {
                    // !obj.Get<IList<T>>("path").Contains(someValue)
                    if (IsAVObjectGet(node.Arguments[0] as MethodCallExpression)) {
                        return source.WhereNotEqualTo(
                            GetValue(((MethodCallExpression)node.Arguments[0]).Arguments[0]) as string,
                            GetValue(node.Arguments[1]));
                    }
                    // !someList.Contains(obj.Get<T>("path"))
                    if (IsAVObjectGet(node.Arguments[1] as MethodCallExpression)) {
                        var collection = GetValue(node.Arguments[0]) as System.Collections.IEnumerable;
                        return source.WhereNotContainedIn(
                            GetValue(((MethodCallExpression)node.Arguments[1]).Arguments[0]) as string,
                            collection.Cast<object>());
                    }
                }
            }
            throw new InvalidOperationException(node.Method + " is not a supported method call in a where expression.");
        }

        /// <summary>
        /// Converts a normalized binary expression into the appropriate AVQuery clause.
        /// </summary>
        /// <typeparam name="T">The type of AVObject being queried for.</typeparam>
        /// <param name="source">The query the clause is added to.</param>
        /// <param name="expression">The predicate the comparison was taken from; its
        /// parameter identifies the object being queried.</param>
        /// <param name="node">The normalized comparison, with the field access on the left.</param>
        /// <returns>The query with the translated clause applied.</returns>
        /// <exception cref="InvalidOperationException">The left side is not a field access
        /// on the queried object, the value is not a storable type, or the operator is
        /// not supported.</exception>
        private static AVQuery<T> WhereBinaryExpression<T>(
            this AVQuery<T> source, Expression<Func<T, bool>> expression, BinaryExpression node)
            where T : AVObject {
            var leftTransformed = new ObjectNormalizer().Visit(node.Left) as MethodCallExpression;

            if (!(IsAVObjectGet(leftTransformed) &&
                leftTransformed.Object == expression.Parameters[0])) {
                throw new InvalidOperationException(
                    "Where expressions must have one side be a field operation on a AVObject.");
            }

            var fieldPath = GetValue(leftTransformed.Arguments[0]) as string;
            var filterValue = GetValue(node.Right);

            if (filterValue != null && !AVEncoder.IsValidType(filterValue)) {
                throw new InvalidOperationException(
                    "Where clauses must use types compatible with AVObjects.");
            }

            switch (node.NodeType) {
                case ExpressionType.GreaterThan:
                    return source.WhereGreaterThan(fieldPath, filterValue);
                case ExpressionType.GreaterThanOrEqual:
                    return source.WhereGreaterThanOrEqualTo(fieldPath, filterValue);
                case ExpressionType.LessThan:
                    return source.WhereLessThan(fieldPath, filterValue);
                case ExpressionType.LessThanOrEqual:
                    return source.WhereLessThanOrEqualTo(fieldPath, filterValue);
                case ExpressionType.Equal:
                    return source.WhereEqualTo(fieldPath, filterValue);
                case ExpressionType.NotEqual:
                    return source.WhereNotEqualTo(fieldPath, filterValue);
                default:
                    throw new InvalidOperationException(
                        "Where expressions do not support this operator.");
            }
        }

        /// <summary>
        /// Filters a query based upon the predicate provided.
        /// </summary>
        /// <typeparam name="TSource">The type of AVObject being queried for.</typeparam>
        /// <param name="source">The base <see cref="AVQuery{TSource}"/> to which
        /// the predicate will be added.</param>
        /// <param name="predicate">A function to test each AVObject for a condition.
        /// The predicate must be able to be represented by one of the standard Where
        /// functions on AVQuery</param>
        /// <returns>A new AVQuery whose results will match the given predicate as
        /// well as the Source's filters.</returns>
        /// <exception cref="InvalidOperationException">The predicate cannot be
        /// translated into query clauses.</exception>
        public static AVQuery<TSource> Where<TSource>(
            this AVQuery<TSource> source, Expression<Func<TSource, bool>> predicate)
            where TSource : AVObject {
            // Handle top-level logic operators && and ||
            var binaryExpression = predicate.Body as BinaryExpression;
            if (binaryExpression != null) {
                if (binaryExpression.NodeType == ExpressionType.AndAlso) {
                    return source
                        .Where(Expression.Lambda<Func<TSource, bool>>(
                            binaryExpression.Left, predicate.Parameters))
                        .Where(Expression.Lambda<Func<TSource, bool>>(
                            binaryExpression.Right, predicate.Parameters));
                }
                if (binaryExpression.NodeType == ExpressionType.OrElse) {
                    var left = source.Where(Expression.Lambda<Func<TSource, bool>>(
                        binaryExpression.Left, predicate.Parameters));
                    var right = source.Where(Expression.Lambda<Func<TSource, bool>>(
                        binaryExpression.Right, predicate.Parameters));
                    return left.Or(right);
                }
            }

            var normalized = new WhereNormalizer().Visit(predicate.Body);

            var methodCallExpr = normalized as MethodCallExpression;
            if (methodCallExpr != null) {
                return source.WhereMethodCall(predicate, methodCallExpr);
            }

            var binaryExpr = normalized as BinaryExpression;
            if (binaryExpr != null) {
                return source.WhereBinaryExpression(predicate, binaryExpr);
            }

            var unaryExpr = normalized as UnaryExpression;
            if (unaryExpr != null && unaryExpr.NodeType == ExpressionType.Not) {
                var node = unaryExpr.Operand as MethodCallExpression;
                if (IsAVObjectGet(node) && (node.Type == typeof(bool) || node.Type == typeof(bool?))) {
                    // This is a raw boolean field access like 'where !obj.Get<bool>("foo")'
                    return source.WhereNotEqualTo(GetValue(node.Arguments[0]) as string, true);
                }
            }

            throw new InvalidOperationException(
                "Encountered an unsupported expression for AVQueries.");
        }

        /// <summary>
        /// Normalizes an OrderBy's keySelector expression and then extracts the path
        /// from the AVObject.Get() call.
        /// </summary>
        /// <typeparam name="TSource">The type of AVObject being queried for.</typeparam>
        /// <typeparam name="TSelector">The type of key returned by keySelector.</typeparam>
        /// <param name="keySelector">A function to extract a key from the AVObject.</param>
        /// <returns>The field path selected by <paramref name="keySelector"/>.</returns>
        /// <exception cref="InvalidOperationException">The selector is not a field access
        /// on its parameter.</exception>
        private static string GetOrderByPath<TSource, TSelector>(
            Expression<Func<TSource, TSelector>> keySelector) {
            string result = null;
            var normalized = new ObjectNormalizer().Visit(keySelector.Body);
            var callExpr = normalized as MethodCallExpression;
            if (IsAVObjectGet(callExpr) && callExpr.Object == keySelector.Parameters[0]) {
                // We're operating on the parameter
                result = GetValue(callExpr.Arguments[0]) as string;
            }
            if (result == null) {
                throw new InvalidOperationException(
                    "OrderBy expression must be a field access on a AVObject.");
            }
            return result;
        }

        /// <summary>
        /// Orders a query based upon the key selector provided.
        /// </summary>
        /// <typeparam name="TSource">The type of AVObject being queried for.</typeparam>
        /// <typeparam name="TSelector">The type of key returned by keySelector.</typeparam>
        /// <param name="source">The query to order.</param>
        /// <param name="keySelector">A function to extract a key from the AVObject.</param>
        /// <returns>A new AVQuery based on Source whose results will be ordered by
        /// the key specified in the keySelector.</returns>
        public static AVQuery<TSource> OrderBy<TSource, TSelector>(
            this AVQuery<TSource> source, Expression<Func<TSource, TSelector>> keySelector)
            where TSource : AVObject {
            return source.OrderBy(GetOrderByPath(keySelector));
        }

        /// <summary>
        /// Orders a query based upon the key selector provided.
        /// </summary>
        /// <typeparam name="TSource">The type of AVObject being queried for.</typeparam>
        /// <typeparam name="TSelector">The type of key returned by keySelector.</typeparam>
        /// <param name="source">The query to order.</param>
        /// <param name="keySelector">A function to extract a key from the AVObject.</param>
        /// <returns>A new AVQuery based on Source whose results will be ordered by
        /// the key specified in the keySelector.</returns>
        public static AVQuery<TSource> OrderByDescending<TSource, TSelector>(
            this AVQuery<TSource> source, Expression<Func<TSource, TSelector>> keySelector)
            where TSource : AVObject {
            return source.OrderByDescending(GetOrderByPath(keySelector));
        }

        /// <summary>
        /// Performs a subsequent ordering of a query based upon the key selector provided.
        /// </summary>
        /// <typeparam name="TSource">The type of AVObject being queried for.</typeparam>
        /// <typeparam name="TSelector">The type of key returned by keySelector.</typeparam>
        /// <param name="source">The query to order.</param>
        /// <param name="keySelector">A function to extract a key from the AVObject.</param>
        /// <returns>A new AVQuery based on Source whose results will be ordered by
        /// the key specified in the keySelector.</returns>
        public static AVQuery<TSource> ThenBy<TSource, TSelector>(
            this AVQuery<TSource> source, Expression<Func<TSource, TSelector>> keySelector)
            where TSource : AVObject {
            return source.ThenBy(GetOrderByPath(keySelector));
        }

        /// <summary>
        /// Performs a subsequent ordering of a query based upon the key selector provided.
        /// </summary>
        /// <typeparam name="TSource">The type of AVObject being queried for.</typeparam>
        /// <typeparam name="TSelector">The type of key returned by keySelector.</typeparam>
        /// <param name="source">The query to order.</param>
        /// <param name="keySelector">A function to extract a key from the AVObject.</param>
        /// <returns>A new AVQuery based on Source whose results will be ordered by
        /// the key specified in the keySelector.</returns>
        public static AVQuery<TSource> ThenByDescending<TSource, TSelector>(
            this AVQuery<TSource> source, Expression<Func<TSource, TSelector>> keySelector)
            where TSource : AVObject {
            return source.ThenByDescending(GetOrderByPath(keySelector));
        }

        /// <summary>
        /// Correlates the elements of two queries based on matching keys.
        /// </summary>
        /// <typeparam name="TOuter">The type of AVObjects of the first query.</typeparam>
        /// <typeparam name="TInner">The type of AVObjects of the second query.</typeparam>
        /// <typeparam name="TKey">The type of the keys returned by the key selector
        /// functions.</typeparam>
        /// <typeparam name="TResult">The type of the result. This must match either
        /// TOuter or TInner</typeparam>
        /// <param name="outer">The first query to join.</param>
        /// <param name="inner">The query to join to the first query.</param>
        /// <param name="outerKeySelector">A function to extract a join key from the results of
        /// the first query.</param>
        /// <param name="innerKeySelector">A function to extract a join key from the results of
        /// the second query.</param>
        /// <param name="resultSelector">A function to select either the outer or inner query
        /// result to determine which query is the base query.</param>
        /// <returns>A new AVQuery with a WhereMatchesQuery or WhereMatchesKeyInQuery
        /// clause based upon the query indicated in the <paramref name="resultSelector"/>.</returns>
        /// <exception cref="InvalidOperationException">The result selector does not select
        /// one of its parameters, or a key selector is not a supported form.</exception>
        public static AVQuery<TResult> Join<TOuter, TInner, TKey, TResult>(
            this AVQuery<TOuter> outer,
            AVQuery<TInner> inner,
            Expression<Func<TOuter, TKey>> outerKeySelector,
            Expression<Func<TInner, TKey>> innerKeySelector,
            Expression<Func<TOuter, TInner, TResult>> resultSelector)
            where TOuter : AVObject
            where TInner : AVObject
            where TResult : AVObject {
            // resultSelector must select either the inner object or the outer object. If it's the inner
            // object, reverse the query.
            if (resultSelector.Body == resultSelector.Parameters[1]) {
                // The inner object was selected.
                return inner.Join<TInner, TOuter, TKey, TInner>(
                    outer,
                    innerKeySelector,
                    outerKeySelector,
                    (i, o) => i) as AVQuery<TResult>;
            }
            if (resultSelector.Body != resultSelector.Parameters[0]) {
                throw new InvalidOperationException("Joins must select either the outer or inner object.");
            }

            // Normalize both selectors
            Expression outerNormalized = new ObjectNormalizer().Visit(outerKeySelector.Body);
            Expression innerNormalized = new ObjectNormalizer().Visit(innerKeySelector.Body);
            MethodCallExpression outerAsGet = outerNormalized as MethodCallExpression;
            MethodCallExpression innerAsGet = innerNormalized as MethodCallExpression;
            if (IsAVObjectGet(outerAsGet) && outerAsGet.Object == outerKeySelector.Parameters[0]) {
                var outerKey = GetValue(outerAsGet.Arguments[0]) as string;

                if (IsAVObjectGet(innerAsGet) && innerAsGet.Object == innerKeySelector.Parameters[0]) {
                    // Both are key accesses, so treat this as a WhereMatchesKeyInQuery
                    var innerKey = GetValue(innerAsGet.Arguments[0]) as string;
                    return outer.WhereMatchesKeyInQuery(outerKey, innerKey, inner) as AVQuery<TResult>;
                }

                if (innerKeySelector.Body == innerKeySelector.Parameters[0]) {
                    // The inner selector is on the result of the query itself, so treat this as a
                    // WhereMatchesQuery
                    return outer.WhereMatchesQuery(outerKey, inner) as AVQuery<TResult>;
                }
                throw new InvalidOperationException(
                    "The key for the joined object must be a AVObject or a field access " +
                    "on the AVObject.");
            }

            // TODO (hallucinogen): If we ever support "and" queries fully and/or support a "where this object
            // matches some key in some other query" (as opposed to requiring a key on this query), we
            // can add support for even more types of joins.
            throw new InvalidOperationException(
                "The key for the selected object must be a field access on the AVObject.");
        }
    }
}