diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 5a40abd24..28bcc3e18 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -496,7 +496,7 @@ public class ArrayOperators { } NestedDelegatingExpressionAggregationOperationContext nea = new NestedDelegatingExpressionAggregationOperationContext( - context); + context, Collections.singleton(as)); return ((AggregationExpression) condition).toDocument(nea); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/NestedDelegatingExpressionAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/NestedDelegatingExpressionAggregationOperationContext.java index 59c449930..4657d1033 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/NestedDelegatingExpressionAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/NestedDelegatingExpressionAggregationOperationContext.java @@ -15,6 +15,11 @@ */ package org.springframework.data.mongodb.core.aggregation; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + import org.bson.Document; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExpressionFieldReference; @@ -31,16 +36,18 @@ import org.springframework.util.Assert; class NestedDelegatingExpressionAggregationOperationContext implements AggregationOperationContext { private final AggregationOperationContext delegate; + private final Set inners; /** * Creates new {@link NestedDelegatingExpressionAggregationOperationContext}. * * @param referenceContext must not be {@literal null}. */ - public NestedDelegatingExpressionAggregationOperationContext(AggregationOperationContext referenceContext) { + NestedDelegatingExpressionAggregationOperationContext(AggregationOperationContext referenceContext, Collection inners) { Assert.notNull(referenceContext, "Reference context must not be null!"); this.delegate = referenceContext; + this.inners = inners.stream().map(Field::getName).collect(Collectors.toSet()); } /* @@ -58,7 +65,22 @@ class NestedDelegatingExpressionAggregationOperationContext implements Aggregati */ @Override public FieldReference getReference(Field field) { - return new ExpressionFieldReference(delegate.getReference(field)); + + FieldReference reference = delegate.getReference(field); + return !isInnerVariableReference(field) ? reference : new ExpressionFieldReference(delegate.getReference(field)) ; + } + + private boolean isInnerVariableReference(Field field) { + + if(inners.isEmpty()) { + return false; + } + + if(inners.contains(field.getName())) { + return true; + } + + return inners.stream().anyMatch(it -> field.getTarget().contains(".") && field.getTarget().startsWith(it)); } /* diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java index 99642aacb..4daef44e6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java @@ -18,7 +18,9 @@ package org.springframework.data.mongodb.core.aggregation; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import org.bson.Document; import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.ExpressionVariable; @@ -185,7 +187,8 @@ public class VariableOperators { map.putAll(context.getMappedObject(input)); map.put("as", itemVariableName); map.put("in", - functionToApply.toDocument(new NestedDelegatingExpressionAggregationOperationContext(operationContext))); + functionToApply.toDocument(new NestedDelegatingExpressionAggregationOperationContext(operationContext, + Collections.singleton(Fields.field(itemVariableName))))); return new Document("$map", map); } @@ -322,12 +325,14 @@ public class VariableOperators { private Document getMappedVariable(ExpressionVariable var, AggregationOperationContext context) { - return new Document(var.variableName, var.expression instanceof AggregationExpression - ? ((AggregationExpression) var.expression).toDocument(context) : var.expression); + return new Document(var.variableName, + var.expression instanceof AggregationExpression ? ((AggregationExpression) var.expression).toDocument(context) + : var.expression); } private Object getMappedIn(AggregationOperationContext context) { - return expression.toDocument(new NestedDelegatingExpressionAggregationOperationContext(context)); + return expression.toDocument(new NestedDelegatingExpressionAggregationOperationContext(context, + this.vars.stream().map(var -> Fields.field(var.variableName)).collect(Collectors.toList()))); } /** diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java index be2fe8a9f..688929c7a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java @@ -123,6 +123,29 @@ public class FilterExpressionUnitTests { assertThat($filter, is(expected)); } + @Test // DATAMONGO-2320 + public void shouldConstructFilterExpressionCorrectlyWhenConditionContainsFieldReference() { + + Aggregation agg = Aggregation.newAggregation(Aggregation.project().and((ctx) -> new Document()).as("field-1") + .and(filter("items").as("item").by(ComparisonOperators.valueOf("item.price").greaterThan("field-1"))) + .as("items")); + + Document dbo = agg.toDocument("sales", Aggregation.DEFAULT_CONTEXT); + + List pipeline = DocumentTestUtils.getAsDBList(dbo, "pipeline"); + Document $project = DocumentTestUtils.getAsDocument((Document) pipeline.get(0), "$project"); + Document items = DocumentTestUtils.getAsDocument($project, "items"); + Document $filter = DocumentTestUtils.getAsDocument(items, "$filter"); + + Document expected = Document.parse("{" + // + "input: \"$items\"," + // + "as: \"item\"," + // + "cond: { $gt: [ \"$$item.price\", \"$field-1\" ] }" + // + "}"); + + assertThat($filter).isEqualTo(new Document(expected)); + } + static class Sales { List items;