diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java index 82c03758b..07fa9023c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java @@ -29,8 +29,11 @@ import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; /** + * Support class for {@link AggregationExpression} implementations. + * * @author Christoph Strobl * @author Matt Morrissette + * @author Mark Paluch * @since 1.10 */ abstract class AbstractAggregationExpression implements AggregationExpression { @@ -49,7 +52,6 @@ abstract class AbstractAggregationExpression implements AggregationExpression { return toDocument(this.value, context); } - @SuppressWarnings("unchecked") public Document toDocument(Object value, AggregationOperationContext context) { return new Document(getMongoMethod(), unpack(value, context)); } @@ -101,17 +103,19 @@ abstract class AbstractAggregationExpression implements AggregationExpression { return value; } + @SuppressWarnings("unchecked") protected List append(Object value, Expand expandList) { if (this.value instanceof List) { - List clone = new ArrayList((List) this.value); + List clone = new ArrayList<>((List) this.value); if (value instanceof Collection && Expand.EXPAND_VALUES.equals(expandList)) { clone.addAll((Collection) value); } else { clone.add(value); } + return clone; } @@ -129,22 +133,23 @@ abstract class AbstractAggregationExpression implements AggregationExpression { return append(value, Expand.EXPAND_VALUES); } - @SuppressWarnings("unchecked") - protected java.util.Map append(String key, Object value) { + @SuppressWarnings({ "unchecked", "rawtypes" }) + protected Map append(String key, Object value) { Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); - java.util.Map clone = new LinkedHashMap<>((java.util.Map) this.value); + Map clone = new LinkedHashMap<>((java.util.Map) this.value); clone.put(key, value); return clone; } - protected java.util.Map remove(String key) { + @SuppressWarnings({ "unchecked", "rawtypes" }) + protected Map remove(String key) { Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); - java.util.Map clone = new LinkedHashMap<>((java.util.Map) this.value); + Map clone = new LinkedHashMap<>((java.util.Map) this.value); clone.remove(key); return clone; } @@ -158,14 +163,15 @@ abstract class AbstractAggregationExpression implements AggregationExpression { * @return * @since 3.1 */ - protected java.util.Map appendAt(int index, String key, Object value) { + @SuppressWarnings({ "unchecked" }) + protected Map appendAt(int index, String key, Object value) { Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); - java.util.LinkedHashMap clone = new java.util.LinkedHashMap<>(); + Map clone = new LinkedHashMap<>(); int i = 0; - for (Map.Entry entry : ((java.util.Map) this.value).entrySet()) { + for (Map.Entry entry : ((Map) this.value).entrySet()) { if (i == index) { clone.put(key, value); @@ -182,14 +188,17 @@ abstract class AbstractAggregationExpression implements AggregationExpression { } + @SuppressWarnings({ "rawtypes" }) protected List values() { if (value instanceof List) { return new ArrayList((List) value); } + if (value instanceof java.util.Map) { return new ArrayList(((java.util.Map) value).values()); } + return new ArrayList<>(Collections.singletonList(value)); } @@ -219,7 +228,7 @@ abstract class AbstractAggregationExpression implements AggregationExpression { Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); - return (T) ((java.util.Map) this.value).get(key); + return (T) ((Map) this.value).get(key); } /** @@ -229,11 +238,11 @@ abstract class AbstractAggregationExpression implements AggregationExpression { * @return */ @SuppressWarnings("unchecked") - protected java.util.Map argumentMap() { + protected Map argumentMap() { Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); - return Collections.unmodifiableMap((java.util.Map) value); + return Collections.unmodifiableMap((java.util.Map) value); } /** @@ -250,7 +259,7 @@ abstract class AbstractAggregationExpression implements AggregationExpression { return false; } - return ((java.util.Map) this.value).containsKey(key); + return ((Map) this.value).containsKey(key); } protected abstract String getMongoMethod(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java index 6d451aca0..3fde9da6a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java @@ -36,6 +36,7 @@ import org.springframework.util.CollectionUtils; * enabled. * * @author Christoph Strobl + * @author Mark Paluch * @since 3.1 */ public class ScriptOperators { @@ -83,7 +84,6 @@ public class ScriptOperators { * * @see MongoDB Documentation: * $function - * @since 3.1 */ public static class Function extends AbstractAggregationExpression { @@ -99,6 +99,8 @@ public class ScriptOperators { */ public static Function function(String body) { + Assert.notNull(body, "Function body must not be null!"); + Map function = new LinkedHashMap<>(2); function.put(Fields.BODY.toString(), body); function.put(Fields.ARGS.toString(), Collections.emptyList()); @@ -126,6 +128,7 @@ public class ScriptOperators { public Function args(List args) { Assert.notNull(args, "Args must not be null! Use an empty list instead."); + return new Function(appendAt(1, Fields.ARGS.toString(), args)); } @@ -137,7 +140,8 @@ public class ScriptOperators { */ public Function lang(String lang) { - Assert.hasText(lang, "Lang must not be null nor emtpy! The default would be 'js'."); + Assert.hasText(lang, "Lang must not be null nor empty! The default would be 'js'."); + return new Function(appendAt(2, Fields.LANG.toString(), lang)); } @@ -198,7 +202,6 @@ public class ScriptOperators { * * @see MongoDB Documentation: * $accumulator - * @since 3.1 */ public static class Accumulator extends AbstractAggregationExpression { @@ -293,10 +296,10 @@ public class ScriptOperators { /** * Define the optional {@code initArgs} for the {@link AccumulatorInitBuilder#init(String)} function. * - * @param args can be {@literal null}. + * @param args must not be {@literal null}. * @return this. */ - AccumulatorAccumulateBuilder initArgs(@Nullable List args); + AccumulatorAccumulateBuilder initArgs(List args); } public interface AccumulatorAccumulateBuilder { @@ -355,10 +358,10 @@ public class ScriptOperators { * Define additional {@code accumulateArgs} for the {@link AccumulatorAccumulateBuilder#accumulate(String)} * function. * - * @param args can be {@literal null}. + * @param args must not be {@literal null}. * @return this. */ - AccumulatorMergeBuilder accumulateArgs(@Nullable List args); + AccumulatorMergeBuilder accumulateArgs(List args); } public interface AccumulatorMergeBuilder { @@ -398,9 +401,16 @@ public class ScriptOperators { * @return new instance of {@link Accumulator}. */ Accumulator finalize(String function); + + /** + * Build the {@link Accumulator} object without specifying a {@link #finalize(String) finalize function}. + * + * @return new instance of {@link Accumulator}. + */ + Accumulator build(); } - public static class AccumulatorBuilder + static class AccumulatorBuilder implements AccumulatorInitBuilder, AccumulatorInitArgsBuilder, AccumulatorAccumulateBuilder, AccumulatorAccumulateArgsBuilder, AccumulatorMergeBuilder, AccumulatorFinalizeBuilder { @@ -426,6 +436,7 @@ public class ScriptOperators { * @param function must not be {@literal null}. * @return this. */ + @Override public AccumulatorBuilder init(String function) { this.initFunction = function; @@ -435,12 +446,15 @@ public class ScriptOperators { /** * Define the optional {@code initArgs} for the {@link #init(String)} function. * - * @param args can be {@literal null}. + * @param function must not be {@literal null}. * @return this. */ - public AccumulatorBuilder initArgs(@Nullable List args) { + @Override + public AccumulatorBuilder initArgs(List args) { - this.initArgs = args != null ? new ArrayList<>(args) : Collections.emptyList(); + Assert.notNull(args, "Args must not be null"); + + this.initArgs = new ArrayList<>(args); return this; } @@ -458,8 +472,11 @@ public class ScriptOperators { * @param function must not be {@literal null}. * @return this. */ + @Override public AccumulatorBuilder accumulate(String function) { + Assert.notNull(function, "Accumulate function must not be null"); + this.accumulateFunction = function; return this; } @@ -467,12 +484,15 @@ public class ScriptOperators { /** * Define additional {@code accumulateArgs} for the {@link #accumulate(String)} function. * - * @param args can be {@literal null}. + * @param args must not be {@literal null}. * @return this. */ - public AccumulatorBuilder accumulateArgs(@Nullable List args) { + @Override + public AccumulatorBuilder accumulateArgs(List args) { - this.accumulateArgs = args != null ? new ArrayList<>(args) : Collections.emptyList(); + Assert.notNull(args, "Args must not be null"); + + this.accumulateArgs = new ArrayList<>(args); return this; } @@ -491,8 +511,11 @@ public class ScriptOperators { * @param function must not be {@literal null}. * @return this. */ + @Override public AccumulatorBuilder merge(String function) { + Assert.notNull(function, "Merge function must not be null"); + this.mergeFunction = function; return this; } @@ -505,6 +528,8 @@ public class ScriptOperators { */ public AccumulatorBuilder lang(String lang) { + Assert.hasText(lang, "Lang must not be null nor empty! The default would be 'js'."); + this.lang = lang; return this; } @@ -523,10 +548,26 @@ public class ScriptOperators { * @param function must not be {@literal null}. * @return new instance of {@link Accumulator}. */ + @Override public Accumulator finalize(String function) { + Assert.notNull(function, "Finalize function must not be null"); + this.finalizeFunction = function; + Map args = createArgumentMap(); + args.put(Fields.FINALIZE.toString(), finalizeFunction); + + return new Accumulator(args); + } + + @Override + public Accumulator build() { + return new Accumulator(createArgumentMap()); + } + + private Map createArgumentMap() { + Map args = new LinkedHashMap<>(); args.put(Fields.INIT.toString(), initFunction); if (!CollectionUtils.isEmpty(initArgs)) { @@ -537,12 +578,10 @@ public class ScriptOperators { args.put(Fields.ACCUMULATE_ARGS.toString(), accumulateArgs); } args.put(Fields.MERGE.toString(), mergeFunction); - args.put(Fields.FINALIZE.toString(), finalizeFunction); args.put(Fields.LANG.toString(), lang); - return new Accumulator(args); + return args; } - } } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java index fb237b631..3e9d9e804 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java @@ -24,6 +24,8 @@ import org.bson.Document; import org.junit.jupiter.api.Test; /** + * Unit tests for {@link ScriptOperators}. + * * @author Christoph Strobl */ class ScriptOperatorsUnitTests { @@ -32,20 +34,6 @@ class ScriptOperatorsUnitTests { private static final Document EMPTY_ARGS_FUNCTION_DOCUMENT = new Document("body", FUNCTION_BODY) .append("args", Collections.emptyList()).append("lang", "js"); - @Test // DATAMONGO-2623 - void functionWithoutArgsShouldBeRenderedCorrectly() { - - assertThat(function(FUNCTION_BODY).toDocument(Aggregation.DEFAULT_CONTEXT)) - .isEqualTo($function(EMPTY_ARGS_FUNCTION_DOCUMENT)); - } - - @Test // DATAMONGO-2623 - void functionWithArgsShouldBeRenderedCorrectly() { - - assertThat(function(FUNCTION_BODY).args("$name").toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo( - $function(new Document(EMPTY_ARGS_FUNCTION_DOCUMENT).append("args", Collections.singletonList("$name")))); - } - private static final String INIT_FUNCTION = "function() { return { count: 0, sum: 0 } }"; private static final String ACC_FUNCTION = "function(state, numCopies) { return { count: state.count + 1, sum: state.sum + numCopies } }"; private static final String MERGE_FUNCTION = "function(state1, state2) { return { count: state1.count + state2.count, sum: state1.sum + state2.sum } }"; @@ -64,6 +52,20 @@ class ScriptOperatorsUnitTests { " }" + // " }"); + @Test // DATAMONGO-2623 + void functionWithoutArgsShouldBeRenderedCorrectly() { + + assertThat(function(FUNCTION_BODY).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo($function(EMPTY_ARGS_FUNCTION_DOCUMENT)); + } + + @Test // DATAMONGO-2623 + void functionWithArgsShouldBeRenderedCorrectly() { + + assertThat(function(FUNCTION_BODY).args("$name").toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo( + $function(new Document(EMPTY_ARGS_FUNCTION_DOCUMENT).append("args", Collections.singletonList("$name")))); + } + @Test // DATAMONGO-2623 void accumulatorWithStringInput() {