Compare commits


4 Commits

Author SHA1 Message Date
Mark Paluch
7103ca2228 Polishing.
Reformatting, trailing whitespaces.
2023-07-10 10:15:30 +02:00
Christoph Strobl
b46aab2c28 Map collection and fields for $graphLookup aggregation against type.
This commit enables using a type parameter to define the from collection of a graphLookup aggregation stage. In doing so we can derive the target collection name from the type and use the given information to also map the from field against the domain object, so that the user can operate on property names instead of the target database field name.
2023-07-10 09:47:06 +02:00
Christoph Strobl
ff137eca8a Map collection and fields for $lookup aggregation against type.
This commit enables using a type parameter to define the from collection of a lookup aggregation stage. In doing so we can derive the target collection name from the type and use the given information to also map the from field against the domain object, so that the user can operate on property names instead of the target database field name.
2023-07-10 09:47:06 +02:00
Christoph Strobl
a93c870b45 Prepare issue branch. 2023-07-10 09:47:06 +02:00
24 changed files with 1102 additions and 1314 deletions
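A minimal usage sketch (not part of the diff) of the type-based from(...) variants introduced by the two mapping commits above. It assumes the annotated Restaurant and Employee sample types from the unit tests further down; the @Document/@Field mapping is only applied when the stage is rendered against a type-based aggregation context.

// $lookup: collection name derived from @Document("sites"), foreign field mapped via @Field("rs_name")
LookupOperation lookup = Aggregation.lookup().from(Restaurant.class)
    .localField("restaurant_name")
    .foreignField("name")
    .as("restaurants");

// $graphLookup: collection name derived from @Document("employees"), connectFrom mapped via @Field("reportsTo")
GraphLookupOperation graphLookup = GraphLookupOperation.builder()
    .from(Employee.class)
    .startWith(LiteralOperators.Literal.asLiteral("hello"))
    .connectFrom("manager")
    .connectTo("name")
    .as("reportingHierarchy");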

View File

@@ -5,7 +5,7 @@
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.x-4393-SNAPSHOT</version>
<version>4.2.x-4379-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Spring Data MongoDB</name>
@@ -27,7 +27,7 @@
<project.type>multi</project.type>
<dist.id>spring-data-mongodb</dist.id>
<springdata.commons>3.2.0-SNAPSHOT</springdata.commons>
<mongo>4.10.2</mongo>
<mongo>4.9.1</mongo>
<mongo.reactivestreams>${mongo}</mongo.reactivestreams>
<jmh.version>1.19</jmh.version>
</properties>

View File

@@ -7,7 +7,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.x-4393-SNAPSHOT</version>
<version>4.2.x-4379-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.x-4393-SNAPSHOT</version>
<version>4.2.x-4379-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.x-4393-SNAPSHOT</version>
<version>4.2.x-4379-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -104,12 +104,6 @@
<version>${mongo}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.mongodb</groupId>
<artifactId>mongodb-driver-kotlin-sync</artifactId>
<version>${mongo}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.mongodb</groupId>
@@ -117,17 +111,11 @@
<version>${mongo.reactivestreams}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.mongodb</groupId>
<artifactId>mongodb-driver-kotlin-coroutine</artifactId>
<version>${mongo}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.mongodb</groupId>
<artifactId>mongodb-crypt</artifactId>
<version>1.8.0</version>
<version>1.6.1</version>
<optional>true</optional>
</dependency>

View File

@@ -23,6 +23,7 @@ import org.bson.Document;
import org.bson.codecs.configuration.CodecRegistry;
import org.springframework.beans.BeanUtils;
import org.springframework.data.mongodb.CodecRegistryProvider;
import org.springframework.data.mongodb.MongoCollectionUtils;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@@ -79,7 +80,30 @@ public interface AggregationOperationContext extends CodecRegistryProvider {
FieldReference getReference(String name);
/**
* Returns the {@link Fields} exposed by the type. May be a {@literal class} or an {@literal interface}. The default
* Obtain the target field name for a given field/type combination.
*
* @param type The type containing the field.
* @param field The property/field name.
* @return never {@literal null}.
* @since 4.2
*/
default String getMappedFieldName(Class<?> type, String field) {
return field;
}
/**
* Obtain the collection name for the given {@link Class type}.
*
* @param type the type to derive the collection name from.
* @return never {@literal null}.
* @since 4.2
*/
default String getCollection(Class<?> type) {
return MongoCollectionUtils.getPreferredCollectionName(type);
}
/**
* Returns the {@link Fields} exposed by the type. Can be a {@literal class} or an {@literal interface}. The default
* implementation uses {@link BeanUtils#getPropertyDescriptors(Class) property descriptors} to discover fields from a
* {@link Class}.
*
@@ -109,7 +133,7 @@ public interface AggregationOperationContext extends CodecRegistryProvider {
/**
* This toggle allows the {@link AggregationOperationContext context} to use any given field name without checking for
* its existence. Typically the {@link AggregationOperationContext} fails when referencing unknown fields, those that
* its existence. Typically, the {@link AggregationOperationContext} fails when referencing unknown fields, those that
* are not present in one of the previous stages or the input source, throughout the pipeline.
*
* @return a more relaxed {@link AggregationOperationContext}.
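A rough sketch (illustrative, not taken from the diff) of how the new default methods behave when no mapping information is available, matching the unmapped test cases further down: the field name passes through unchanged and the collection name falls back to MongoCollectionUtils.

// Using the default (untyped) context
AggregationOperationContext ctx = Aggregation.DEFAULT_CONTEXT;
ctx.getMappedFieldName(Employee.class, "manager"); // "manager" - returned as-is
ctx.getCollection(Employee.class);                 // "employee" - MongoCollectionUtils.getPreferredCollectionName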

View File

@@ -46,7 +46,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
private static final Set<Class<?>> ALLOWED_START_TYPES = new HashSet<Class<?>>(
Arrays.<Class<?>> asList(AggregationExpression.class, String.class, Field.class, Document.class));
private final String from;
private final Object from;
private final List<Object> startWith;
private final Field connectFrom;
private final Field connectTo;
@@ -55,7 +55,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
private final @Nullable Field depthField;
private final @Nullable CriteriaDefinition restrictSearchWithMatch;
private GraphLookupOperation(String from, List<Object> startWith, Field connectFrom, Field connectTo, Field as,
private GraphLookupOperation(Object from, List<Object> startWith, Field connectFrom, Field connectTo, Field as,
@Nullable Long maxDepth, @Nullable Field depthField, @Nullable CriteriaDefinition restrictSearchWithMatch) {
this.from = from;
@@ -82,7 +82,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Document graphLookup = new Document();
graphLookup.put("from", from);
graphLookup.put("from", getCollectionName(context));
List<Object> mappedStartWith = new ArrayList<>(startWith.size());
@@ -99,7 +99,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
graphLookup.put("startWith", mappedStartWith.size() == 1 ? mappedStartWith.iterator().next() : mappedStartWith);
graphLookup.put("connectFromField", connectFrom.getTarget());
graphLookup.put("connectFromField", getForeignFieldName(context));
graphLookup.put("connectToField", connectTo.getTarget());
graphLookup.put("as", as.getName());
@@ -118,6 +118,16 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
return new Document(getOperator(), graphLookup);
}
String getCollectionName(AggregationOperationContext context) {
return from instanceof Class<?> type ? context.getCollection(type) : from.toString();
}
String getForeignFieldName(AggregationOperationContext context) {
return from instanceof Class<?> type ? context.getMappedFieldName(type, connectFrom.getTarget())
: connectFrom.getTarget();
}
@Override
public String getOperator() {
return "$graphLookup";
@@ -128,7 +138,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
List<ExposedField> fields = new ArrayList<>(2);
fields.add(new ExposedField(as, true));
if(depthField != null) {
if (depthField != null) {
fields.add(new ExposedField(depthField, true));
}
return ExposedFields.from(fields.toArray(new ExposedField[0]));
@@ -146,6 +156,17 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
* @return never {@literal null}.
*/
StartWithBuilder from(String collectionName);
/**
* Use the given type to determine the name of the foreign collection and map
* {@link ConnectFromBuilder#connectFrom(String)} against it, taking potentially present
* {@link org.springframework.data.mongodb.core.mapping.Field} annotations into account.
*
* @param type must not be {@literal null}.
* @return never {@literal null}.
* @since 4.2
*/
StartWithBuilder from(Class<?> type);
}
/**
@@ -218,7 +239,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
static final class GraphLookupOperationFromBuilder
implements FromBuilder, StartWithBuilder, ConnectFromBuilder, ConnectToBuilder {
private @Nullable String from;
private @Nullable Object from;
private @Nullable List<? extends Object> startWith;
private @Nullable String connectFrom;
@@ -231,6 +252,14 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
return this;
}
@Override
public StartWithBuilder from(Class<?> type) {
Assert.notNull(type, "Type must not be null");
this.from = type;
return this;
}
@Override
public ConnectFromBuilder startWith(String... fieldReferences) {
@@ -321,7 +350,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
*/
public static final class GraphLookupOperationBuilder {
private final String from;
private final Object from;
private final List<Object> startWith;
private final Field connectFrom;
private final Field connectTo;
@@ -329,7 +358,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
private @Nullable Field depthField;
private @Nullable CriteriaDefinition restrictSearchWithMatch;
protected GraphLookupOperationBuilder(String from, List<? extends Object> startWith, String connectFrom,
protected GraphLookupOperationBuilder(Object from, List<? extends Object> startWith, String connectFrom,
String connectTo) {
this.from = from;

View File

@@ -39,7 +39,7 @@ import org.springframework.util.Assert;
*/
public class LookupOperation implements FieldsExposingAggregationOperation, InheritsFieldsAggregationOperation {
private final String from;
private final Object from;
@Nullable //
private final Field localField;
@@ -97,6 +97,22 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
*/
public LookupOperation(String from, @Nullable Field localField, @Nullable Field foreignField, @Nullable Let let,
@Nullable AggregationPipeline pipeline, Field as) {
this((Object) from, localField, foreignField, let, pipeline, as);
}
/**
* Creates a new {@link LookupOperation} for the given combination of {@link Field}s and {@link AggregationPipeline
* pipeline}.
*
* @param from must not be {@literal null}. Can be either the target collection name or a {@link Class}.
* @param localField can be {@literal null} if {@literal pipeline} is present.
* @param foreignField can be {@literal null} if {@literal pipeline} is present.
* @param let can be {@literal null} if {@literal localField} and {@literal foreignField} are present.
* @param as must not be {@literal null}.
* @since 4.2
*/
private LookupOperation(Object from, @Nullable Field localField, @Nullable Field foreignField, @Nullable Let let,
@Nullable AggregationPipeline pipeline, Field as) {
Assert.notNull(from, "From must not be null");
if (pipeline == null) {
@@ -125,12 +141,14 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
Document lookupObject = new Document();
lookupObject.append("from", from);
lookupObject.append("from", getCollectionName(context));
if (localField != null) {
lookupObject.append("localField", localField.getTarget());
}
if (foreignField != null) {
lookupObject.append("foreignField", foreignField.getTarget());
lookupObject.append("foreignField", getForeignFieldName(context));
}
if (let != null) {
lookupObject.append("let", let.toDocument(context).get("$let", Document.class).get("vars"));
@@ -144,6 +162,16 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
return new Document(getOperator(), lookupObject);
}
String getCollectionName(AggregationOperationContext context) {
return from instanceof Class<?> type ? context.getCollection(type) : from.toString();
}
String getForeignFieldName(AggregationOperationContext context) {
return from instanceof Class<?> type ? context.getMappedFieldName(type, foreignField.getTarget())
: foreignField.getTarget();
}
@Override
public String getOperator() {
return "$lookup";
@@ -158,16 +186,28 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
return new LookupOperationBuilder();
}
public static interface FromBuilder {
public interface FromBuilder {
/**
* @param name the collection in the same database to perform the join with, must not be {@literal null} or empty.
* @return never {@literal null}.
*/
LocalFieldBuilder from(String name);
/**
* Use the given type to determine the name of the foreign collection and map
* {@link ForeignFieldBuilder#foreignField(String)} against it, taking potentially present
* {@link org.springframework.data.mongodb.core.mapping.Field} annotations into account.
*
* @param type the type of the target collection in the same database to perform the join with, must not be
* {@literal null}.
* @return never {@literal null}.
* @since 4.2
*/
LocalFieldBuilder from(Class<?> type);
}
public static interface LocalFieldBuilder extends PipelineBuilder {
public interface LocalFieldBuilder extends PipelineBuilder {
/**
* @param name the field from the documents input to the {@code $lookup} stage, must not be {@literal null} or
@@ -177,7 +217,7 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
ForeignFieldBuilder localField(String name);
}
public static interface ForeignFieldBuilder {
public interface ForeignFieldBuilder {
/**
* @param name the field from the documents in the {@code from} collection, must not be {@literal null} or empty.
@@ -246,7 +286,7 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
LookupOperation as(String name);
}
public static interface AsBuilder extends PipelineBuilder {
public interface AsBuilder extends PipelineBuilder {
/**
* @param name the name of the new array field to add to the input documents, must not be {@literal null} or empty.
@@ -264,7 +304,7 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
public static final class LookupOperationBuilder
implements FromBuilder, LocalFieldBuilder, ForeignFieldBuilder, AsBuilder {
private @Nullable String from;
private @Nullable Object from;
private @Nullable Field localField;
private @Nullable Field foreignField;
private @Nullable ExposedField as;
@@ -288,6 +328,14 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
return this;
}
@Override
public LocalFieldBuilder from(Class<?> type) {
Assert.notNull(type, "'From' must not be null");
from = type;
return this;
}
@Override
public AsBuilder foreignField(String name) {

View File

@@ -30,9 +30,8 @@ import org.springframework.lang.Nullable;
/**
* {@link AggregationOperationContext} implementation prefixing non-command keys on root level with the given prefix.
* Useful when mapping fields to domain specific types while having to prefix keys for query purpose.
* <br />
* Fields to be excluded from prefixing my be added to a {@literal denylist}.
* Useful when mapping fields to domain specific types while having to prefix keys for query purpose. <br />
* Fields to be excluded from prefixing can be added to a {@literal denylist}.
*
* @author Christoph Strobl
* @author Mark Paluch

View File

@@ -92,6 +92,20 @@ public class TypeBasedAggregationOperationContext implements AggregationOperatio
return getReferenceFor(field(name));
}
@Override
public String getCollection(Class<?> type) {
MongoPersistentEntity<?> persistentEntity = mappingContext.getPersistentEntity(type);
return persistentEntity != null ? persistentEntity.getCollection() : AggregationOperationContext.super.getCollection(type);
}
@Override
public String getMappedFieldName(Class<?> type, String field) {
PersistentPropertyPath<MongoPersistentProperty> persistentPropertyPath = mappingContext.getPersistentPropertyPath(field, type);
return persistentPropertyPath.getLeafProperty().getFieldName();
}
@Override
public Fields getFields(Class<?> type) {
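For contrast, a short illustration (not part of the diff) of what the overrides above resolve to for the annotated Employee sample type used in the tests below; context here is a hypothetical TypeBasedAggregationOperationContext backed by a mapping context that knows that type.

// Type-based context aware of Employee (@Document("employees"), @Field("reportsTo") manager)
context.getCollection(Employee.class);                 // "employees" - taken from @Document
context.getMappedFieldName(Employee.class, "manager"); // "reportsTo" - taken from @Field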

View File

@@ -17,7 +17,6 @@ package org.springframework.data.mongodb.core.convert.encryption;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -64,7 +63,7 @@ public class MongoEncryptionConverter implements EncryptingConverter<Object, Obj
public Object read(Object value, MongoConversionContext context) {
Object decrypted = EncryptingConverter.super.read(value, context);
return decrypted instanceof BsonValue bsonValue ? BsonUtils.toJavaType(bsonValue) : decrypted;
return decrypted instanceof BsonValue ? BsonUtils.toJavaType((BsonValue) decrypted) : decrypted;
}
@Override
@@ -88,56 +87,36 @@ public class MongoEncryptionConverter implements EncryptingConverter<Object, Obj
}
MongoPersistentProperty persistentProperty = getProperty(context);
if (getProperty(context).isCollectionLike() && decryptedValue instanceof Iterable<?> iterable) {
int size = iterable instanceof Collection<?> c ? c.size() : 10;
if (!persistentProperty.isEntity()) {
Collection<Object> collection = CollectionFactory.createCollection(persistentProperty.getType(), size);
iterable.forEach(it -> {
if (it instanceof BsonValue bsonValue) {
collection.add(BsonUtils.toJavaType(bsonValue));
} else {
collection.add(context.read(it, persistentProperty.getActualType()));
}
});
iterable.forEach(it -> collection.add(BsonUtils.toJavaType((BsonValue) it)));
return collection;
} else {
Collection<Object> collection = CollectionFactory.createCollection(persistentProperty.getType(), size);
iterable.forEach(it -> {
if (it instanceof BsonValue bsonValue) {
collection.add(context.read(BsonUtils.toJavaType(bsonValue), persistentProperty.getActualType()));
} else {
collection.add(context.read(it, persistentProperty.getActualType()));
}
collection.add(context.read(BsonUtils.toJavaType((BsonValue) it), persistentProperty.getActualType()));
});
return collection;
}
}
if (!persistentProperty.isEntity() && persistentProperty.isMap()) {
if (persistentProperty.getType() != Document.class) {
if (decryptedValue instanceof BsonValue bsonValue) {
return new LinkedHashMap<>((Document) BsonUtils.toJavaType(bsonValue));
}
if (decryptedValue instanceof Document document) {
return new LinkedHashMap<>(document);
}
if (decryptedValue instanceof Map map) {
return map;
}
if (!persistentProperty.isEntity() && decryptedValue instanceof BsonValue bsonValue) {
if (persistentProperty.isMap() && persistentProperty.getType() != Document.class) {
return new LinkedHashMap<>((Document) BsonUtils.toJavaType(bsonValue));
}
return BsonUtils.toJavaType(bsonValue);
}
if (persistentProperty.isEntity() && decryptedValue instanceof BsonDocument bsonDocument) {
return context.read(BsonUtils.toJavaType(bsonDocument), persistentProperty.getTypeInformation().getType());
}
if (persistentProperty.isEntity() && decryptedValue instanceof Document document) {
return context.read(document, persistentProperty.getTypeInformation().getType());
}
return decryptedValue;
}

View File

@@ -15,33 +15,26 @@
*/
package org.springframework.data.mongodb.util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.function.Function;
import java.util.stream.StreamSupport;
import org.bson.*;
import org.bson.codecs.Codec;
import org.bson.codecs.DocumentCodec;
import org.bson.codecs.EncoderContext;
import org.bson.codecs.configuration.CodecConfigurationException;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;
import org.bson.json.JsonParseException;
import org.bson.types.Binary;
import org.bson.types.Decimal128;
import org.bson.types.ObjectId;
import org.springframework.core.convert.converter.Converter;
import org.springframework.data.mongodb.CodecRegistryProvider;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
@@ -110,7 +103,7 @@ public class BsonUtils {
return dbo.toMap();
}
return new Document(bson.toBsonDocument(Document.class, codecRegistry));
return new Document((Map) bson.toBsonDocument(Document.class, codecRegistry));
}
/**
@@ -287,22 +280,36 @@ public class BsonUtils {
*/
public static Object toJavaType(BsonValue value) {
return switch (value.getBsonType()) {
case INT32 -> value.asInt32().getValue();
case INT64 -> value.asInt64().getValue();
case STRING -> value.asString().getValue();
case DECIMAL128 -> value.asDecimal128().doubleValue();
case DOUBLE -> value.asDouble().getValue();
case BOOLEAN -> value.asBoolean().getValue();
case OBJECT_ID -> value.asObjectId().getValue();
case DB_POINTER -> new DBRef(value.asDBPointer().getNamespace(), value.asDBPointer().getId());
case BINARY -> value.asBinary().getData();
case DATE_TIME -> new Date(value.asDateTime().getValue());
case SYMBOL -> value.asSymbol().getSymbol();
case ARRAY -> value.asArray().toArray();
case DOCUMENT -> Document.parse(value.asDocument().toJson());
default -> value;
};
switch (value.getBsonType()) {
case INT32:
return value.asInt32().getValue();
case INT64:
return value.asInt64().getValue();
case STRING:
return value.asString().getValue();
case DECIMAL128:
return value.asDecimal128().doubleValue();
case DOUBLE:
return value.asDouble().getValue();
case BOOLEAN:
return value.asBoolean().getValue();
case OBJECT_ID:
return value.asObjectId().getValue();
case DB_POINTER:
return new DBRef(value.asDBPointer().getNamespace(), value.asDBPointer().getId());
case BINARY:
return value.asBinary().getData();
case DATE_TIME:
return new Date(value.asDateTime().getValue());
case SYMBOL:
return value.asSymbol().getSymbol();
case ARRAY:
return value.asArray().toArray();
case DOCUMENT:
return Document.parse(value.asDocument().toJson());
default:
return value;
}
}
/**
@@ -314,21 +321,6 @@ public class BsonUtils {
* @since 3.0
*/
public static BsonValue simpleToBsonValue(Object source) {
return simpleToBsonValue(source, MongoClientSettings.getDefaultCodecRegistry());
}
/**
* Convert a given simple value (eg. {@link String}, {@link Long}) to its corresponding {@link BsonValue}.
*
* @param source must not be {@literal null}.
* @param codecRegistry The {@link CodecRegistry} used as a fallback to convert types using native {@link Codec}. Must
* not be {@literal null}.
* @return the corresponding {@link BsonValue} representation.
* @throws IllegalArgumentException if {@literal source} does not correspond to a {@link BsonValue} type.
* @since 4.2
*/
@SuppressWarnings("unchecked")
public static BsonValue simpleToBsonValue(Object source, CodecRegistry codecRegistry) {
if (source instanceof BsonValue bsonValue) {
return bsonValue;
@@ -366,35 +358,17 @@ public class BsonUtils {
return new BsonDouble(floatValue);
}
if (source instanceof Binary binary) {
if(source instanceof Binary binary) {
return new BsonBinary(binary.getType(), binary.getData());
}
if (source instanceof Date date) {
new BsonDateTime(date.getTime());
}
try {
Object value = source;
if (ClassUtils.isPrimitiveArray(source.getClass())) {
value = CollectionUtils.arrayToList(source);
}
Codec codec = codecRegistry.get(value.getClass());
BsonCapturingWriter writer = new BsonCapturingWriter(value.getClass());
codec.encode(writer, value,
ObjectUtils.isArray(value) || value instanceof Collection<?> ? EncoderContext.builder().build() : null);
return writer.getCapturedValue();
} catch (CodecConfigurationException e) {
throw new IllegalArgumentException(
String.format("Unable to convert %s to BsonValue.", source != null ? source.getClass().getName() : "null"));
}
throw new IllegalArgumentException(String.format("Unable to convert %s (%s) to BsonValue.", source,
source != null ? source.getClass().getName() : "null"));
}
/**
* Merge the given {@link Document documents} into one in the given order. Keys contained within multiple documents are
* overwritten by their follow-ups.
* overwritten by their follow ups.
*
* @param documents must not be {@literal null}. Can be empty.
* @return the document containing all key value pairs.
@@ -695,7 +669,7 @@ public class BsonUtils {
if (value instanceof Collection<?> collection) {
return toString(collection);
} else if (value instanceof Map<?, ?> map) {
} else if (value instanceof Map<?,?> map) {
return toString(map);
} else if (ObjectUtils.isArray(value)) {
return toString(Arrays.asList(ObjectUtils.toObjectArray(value)));
@@ -717,9 +691,8 @@ public class BsonUtils {
private static String toString(Map<?, ?> source) {
// Avoid String.format for performance
return iterableToDelimitedString(source.entrySet(), "{ ", " }",
entry -> "\"" + entry.getKey() + "\" : " + toJson(entry.getValue()));
entry -> String.format("\"%s\" : %s", entry.getKey(), toJson(entry.getValue())));
}
private static String toString(Collection<?> source) {
@@ -735,160 +708,4 @@ public class BsonUtils {
return joiner.toString();
}
static class BsonCapturingWriter extends AbstractBsonWriter {
private final List<BsonValue> values = new ArrayList<>(0);
public BsonCapturingWriter(Class<?> type) {
super(new BsonWriterSettings());
if (ClassUtils.isAssignable(Map.class, type)) {
setContext(new Context(null, BsonContextType.DOCUMENT));
} else if (ClassUtils.isAssignable(List.class, type) || type.isArray()) {
setContext(new Context(null, BsonContextType.ARRAY));
} else {
setContext(new Context(null, BsonContextType.DOCUMENT));
}
}
@Nullable
BsonValue getCapturedValue() {
if (values.isEmpty()) {
return null;
}
if (!getContext().getContextType().equals(BsonContextType.ARRAY)) {
return values.get(0);
}
return new BsonArray(values);
}
@Override
protected void doWriteStartDocument() {
}
@Override
protected void doWriteEndDocument() {
}
@Override
public void writeStartArray() {
setState(State.VALUE);
}
@Override
public void writeEndArray() {
setState(State.NAME);
}
@Override
protected void doWriteStartArray() {
}
@Override
protected void doWriteEndArray() {
}
@Override
protected void doWriteBinaryData(BsonBinary value) {
values.add(value);
}
@Override
protected void doWriteBoolean(boolean value) {
values.add(BsonBoolean.valueOf(value));
}
@Override
protected void doWriteDateTime(long value) {
values.add(new BsonDateTime(value));
}
@Override
protected void doWriteDBPointer(BsonDbPointer value) {
values.add(value);
}
@Override
protected void doWriteDouble(double value) {
values.add(new BsonDouble(value));
}
@Override
protected void doWriteInt32(int value) {
values.add(new BsonInt32(value));
}
@Override
protected void doWriteInt64(long value) {
values.add(new BsonInt64(value));
}
@Override
protected void doWriteDecimal128(Decimal128 value) {
values.add(new BsonDecimal128(value));
}
@Override
protected void doWriteJavaScript(String value) {
values.add(new BsonJavaScript(value));
}
@Override
protected void doWriteJavaScriptWithScope(String value) {
throw new UnsupportedOperationException("Cannot capture JavaScriptWith");
}
@Override
protected void doWriteMaxKey() {}
@Override
protected void doWriteMinKey() {}
@Override
protected void doWriteNull() {
values.add(new BsonNull());
}
@Override
protected void doWriteObjectId(ObjectId value) {
values.add(new BsonObjectId(value));
}
@Override
protected void doWriteRegularExpression(BsonRegularExpression value) {
values.add(value);
}
@Override
protected void doWriteString(String value) {
values.add(new BsonString(value));
}
@Override
protected void doWriteSymbol(String value) {
values.add(new BsonSymbol(value));
}
@Override
protected void doWriteTimestamp(BsonTimestamp value) {
values.add(value);
}
@Override
protected void doWriteUndefined() {
values.add(new BsonUndefined());
}
@Override
public void flush() {
values.clear();
}
}
}

View File

@@ -1,32 +0,0 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core
import com.mongodb.kotlin.client.MongoClient
import org.springframework.beans.DirectFieldAccessor
/**
* Extension for [SimpleMongoClientDatabaseFactory] that accepts a [MongoClient].
*
* @author Christoph Strobl
* @since 4.2
*/
fun SimpleMongoClientDatabaseFactory(client: MongoClient, database: String): SimpleMongoClientDatabaseFactory =
SimpleMongoClientDatabaseFactory(
DirectFieldAccessor(client).getPropertyValue("wrapped") as com.mongodb.client.MongoClient,
database
)

View File

@@ -1,32 +0,0 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core
import com.mongodb.kotlin.client.coroutine.MongoClient
import org.springframework.beans.DirectFieldAccessor
/**
* Extension for [SimpleReactiveMongoDatabaseFactory] that accepts a [MongoClient].
*
* @author Christoph Strobl
* @since 4.2
*/
fun SimpleReactiveMongoDatabaseFactory(client: MongoClient, database: String): SimpleReactiveMongoDatabaseFactory =
SimpleReactiveMongoDatabaseFactory(
DirectFieldAccessor(client).getPropertyValue("wrapped") as com.mongodb.reactivestreams.client.MongoClient,
database
)

View File

@@ -0,0 +1,87 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.test.util.MongoTestMappingContext;
/**
* @author Christoph Strobl
*/
public final class AggregationTestUtils {
public static AggregationContextBuilder<TypeBasedAggregationOperationContext> strict(Class<?> type) {
AggregationContextBuilder<AggregationOperationContext> builder = new AggregationContextBuilder<>();
builder.strict = true;
return builder.forType(type);
}
public static AggregationContextBuilder<TypeBasedAggregationOperationContext> relaxed(Class<?> type) {
AggregationContextBuilder<AggregationOperationContext> builder = new AggregationContextBuilder<>();
builder.strict = false;
return builder.forType(type);
}
public static class AggregationContextBuilder<T extends AggregationOperationContext> {
Class<?> targetType;
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
QueryMapper queryMapper;
boolean strict;
public AggregationContextBuilder<TypeBasedAggregationOperationContext> forType(Class<?> type) {
this.targetType = type;
return (AggregationContextBuilder<TypeBasedAggregationOperationContext>) this;
}
public AggregationContextBuilder<T> using(
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext) {
this.mappingContext = mappingContext;
return this;
}
public AggregationContextBuilder<T> using(QueryMapper queryMapper) {
this.queryMapper = queryMapper;
return this;
}
public T ctx() {
//
if (targetType == null) {
return (T) Aggregation.DEFAULT_CONTEXT;
}
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> ctx = mappingContext != null
? mappingContext
: MongoTestMappingContext.newTestContext().init();
QueryMapper qm = queryMapper != null ? queryMapper
: new QueryMapper(new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx));
return (T) (strict ? new TypeBasedAggregationOperationContext(targetType, ctx, qm)
: new RelaxedTypeBasedAggregationOperationContext(targetType, ctx, qm));
}
}
}
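Usage note (mirroring the tests below): the builder yields either a strict or a relaxed TypeBasedAggregationOperationContext that the operation under test can be rendered against.

// e.g. inside a test
Document mapped = lookupOperation.toDocument(AggregationTestUtils.strict(Restaurant.class).ctx());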

View File

@@ -22,6 +22,7 @@ import java.util.Arrays;
import org.bson.Document;
import org.junit.jupiter.api.Test;
import org.springframework.data.mongodb.core.Person;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.query.Criteria;
/**
@@ -34,7 +35,7 @@ public class GraphLookupOperationUnitTests {
@Test // DATAMONGO-1551
public void rejectsNullFromCollection() {
assertThatIllegalArgumentException().isThrownBy(() -> GraphLookupOperation.builder().from(null));
assertThatIllegalArgumentException().isThrownBy(() -> GraphLookupOperation.builder().from((String) null));
}
@Test // DATAMONGO-1551
@@ -158,4 +159,59 @@ public class GraphLookupOperationUnitTests {
assertThat(document).containsEntry("$graphLookup.depthField", "foo.bar");
}
@Test // GH-4379
void unmappedLookupWithFromExtractedFromType() {
GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() //
.from(Employee.class) //
.startWith(LiteralOperators.Literal.asLiteral("hello")) //
.connectFrom("manager") //
.connectTo("name") //
.as("reportingHierarchy");
assertThat(graphLookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo("""
{ $graphLookup:
{
from: "employee",
startWith : { $literal : "hello" },
connectFromField: "manager",
connectToField: "name",
as: "reportingHierarchy"
}
}}
""");
}
@Test // GH-4379
void mappedLookupWithFromExtractedFromType() {
GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() //
.from(Employee.class) //
.startWith(LiteralOperators.Literal.asLiteral("hello")) //
.connectFrom("manager") //
.connectTo("name") //
.as("reportingHierarchy");
assertThat(graphLookupOperation.toDocument(AggregationTestUtils.strict(Employee.class).ctx())).isEqualTo("""
{ $graphLookup:
{
from: "employees",
startWith : { $literal : "hello" },
connectFromField: "reportsTo",
connectToField: "name",
as: "reportingHierarchy"
}
}}
""");
}
@org.springframework.data.mongodb.core.mapping.Document("employees")
static class Employee {
String id;
@Field("reportsTo")
String manager;
}
}

View File

@@ -25,6 +25,7 @@ import java.util.List;
import org.bson.Document;
import org.junit.jupiter.api.Test;
import org.springframework.data.mongodb.core.DocumentTestUtils;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.query.Criteria;
/**
@@ -92,7 +93,7 @@ public class LookupOperationUnitTests {
@Test // DATAMONGO-1326
public void builderRejectsNullFromField() {
assertThatIllegalArgumentException().isThrownBy(() -> LookupOperation.newLookup().from(null));
assertThatIllegalArgumentException().isThrownBy(() -> LookupOperation.newLookup().from((String) null));
}
@Test // DATAMONGO-1326
@@ -195,10 +196,10 @@ public class LookupOperationUnitTests {
void buildsLookupWithLocalAndForeignFieldAsWellAsLetAndPipeline() {
LookupOperation lookupOperation = Aggregation.lookup().from("restaurants") //
.localField("restaurant_name")
.foreignField("name")
.localField("restaurant_name") //
.foreignField("name") //
.let(newVariable("orders_drink").forField("drink")) //
.pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages")))))
.pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages"))))) //
.as("matches");
assertThat(lookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo("""
@@ -216,4 +217,54 @@ public class LookupOperationUnitTests {
}}
""");
}
@Test // GH-4379
void unmappedLookupWithFromExtractedFromType() {
LookupOperation lookupOperation = Aggregation.lookup().from(Restaurant.class) //
.localField("restaurant_name") //
.foreignField("name") //
.as("restaurants");
assertThat(lookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo("""
{ $lookup:
{
from: "restaurant",
localField: "restaurant_name",
foreignField: "name",
as: "restaurants"
}
}}
""");
}
@Test // GH-4379
void mappedLookupWithFromExtractedFromType() {
LookupOperation lookupOperation = Aggregation.lookup().from(Restaurant.class) //
.localField("restaurant_name") //
.foreignField("name") //
.as("restaurants");
assertThat(lookupOperation.toDocument(AggregationTestUtils.strict(Restaurant.class).ctx())).isEqualTo("""
{ $lookup:
{
from: "sites",
localField: "restaurant_name",
foreignField: "rs_name",
as: "restaurants"
}
}}
""");
}
@org.springframework.data.mongodb.core.mapping.Document("sites")
static class Restaurant {
String id;
@Field("rs_name") //
String name;
}
}

View File

@@ -25,6 +25,7 @@ import java.math.BigInteger;
import java.net.URL;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.*;
@@ -105,7 +106,6 @@ import com.mongodb.DBRef;
* @author Mark Paluch
* @author Roman Puchkovskiy
* @author Heesu Jung
* @author Julia Lee
*/
@ExtendWith(MockitoExtension.class)
class MappingMongoConverterUnitTests {
@@ -2619,7 +2619,7 @@ class MappingMongoConverterUnitTests {
void projectShouldReadSimpleInterfaceProjection() {
org.bson.Document source = new org.bson.Document("birthDate",
Date.from(LocalDate.of(1999, 12, 1).atStartOfDay(systemDefault()).toInstant())).append("foo", "Walter");
Date.from(LocalDate.of(1999, 12, 1).atStartOfDay().toInstant(ZoneOffset.UTC))).append("foo", "Walter");
EntityProjectionIntrospector discoverer = EntityProjectionIntrospector.create(converter.getProjectionFactory(),
EntityProjectionIntrospector.ProjectionPredicate.typeHierarchy()
@@ -2637,7 +2637,7 @@ class MappingMongoConverterUnitTests {
void projectShouldReadSimpleDtoProjection() {
org.bson.Document source = new org.bson.Document("birthDate",
Date.from(LocalDate.of(1999, 12, 1).atStartOfDay(systemDefault()).toInstant())).append("foo", "Walter");
Date.from(LocalDate.of(1999, 12, 1).atStartOfDay().toInstant(ZoneOffset.UTC))).append("foo", "Walter");
EntityProjectionIntrospector introspector = EntityProjectionIntrospector.create(converter.getProjectionFactory(),
EntityProjectionIntrospector.ProjectionPredicate.typeHierarchy()

View File

@@ -1,756 +0,0 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.encryption;
import static org.assertj.core.api.Assertions.*;
import static org.springframework.data.mongodb.core.EncryptionAlgorithms.*;
import static org.springframework.data.mongodb.core.aggregation.Aggregation.*;
import static org.springframework.data.mongodb.core.query.Criteria.*;
import java.security.SecureRandom;
import java.time.LocalDate;
import java.time.Month;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.assertj.core.api.Assertions;
import org.bson.BsonBinary;
import org.bson.Document;
import org.bson.types.Binary;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.dao.PermissionDeniedDataAccessException;
import org.springframework.data.convert.PropertyValueConverterFactory;
import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter;
import org.springframework.data.mongodb.core.convert.encryption.MongoEncryptionConverter;
import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.util.Lazy;
import com.mongodb.ClientEncryptionSettings;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoNamespace;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.IndexOptions;
import com.mongodb.client.model.Indexes;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.vault.ClientEncryption;
import com.mongodb.client.vault.ClientEncryptions;
/**
* @author Christoph Strobl
* @author Julia Lee
*/
public abstract class AbstractEncryptionTestBase {
@Autowired MongoTemplate template;
@Test // GH-4284
void encryptAndDecryptSimpleValue() {
Person source = new Person();
source.id = "id-1";
source.ssn = "mySecretSSN";
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("ssn")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4432
void encryptAndDecryptJavaTime() {
Person source = new Person();
source.id = "id-1";
source.today = LocalDate.of(1979, Month.SEPTEMBER, 18);
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("today")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptComplexValue() {
Person source = new Person();
source.id = "id-1";
source.address = new Address();
source.address.city = "NYC";
source.address.street = "4th Ave.";
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("address")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptValueWithinComplexOne() {
Person source = new Person();
source.id = "id-1";
source.encryptedZip = new AddressWithEncryptedZip();
source.encryptedZip.city = "Boston";
source.encryptedZip.street = "central square";
source.encryptedZip.zip = "1234567890";
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> {
assertThat(it.get("encryptedZip")).isInstanceOf(Document.class);
assertThat(it.get("encryptedZip", Document.class).get("city")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("street")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("zip")).isInstanceOf(Binary.class);
}) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptListOfSimpleValue() {
Person source = new Person();
source.id = "id-1";
source.listOfString = Arrays.asList("spring", "data", "mongodb");
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("listOfString")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptListOfComplexValue() {
Person source = new Person();
source.id = "id-1";
Address address = new Address();
address.city = "SFO";
address.street = "---";
source.listOfComplex = Collections.singletonList(address);
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("listOfComplex")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptMapOfSimpleValues() {
Person source = new Person();
source.id = "id-1";
source.mapOfString = Map.of("k1", "v1", "k2", "v2");
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("mapOfString")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptMapOfComplexValues() {
Person source = new Person();
source.id = "id-1";
Address address1 = new Address();
address1.city = "SFO";
address1.street = "---";
Address address2 = new Address();
address2.city = "NYC";
address2.street = "---";
source.mapOfComplex = Map.of("a1", address1, "a2", address2);
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("mapOfComplex")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void canQueryDeterministicallyEncrypted() {
Person source = new Person();
source.id = "id-1";
source.ssn = "mySecretSSN";
template.save(source);
Person loaded = template.query(Person.class).matching(where("ssn").is(source.ssn)).firstValue();
assertThat(loaded).isEqualTo(source);
}
@Test // GH-4284
void cannotQueryRandomlyEncrypted() {
Person source = new Person();
source.id = "id-1";
source.wallet = "secret-wallet-id";
template.save(source);
Person loaded = template.query(Person.class).matching(where("wallet").is(source.wallet)).firstValue();
assertThat(loaded).isNull();
}
@Test // GH-4284
void updateSimpleTypeEncryptedFieldWithNewValue() {
Person source = new Person();
source.id = "id-1";
template.save(source);
template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("ssn", "secret-value"))
.first();
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("ssn")).isInstanceOf(Binary.class)) //
.loadedMatches(it -> assertThat(it.getSsn()).isEqualTo("secret-value"));
}
@Test // GH-4284
void updateComplexTypeEncryptedFieldWithNewValue() {
Person source = new Person();
source.id = "id-1";
template.save(source);
Address address = new Address();
address.city = "SFO";
address.street = "---";
template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("address", address)).first();
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("address")).isInstanceOf(Binary.class)) //
.loadedMatches(it -> assertThat(it.getAddress()).isEqualTo(address));
}
@Test // GH-4284
void updateEncryptedFieldInNestedElementWithNewValue() {
Person source = new Person();
source.id = "id-1";
source.encryptedZip = new AddressWithEncryptedZip();
source.encryptedZip.city = "Boston";
source.encryptedZip.street = "central square";
template.save(source);
template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("encryptedZip.zip", "179"))
.first();
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> {
assertThat(it.get("encryptedZip")).isInstanceOf(Document.class);
assertThat(it.get("encryptedZip", Document.class).get("city")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("street")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("zip")).isInstanceOf(Binary.class);
}) //
.loadedMatches(it -> assertThat(it.getEncryptedZip().getZip()).isEqualTo("179"));
}
@Test
void aggregationWithMatch() {
Person person = new Person();
person.id = "id-1";
person.name = "p1-name";
person.ssn = "mySecretSSN";
template.save(person);
AggregationResults<Person> aggregationResults = template.aggregateAndReturn(Person.class)
.by(newAggregation(Person.class, Aggregation.match(where("ssn").is(person.ssn)))).all();
assertThat(aggregationResults.getMappedResults()).containsExactly(person);
}
@Test
void altKeyDetection(@Autowired CachingMongoClientEncryption mongoClientEncryption) throws InterruptedException {
BsonBinary user1key = mongoClientEncryption.getClientEncryption().createDataKey("local",
new DataKeyOptions().keyAltNames(Collections.singletonList("user-1")));
BsonBinary user2key = mongoClientEncryption.getClientEncryption().createDataKey("local",
new DataKeyOptions().keyAltNames(Collections.singletonList("user-2")));
Person p1 = new Person();
p1.id = "id-1";
p1.name = "user-1";
p1.ssn = "ssn";
p1.viaAltKeyNameField = "value-1";
Person p2 = new Person();
p2.id = "id-2";
p2.name = "user-2";
p2.viaAltKeyNameField = "value-1";
Person p3 = new Person();
p3.id = "id-3";
p3.name = "user-1";
p3.viaAltKeyNameField = "value-1";
template.save(p1);
template.save(p2);
template.save(p3);
template.execute(Person.class, collection -> {
collection.find(new Document()).forEach(it -> System.out.println(it.toJson()));
return null;
});
// remove the key and invalidate encrypted data
mongoClientEncryption.getClientEncryption().deleteKey(user2key);
// clear the 60 second key cache within the mongo client
mongoClientEncryption.destroy();
assertThat(template.query(Person.class).matching(where("id").is(p1.id)).firstValue()).isEqualTo(p1);
assertThatExceptionOfType(PermissionDeniedDataAccessException.class)
.isThrownBy(() -> template.query(Person.class).matching(where("id").is(p2.id)).firstValue());
}
<T> SaveAndLoadAssert<T> verifyThat(T source) {
return new SaveAndLoadAssert<>(source);
}
class SaveAndLoadAssert<T> {
T source;
Function<T, ?> idProvider;
SaveAndLoadAssert(T source) {
this.source = source;
}
SaveAndLoadAssert<T> identifiedBy(Function<T, ?> idProvider) {
this.idProvider = idProvider;
return this;
}
SaveAndLoadAssert<T> wasSavedAs(Document expected) {
return wasSavedMatching(it -> Assertions.assertThat(it).isEqualTo(expected));
}
SaveAndLoadAssert<T> wasSavedMatching(Consumer<Document> saved) {
AbstractEncryptionTestBase.this.assertSaved(source, idProvider, saved);
return this;
}
SaveAndLoadAssert<T> loadedMatches(Consumer<T> expected) {
AbstractEncryptionTestBase.this.assertLoaded(source, idProvider, expected);
return this;
}
SaveAndLoadAssert<T> loadedIsEqualToSource() {
return loadedIsEqualTo(source);
}
SaveAndLoadAssert<T> loadedIsEqualTo(T expected) {
return loadedMatches(it -> Assertions.assertThat(it).isEqualTo(expected));
}
}
<T> void assertSaved(T source, Function<T, ?> idProvider, Consumer<Document> dbValue) {
Document savedDocument = template.execute(Person.class, collection -> {
MongoNamespace namespace = collection.getNamespace();
try (MongoClient rawClient = MongoClients.create()) {
return rawClient.getDatabase(namespace.getDatabaseName()).getCollection(namespace.getCollectionName())
.find(new Document("_id", idProvider.apply(source))).first();
}
});
dbValue.accept(savedDocument);
}
<T> void assertLoaded(T source, Function<T, ?> idProvider, Consumer<T> loadedValue) {
T loaded = template.query((Class<T>) source.getClass()).matching(where("id").is(idProvider.apply(source)))
.firstValue();
loadedValue.accept(loaded);
}
protected static class EncryptionConfig extends AbstractMongoClientConfiguration {
@Autowired ApplicationContext applicationContext;
@Override
protected String getDatabaseName() {
return "fle-test";
}
@Bean
public MongoClient mongoClient() {
return super.mongoClient();
}
@Override
protected void configureConverters(MongoConverterConfigurationAdapter converterConfigurationAdapter) {
converterConfigurationAdapter
.registerPropertyValueConverterFactory(PropertyValueConverterFactory.beanFactoryAware(applicationContext))
.useNativeDriverJavaTimeCodecs();
}
@Bean
MongoEncryptionConverter encryptingConverter(MongoClientEncryption mongoClientEncryption) {
Lazy<BsonBinary> dataKey = Lazy.of(() -> mongoClientEncryption.getClientEncryption().createDataKey("local",
new DataKeyOptions().keyAltNames(Collections.singletonList("mySuperSecretKey"))));
return new MongoEncryptionConverter(mongoClientEncryption,
EncryptionKeyResolver.annotated((ctx) -> EncryptionKey.keyId(dataKey.get())));
}
@Bean
CachingMongoClientEncryption clientEncryption(ClientEncryptionSettings encryptionSettings) {
return new CachingMongoClientEncryption(() -> ClientEncryptions.create(encryptionSettings));
}
@Bean
ClientEncryptionSettings encryptionSettings(MongoClient mongoClient) {
MongoNamespace keyVaultNamespace = new MongoNamespace("encryption.testKeyVault");
MongoCollection<Document> keyVaultCollection = mongoClient.getDatabase(keyVaultNamespace.getDatabaseName())
.getCollection(keyVaultNamespace.getCollectionName());
keyVaultCollection.drop();
// Ensure that two data keys cannot share the same keyAltName.
keyVaultCollection.createIndex(Indexes.ascending("keyAltNames"),
new IndexOptions().unique(true).partialFilterExpression(Filters.exists("keyAltNames")));
MongoCollection<Document> collection = mongoClient.getDatabase(getDatabaseName()).getCollection("test");
collection.drop(); // Clear old data
byte[] localMasterKey = new byte[96];
new SecureRandom().nextBytes(localMasterKey);
Map<String, Map<String, Object>> kmsProviders = Map.of("local", Map.of("key", localMasterKey));
// Create the ClientEncryption instance
return ClientEncryptionSettings.builder() //
.keyVaultMongoClientSettings(
MongoClientSettings.builder().applyConnectionString(new ConnectionString("mongodb://localhost")).build()) //
.keyVaultNamespace(keyVaultNamespace.getFullName()) //
.kmsProviders(kmsProviders) //
.build();
}
}
static class CachingMongoClientEncryption extends MongoClientEncryption implements DisposableBean {
static final AtomicReference<ClientEncryption> cache = new AtomicReference<>();
CachingMongoClientEncryption(Supplier<ClientEncryption> source) {
super(() -> {
if (cache.get() != null) {
return cache.get();
}
ClientEncryption clientEncryption = source.get();
cache.set(clientEncryption);
return clientEncryption;
});
}
@Override
public void destroy() {
ClientEncryption clientEncryption = cache.get();
if (clientEncryption != null) {
clientEncryption.close();
cache.set(null);
}
}
}
@org.springframework.data.mongodb.core.mapping.Document("test")
static class Person {
String id;
String name;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic) //
String ssn;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random, keyAltName = "mySuperSecretKey") //
String wallet;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) // full document must be random
Address address;
AddressWithEncryptedZip encryptedZip;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) // lists must be random
List<String> listOfString;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) // lists must be random
List<Address> listOfComplex;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random, keyAltName = "/name") //
String viaAltKeyNameField;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) //
Map<String, String> mapOfString;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) //
Map<String, Address> mapOfComplex;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) //
LocalDate today;
public String getId() {
return this.id;
}
public String getName() {
return this.name;
}
public String getSsn() {
return this.ssn;
}
public String getWallet() {
return this.wallet;
}
public Address getAddress() {
return this.address;
}
public AddressWithEncryptedZip getEncryptedZip() {
return this.encryptedZip;
}
public List<String> getListOfString() {
return this.listOfString;
}
public List<Address> getListOfComplex() {
return this.listOfComplex;
}
public String getViaAltKeyNameField() {
return this.viaAltKeyNameField;
}
public Map<String, String> getMapOfString() {
return this.mapOfString;
}
public Map<String, Address> getMapOfComplex() {
return this.mapOfComplex;
}
public LocalDate getToday() {
return today;
}
public void setId(String id) {
this.id = id;
}
public void setName(String name) {
this.name = name;
}
public void setSsn(String ssn) {
this.ssn = ssn;
}
public void setWallet(String wallet) {
this.wallet = wallet;
}
public void setAddress(Address address) {
this.address = address;
}
public void setEncryptedZip(AddressWithEncryptedZip encryptedZip) {
this.encryptedZip = encryptedZip;
}
public void setListOfString(List<String> listOfString) {
this.listOfString = listOfString;
}
public void setListOfComplex(List<Address> listOfComplex) {
this.listOfComplex = listOfComplex;
}
public void setViaAltKeyNameField(String viaAltKeyNameField) {
this.viaAltKeyNameField = viaAltKeyNameField;
}
public void setMapOfString(Map<String, String> mapOfString) {
this.mapOfString = mapOfString;
}
public void setMapOfComplex(Map<String, Address> mapOfComplex) {
this.mapOfComplex = mapOfComplex;
}
public void setToday(LocalDate today) {
this.today = today;
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Person person = (Person) o;
return Objects.equals(id, person.id) && Objects.equals(name, person.name) && Objects.equals(ssn, person.ssn)
&& Objects.equals(wallet, person.wallet) && Objects.equals(address, person.address)
&& Objects.equals(encryptedZip, person.encryptedZip) && Objects.equals(listOfString, person.listOfString)
&& Objects.equals(listOfComplex, person.listOfComplex)
&& Objects.equals(viaAltKeyNameField, person.viaAltKeyNameField)
&& Objects.equals(mapOfString, person.mapOfString) && Objects.equals(mapOfComplex, person.mapOfComplex)
&& Objects.equals(today, person.today);
}
@Override
public int hashCode() {
return Objects.hash(id, name, ssn, wallet, address, encryptedZip, listOfString, listOfComplex, viaAltKeyNameField,
mapOfString, mapOfComplex, today);
}
public String toString() {
return "EncryptionTests.Person(id=" + this.getId() + ", name=" + this.getName() + ", ssn=" + this.getSsn()
+ ", wallet=" + this.getWallet() + ", address=" + this.getAddress() + ", encryptedZip="
+ this.getEncryptedZip() + ", listOfString=" + this.getListOfString() + ", listOfComplex="
+ this.getListOfComplex() + ", viaAltKeyNameField=" + this.getViaAltKeyNameField() + ", mapOfString="
+ this.getMapOfString() + ", mapOfComplex=" + this.getMapOfComplex() + ", today=" + this.getToday() + ")";
}
}
static class Address {
String city;
String street;
public Address() {}
public String getCity() {
return this.city;
}
public String getStreet() {
return this.street;
}
public void setCity(String city) {
this.city = city;
}
public void setStreet(String street) {
this.street = street;
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Address address = (Address) o;
return Objects.equals(city, address.city) && Objects.equals(street, address.street);
}
@Override
public int hashCode() {
return Objects.hash(city, street);
}
public String toString() {
return "EncryptionTests.Address(city=" + this.getCity() + ", street=" + this.getStreet() + ")";
}
}
static class AddressWithEncryptedZip extends Address {
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) String zip;
@Override
public String toString() {
return "AddressWithEncryptedZip{" + "zip='" + zip + '\'' + ", city='" + getCity() + '\'' + ", street='"
+ getStreet() + '\'' + '}';
}
public String getZip() {
return this.zip;
}
public void setZip(String zip) {
this.zip = zip;
}
}
}

View File

@@ -1,64 +0,0 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.encryption;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import com.mongodb.AutoEncryptionSettings;
import com.mongodb.ClientEncryptionSettings;
import com.mongodb.MongoClientSettings.Builder;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
/**
* Encryption tests for a client with {@link AutoEncryptionSettings#isBypassAutoEncryption()} enabled.
*
* @author Christoph Strobl
* @author Julia Lee
*/
@ExtendWith(SpringExtension.class)
@ContextConfiguration(classes = BypassAutoEncryptionTest.Config.class)
public class BypassAutoEncryptionTest extends AbstractEncryptionTestBase {
@Disabled
@Override
void altKeyDetection(@Autowired CachingMongoClientEncryption mongoClientEncryption) throws InterruptedException {
super.altKeyDetection(mongoClientEncryption);
}
@Configuration
static class Config extends EncryptionConfig {
@Override
protected void configureClientSettings(Builder builder) {
MongoClient mongoClient = MongoClients.create();
ClientEncryptionSettings clientEncryptionSettings = encryptionSettings(mongoClient);
mongoClient.close();
builder.autoEncryptionSettings(AutoEncryptionSettings.builder() //
.kmsProviders(clientEncryptionSettings.getKmsProviders()) //
.keyVaultNamespace(clientEncryptionSettings.getKeyVaultNamespace()) //
.bypassAutoEncryption(true).build());
}
}
}

View File

@@ -15,16 +15,721 @@
*/
package org.springframework.data.mongodb.core.encryption;
import static org.assertj.core.api.Assertions.*;
import static org.springframework.data.mongodb.core.EncryptionAlgorithms.*;
import static org.springframework.data.mongodb.core.aggregation.Aggregation.*;
import static org.springframework.data.mongodb.core.query.Criteria.*;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.assertj.core.api.Assertions;
import org.bson.BsonBinary;
import org.bson.Document;
import org.bson.types.Binary;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.dao.PermissionDeniedDataAccessException;
import org.springframework.data.convert.PropertyValueConverterFactory;
import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter;
import org.springframework.data.mongodb.core.convert.encryption.MongoEncryptionConverter;
import org.springframework.data.mongodb.core.encryption.EncryptionTests.Config;
import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.util.Lazy;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import com.mongodb.ClientEncryptionSettings;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoNamespace;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.IndexOptions;
import com.mongodb.client.model.Indexes;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.vault.ClientEncryption;
import com.mongodb.client.vault.ClientEncryptions;
/**
* @author Christoph Strobl
* @author Julia Lee
*/
@ExtendWith(SpringExtension.class)
@ContextConfiguration(classes = AbstractEncryptionTestBase.EncryptionConfig.class)
public class EncryptionTests extends AbstractEncryptionTestBase {
@ContextConfiguration(classes = Config.class)
public class EncryptionTests {
@Autowired MongoTemplate template;
@Test // GH-4284
void encryptAndDecryptSimpleValue() {
Person source = new Person();
source.id = "id-1";
source.ssn = "mySecretSSN";
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("ssn")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptComplexValue() {
Person source = new Person();
source.id = "id-1";
source.address = new Address();
source.address.city = "NYC";
source.address.street = "4th Ave.";
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("address")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptValueWithinComplexOne() {
Person source = new Person();
source.id = "id-1";
source.encryptedZip = new AddressWithEncryptedZip();
source.encryptedZip.city = "Boston";
source.encryptedZip.street = "central square";
source.encryptedZip.zip = "1234567890";
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> {
assertThat(it.get("encryptedZip")).isInstanceOf(Document.class);
assertThat(it.get("encryptedZip", Document.class).get("city")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("street")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("zip")).isInstanceOf(Binary.class);
}) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptListOfSimpleValue() {
Person source = new Person();
source.id = "id-1";
source.listOfString = Arrays.asList("spring", "data", "mongodb");
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("listOfString")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptListOfComplexValue() {
Person source = new Person();
source.id = "id-1";
Address address = new Address();
address.city = "SFO";
address.street = "---";
source.listOfComplex = Collections.singletonList(address);
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("listOfComplex")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptMapOfSimpleValues() {
Person source = new Person();
source.id = "id-1";
source.mapOfString = Map.of("k1", "v1", "k2", "v2");
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("mapOfString")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void encryptAndDecryptMapOfComplexValues() {
Person source = new Person();
source.id = "id-1";
Address address1 = new Address();
address1.city = "SFO";
address1.street = "---";
Address address2 = new Address();
address2.city = "NYC";
address2.street = "---";
source.mapOfComplex = Map.of("a1", address1, "a2", address2);
template.save(source);
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("mapOfComplex")).isInstanceOf(Binary.class)) //
.loadedIsEqualToSource();
}
@Test // GH-4284
void canQueryDeterministicallyEncrypted() {
Person source = new Person();
source.id = "id-1";
source.ssn = "mySecretSSN";
template.save(source);
Person loaded = template.query(Person.class).matching(where("ssn").is(source.ssn)).firstValue();
assertThat(loaded).isEqualTo(source);
}
@Test // GH-4284
void cannotQueryRandomlyEncrypted() {
Person source = new Person();
source.id = "id-1";
source.wallet = "secret-wallet-id";
template.save(source);
Person loaded = template.query(Person.class).matching(where("wallet").is(source.wallet)).firstValue();
assertThat(loaded).isNull();
}
@Test // GH-4284
void updateSimpleTypeEncryptedFieldWithNewValue() {
Person source = new Person();
source.id = "id-1";
template.save(source);
template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("ssn", "secret-value"))
.first();
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("ssn")).isInstanceOf(Binary.class)) //
.loadedMatches(it -> assertThat(it.getSsn()).isEqualTo("secret-value"));
}
@Test // GH-4284
void updateComplexTypeEncryptedFieldWithNewValue() {
Person source = new Person();
source.id = "id-1";
template.save(source);
Address address = new Address();
address.city = "SFO";
address.street = "---";
template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("address", address)).first();
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> assertThat(it.get("address")).isInstanceOf(Binary.class)) //
.loadedMatches(it -> assertThat(it.getAddress()).isEqualTo(address));
}
@Test // GH-4284
void updateEncryptedFieldInNestedElementWithNewValue() {
Person source = new Person();
source.id = "id-1";
source.encryptedZip = new AddressWithEncryptedZip();
source.encryptedZip.city = "Boston";
source.encryptedZip.street = "central square";
template.save(source);
template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("encryptedZip.zip", "179"))
.first();
verifyThat(source) //
.identifiedBy(Person::getId) //
.wasSavedMatching(it -> {
assertThat(it.get("encryptedZip")).isInstanceOf(Document.class);
assertThat(it.get("encryptedZip", Document.class).get("city")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("street")).isInstanceOf(String.class);
assertThat(it.get("encryptedZip", Document.class).get("zip")).isInstanceOf(Binary.class);
}) //
.loadedMatches(it -> assertThat(it.getEncryptedZip().getZip()).isEqualTo("179"));
}
@Test
void aggregationWithMatch() {
Person person = new Person();
person.id = "id-1";
person.name = "p1-name";
person.ssn = "mySecretSSN";
template.save(person);
AggregationResults<Person> aggregationResults = template.aggregateAndReturn(Person.class)
.by(newAggregation(Person.class, Aggregation.match(where("ssn").is(person.ssn)))).all();
assertThat(aggregationResults.getMappedResults()).containsExactly(person);
}
@Test
void altKeyDetection(@Autowired CachingMongoClientEncryption mongoClientEncryption) throws InterruptedException {
BsonBinary user1key = mongoClientEncryption.getClientEncryption().createDataKey("local",
new DataKeyOptions().keyAltNames(Collections.singletonList("user-1")));
BsonBinary user2key = mongoClientEncryption.getClientEncryption().createDataKey("local",
new DataKeyOptions().keyAltNames(Collections.singletonList("user-2")));
Person p1 = new Person();
p1.id = "id-1";
p1.name = "user-1";
p1.ssn = "ssn";
p1.viaAltKeyNameField = "value-1";
Person p2 = new Person();
p2.id = "id-2";
p2.name = "user-2";
p2.viaAltKeyNameField = "value-1";
Person p3 = new Person();
p3.id = "id-3";
p3.name = "user-1";
p3.viaAltKeyNameField = "value-1";
template.save(p1);
template.save(p2);
template.save(p3);
template.execute(Person.class, collection -> {
collection.find(new Document()).forEach(it -> System.out.println(it.toJson()));
return null;
});
// remove the key and invalidate encrypted data
mongoClientEncryption.getClientEncryption().deleteKey(user2key);
// clear the 60 second key cache within the mongo client
mongoClientEncryption.destroy();
assertThat(template.query(Person.class).matching(where("id").is(p1.id)).firstValue()).isEqualTo(p1);
assertThatExceptionOfType(PermissionDeniedDataAccessException.class)
.isThrownBy(() -> template.query(Person.class).matching(where("id").is(p2.id)).firstValue());
}
<T> SaveAndLoadAssert<T> verifyThat(T source) {
return new SaveAndLoadAssert<>(source);
}
class SaveAndLoadAssert<T> {
T source;
Function<T, ?> idProvider;
SaveAndLoadAssert(T source) {
this.source = source;
}
SaveAndLoadAssert<T> identifiedBy(Function<T, ?> idProvider) {
this.idProvider = idProvider;
return this;
}
SaveAndLoadAssert<T> wasSavedAs(Document expected) {
return wasSavedMatching(it -> Assertions.assertThat(it).isEqualTo(expected));
}
SaveAndLoadAssert<T> wasSavedMatching(Consumer<Document> saved) {
EncryptionTests.this.assertSaved(source, idProvider, saved);
return this;
}
SaveAndLoadAssert<T> loadedMatches(Consumer<T> expected) {
EncryptionTests.this.assertLoaded(source, idProvider, expected);
return this;
}
SaveAndLoadAssert<T> loadedIsEqualToSource() {
return loadedIsEqualTo(source);
}
SaveAndLoadAssert<T> loadedIsEqualTo(T expected) {
return loadedMatches(it -> Assertions.assertThat(it).isEqualTo(expected));
}
}
<T> void assertSaved(T source, Function<T, ?> idProvider, Consumer<Document> dbValue) {
Document savedDocument = template.execute(Person.class, collection -> {
return collection.find(new Document("_id", idProvider.apply(source))).first();
});
dbValue.accept(savedDocument);
}
<T> void assertLoaded(T source, Function<T, ?> idProvider, Consumer<T> loadedValue) {
T loaded = template.query((Class<T>) source.getClass()).matching(where("id").is(idProvider.apply(source)))
.firstValue();
loadedValue.accept(loaded);
}
@Configuration
static class Config extends AbstractMongoClientConfiguration {
@Autowired ApplicationContext applicationContext;
@Override
protected String getDatabaseName() {
return "fle-test";
}
@Bean
public MongoClient mongoClient() {
return super.mongoClient();
}
@Override
protected void configureConverters(MongoConverterConfigurationAdapter converterConfigurationAdapter) {
converterConfigurationAdapter
.registerPropertyValueConverterFactory(PropertyValueConverterFactory.beanFactoryAware(applicationContext));
}
@Bean
MongoEncryptionConverter encryptingConverter(MongoClientEncryption mongoClientEncryption) {
Lazy<BsonBinary> dataKey = Lazy.of(() -> mongoClientEncryption.getClientEncryption().createDataKey("local",
new DataKeyOptions().keyAltNames(Collections.singletonList("mySuperSecretKey"))));
return new MongoEncryptionConverter(mongoClientEncryption,
EncryptionKeyResolver.annotated((ctx) -> EncryptionKey.keyId(dataKey.get())));
}
@Bean
CachingMongoClientEncryption clientEncryption(ClientEncryptionSettings encryptionSettings) {
return new CachingMongoClientEncryption(() -> ClientEncryptions.create(encryptionSettings));
}
@Bean
ClientEncryptionSettings encryptionSettings(MongoClient mongoClient) {
MongoNamespace keyVaultNamespace = new MongoNamespace("encryption.testKeyVault");
MongoCollection<Document> keyVaultCollection = mongoClient.getDatabase(keyVaultNamespace.getDatabaseName())
.getCollection(keyVaultNamespace.getCollectionName());
keyVaultCollection.drop();
// Ensure that two data keys cannot share the same keyAltName.
keyVaultCollection.createIndex(Indexes.ascending("keyAltNames"),
new IndexOptions().unique(true).partialFilterExpression(Filters.exists("keyAltNames")));
MongoCollection<Document> collection = mongoClient.getDatabase(getDatabaseName()).getCollection("test");
collection.drop(); // Clear old data
final byte[] localMasterKey = new byte[96];
new SecureRandom().nextBytes(localMasterKey);
Map<String, Map<String, Object>> kmsProviders = new HashMap<>() {
{
put("local", new HashMap<>() {
{
put("key", localMasterKey);
}
});
}
};
// Create the ClientEncryption instance
ClientEncryptionSettings clientEncryptionSettings = ClientEncryptionSettings.builder()
.keyVaultMongoClientSettings(
MongoClientSettings.builder().applyConnectionString(new ConnectionString("mongodb://localhost")).build())
.keyVaultNamespace(keyVaultNamespace.getFullName()).kmsProviders(kmsProviders).build();
return clientEncryptionSettings;
}
}
static class CachingMongoClientEncryption extends MongoClientEncryption implements DisposableBean {
static final AtomicReference<ClientEncryption> cache = new AtomicReference<>();
CachingMongoClientEncryption(Supplier<ClientEncryption> source) {
super(() -> {
if (cache.get() != null) {
return cache.get();
}
ClientEncryption clientEncryption = source.get();
cache.set(clientEncryption);
return clientEncryption;
});
}
@Override
public void destroy() {
ClientEncryption clientEncryption = cache.get();
if (clientEncryption != null) {
clientEncryption.close();
cache.set(null);
}
}
}
@org.springframework.data.mongodb.core.mapping.Document("test")
static class Person {
String id;
String name;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic) //
String ssn;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random, keyAltName = "mySuperSecretKey") //
String wallet;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) // full document must be random
Address address;
AddressWithEncryptedZip encryptedZip;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) // lists must be random
List<String> listOfString;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) // lists must be random
List<Address> listOfComplex;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random, keyAltName = "/name") //
String viaAltKeyNameField;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) //
Map<String, String> mapOfString;
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) //
Map<String, Address> mapOfComplex;
public String getId() {
return this.id;
}
public String getName() {
return this.name;
}
public String getSsn() {
return this.ssn;
}
public String getWallet() {
return this.wallet;
}
public Address getAddress() {
return this.address;
}
public AddressWithEncryptedZip getEncryptedZip() {
return this.encryptedZip;
}
public List<String> getListOfString() {
return this.listOfString;
}
public List<Address> getListOfComplex() {
return this.listOfComplex;
}
public String getViaAltKeyNameField() {
return this.viaAltKeyNameField;
}
public Map<String, String> getMapOfString() {
return this.mapOfString;
}
public Map<String, Address> getMapOfComplex() {
return this.mapOfComplex;
}
public void setId(String id) {
this.id = id;
}
public void setName(String name) {
this.name = name;
}
public void setSsn(String ssn) {
this.ssn = ssn;
}
public void setWallet(String wallet) {
this.wallet = wallet;
}
public void setAddress(Address address) {
this.address = address;
}
public void setEncryptedZip(AddressWithEncryptedZip encryptedZip) {
this.encryptedZip = encryptedZip;
}
public void setListOfString(List<String> listOfString) {
this.listOfString = listOfString;
}
public void setListOfComplex(List<Address> listOfComplex) {
this.listOfComplex = listOfComplex;
}
public void setViaAltKeyNameField(String viaAltKeyNameField) {
this.viaAltKeyNameField = viaAltKeyNameField;
}
public void setMapOfString(Map<String, String> mapOfString) {
this.mapOfString = mapOfString;
}
public void setMapOfComplex(Map<String, Address> mapOfComplex) {
this.mapOfComplex = mapOfComplex;
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Person person = (Person) o;
return Objects.equals(id, person.id) && Objects.equals(name, person.name) && Objects.equals(ssn, person.ssn)
&& Objects.equals(wallet, person.wallet) && Objects.equals(address, person.address)
&& Objects.equals(encryptedZip, person.encryptedZip) && Objects.equals(listOfString, person.listOfString)
&& Objects.equals(listOfComplex, person.listOfComplex)
&& Objects.equals(viaAltKeyNameField, person.viaAltKeyNameField)
&& Objects.equals(mapOfString, person.mapOfString) && Objects.equals(mapOfComplex, person.mapOfComplex);
}
@Override
public int hashCode() {
return Objects.hash(id, name, ssn, wallet, address, encryptedZip, listOfString, listOfComplex, viaAltKeyNameField,
mapOfString, mapOfComplex);
}
public String toString() {
return "EncryptionTests.Person(id=" + this.getId() + ", name=" + this.getName() + ", ssn=" + this.getSsn()
+ ", wallet=" + this.getWallet() + ", address=" + this.getAddress() + ", encryptedZip="
+ this.getEncryptedZip() + ", listOfString=" + this.getListOfString() + ", listOfComplex="
+ this.getListOfComplex() + ", viaAltKeyNameField=" + this.getViaAltKeyNameField() + ", mapOfString="
+ this.getMapOfString() + ", mapOfComplex=" + this.getMapOfComplex() + ")";
}
}
static class Address {
String city;
String street;
public Address() {}
public String getCity() {
return this.city;
}
public String getStreet() {
return this.street;
}
public void setCity(String city) {
this.city = city;
}
public void setStreet(String street) {
this.street = street;
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Address address = (Address) o;
return Objects.equals(city, address.city) && Objects.equals(street, address.street);
}
@Override
public int hashCode() {
return Objects.hash(city, street);
}
public String toString() {
return "EncryptionTests.Address(city=" + this.getCity() + ", street=" + this.getStreet() + ")";
}
}
static class AddressWithEncryptedZip extends Address {
@ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Random) String zip;
@Override
public String toString() {
return "AddressWithEncryptedZip{" + "zip='" + zip + '\'' + ", city='" + getCity() + '\'' + ", street='"
+ getStreet() + '\'' + '}';
}
public String getZip() {
return this.zip;
}
public void setZip(String zip) {
this.zip = zip;
}
}
}

View File

@@ -17,19 +17,10 @@ package org.springframework.data.mongodb.util.json;
import static org.assertj.core.api.Assertions.*;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.temporal.Temporal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.stream.Stream;
import org.bson.BsonArray;
import org.bson.BsonDouble;
import org.bson.BsonInt32;
import org.bson.BsonInt64;
@@ -38,9 +29,7 @@ import org.bson.BsonString;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.data.mongodb.util.BsonUtils;
import com.mongodb.BasicDBList;
@@ -116,9 +105,9 @@ class BsonUtilsTest {
@Test // GH-3571
void asCollectionConvertsArrayToCollection() {
Object source = new String[] { "one", "two" };
Object source = new String[]{"one", "two"};
assertThat((Collection) BsonUtils.asCollection(source)).containsExactly("one", "two");
assertThat((Collection)BsonUtils.asCollection(source)).containsExactly("one", "two");
}
@Test // GH-3571
@@ -126,7 +115,7 @@ class BsonUtilsTest {
Object source = 100L;
assertThat((Collection) BsonUtils.asCollection(source)).containsExactly(source);
assertThat((Collection)BsonUtils.asCollection(source)).containsExactly(source);
}
@Test // GH-3702
@@ -137,41 +126,4 @@ class BsonUtilsTest {
assertThat(BsonUtils.supportsBson(new BasicDBList())).isTrue();
assertThat(BsonUtils.supportsBson(Collections.emptyMap())).isTrue();
}
@ParameterizedTest // GH-4432
@MethodSource("javaTimeInstances")
void convertsJavaTimeTypesToBsonDateTime(Temporal source) {
assertThat(BsonUtils.simpleToBsonValue(source))
.isEqualTo(new Document("value", source).toBsonDocument().get("value"));
}
@ParameterizedTest // GH-4432
@MethodSource("collectionLikeInstances")
void convertsCollectionLikeToBsonArray(Object source) {
assertThat(BsonUtils.simpleToBsonValue(source))
.isEqualTo(new Document("value", source).toBsonDocument().get("value"));
}
@Test // GH-4432
void convertsPrimitiveArrayToBsonArray() {
assertThat(BsonUtils.simpleToBsonValue(new int[] { 1, 2, 3 }))
.isEqualTo(new BsonArray(List.of(new BsonInt32(1), new BsonInt32(2), new BsonInt32(3))));
}
static Stream<Arguments> javaTimeInstances() {
return Stream.of(Arguments.of(Instant.now()), Arguments.of(LocalDate.now()), Arguments.of(LocalDateTime.now()),
Arguments.of(LocalTime.now()));
}
static Stream<Arguments> collectionLikeInstances() {
return Stream.of(Arguments.of(new String[] { "1", "2", "3" }), Arguments.of(List.of("1", "2", "3")),
Arguments.of(new Integer[] { 1, 2, 3 }), Arguments.of(List.of(1, 2, 3)),
Arguments.of(new Date[] { new Date() }), Arguments.of(List.of(new Date())),
Arguments.of(new LocalDate[] { LocalDate.now() }), Arguments.of(List.of(LocalDate.now())));
}
}

View File

@@ -1,36 +0,0 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core
import com.mongodb.kotlin.client.MongoClient
import org.bson.Document
import org.junit.jupiter.api.Test
import org.springframework.data.mongodb.test.util.Assertions.assertThat
/**
* @author Christoph Strobl
*/
class SimpleMongoClientDatabaseFactoryExtensionTests {
@Test // GH-4393
fun `extension allows to create SimpleMongoClientDatabaseFactory with a Kotlin Driver instance`() {
val factory = SimpleMongoClientDatabaseFactory(MongoClient.create(), "test")
assertThat(factory.mongoDatabase.runCommand(Document("ping", 1))).containsKey("ok")
}
}

View File

@@ -1,40 +0,0 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core
import com.mongodb.kotlin.client.coroutine.MongoClient
import org.bson.Document
import org.junit.jupiter.api.Test
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
/**
* @author Christoph Strobl
*/
class SimpleReactiveMongoDatabaseFactoryExtensionTests {
@Test // GH-4393
fun `extension allows to create SimpleReactiveMongoDatabaseFactory with a Kotlin Coroutine Driver instance`() {
val factory = SimpleReactiveMongoDatabaseFactory(MongoClient.create(), "test")
factory.mongoDatabase.flatMap { Mono.from(it.runCommand(Document("ping", 1))) }
.`as` { StepVerifier.create(it) }
.expectNextCount(1)
.verifyComplete()
}
}

View File

@@ -1,4 +1,4 @@
Spring Data MongoDB 4.2 M1 (2023.1.0)
Spring Data MongoDB 4.1 GA (2023.0.0)
Copyright (c) [2010-2019] Pivotal Software, Inc.
This product is licensed to you under the Apache License, Version 2.0 (the "License").
@@ -45,6 +45,5 @@ conditions of the subcomponent's license, as noted in the LICENSE file.