Skip to content

Commit

Permalink
Merge branch '5.1.x'
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Jul 3, 2019
2 parents 03a3423 + 4e6e47b commit 3d913b8
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
Expand Down Expand Up @@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String IGNORE_ERROR = "simpIgnoreError";


@Nullable
private Consumer<Principal> userCallback;


/**
* A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes.
Expand Down Expand Up @@ -171,6 +176,9 @@ public Map<String, Object> getSessionAttributes() {

public void setUser(@Nullable Principal principal) {
setHeader(USER_HEADER, principal);
if (this.userCallback != null) {
this.userCallback.accept(principal);
}
}

/**
Expand All @@ -181,6 +189,18 @@ public Principal getUser() {
return (Principal) getHeader(USER_HEADER);
}

/**
* Provide a callback to be invoked if and when {@link #setUser(Principal)}
* is called. This is used internally on the inbound channel to detect
* token-based authentications through an interceptor.
* @param callback the callback to invoke
* @since 5.1.9
*/
public void setUserChangeCallback(Consumer<Principal> callback) {
Assert.notNull(callback, "'callback' is required");
this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
}

@Override
public String getShortLogMessage(Object payload) {
if (getMessageType() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

package org.springframework.messaging.simp;

import java.security.Principal;
import java.util.Collections;
import java.util.function.Consumer;

import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

/**
* Unit tests for SimpMessageHeaderAccessor.
Expand All @@ -32,7 +35,8 @@ public class SimpMessageHeaderAccessorTests {

@Test
public void getShortLogMessage() {
assertThat(SimpMessageHeaderAccessor.create().getShortLogMessage("p")).isEqualTo("MESSAGE session=null payload=p");
assertThat(SimpMessageHeaderAccessor.create().getShortLogMessage("p"))
.isEqualTo("MESSAGE session=null payload=p");
}

@Test
Expand All @@ -44,8 +48,9 @@ public void getLogMessageWithValuesSet() {
accessor.setUser(new TestPrincipal("user"));
accessor.setSessionAttributes(Collections.<String, Object>singletonMap("key", "value"));

assertThat(accessor.getShortLogMessage("p")).isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
"session=session user=user attributes[1] payload=p"));
assertThat(accessor.getShortLogMessage("p"))
.isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
"session=session user=user attributes[1] payload=p"));
}

@Test
Expand All @@ -58,9 +63,41 @@ public void getDetailedLogMessageWithValuesSet() {
accessor.setSessionAttributes(Collections.<String, Object>singletonMap("key", "value"));
accessor.setNativeHeader("nativeKey", "nativeValue");

assertThat(accessor.getDetailedLogMessage("p")).isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
"session=session user=user attributes={key=value} nativeHeaders=" +
"{nativeKey=[nativeValue]} payload=p"));
assertThat(accessor.getDetailedLogMessage("p"))
.isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
"session=session user=user attributes={key=value} nativeHeaders=" +
"{nativeKey=[nativeValue]} payload=p"));
}

@Test
public void userChangeCallback() {
UserCallback userCallback = new UserCallback();
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
accessor.setUserChangeCallback(userCallback);

Principal user1 = mock(Principal.class);
accessor.setUser(user1);
assertThat(userCallback.getUser()).isEqualTo(user1);

Principal user2 = mock(Principal.class);
accessor.setUser(user2);
assertThat(userCallback.getUser()).isEqualTo(user2);
}


private static class UserCallback implements Consumer<Principal> {

private Principal user;


public Principal getUser() {
return this.user;
}

@Override
public void accept(Principal principal) {
this.user = principal;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,19 @@ else if (webSocketMessage instanceof BinaryMessage) {
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.state(headerAccessor != null, "No StompHeaderAccessor");

StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);

headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(getUser(session));
if (isConnect) {
headerAccessor.setUserChangeCallback(user -> {
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
Expand All @@ -279,8 +289,6 @@ else if (webSocketMessage instanceof BinaryMessage) {
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
}

StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
if (isConnect) {
this.stats.incrementConnectCount();
}
Expand All @@ -293,12 +301,6 @@ else if (StompCommand.DISCONNECT.equals(command)) {
boolean sent = outputChannel.send(message);

if (sent) {
if (isConnect) {
Principal user = headerAccessor.getUser();
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
}
if (this.eventPublisher != null) {
Principal user = getUser(session);
if (isConnect) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -166,7 +166,6 @@ public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse
}
this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
chain.applyAfterHandshake(request, response, null);
response.close();
}
catch (HandshakeFailureException ex) {
failure = ex;
Expand All @@ -177,8 +176,10 @@ public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse
finally {
if (failure != null) {
chain.applyAfterHandshake(request, response, failure);
response.close();
throw failure;
}
response.close();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,15 @@ public void handleMessageFromClientWithTokenAuthentication() {
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
assertThat(user).isNotNull();
assertThat(user.getName()).isEqualTo("__pete__@gmail.com");

StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
handler.handleMessageToClient(this.session, message);

assertThat(this.session.getSentMessages()).hasSize(1);
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertThat(textMessage.getPayload())
.isEqualTo("CONNECTED\n" + "user-name:__pete__@gmail.com\n" + "\n" + "\u0000");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.junit.Test;

import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketExtension;
Expand Down Expand Up @@ -51,14 +52,9 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
public void supportedSubProtocols() {
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
this.servletRequest.setMethod("GET");

WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketProtocol("STOMP");
this.servletRequest.setMethod("GET");
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("STOMP");

WebSocketHandler handler = new TextWebSocketHandler();
Map<String, Object> attributes = Collections.emptyMap();
Expand All @@ -77,16 +73,10 @@ public void supportedExtensions() {
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1));

this.servletRequest.setMethod("GET");

WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
initHeaders(this.request.getHeaders()).setSecWebSocketExtensions(Arrays.asList(extension1, extension2));

WebSocketHandler handler = new TextWebSocketHandler();
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
Map<String, Object> attributes = Collections.emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);

verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
Expand All @@ -98,16 +88,10 @@ public void subProtocolCapableHandler() {
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});

this.servletRequest.setMethod("GET");

WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketProtocol("v11.stomp");
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v11.stomp");

WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
Map<String, Object> attributes = Collections.emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);

verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp",
Expand All @@ -119,22 +103,25 @@ public void subProtocolCapableHandlerNoMatch() {
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});

this.servletRequest.setMethod("GET");

WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketProtocol("v10.stomp");
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v10.stomp");

WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
Map<String, Object> attributes = Collections.emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);

verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
Collections.emptyList(), null, handler, attributes);
}

private WebSocketHttpHeaders initHeaders(HttpHeaders httpHeaders) {
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(httpHeaders);
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
return headers;
}


private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable {

Expand Down
Loading

0 comments on commit 3d913b8

Please sign in to comment.