Skip to content

Commit 4e6e47b

Browse files
committed
Earlier detection of token authentication
Use a callback to detect token authentication (via inteceptor) thus avoiding a potential race between that detection after the message is sent on the inbound channel (via Executor) and the processing of the CONNECTED frame returned from the broker on the outbound channel. Closes gh-23160
1 parent 5af9a8e commit 4e6e47b

File tree

4 files changed

+75
-10
lines changed

4 files changed

+75
-10
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import java.security.Principal;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.function.Consumer;
2223

2324
import org.springframework.lang.Nullable;
2425
import org.springframework.messaging.Message;
@@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
8485
public static final String IGNORE_ERROR = "simpIgnoreError";
8586

8687

88+
@Nullable
89+
private Consumer<Principal> userCallback;
90+
91+
8792
/**
8893
* A constructor for creating new message headers.
8994
* This constructor is protected. See factory methods in this and sub-classes.
@@ -171,6 +176,9 @@ public Map<String, Object> getSessionAttributes() {
171176

172177
public void setUser(@Nullable Principal principal) {
173178
setHeader(USER_HEADER, principal);
179+
if (this.userCallback != null) {
180+
this.userCallback.accept(principal);
181+
}
174182
}
175183

176184
/**
@@ -181,6 +189,18 @@ public Principal getUser() {
181189
return (Principal) getHeader(USER_HEADER);
182190
}
183191

192+
/**
193+
* Provide a callback to be invoked if and when {@link #setUser(Principal)}
194+
* is called. This is used internally on the inbound channel to detect
195+
* token-based authentications through an interceptor.
196+
* @param callback the callback to invoke
197+
* @since 5.1.9
198+
*/
199+
public void setUserChangeCallback(Consumer<Principal> callback) {
200+
Assert.notNull(callback, "'callback' is required");
201+
this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
202+
}
203+
184204
@Override
185205
public String getShortLogMessage(Object payload) {
186206
if (getMessageType() == null) {

spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2014 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,11 +16,14 @@
1616

1717
package org.springframework.messaging.simp;
1818

19+
import java.security.Principal;
1920
import java.util.Collections;
21+
import java.util.function.Consumer;
2022

2123
import org.junit.Test;
2224

2325
import static org.junit.Assert.*;
26+
import static org.mockito.Mockito.mock;
2427

2528
/**
2629
* Unit tests for SimpMessageHeaderAccessor.
@@ -63,4 +66,35 @@ public void getDetailedLogMessageWithValuesSet() {
6366
"{nativeKey=[nativeValue]} payload=p", accessor.getDetailedLogMessage("p"));
6467
}
6568

69+
@Test
70+
public void userChangeCallback() {
71+
UserCallback userCallback = new UserCallback();
72+
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
73+
accessor.setUserChangeCallback(userCallback);
74+
75+
Principal user1 = mock(Principal.class);
76+
accessor.setUser(user1);
77+
assertEquals(user1, userCallback.getUser());
78+
79+
Principal user2 = mock(Principal.class);
80+
accessor.setUser(user2);
81+
assertEquals(user2, userCallback.getUser());
82+
}
83+
84+
85+
private static class UserCallback implements Consumer<Principal> {
86+
87+
private Principal user;
88+
89+
90+
public Principal getUser() {
91+
return this.user;
92+
}
93+
94+
@Override
95+
public void accept(Principal principal) {
96+
this.user = principal;
97+
}
98+
}
99+
66100
}

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,19 @@ else if (webSocketMessage instanceof BinaryMessage) {
258258
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
259259
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
260260

261+
StompCommand command = headerAccessor.getCommand();
262+
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
263+
261264
headerAccessor.setSessionId(session.getId());
262265
headerAccessor.setSessionAttributes(session.getAttributes());
263266
headerAccessor.setUser(getUser(session));
267+
if (isConnect) {
268+
headerAccessor.setUserChangeCallback(user -> {
269+
if (user != null && user != session.getPrincipal()) {
270+
this.stompAuthentications.put(session.getId(), user);
271+
}
272+
});
273+
}
264274
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
265275
if (!detectImmutableMessageInterceptor(outputChannel)) {
266276
headerAccessor.setImmutable();
@@ -270,8 +280,6 @@ else if (webSocketMessage instanceof BinaryMessage) {
270280
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
271281
}
272282

273-
StompCommand command = headerAccessor.getCommand();
274-
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
275283
if (isConnect) {
276284
this.stats.incrementConnectCount();
277285
}
@@ -284,12 +292,6 @@ else if (StompCommand.DISCONNECT.equals(command)) {
284292
boolean sent = outputChannel.send(message);
285293

286294
if (sent) {
287-
if (isConnect) {
288-
Principal user = headerAccessor.getUser();
289-
if (user != null && user != session.getPrincipal()) {
290-
this.stompAuthentications.put(session.getId(), user);
291-
}
292-
}
293295
if (this.eventPublisher != null) {
294296
Principal user = getUser(session);
295297
if (isConnect) {

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,15 @@ public void handleMessageFromClientWithTokenAuthentication() {
378378
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
379379
assertNotNull(user);
380380
assertEquals("[email protected]", user.getName());
381+
382+
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
383+
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
384+
handler.handleMessageToClient(this.session, message);
385+
386+
assertEquals(1, this.session.getSentMessages().size());
387+
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
388+
assertEquals("CONNECTED\n" + "user-name:[email protected]\n" + "\n" + "\u0000",
389+
textMessage.getPayload());
381390
}
382391

383392
@Test

0 commit comments

Comments
 (0)