Skip to content

Commit

Permalink
Avoid multiple mapping iterations.
Browse files Browse the repository at this point in the history
A 2nd pass is no longer needed as the context already does all the work.

Closes: #4043
Original pull request: #4240
  • Loading branch information
christophstrobl authored and mp911de committed Jan 11, 2023
1 parent 1839f55 commit 8bcab93
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.util.Lazy;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;

/**
* Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and
Expand Down Expand Up @@ -96,12 +95,7 @@ AggregationOperationContext createAggregationContext(Aggregation aggregation, @N
* @return
*/
List<Document> createPipeline(Aggregation aggregation, AggregationOperationContext context) {

if (ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
return aggregation.toPipeline(context);
}

return mapAggregationPipeline(aggregation.toPipeline(context));
return aggregation.toPipeline(context);
}

/**
Expand All @@ -112,16 +106,7 @@ List<Document> createPipeline(Aggregation aggregation, AggregationOperationConte
* @return
*/
Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) {

Document command = aggregation.toDocument(collection, context);

if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
return command;
}

command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class)));

return command;
return aggregation.toDocument(collection, context);
}

private List<Document> mapAggregationPipeline(List<Document> pipeline) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@

import org.assertj.core.data.Offset;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import org.springframework.core.io.ClassPathResource;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Sort;
Expand All @@ -65,6 +65,7 @@
import org.springframework.data.mongodb.core.geo.GeoJsonPoint;
import org.springframework.data.mongodb.core.index.GeoSpatialIndexType;
import org.springframework.data.mongodb.core.index.GeospatialIndex;
import org.springframework.data.mongodb.core.mapping.MongoId;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
Expand Down Expand Up @@ -1933,6 +1934,24 @@ void mapsEnumsInMatchClauseUsingInCriteriaCorrectly() {
assertThat(results.getMappedResults()).hasSize(1);
}

@Test // GH-4043
void considersMongoIdWithinTypedCollections() {

UserRef userRef = new UserRef();
userRef.id = "4ee921aca44fd11b3254e001";
userRef.name = "u-1";

Widget widget = new Widget();
widget.id = "w-1";
widget.users = List.of(userRef);

mongoTemplate.save(widget);

Criteria criteria = Criteria.where("users").elemMatch(Criteria.where("id").is("4ee921aca44fd11b3254e001"));
AggregationResults<Widget> aggregate = mongoTemplate.aggregate(newAggregation(match(criteria)), Widget.class, Widget.class);
assertThat(aggregate.getMappedResults()).contains(widget);
}

private void createUsersWithReferencedPersons() {

mongoTemplate.dropCollection(User.class);
Expand Down Expand Up @@ -2250,4 +2269,18 @@ static class WithEnum {
@Id String id;
MyEnum enumValue;
}

@lombok.Data
static class Widget {
@Id
String id;
List<UserRef> users;
}

@lombok.Data
static class UserRef {
@MongoId
String id;
String name;
}
}

0 comments on commit 8bcab93

Please sign in to comment.