diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 48331e10b..67b1cdba8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -1929,86 +1929,89 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, return doAggregate(aggregation, collectionName, outputType, rootContext); } + @SuppressWarnings("ConstantConditions") protected AggregationResults doAggregate(Aggregation aggregation, String collectionName, Class outputType, AggregationOperationContext context) { - Document command = aggregation.toDocument(collectionName, context); - - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Executing aggregation: {}", serializeToJsonSafely(command)); - } - DocumentCallback callback = new UnwrapAndReadDocumentCallback<>(mongoConverter, outputType, collectionName); - if (aggregation.getOptions().isExplain()) { + AggregationOptions options = aggregation.getOptions(); + if (options.isExplain()) { + + Document command = aggregation.toDocument(collectionName, context); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Executing aggregation: {}", serializeToJsonSafely(command)); + } Document commandResult = executeCommand(command); return new AggregationResults<>(commandResult.get("results", new ArrayList(0)).stream() .map(callback::doWith).collect(Collectors.toList()), commandResult); } + List pipeline = aggregation.toPipeline(context); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Executing aggregation: {} in collection {}", serializeToJsonSafely(pipeline), collectionName); + } + return execute(collectionName, collection -> { List rawResult = new ArrayList<>(); - MongoIterable iterable = collection // - .aggregate(command.get("pipeline", new ArrayList(0)), Document.class) // - .collation(aggregation.getOptions().getCollation().map(Collation::toMongoCollation).orElse(null)).map(val -> { + AggregateIterable aggregateIterable = collection.aggregate(pipeline, Document.class) // + .collation(options.getCollation().map(Collation::toMongoCollation).orElse(null)) // + .allowDiskUse(options.isAllowDiskUse()); - rawResult.add(val); - return callback.doWith(val); - }); + if (options.getCursorBatchSize() != null) { + aggregateIterable = aggregateIterable.batchSize(options.getCursorBatchSize()); + } + + MongoIterable iterable = aggregateIterable.map(val -> { + + rawResult.add(val); + return callback.doWith(val); + }); return new AggregationResults<>(iterable.into(new ArrayList<>()), new Document("results", rawResult).append("ok", 1.0D)); - }); - } + @SuppressWarnings("ConstantConditions") protected CloseableIterator aggregateStream(Aggregation aggregation, String collectionName, Class outputType, @Nullable AggregationOperationContext context) { Assert.hasText(collectionName, "Collection name must not be null or empty!"); Assert.notNull(aggregation, "Aggregation pipeline must not be null!"); Assert.notNull(outputType, "Output type must not be null!"); + Assert.isTrue(!aggregation.getOptions().isExplain(), "Can't use explain option with streaming!"); AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; - - Document command = aggregation.toDocument(collectionName, rootContext); - - assertNotExplain(command); + AggregationOptions options = aggregation.getOptions(); + List pipeline = aggregation.toPipeline(rootContext); if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Streaming aggregation: {}", serializeToJsonSafely(command)); + LOGGER.debug("Streaming aggregation: {} in collection {}", serializeToJsonSafely(pipeline), collectionName); } - ReadDocumentCallback readCallback = new ReadDocumentCallback(mongoConverter, outputType, collectionName); + ReadDocumentCallback readCallback = new ReadDocumentCallback<>(mongoConverter, outputType, collectionName); - return execute(collectionName, new CollectionCallback>() { + return execute(collectionName, (CollectionCallback>) collection -> { - @Override - public CloseableIterator doInCollection(MongoCollection collection) - throws MongoException, DataAccessException { + AggregateIterable cursor = collection.aggregate(pipeline) // + .allowDiskUse(options.isAllowDiskUse()) // + .useCursor(true); - List pipeline = (List) command.get("pipeline"); - - AggregationOptions options = AggregationOptions.fromDocument(command); - - AggregateIterable cursor = collection.aggregate(pipeline).allowDiskUse(options.isAllowDiskUse()) - .useCursor(true); - - Integer cursorBatchSize = options.getCursorBatchSize(); - if (cursorBatchSize != null) { - cursor = cursor.batchSize(cursorBatchSize); - } - - if (options.getCollation().isPresent()) { - cursor = cursor.collation(options.getCollation().map(Collation::toMongoCollation).get()); - } - - return new CloseableIterableCursorAdapter(cursor.iterator(), exceptionTranslator, readCallback); + if (options.getCursorBatchSize() != null) { + cursor = cursor.batchSize(options.getCursorBatchSize()); } + + if (options.getCollation().isPresent()) { + cursor = cursor.collation(options.getCollation().map(Collation::toMongoCollation).get()); + } + + return new CloseableIterableCursorAdapter<>(cursor.iterator(), exceptionTranslator, readCallback); }); } @@ -2057,20 +2060,6 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, return new ExecutableInsertOperationSupport(this).insert(domainType); } - /** - * Assert that the {@link Document} does not enable Aggregation explain mode. - * - * @param command the command {@link Document}. - */ - private void assertNotExplain(Document command) { - - Boolean explain = command.get("explain", Boolean.class); - - if (explain != null && explain) { - throw new IllegalArgumentException("Can't use explain option with streaming!"); - } - } - protected String replaceWithResourceIfNecessary(String function) { String func = function; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 6b9e97bb6..5a67d8e9c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -733,19 +733,19 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati Assert.notNull(outputType, "Output type must not be null!"); AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; - Document command = aggregation.toDocument(collectionName, rootContext); - AggregationOptions options = AggregationOptions.fromDocument(command); + AggregationOptions options = aggregation.getOptions(); + List pipeline = aggregation.toPipeline(rootContext); Assert.isTrue(!options.isExplain(), "Cannot use explain option with streaming!"); Assert.isNull(options.getCursorBatchSize(), "Cannot use batchSize cursor option with streaming!"); if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Streaming aggregation: {}", serializeToJsonSafely(command)); + LOGGER.debug("Streaming aggregation: {} in collection {}", serializeToJsonSafely(pipeline), collectionName); } ReadDocumentCallback readCallback = new ReadDocumentCallback<>(mongoConverter, outputType, collectionName); return execute(collectionName, - collection -> aggregateAndMap(collection, (List) command.get("pipeline"), options, readCallback)); + collection -> aggregateAndMap(collection, pipeline, options, readCallback)); } private Flux aggregateAndMap(MongoCollection collection, List pipeline, diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 2ec58c81c..c6f13241f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -181,7 +181,7 @@ public class Aggregation { * Get the {@link AggregationOptions}. * * @return never {@literal null}. - * @since 2.0.2 + * @since 2.1 */ public AggregationOptions getOptions() { return options; @@ -585,21 +585,31 @@ public class Aggregation { } /** - * Converts this {@link Aggregation} specification to a {@link Document}. + * Renders this {@link Aggregation} specification to an aggregation pipeline returning a {@link List} of + * {@link Document}. * - * @param inputCollectionName the name of the input collection - * @return the {@code Document} representing this aggregation + * @return the aggregation pipeline representing this aggregation. + * @since 2.1 + */ + public List toPipeline(AggregationOperationContext rootContext) { + return AggregationOperationRenderer.toDocument(operations, rootContext); + } + + /** + * Converts this {@link Aggregation} specification to a {@link Document}. + *

+ * MongoDB requires as of 3.6 cursor-based aggregation. Use {@link #toPipeline(AggregationOperationContext)} to render + * an aggregation pipeline. + * + * @param inputCollectionName the name of the input collection. + * @return the {@code Document} representing this aggregation. */ public Document toDocument(String inputCollectionName, AggregationOperationContext rootContext) { - List operationDocuments = AggregationOperationRenderer.toDocument(operations, rootContext); - Document command = new Document("aggregate", inputCollectionName); - command.put("pipeline", operationDocuments); + command.put("pipeline", toPipeline(rootContext)); - command = options.applyAndReturnPotentiallyChangedCommand(command); - - return command; + return options.applyAndReturnPotentiallyChangedCommand(command); } /* diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java index 7a50de10d..05e7a7bf5 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java @@ -147,6 +147,7 @@ public class AggregationOptions { * @return the batch size or {@literal null}. * @since 2.0 */ + @Nullable public Integer getCursorBatchSize() { if (cursor.filter(val -> val.containsKey(BATCH_SIZE)).isPresent()) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java index 84a2ecaac..073618a72 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java @@ -132,7 +132,7 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { when(collection.find(Mockito.any(org.bson.Document.class))).thenReturn(findIterable); when(collection.mapReduce(Mockito.any(), Mockito.any())).thenReturn(mapReduceIterable); when(collection.count(any(Bson.class), any(CountOptions.class))).thenReturn(1L); - when(collection.aggregate(any(), any())).thenReturn(aggregateIterable); + when(collection.aggregate(any(List.class), any())).thenReturn(aggregateIterable); when(collection.withReadPreference(any())).thenReturn(collection); when(findIterable.projection(Mockito.any())).thenReturn(findIterable); when(findIterable.sort(Mockito.any(org.bson.Document.class))).thenReturn(findIterable); @@ -144,6 +144,8 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { when(mapReduceIterable.iterator()).thenReturn(cursor); when(mapReduceIterable.filter(any())).thenReturn(mapReduceIterable); when(aggregateIterable.collation(any())).thenReturn(aggregateIterable); + when(aggregateIterable.allowDiskUse(any())).thenReturn(aggregateIterable); + when(aggregateIterable.batchSize(anyInt())).thenReturn(aggregateIterable); when(aggregateIterable.map(any())).thenReturn(aggregateIterable); when(aggregateIterable.into(any())).thenReturn(Collections.emptyList()); @@ -764,6 +766,16 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { verify(aggregateIterable).collation(eq(com.mongodb.client.model.Collation.builder().locale("fr").build())); } + @Test // DATAMONGO-1824 + public void aggregateShouldUseBatchSizeWhenPresent() { + + Aggregation aggregation = newAggregation(project("id")) + .withOptions(newAggregationOptions().collation(Collation.of("fr")).cursorBatchSize(100).build()); + template.aggregate(aggregation, AutogenerateableId.class, Document.class); + + verify(aggregateIterable).batchSize(100); + } + @Test // DATAMONGO-1518 public void mapReduceShouldUseCollationWhenPresent() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 48e60bc2e..42ab2a9dc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -1337,7 +1337,6 @@ public class AggregationTests { Document rawResult = result.getRawResults(); assertThat(rawResult, is(notNullValue())); - assertThat(rawResult.containsKey("stages"), is(true)); } @@ -1357,7 +1356,7 @@ public class AggregationTests { AggregationResults result = mongoTemplate.aggregate(agg, Person.class, Document.class); assertThat(result.getMappedResults(), hasSize(3)); - Document o = (Document) result.getMappedResults().get(2); + Document o = result.getMappedResults().get(2); assertThat(o.get("_id"), is((Object) 25)); assertThat((List) o.get("users"), hasSize(2)); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersion.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersion.java index e47938bbc..2aaa0466b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersion.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersion.java @@ -26,7 +26,7 @@ import java.lang.annotation.Target; * be used along with {@link MongoVersionRule}. * * @author Christoph Strobl - * @since 2.0.2 + * @since 2.1 */ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersionRule.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersionRule.java index 78c342e24..0de7e8534 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersionRule.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoVersionRule.java @@ -1,5 +1,5 @@ /* - * Copyright 2014 the original author or authors. + * Copyright 2014-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.data.mongodb.test.util; import java.util.concurrent.atomic.AtomicReference; +import org.bson.Document; import org.junit.AssumptionViolatedException; import org.junit.ClassRule; import org.junit.Rule; @@ -26,17 +27,16 @@ import org.junit.runners.model.Statement; import org.springframework.data.util.Version; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import com.mongodb.BasicDBObject; -import com.mongodb.CommandResult; -import com.mongodb.DB; import com.mongodb.MongoClient; +import com.mongodb.client.MongoDatabase; /** * {@link TestRule} verifying server tests are executed against match a given version. This one can be used as * {@link ClassRule} eg. in context depending tests run with {@link SpringJUnit4ClassRunner} when the context would fail * to start in case of invalid version, or as simple {@link Rule} on specific tests. - * + * * @author Christoph Strobl + * @author Mark Paluch * @since 1.6 */ public class MongoVersionRule implements TestRule { @@ -170,11 +170,11 @@ public class MongoVersionRule implements TestRule { MongoClient client; client = new MongoClient(host, port); - DB db = client.getDB("test"); - CommandResult result = db.command(new BasicDBObject().append("buildInfo", 1)); + MongoDatabase database = client.getDatabase("test"); + Document result = database.runCommand(new Document("buildInfo", 1)); client.close(); - return Version.parse(result.get("version").toString()); + return Version.parse(result.get("version", String.class)); } catch (Exception e) { return ANY; } diff --git a/src/main/asciidoc/new-features.adoc b/src/main/asciidoc/new-features.adoc index 2f5e65a90..57e6addb3 100644 --- a/src/main/asciidoc/new-features.adoc +++ b/src/main/asciidoc/new-features.adoc @@ -1,6 +1,10 @@ [[new-features]] = New & Noteworthy +[[new-features.2-1-0]] +== What's new in Spring Data MongoDB 2.1 +* Cursor-based aggregation execution. + [[new-features.2-0-0]] == What's new in Spring Data MongoDB 2.0 * Upgrade to Java 8. diff --git a/src/main/asciidoc/reference/mongodb.adoc b/src/main/asciidoc/reference/mongodb.adoc index 857f08736..de79c5d95 100644 --- a/src/main/asciidoc/reference/mongodb.adoc +++ b/src/main/asciidoc/reference/mongodb.adoc @@ -1410,7 +1410,7 @@ List results = template.find(query, Person.class); ---- Collation collation = Collation.of("de"); -AggregationOptions options = new AggregationOptions.Builder().collation(collation).build(); +AggregationOptions options = AggregationOptions.builder().collation(collation).build(); Aggregation aggregation = newAggregation( project("tags"),