Include & Exclude paths in project aggregation stage.

This commit is contained in:
Christoph Strobl
2023-07-06 14:19:11 +02:00
parent f42e63d613
commit cd44432f81
2 changed files with 110 additions and 5 deletions

View File

@@ -167,6 +167,46 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation {
return new ProjectionOperation(this.projections, projections);
}
/**
* @param path
* @return
* @since 4.2
*/
public ProjectionOperation andIncludePath(String path) {
return andIncludePaths(new String[] { path });
}
/**
* @param paths
* @return
* @since 4.2
*/
public ProjectionOperation andIncludePaths(String... paths) {
List<FieldProjection> projections = FieldProjection.from(Fields.fields(paths), 1);
return new ProjectionOperation(this.projections, projections);
}
/**
* @param path
* @return
* @since 4.2
*/
public ProjectionOperation andExcludePath(String path) {
return andExcludePaths(new String[] { path });
}
/**
* @param paths
* @return
* @since 4.2
*/
public ProjectionOperation andExcludePaths(String... paths) {
List<FieldProjection> projections = FieldProjection.from(Fields.fields(paths), 0);
return new ProjectionOperation(this.projections, projections);
}
/**
* Includes the given fields into the projection.
*
@@ -495,8 +535,8 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation {
if (value instanceof AggregationExpression) {
return this.operation.and(new ExpressionProjection(Fields.field(alias, alias), (AggregationExpression) value));
}
return this.operation.and(new FieldProjection(Fields.field(alias, getRequiredName()), null));
Field field = Fields.field(alias, getRequiredName());
return this.operation.and(new FieldProjection(field, null));
}
@Override
@@ -1329,6 +1369,11 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation {
private final Field field;
private final @Nullable Object value;
private final ProjectOn projectOn;
enum ProjectOn {
PATH, NAME
}
/**
* Creates a new {@link FieldProjection} for the field of the given name, assigning the given value.
@@ -1341,11 +1386,16 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation {
}
private FieldProjection(Field field, @Nullable Object value) {
this(field, value, value instanceof Integer ? ProjectOn.PATH : ProjectOn.NAME);
}
private FieldProjection(Field field, @Nullable Object value, ProjectOn project) {
super(new ExposedField(field.getName(), true));
this.field = field;
this.value = value;
this.projectOn = project;
}
/**
@@ -1372,7 +1422,8 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation {
List<FieldProjection> projections = new ArrayList<FieldProjection>();
for (Field field : fields) {
projections.add(new FieldProjection(field, value));
projections
.add(new FieldProjection(field, value, value instanceof Integer ? ProjectOn.PATH : ProjectOn.NAME));
}
return projections;
@@ -1382,12 +1433,13 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation {
* @return {@literal true} if this field is excluded.
*/
public boolean isExcluded() {
return Boolean.FALSE.equals(value);
return Boolean.FALSE.equals(value) || value instanceof Number number && number.intValue() == 0;
}
@Override
public Document toDocument(AggregationOperationContext context) {
return new Document(field.getName(), renderFieldValue(context));
return new Document(ProjectOn.NAME.equals(projectOn) ? field.getName() : context.getReference(field.getTarget()).getRaw(),
renderFieldValue(context));
}
private Object renderFieldValue(AggregationOperationContext context) {

View File

@@ -36,6 +36,7 @@ import org.springframework.data.mongodb.core.aggregation.ProjectionOperationUnit
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.core.query.Criteria;
@@ -653,6 +654,58 @@ public class AggregationUnitTests {
assertThat(documents.get(2)).isEqualTo("{ $sort : { 'serial_number' : -1, 'label_name' : -1 } }");
}
@Test // GH-4428
void projectIncludePath() {
MongoMappingContext mappingContext = new MongoMappingContext();
RelaxedTypeBasedAggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(
Root.class, mappingContext,
new QueryMapper(new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, mappingContext)));
assertThat(project("flat").andIncludePath("list.element").toDocument(context)).isEqualTo(
Document.parse("""
{
"$project": {
"flat": 1,
"list.elE_m_enT": 1
}
}""")
);
}
@Test // GH-4428
void projectExcludePath() {
MongoMappingContext mappingContext = new MongoMappingContext();
RelaxedTypeBasedAggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(
Root.class, mappingContext,
new QueryMapper(new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, mappingContext)));
assertThat(project("flat").andExcludePath("list.element").toDocument(context)).isEqualTo(
Document.parse("""
{
"$project": {
"flat": 1,
"list.elE_m_enT": 0
}
}""")
);
}
static class Root {
String flat;
List<Nested> list;
}
static class Nested {
@Field("elE_m_enT")
int element;
String description;
}
private Document extractPipelineElement(Document agg, int index, String operation) {
List<Document> pipeline = (List<Document>) agg.get("pipeline");