diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java index e04193d14..2b5e87374 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java @@ -28,7 +28,6 @@ import org.springframework.data.mongodb.core.aggregation.AggregationExpressions. import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import org.springframework.data.mongodb.core.query.CriteriaDefinition; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -3905,31 +3904,19 @@ public interface AggregationExpressions { */ @Override public DBObject toDbObject(final AggregationOperationContext context) { - - return toFilter(new ExposedFieldsAggregationOperationContext(ExposedFields.from(as), context) { - - @Override - public FieldReference getReference(Field field) { - - FieldReference ref = null; - try { - ref = context.getReference(field); - } catch (Exception e) { - // just ignore that one. - } - return ref != null ? ref : super.getReference(field); - } - }); + return toFilter(ExposedFields.from(as), context); } - private DBObject toFilter(AggregationOperationContext context) { + private DBObject toFilter(ExposedFields exposedFields, AggregationOperationContext context) { DBObject filterExpression = new BasicDBObject(); + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context); filterExpression.putAll(context.getMappedObject(new BasicDBObject("input", getMappedInput(context)))); filterExpression.put("as", as.getTarget()); - filterExpression.putAll(context.getMappedObject(new BasicDBObject("cond", getMappedCondition(context)))); + filterExpression.putAll(context.getMappedObject(new BasicDBObject("cond", getMappedCondition(operationContext)))); return new BasicDBObject("$filter", filterExpression); } @@ -6019,27 +6006,14 @@ public interface AggregationExpressions { @Override public DBObject toDbObject(final AggregationOperationContext context) { - - return toMap(new ExposedFieldsAggregationOperationContext( - ExposedFields.synthetic(Fields.fields(itemVariableName)), context) { - - @Override - public FieldReference getReference(Field field) { - - FieldReference ref = null; - try { - ref = context.getReference(field); - } catch (Exception e) { - // just ignore that one. - } - return ref != null ? ref : super.getReference(field); - } - }); + return toMap(ExposedFields.synthetic(Fields.fields(itemVariableName)), context); } - private DBObject toMap(AggregationOperationContext context) { + private DBObject toMap(ExposedFields exposedFields, AggregationOperationContext context) { BasicDBObject map = new BasicDBObject(); + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context); BasicDBObject input; if (sourceArray instanceof Field) { @@ -6050,7 +6024,8 @@ public interface AggregationExpressions { map.putAll(context.getMappedObject(input)); map.put("as", itemVariableName); - map.put("in", functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context))); + map.put("in", + functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(operationContext))); return new BasicDBObject("$map", map); } @@ -6792,22 +6767,7 @@ public interface AggregationExpressions { @Override public DBObject toDbObject(final AggregationOperationContext context) { - - return toLet(new ExposedFieldsAggregationOperationContext( - ExposedFields.synthetic(Fields.fields(getVariableNames())), context) { - - @Override - public FieldReference getReference(Field field) { - - FieldReference ref = null; - try { - ref = context.getReference(field); - } catch (Exception e) { - // just ignore that one. - } - return ref != null ? ref : super.getReference(field); - } - }); + return toLet(ExposedFields.synthetic(Fields.fields(getVariableNames())), context); } private String[] getVariableNames() { @@ -6816,20 +6776,23 @@ public interface AggregationExpressions { for (int i = 0; i < this.vars.size(); i++) { varNames[i] = this.vars.get(i).variableName; } + return varNames; } - private DBObject toLet(AggregationOperationContext context) { + private DBObject toLet(ExposedFields exposedFields, AggregationOperationContext context) { DBObject letExpression = new BasicDBObject(); - DBObject mappedVars = new BasicDBObject(); + InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( + exposedFields, context); + for (ExpressionVariable var : this.vars) { mappedVars.putAll(getMappedVariable(var, context)); } letExpression.put("vars", mappedVars); - letExpression.put("in", getMappedIn(context)); + letExpression.put("in", getMappedIn(operationContext)); return new BasicDBObject("$let", letExpression); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java index c25b56732..2071dc0b6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java @@ -13,17 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.data.mongodb.core.aggregation; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; -import org.springframework.util.Assert; /** * {@link ExposedFieldsAggregationOperationContext} that inherits fields from its parent * {@link AggregationOperationContext}. * * @author Mark Paluch + * @since 1.9 */ class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAggregationOperationContext { @@ -40,7 +39,7 @@ class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAg AggregationOperationContext previousContext) { super(exposedFields, previousContext); - Assert.notNull(previousContext, "PreviousContext must not be null!"); + this.previousContext = previousContext; } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 4faddf85d..d89d1a782 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -143,6 +143,8 @@ public class AggregationTests { mongoTemplate.dropCollection(MeterData.class); mongoTemplate.dropCollection(LineItem.class); mongoTemplate.dropCollection(InventoryItem.class); + mongoTemplate.dropCollection(Sales.class); + mongoTemplate.dropCollection(Sales2.class); } /**