From 66d98b355ee063c869443edb1e3d2c444f3c5fd7 Mon Sep 17 00:00:00 2001 From: Oliver Gierke Date: Mon, 15 Oct 2012 11:00:12 -0400 Subject: [PATCH] DATAMONGO-550 - Fixed potential NullPointerExceptions in MongoTemplate. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Guarded access to results of mappingContext.getPersistentEntity(…) to prevent NullPointerExceptions in case the template is used with non-entity types like a plain DBObject. Added some custom logic for plain BasicDBObjects to get the _id field populated appropriately. --- .../data/mongodb/core/MongoTemplate.java | 18 +++++--- .../data/mongodb/core/MongoTemplateTests.java | 41 +++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 010aafb52..560004f1b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -505,13 +505,12 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { } public T findById(Object id, Class entityClass) { - MongoPersistentEntity persistentEntity = mappingContext.getPersistentEntity(entityClass); - return findById(id, entityClass, persistentEntity.getCollection()); + return findById(id, entityClass, determineCollectionName(entityClass)); } public T findById(Object id, Class entityClass, String collectionName) { MongoPersistentEntity persistentEntity = mappingContext.getPersistentEntity(entityClass); - MongoPersistentProperty idProperty = persistentEntity.getIdProperty(); + MongoPersistentProperty idProperty = persistentEntity == null ? null : persistentEntity.getIdProperty(); String idKey = idProperty == null ? ID : idProperty.getName(); return doFindOne(collectionName, new BasicDBObject(idKey, id), null, entityClass); } @@ -795,7 +794,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { assertUpdateableIdIfNotSet(objectToSave); - BasicDBObject dbDoc = new BasicDBObject(); + DBObject dbDoc = new BasicDBObject(); maybeEmitEvent(new BeforeConvertEvent(objectToSave)); writer.write(objectToSave, dbDoc); @@ -982,7 +981,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { Assert.notNull(object); MongoPersistentEntity entity = mappingContext.getPersistentEntity(object.getClass()); - MongoPersistentProperty idProp = entity.getIdProperty(); + MongoPersistentProperty idProp = entity == null ? null : entity.getIdProperty(); if (idProp == null) { throw new MappingException("No id property found for object of type " + entity.getType().getName()); @@ -1438,6 +1437,12 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { return; } + if (savedObject instanceof BasicDBObject) { + DBObject dbObject = (DBObject) savedObject; + dbObject.put(ID, id); + return; + } + MongoPersistentProperty idProp = getIdPropertyFor(savedObject.getClass()); if (idProp == null) { @@ -1555,7 +1560,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { } private MongoPersistentProperty getIdPropertyFor(Class type) { - return mappingContext.getPersistentEntity(type).getIdProperty(); + MongoPersistentEntity persistentEntity = mappingContext.getPersistentEntity(type); + return persistentEntity == null ? null : persistentEntity.getIdProperty(); } private String determineEntityCollectionName(T obj) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java index 6440674f0..edca3d7ae 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java @@ -42,6 +42,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.convert.converter.Converter; import org.springframework.dao.DataAccessException; import org.springframework.dao.DataIntegrityViolationException; +import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.annotation.Id; import org.springframework.data.annotation.PersistenceConstructor; @@ -63,6 +64,7 @@ import org.springframework.data.mongodb.core.query.Update; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import com.mongodb.BasicDBObject; import com.mongodb.DBCollection; import com.mongodb.DBCursor; import com.mongodb.DBObject; @@ -143,6 +145,7 @@ public class MongoTemplateTests { template.dropCollection(TestClass.class); template.dropCollection(Sample.class); template.dropCollection(MyPerson.class); + template.dropCollection("collection"); } @Test @@ -1352,6 +1355,44 @@ public class MongoTemplateTests { template.save(null); } + /** + * @see DATAMONGO-550 + */ + @Test + public void savesPlainDbObjectCorrectly() { + + DBObject dbObject = new BasicDBObject("foo", "bar"); + template.save(dbObject, "collection"); + + assertThat(dbObject.containsField("_id"), is(true)); + } + + /** + * @see DATAMONGO-550 + */ + @Test(expected = InvalidDataAccessApiUsageException.class) + public void rejectsPlainObjectWithOutExplicitCollection() { + + DBObject dbObject = new BasicDBObject("foo", "bar"); + template.save(dbObject, "collection"); + + template.findById(dbObject.get("_id"), DBObject.class); + } + + /** + * @see DATAMONGO-550 + */ + @Test + public void readsPlainDbObjectById() { + + DBObject dbObject = new BasicDBObject("foo", "bar"); + template.save(dbObject, "collection"); + + DBObject result = template.findById(dbObject.get("_id"), DBObject.class, "collection"); + assertThat(result.get("foo"), is(dbObject.get("foo"))); + assertThat(result.get("_id"), is(dbObject.get("_id"))); + } + static class MyId { String first;