Polishing.

Introduce HintFunction to encapsulate how hints are applied and to remove code duplications.

See #4238
Original pull request: #4243
This commit is contained in:
Mark Paluch
2023-01-11 16:02:02 +01:00
parent 4220df5bf8
commit 1839f55055
6 changed files with 133 additions and 97 deletions

View File

@@ -0,0 +1,102 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core;
import java.util.function.Function;
import org.bson.conversions.Bson;
import org.springframework.data.mongodb.CodecRegistryProvider;
import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
/**
* Function object to apply a query hint. Can be an index name or a BSON document.
*
* @author Mark Paluch
* @since 4.1
*/
class HintFunction {
private static final HintFunction EMPTY = new HintFunction(null);
private final @Nullable Object hint;
private HintFunction(@Nullable Object hint) {
this.hint = hint;
}
/**
* Return an empty hint function.
*
* @return
*/
static HintFunction empty() {
return EMPTY;
}
/**
* Create a {@link HintFunction} from a {@link Bson document} or {@link String index name}.
*
* @param hint
* @return
*/
static HintFunction from(@Nullable Object hint) {
return new HintFunction(hint);
}
/**
* Return whether a hint is present.
*
* @return
*/
public boolean isPresent() {
return (hint instanceof String hintString && StringUtils.hasText(hintString)) || hint instanceof Bson;
}
/**
* Apply the hint to consumers depending on the hint format.
*
* @param registryProvider
* @param stringConsumer
* @param bsonConsumer
* @return
* @param <R>
*/
public <R> R apply(@Nullable CodecRegistryProvider registryProvider, Function<String, R> stringConsumer,
Function<Bson, R> bsonConsumer) {
if (!isPresent()) {
throw new IllegalStateException("No hint present");
}
if (hint instanceof Bson bson) {
return bsonConsumer.apply(bson);
}
if (hint instanceof String hintString) {
if (BsonUtils.isJsonDocument(hintString)) {
return bsonConsumer.apply(BsonUtils.parse(hintString, registryProvider));
}
return stringConsumer.apply(hintString);
}
throw new IllegalStateException(
"Unable to read hint of type %s".formatted(hint != null ? hint.getClass() : "null"));
}
}

View File

@@ -69,16 +69,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.DbRefResolver;
import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver;
import org.springframework.data.mongodb.core.convert.JsonSchemaMapper;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.convert.MongoCustomConversions;
import org.springframework.data.mongodb.core.convert.MongoJsonSchemaMapper;
import org.springframework.data.mongodb.core.convert.MongoWriter;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.convert.*;
import org.springframework.data.mongodb.core.index.IndexOperations;
import org.springframework.data.mongodb.core.index.IndexOperationsProvider;
import org.springframework.data.mongodb.core.index.MongoMappingEventPublisher;
@@ -99,7 +90,6 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter;
import org.springframework.data.mongodb.core.timeseries.Granularity;
import org.springframework.data.mongodb.core.validation.Validator;
import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.data.projection.EntityProjection;
import org.springframework.data.util.CloseableIterator;
import org.springframework.data.util.Optionals;
@@ -116,16 +106,7 @@ import com.mongodb.ClientSessionOptions;
import com.mongodb.MongoException;
import com.mongodb.ReadPreference;
import com.mongodb.WriteConcern;
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.ClientSession;
import com.mongodb.client.DistinctIterable;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MapReduceIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.MongoIterable;
import com.mongodb.client.*;
import com.mongodb.client.model.*;
import com.mongodb.client.result.DeleteResult;
import com.mongodb.client.result.UpdateResult;
@@ -2067,15 +2048,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
}
options.getComment().ifPresent(aggregateIterable::comment);
if (options.getHintObject().isPresent()) {
Object hintObject = options.getHintObject().get();
if (hintObject instanceof String hintString) {
aggregateIterable = aggregateIterable.hintString(hintString);
} else if (hintObject instanceof Document hintDocument) {
aggregateIterable = aggregateIterable.hint(hintDocument);
} else {
throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass()));
}
HintFunction hintFunction = options.getHintObject().map(HintFunction::from).orElseGet(HintFunction::empty);
if (hintFunction.isPresent()) {
aggregateIterable = hintFunction.apply(mongoDbFactory, aggregateIterable::hintString, aggregateIterable::hint);
}
if (options.hasExecutionTimeLimit()) {
@@ -2135,15 +2110,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
}
options.getComment().ifPresent(cursor::comment);
HintFunction hintFunction = options.getHintObject().map(HintFunction::from).orElseGet(HintFunction::empty);
if (options.getHintObject().isPresent()) {
Object hintObject = options.getHintObject().get();
if (hintObject instanceof String hintString) {
cursor = cursor.hintString(hintString);
} else if (hintObject instanceof Document hintDocument) {
cursor = cursor.hint(hintDocument);
} else {
throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass()));
}
cursor = hintFunction.apply(mongoDbFactory, cursor::hintString, cursor::hint);
}
Class<?> domainType = aggregation instanceof TypedAggregation ? ((TypedAggregation) aggregation).getInputType()
@@ -3172,8 +3141,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
.ifPresent(cursorToUse::collation);
Meta meta = query.getMeta();
HintFunction hintFunction = HintFunction.from(query.getHint());
if (query.getSkip() <= 0 && query.getLimit() <= 0 && ObjectUtils.isEmpty(query.getSortObject())
&& !StringUtils.hasText(query.getHint()) && !meta.hasValues() && !query.getCollation().isPresent()) {
&& !hintFunction.isPresent() && !meta.hasValues() && !query.getCollation().isPresent()) {
return cursorToUse;
}
@@ -3189,15 +3159,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
cursorToUse = cursorToUse.sort(sort);
}
if (StringUtils.hasText(query.getHint())) {
String hint = query.getHint();
if (BsonUtils.isJsonDocument(hint)) {
cursorToUse = cursorToUse.hint(BsonUtils.parse(hint, mongoDbFactory));
} else {
cursorToUse = cursorToUse.hintString(hint);
}
if (hintFunction.isPresent()) {
cursorToUse = hintFunction.apply(mongoDbFactory, cursorToUse::hintString, cursorToUse::hint);
}
if (meta.hasValues()) {

View File

@@ -60,7 +60,6 @@ import org.springframework.data.projection.EntityProjection;
import org.springframework.data.util.Lazy;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import com.mongodb.client.model.CountOptions;
import com.mongodb.client.model.DeleteOptions;
@@ -567,14 +566,11 @@ class QueryOperations {
if (query.getSkip() > 0) {
options.skip((int) query.getSkip());
}
if (StringUtils.hasText(query.getHint())) {
String hint = query.getHint();
if (BsonUtils.isJsonDocument(hint)) {
options.hint(BsonUtils.parse(hint, codecRegistryProvider));
} else {
options.hintString(hint);
}
HintFunction hintFunction = HintFunction.from(query.getHint());
if (hintFunction.isPresent()) {
options = hintFunction.apply(codecRegistryProvider, options::hintString, options::hint);
}
if (callback != null) {

View File

@@ -110,7 +110,6 @@ import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter;
import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.data.projection.EntityProjection;
import org.springframework.data.util.Optionals;
import org.springframework.lang.Nullable;
@@ -938,15 +937,10 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
}
options.getComment().ifPresent(cursor::comment);
if (options.getHintObject().isPresent()) {
Object hintObject = options.getHintObject().get();
if (hintObject instanceof String hintString) {
cursor = cursor.hintString(hintString);
} else if (hintObject instanceof Document hintDocument) {
cursor = cursor.hint(hintDocument);
} else {
throw new IllegalStateException("Unable to read hint of type %s".formatted(hintObject.getClass()));
}
HintFunction hintFunction = options.getHintObject().map(HintFunction::from).orElseGet(HintFunction::empty);
if (hintFunction.isPresent()) {
cursor = hintFunction.apply(mongoDatabaseFactory, cursor::hintString, cursor::hint);
}
Optionals.firstNonEmpty(options::getCollation, () -> operations.forType(inputType).getCollation()) //
@@ -1535,7 +1529,8 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
Publisher<?> publisher;
if (!mapped.hasId()) {
publisher = collectionToUse.insertOne(queryOperations.createInsertContext(mapped).prepareId(entityClass).getDocument());
publisher = collectionToUse
.insertOne(queryOperations.createInsertContext(mapped).prepareId(entityClass).getDocument());
} else {
MongoPersistentEntity<?> entity = mappingContext.getPersistentEntity(entityClass);
@@ -3044,9 +3039,10 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
.map(findPublisher::collation) //
.orElse(findPublisher);
HintFunction hintFunction = HintFunction.from(query.getHint());
Meta meta = query.getMeta();
if (query.getSkip() <= 0 && query.getLimit() <= 0 && ObjectUtils.isEmpty(query.getSortObject())
&& !StringUtils.hasText(query.getHint()) && !meta.hasValues()) {
&& !hintFunction.isPresent() && !meta.hasValues()) {
return findPublisherToUse;
}
@@ -3065,15 +3061,9 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
findPublisherToUse = findPublisherToUse.sort(sort);
}
if (StringUtils.hasText(query.getHint())) {
String hint = query.getHint();
if (BsonUtils.isJsonDocument(hint)) {
findPublisherToUse = findPublisherToUse.hint(BsonUtils.parse(hint, mongoDatabaseFactory));
} else {
findPublisherToUse = findPublisherToUse.hintString(hint);
}
if (hintFunction.isPresent()) {
findPublisherToUse = hintFunction.apply(mongoDatabaseFactory, findPublisherToUse::hintString,
findPublisherToUse::hint);
}
if (meta.hasValues()) {

View File

@@ -237,10 +237,11 @@ public class AggregationOptions {
}
/**
* Get the hint used to to fulfill the aggregation.
* Get the hint used to fulfill the aggregation.
*
* @return never {@literal null}.
* @since 3.1
* @deprecated since 4.1, use {@link #getHintObject()} instead.
*/
public Optional<Document> getHint() {
return hint.map(it -> {
@@ -257,25 +258,7 @@ public class AggregationOptions {
}
/**
* Get the hint (indexName) used to to fulfill the aggregation.
*
* @return never {@literal null}.
* @since 4.1
*/
public Optional<String> getHintAsString() {
return hint.map(it -> {
if (it instanceof String hintString) {
return hintString;
}
if (it instanceof Document doc) {
return BsonUtils.toJson(doc);
}
throw new IllegalStateException("Unable to read hint of type %s".formatted(it.getClass()));
});
}
/**
* Get the hint used to to fulfill the aggregation.
* Get the hint used to fulfill the aggregation.
*
* @return never {@literal null}.
* @since 4.1

View File

@@ -53,6 +53,7 @@ class AggregationOptionsTests {
assertThat(aggregationOptions.isExplain()).isTrue();
assertThat(aggregationOptions.getCursor()).contains(new Document("batchSize", 1));
assertThat(aggregationOptions.getHint()).contains(dummyHint);
assertThat(aggregationOptions.getHintObject()).contains(dummyHint);
}
@Test // DATAMONGO-1637, DATAMONGO-2153, DATAMONGO-1836
@@ -73,6 +74,7 @@ class AggregationOptionsTests {
assertThat(aggregationOptions.getCursorBatchSize()).isEqualTo(1);
assertThat(aggregationOptions.getComment()).contains("hola");
assertThat(aggregationOptions.getHint()).contains(dummyHint);
assertThat(aggregationOptions.getHintObject()).contains(dummyHint);
}
@Test // DATAMONGO-960, DATAMONGO-2153, DATAMONGO-1836