diff --git a/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionManager.java b/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionManager.java index 2cee57fa..2a55905b 100644 --- a/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionManager.java +++ b/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionManager.java @@ -109,21 +109,21 @@ public class SpringSessionWebSessionManager implements WebSessionManager { public Mono getSession(ServerWebExchange exchange) { // @formatter:off return Mono.defer(() -> - retrieveSession(exchange) + retrieveSession(exchange)) .flatMap(session -> removeSessionIfExpired(exchange, session)) .flatMap(session -> { Instant lastAccessTime = Instant.now(getClock()); return this.sessionStore.setLastAccessedTime(session, lastAccessTime); }) .switchIfEmpty(createSession(exchange)) - .doOnNext(session -> exchange.getResponse().beforeCommit(session::save))); + .doOnNext(session -> exchange.getResponse().beforeCommit(session::save)); // @formatter:on } private Mono retrieveSession(ServerWebExchange exchange) { // @formatter:off return Flux.fromIterable(getSessionIdResolver().resolveSessionIds(exchange)) - .concatMap(this.sessionStore::retrieveSession) + .concatMap(sessionId -> this.sessionStore.retrieveSession(sessionId, session -> saveSession(exchange, session))) .cast(WebSession.class) .next(); // @formatter:on @@ -167,7 +167,7 @@ public class SpringSessionWebSessionManager implements WebSessionManager { } private Mono createSession(ServerWebExchange exchange) { - return this.sessionStore.createSession(); + return this.sessionStore.createSession(session -> saveSession(exchange, session)); } } diff --git a/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionStore.java b/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionStore.java index c73f494c..73aad93b 100644 --- a/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionStore.java +++ b/spring-session-core/src/main/java/org/springframework/session/web/server/session/SpringSessionWebSessionStore.java @@ -27,7 +27,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; +import java.util.function.Function; import reactor.core.publisher.Mono; @@ -56,8 +56,8 @@ class SpringSessionWebSessionStore implements WebSessionStore this.sessions = sessions; } - public Mono createSession() { - return this.sessions.createSession().map(this::createSession); + public Mono createSession(Function> saveOperation) { + return this.sessions.createSession().map(session -> this.createSession(session, saveOperation)); } public Mono setLastAccessedTime(WebSession session, @@ -77,7 +77,11 @@ class SpringSessionWebSessionStore implements WebSessionStore @Override public Mono retrieveSession(String sessionId) { - return this.sessions.findById(sessionId).map(this::existingSession); + return Mono.error(new UnsupportedOperationException("This method is not supported. Use retrieveSession(String,Function>)")); + } + + public Mono retrieveSession(String sessionId, Function> saveOperation) { + return this.sessions.findById(sessionId).map(session -> this.existingSession(session, saveOperation)); } @Override @@ -85,12 +89,12 @@ class SpringSessionWebSessionStore implements WebSessionStore return storeSession(webSession); } - private SpringSessionWebSession createSession(S session) { - return new SpringSessionWebSession(session, State.NEW); + private SpringSessionWebSession createSession(S session, Function> saveOperation) { + return new SpringSessionWebSession(session, State.NEW, saveOperation); } - private SpringSessionWebSession existingSession(S session) { - return new SpringSessionWebSession(session, State.STARTED); + private SpringSessionWebSession existingSession(S session, Function> saveOperation) { + return new SpringSessionWebSession(session, State.STARTED, saveOperation); } @Override @@ -250,13 +254,14 @@ class SpringSessionWebSessionStore implements WebSessionStore private AtomicReference state = new AtomicReference<>(); - private volatile transient Supplier> saveOperation = Mono::empty; + private final Function> saveOperation; - SpringSessionWebSession(S session, State state) { + SpringSessionWebSession(S session, State state, Function> saveOperation) { Assert.notNull(session, "session cannot be null"); this.session = session; this.attributes = new SpringSessionMap(session); this.state.set(state); + this.saveOperation = saveOperation; } @Override @@ -291,7 +296,7 @@ class SpringSessionWebSessionStore implements WebSessionStore @Override public Mono save() { - return this.saveOperation.get(); + return this.saveOperation.apply(this); } @Override diff --git a/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionManagerTests.java b/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionManagerTests.java index b8af6e50..5a4ae082 100644 --- a/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionManagerTests.java +++ b/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionManagerTests.java @@ -29,11 +29,15 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.http.HttpCookie; +import org.springframework.http.ResponseCookie; +import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.session.MapReactorSessionRepository; import org.springframework.session.ReactorSessionRepository; import org.springframework.session.Session; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; +import org.springframework.web.server.i18n.LocaleContextResolver; import org.springframework.web.server.session.WebSessionIdResolver; import static org.assertj.core.api.Assertions.assertThat; @@ -56,6 +60,12 @@ public class SpringSessionWebSessionManagerTests { @Mock private WebSessionIdResolver resolver; + @Mock + private ServerCodecConfigurer serverCodecConfigurer; + + @Mock + private LocaleContextResolver localeContextResolver; + @Mock private S createSession; @@ -143,4 +153,23 @@ public class SpringSessionWebSessionManagerTests { verify(this.sessions).findById(findByIdSessionId); } + @Test + public void commitWrites() { + MapReactorSessionRepository repository = new MapReactorSessionRepository(); + this.manager = new SpringSessionWebSessionManager(repository); + Mono getSession = this.manager.getSession(this.exchange) + .doOnSuccess(session -> session.getAttributes().put("foo", "bar")) + .flatMap(webSession -> this.exchange.getResponse().setComplete()); + StepVerifier.create(getSession) + .expectComplete() + .verify(); + + ResponseCookie sessionCookie = this.exchange.getResponse().getCookies() + .getFirst("SESSION"); + assertThat(sessionCookie).isNotNull(); + + Session session = repository.findById(sessionCookie.getValue()).block(); + assertThat(session).isNotNull(); + assertThat(session.getAttribute("foo")).isEqualTo("bar"); + } } diff --git a/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionStoreTests.java b/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionStoreTests.java index 81212515..858227e1 100644 --- a/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionStoreTests.java +++ b/spring-session-core/src/test/java/org/springframework/session/web/server/session/SpringSessionWebSessionStoreTests.java @@ -20,6 +20,7 @@ import java.util.AbstractMap; import java.util.Collections; import java.util.Map; import java.util.Set; +import java.util.function.Function; import org.junit.Before; import org.junit.Test; @@ -55,6 +56,8 @@ public class SpringSessionWebSessionStoreTests { @Mock private S findByIdSession; + private Function> saveOperation; + private SpringSessionWebSessionStore webSessionStore; @Before @@ -73,7 +76,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenNoAttributesThenNotStarted() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); assertThat(createdWebSession.isStarted()).isFalse(); } @@ -82,14 +86,16 @@ public class SpringSessionWebSessionStoreTests { public void createSessionWhenAddAttributeThenStarted() { given(this.createSession.getAttributeNames()) .willReturn(Collections.singleton("a")); - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); assertThat(createdWebSession.isStarted()).isTrue(); } @Test public void createSessionWhenGetAttributesAndSizeThenDelegatesToCreateSession() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -103,7 +109,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndIsEmptyThenDelegatesToCreateSession() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -117,7 +124,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndContainsKeyAndNotStringThenFalse() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -126,7 +134,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndContainsKeyAndNotFoundThenFalse() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -137,7 +146,8 @@ public class SpringSessionWebSessionStoreTests { public void createSessionWhenGetAttributesAndContainsKeyAndFoundThenTrue() { given(this.createSession.getAttributeNames()) .willReturn(Collections.singleton("a")); - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -146,7 +156,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndPutThenDelegatesToCreateSession() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); attributes.put("a", "b"); @@ -156,7 +167,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndPutNullThenDelegatesToCreateSession() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); attributes.put("a", null); @@ -166,7 +178,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndRemoveThenDelegatesToCreateSession() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); attributes.remove("a"); @@ -176,7 +189,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void createSessionWhenGetAttributesAndPutAllThenDelegatesToCreateSession() { - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); attributes.putAll(Collections.singletonMap("a", "b")); @@ -188,7 +202,8 @@ public class SpringSessionWebSessionStoreTests { public void createSessionWhenGetAttributesAndClearThenDelegatesToCreateSession() { given(this.createSession.getAttributeNames()) .willReturn(Collections.singleton("a")); - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); attributes.clear(); @@ -200,7 +215,8 @@ public class SpringSessionWebSessionStoreTests { public void createSessionWhenGetAttributesAndKeySetThenDelegatesToCreateSession() { given(this.createSession.getAttributeNames()) .willReturn(Collections.singleton("a")); - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -212,7 +228,8 @@ public class SpringSessionWebSessionStoreTests { given(this.createSession.getAttributeNames()) .willReturn(Collections.singleton("a")); given(this.createSession.getAttribute("a")).willReturn("b"); - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); @@ -226,7 +243,8 @@ public class SpringSessionWebSessionStoreTests { .willReturn(Collections.singleton(attrName)); String attrValue = "attrValue"; given(this.createSession.getAttribute(attrName)).willReturn(attrValue); - WebSession createdWebSession = this.webSessionStore.createSession().block(); + WebSession createdWebSession = this.webSessionStore.createSession(this.saveOperation) + .block(); Map attributes = createdWebSession.getAttributes(); Set> entries = attributes.entrySet(); @@ -238,7 +256,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void storeSessionWhenInvokedThenSessionSaved() { given(this.sessionRepository.save(this.createSession)).willReturn(Mono.empty()); - WebSession createdSession = this.webSessionStore.createSession().block(); + WebSession createdSession = this.webSessionStore.createSession(this.saveOperation) + .block(); this.webSessionStore.storeSession(createdSession).block(); @@ -248,7 +267,8 @@ public class SpringSessionWebSessionStoreTests { @Test public void retrieveSessionThenStarted() { String id = "id"; - WebSession retrievedWebSession = this.webSessionStore.retrieveSession(id).block(); + WebSession retrievedWebSession = this.webSessionStore + .retrieveSession(id, this.saveOperation).block(); assertThat(retrievedWebSession.isStarted()).isTrue(); }