From 1839f55055360bf903140631dd4f025beeca0e8e Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 11 Jan 2023 16:02:02 +0100 Subject: [PATCH] Polishing. Introduce HintFunction to encapsulate how hints are applied and to remove code duplications. See #4238 Original pull request: #4243 --- .../data/mongodb/core/HintFunction.java | 102 ++++++++++++++++++ .../data/mongodb/core/MongoTemplate.java | 59 ++-------- .../data/mongodb/core/QueryOperations.java | 12 +-- .../mongodb/core/ReactiveMongoTemplate.java | 32 ++---- .../core/aggregation/AggregationOptions.java | 23 +--- .../aggregation/AggregationOptionsTests.java | 2 + 6 files changed, 133 insertions(+), 97 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/HintFunction.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/HintFunction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/HintFunction.java new file mode 100644 index 000000000..c9b5514b2 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/HintFunction.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core; + +import java.util.function.Function; + +import org.bson.conversions.Bson; +import org.springframework.data.mongodb.CodecRegistryProvider; +import org.springframework.data.mongodb.util.BsonUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Function object to apply a query hint. Can be an index name or a BSON document. + * + * @author Mark Paluch + * @since 4.1 + */ +class HintFunction { + + private static final HintFunction EMPTY = new HintFunction(null); + + private final @Nullable Object hint; + + private HintFunction(@Nullable Object hint) { + this.hint = hint; + } + + /** + * Return an empty hint function. + * + * @return + */ + static HintFunction empty() { + return EMPTY; + } + + /** + * Create a {@link HintFunction} from a {@link Bson document} or {@link String index name}. + * + * @param hint + * @return + */ + static HintFunction from(@Nullable Object hint) { + return new HintFunction(hint); + } + + /** + * Return whether a hint is present. + * + * @return + */ + public boolean isPresent() { + return (hint instanceof String hintString && StringUtils.hasText(hintString)) || hint instanceof Bson; + } + + /** + * Apply the hint to consumers depending on the hint format. + * + * @param registryProvider + * @param stringConsumer + * @param bsonConsumer + * @return + * @param + */ + public R apply(@Nullable CodecRegistryProvider registryProvider, Function stringConsumer, + Function bsonConsumer) { + + if (!isPresent()) { + throw new IllegalStateException("No hint present"); + } + + if (hint instanceof Bson bson) { + return bsonConsumer.apply(bson); + } + + if (hint instanceof String hintString) { + + if (BsonUtils.isJsonDocument(hintString)) { + return bsonConsumer.apply(BsonUtils.parse(hintString, registryProvider)); + } + return stringConsumer.apply(hintString); + } + + throw new IllegalStateException( + "Unable to read hint of type %s".formatted(hint != null ? hint.getClass() : "null")); + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 9930e6243..58960074c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -69,16 +69,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationOptions; import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; -import org.springframework.data.mongodb.core.convert.DbRefResolver; -import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver; -import org.springframework.data.mongodb.core.convert.JsonSchemaMapper; -import org.springframework.data.mongodb.core.convert.MappingMongoConverter; -import org.springframework.data.mongodb.core.convert.MongoConverter; -import org.springframework.data.mongodb.core.convert.MongoCustomConversions; -import org.springframework.data.mongodb.core.convert.MongoJsonSchemaMapper; -import org.springframework.data.mongodb.core.convert.MongoWriter; -import org.springframework.data.mongodb.core.convert.QueryMapper; -import org.springframework.data.mongodb.core.convert.UpdateMapper; +import org.springframework.data.mongodb.core.convert.*; import org.springframework.data.mongodb.core.index.IndexOperations; import org.springframework.data.mongodb.core.index.IndexOperationsProvider; import org.springframework.data.mongodb.core.index.MongoMappingEventPublisher; @@ -99,7 +90,6 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter; import org.springframework.data.mongodb.core.timeseries.Granularity; import org.springframework.data.mongodb.core.validation.Validator; -import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.data.projection.EntityProjection; import org.springframework.data.util.CloseableIterator; import org.springframework.data.util.Optionals; @@ -116,16 +106,7 @@ import com.mongodb.ClientSessionOptions; import com.mongodb.MongoException; import com.mongodb.ReadPreference; import com.mongodb.WriteConcern; -import com.mongodb.client.AggregateIterable; -import com.mongodb.client.ClientSession; -import com.mongodb.client.DistinctIterable; -import com.mongodb.client.FindIterable; -import com.mongodb.client.MapReduceIterable; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoCollection; -import com.mongodb.client.MongoCursor; -import com.mongodb.client.MongoDatabase; -import com.mongodb.client.MongoIterable; +import com.mongodb.client.*; import com.mongodb.client.model.*; import com.mongodb.client.result.DeleteResult; import com.mongodb.client.result.UpdateResult; @@ -2067,15 +2048,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, } options.getComment().ifPresent(aggregateIterable::comment); - if (options.getHintObject().isPresent()) { - Object hintObject = options.getHintObject().get(); - if (hintObject instanceof String hintString) { - aggregateIterable = aggregateIterable.hintString(hintString); - } else if (hintObject instanceof Document hintDocument) { - aggregateIterable = aggregateIterable.hint(hintDocument); - } else { - throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass())); - } + HintFunction hintFunction = options.getHintObject().map(HintFunction::from).orElseGet(HintFunction::empty); + if (hintFunction.isPresent()) { + aggregateIterable = hintFunction.apply(mongoDbFactory, aggregateIterable::hintString, aggregateIterable::hint); } if (options.hasExecutionTimeLimit()) { @@ -2135,15 +2110,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, } options.getComment().ifPresent(cursor::comment); + HintFunction hintFunction = options.getHintObject().map(HintFunction::from).orElseGet(HintFunction::empty); if (options.getHintObject().isPresent()) { - Object hintObject = options.getHintObject().get(); - if (hintObject instanceof String hintString) { - cursor = cursor.hintString(hintString); - } else if (hintObject instanceof Document hintDocument) { - cursor = cursor.hint(hintDocument); - } else { - throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass())); - } + cursor = hintFunction.apply(mongoDbFactory, cursor::hintString, cursor::hint); } Class domainType = aggregation instanceof TypedAggregation ? ((TypedAggregation) aggregation).getInputType() @@ -3172,8 +3141,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, .ifPresent(cursorToUse::collation); Meta meta = query.getMeta(); + HintFunction hintFunction = HintFunction.from(query.getHint()); if (query.getSkip() <= 0 && query.getLimit() <= 0 && ObjectUtils.isEmpty(query.getSortObject()) - && !StringUtils.hasText(query.getHint()) && !meta.hasValues() && !query.getCollation().isPresent()) { + && !hintFunction.isPresent() && !meta.hasValues() && !query.getCollation().isPresent()) { return cursorToUse; } @@ -3189,15 +3159,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, cursorToUse = cursorToUse.sort(sort); } - if (StringUtils.hasText(query.getHint())) { - - String hint = query.getHint(); - - if (BsonUtils.isJsonDocument(hint)) { - cursorToUse = cursorToUse.hint(BsonUtils.parse(hint, mongoDbFactory)); - } else { - cursorToUse = cursorToUse.hintString(hint); - } + if (hintFunction.isPresent()) { + cursorToUse = hintFunction.apply(mongoDbFactory, cursorToUse::hintString, cursorToUse::hint); } if (meta.hasValues()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/QueryOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/QueryOperations.java index 114a21788..2e86b8008 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/QueryOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/QueryOperations.java @@ -60,7 +60,6 @@ import org.springframework.data.projection.EntityProjection; import org.springframework.data.util.Lazy; import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; -import org.springframework.util.StringUtils; import com.mongodb.client.model.CountOptions; import com.mongodb.client.model.DeleteOptions; @@ -567,14 +566,11 @@ class QueryOperations { if (query.getSkip() > 0) { options.skip((int) query.getSkip()); } - if (StringUtils.hasText(query.getHint())) { - String hint = query.getHint(); - if (BsonUtils.isJsonDocument(hint)) { - options.hint(BsonUtils.parse(hint, codecRegistryProvider)); - } else { - options.hintString(hint); - } + HintFunction hintFunction = HintFunction.from(query.getHint()); + + if (hintFunction.isPresent()) { + options = hintFunction.apply(codecRegistryProvider, options::hintString, options::hint); } if (callback != null) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index e69f51754..5e33adb75 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -110,7 +110,6 @@ import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter; -import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.data.projection.EntityProjection; import org.springframework.data.util.Optionals; import org.springframework.lang.Nullable; @@ -938,15 +937,10 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati } options.getComment().ifPresent(cursor::comment); - if (options.getHintObject().isPresent()) { - Object hintObject = options.getHintObject().get(); - if (hintObject instanceof String hintString) { - cursor = cursor.hintString(hintString); - } else if (hintObject instanceof Document hintDocument) { - cursor = cursor.hint(hintDocument); - } else { - throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass())); - } + + HintFunction hintFunction = options.getHintObject().map(HintFunction::from).orElseGet(HintFunction::empty); + if (hintFunction.isPresent()) { + cursor = hintFunction.apply(mongoDatabaseFactory, cursor::hintString, cursor::hint); } Optionals.firstNonEmpty(options::getCollation, () -> operations.forType(inputType).getCollation()) // @@ -1535,7 +1529,8 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati Publisher publisher; if (!mapped.hasId()) { - publisher = collectionToUse.insertOne(queryOperations.createInsertContext(mapped).prepareId(entityClass).getDocument()); + publisher = collectionToUse + .insertOne(queryOperations.createInsertContext(mapped).prepareId(entityClass).getDocument()); } else { MongoPersistentEntity entity = mappingContext.getPersistentEntity(entityClass); @@ -3044,9 +3039,10 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati .map(findPublisher::collation) // .orElse(findPublisher); + HintFunction hintFunction = HintFunction.from(query.getHint()); Meta meta = query.getMeta(); if (query.getSkip() <= 0 && query.getLimit() <= 0 && ObjectUtils.isEmpty(query.getSortObject()) - && !StringUtils.hasText(query.getHint()) && !meta.hasValues()) { + && !hintFunction.isPresent() && !meta.hasValues()) { return findPublisherToUse; } @@ -3065,15 +3061,9 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati findPublisherToUse = findPublisherToUse.sort(sort); } - if (StringUtils.hasText(query.getHint())) { - - String hint = query.getHint(); - - if (BsonUtils.isJsonDocument(hint)) { - findPublisherToUse = findPublisherToUse.hint(BsonUtils.parse(hint, mongoDatabaseFactory)); - } else { - findPublisherToUse = findPublisherToUse.hintString(hint); - } + if (hintFunction.isPresent()) { + findPublisherToUse = hintFunction.apply(mongoDatabaseFactory, findPublisherToUse::hintString, + findPublisherToUse::hint); } if (meta.hasValues()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java index 6ea2743f9..690b9470d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOptions.java @@ -237,10 +237,11 @@ public class AggregationOptions { } /** - * Get the hint used to to fulfill the aggregation. + * Get the hint used to fulfill the aggregation. * * @return never {@literal null}. * @since 3.1 + * @deprecated since 4.1, use {@link #getHintObject()} instead. */ public Optional getHint() { return hint.map(it -> { @@ -257,25 +258,7 @@ public class AggregationOptions { } /** - * Get the hint (indexName) used to to fulfill the aggregation. - * - * @return never {@literal null}. - * @since 4.1 - */ - public Optional getHintAsString() { - return hint.map(it -> { - if (it instanceof String hintString) { - return hintString; - } - if (it instanceof Document doc) { - return BsonUtils.toJson(doc); - } - throw new IllegalStateException("Unable to read hint of type %s".formatted(it.getClass())); - }); - } - - /** - * Get the hint used to to fulfill the aggregation. + * Get the hint used to fulfill the aggregation. * * @return never {@literal null}. * @since 4.1 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOptionsTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOptionsTests.java index 472e8d33d..a35af34bb 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOptionsTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOptionsTests.java @@ -53,6 +53,7 @@ class AggregationOptionsTests { assertThat(aggregationOptions.isExplain()).isTrue(); assertThat(aggregationOptions.getCursor()).contains(new Document("batchSize", 1)); assertThat(aggregationOptions.getHint()).contains(dummyHint); + assertThat(aggregationOptions.getHintObject()).contains(dummyHint); } @Test // DATAMONGO-1637, DATAMONGO-2153, DATAMONGO-1836 @@ -73,6 +74,7 @@ class AggregationOptionsTests { assertThat(aggregationOptions.getCursorBatchSize()).isEqualTo(1); assertThat(aggregationOptions.getComment()).contains("hola"); assertThat(aggregationOptions.getHint()).contains(dummyHint); + assertThat(aggregationOptions.getHintObject()).contains(dummyHint); } @Test // DATAMONGO-960, DATAMONGO-2153, DATAMONGO-1836