Skip to content

Commit

Permalink
Use trusted packages in StreamMessage
Browse files Browse the repository at this point in the history
StreamMessage now uses the same "white list" mechanism as
ObjectMessage to avoid some arbitrary code execution on deserialization.

Even though StreamMessage is supposed to handle only primitive types,
it is still to possible to send a message that contains an arbitrary
serializable instance. The consuming application application may
then execute code from this class on deserialization.

The fix consists in using the list of trusted packages that can be
set at the connection factory level.

Fixes #135
  • Loading branch information
acogoluegnes committed Nov 2, 2020
1 parent 4497022 commit f647e5d
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 6 deletions.
18 changes: 16 additions & 2 deletions src/main/java/com/rabbitmq/jms/client/RMQMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,8 @@ static RMQMessage fromMessage(byte[] b, List<String> trustedPackages) throws RMQ
private static RMQMessage instantiateRmqMessage(String messageClass, List<String> trustedPackages) throws RMQJMSException {
if(isRmqObjectMessageClass(messageClass)) {
return instantiateRmqObjectMessageWithTrustedPackages(trustedPackages);
} else if (isRmqStreamMessageClass(messageClass)) {
return instantiateRmqStreamMessageWithTrustedPackages(trustedPackages);
} else {
try {
// instantiate the message object with the thread context classloader
Expand All @@ -1168,12 +1170,24 @@ private static boolean isRmqObjectMessageClass(String clazz) {
return RMQObjectMessage.class.getName().equals(clazz);
}

private static boolean isRmqStreamMessageClass(String clazz) {
return RMQStreamMessage.class.getName().equals(clazz);
}

private static RMQObjectMessage instantiateRmqObjectMessageWithTrustedPackages(List<String> trustedPackages) throws RMQJMSException {
return (RMQObjectMessage) instantiateRmqMessageWithTrustedPackages(RMQObjectMessage.class.getName(), trustedPackages);
}

private static RMQStreamMessage instantiateRmqStreamMessageWithTrustedPackages(List<String> trustedPackages) throws RMQJMSException {
return (RMQStreamMessage) instantiateRmqMessageWithTrustedPackages(RMQStreamMessage.class.getName(), trustedPackages);
}

private static RMQMessage instantiateRmqMessageWithTrustedPackages(String messageClazz, List<String> trustedPackages) throws RMQJMSException {
try {
// instantiate the message object with the thread context classloader
Class<?> messageClass = Class.forName(RMQObjectMessage.class.getName(), true, Thread.currentThread().getContextClassLoader());
Class<?> messageClass = Class.forName(messageClazz, true, Thread.currentThread().getContextClassLoader());
Constructor<?> constructor = messageClass.getConstructor(List.class);
return (RMQObjectMessage) constructor.newInstance(trustedPackages);
return (RMQMessage) constructor.newInstance(trustedPackages);
} catch (NoSuchMethodException e) {
throw new RMQJMSException(e);
} catch (InvocationTargetException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Copyright (c) 2013-2020 VMware, Inc. or its affiliates. All rights reserved.
package com.rabbitmq.jms.client.message;

import com.rabbitmq.jms.util.WhiteListObjectInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
Expand All @@ -15,6 +16,7 @@
import java.io.ObjectOutputStream;
import java.io.UTFDataFormatException;

import java.util.List;
import javax.jms.JMSException;
import javax.jms.MessageEOFException;
import javax.jms.MessageFormatException;
Expand Down Expand Up @@ -47,12 +49,19 @@ public class RMQStreamMessage extends RMQMessage implements StreamMessage {
private volatile transient byte[] buf;
private volatile transient byte[] readbuf = null;

private final List<String> trustedPackages;

public RMQStreamMessage(List<String> trustedPackages) {
this(false, trustedPackages);
}

public RMQStreamMessage() {
this(false);
this(false, WhiteListObjectInputStream.DEFAULT_TRUSTED_PACKAGES);
}

private RMQStreamMessage(boolean reading) {
private RMQStreamMessage(boolean reading, List<String> trustedPackages) {
this.reading = reading;
this.trustedPackages = trustedPackages;
if (!reading) {
this.bout = new ByteArrayOutputStream(RMQMessage.DEFAULT_MESSAGE_BODY_SIZE);
try {
Expand Down Expand Up @@ -513,7 +522,7 @@ protected void readBody(ObjectInput inputStream, ByteArrayInputStream bin) throw
inputStream.read(buf);
this.reading = true;
this.bin = new ByteArrayInputStream(buf);
this.in = new ObjectInputStream(this.bin);
this.in = new WhiteListObjectInputStream(this.bin, this.trustedPackages);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

public class ObjectMessageSerializationIT extends AbstractITQueue {

private static final String QUEUE_NAME = "test.queue." + SimpleQueueMessageDefaultsIT.class.getCanonicalName();
private static final String QUEUE_NAME = "test.queue." + ObjectMessageSerializationIT.class.getCanonicalName();
private static final long TEST_RECEIVE_TIMEOUT = 1000; // one second
private static final java.util.List<String> TRUSTED_PACKAGES = Arrays.asList("java.lang", "com.rabbitmq.jms");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// Copyright (c) 2013-2020 VMware, Inc. or its affiliates. All rights reserved.
package com.rabbitmq.integration.tests;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

import com.rabbitmq.jms.admin.RMQConnectionFactory;
import com.rabbitmq.jms.client.message.RMQStreamMessage;
import com.rabbitmq.jms.client.message.TestMessages;
import com.rabbitmq.jms.util.RMQJMSException;
import java.awt.Color;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import javax.jms.Queue;
import javax.jms.QueueReceiver;
import javax.jms.QueueSender;
import javax.jms.QueueSession;
import javax.jms.Session;
import javax.jms.StreamMessage;
import org.junit.jupiter.api.Test;

public class StreamMessageSerializationIT extends AbstractITQueue {

private static final String QUEUE_NAME = "test.queue." + StreamMessageSerializationIT.class.getCanonicalName();
private static final long TEST_RECEIVE_TIMEOUT = 1000; // one second
private static final java.util.List<String> TRUSTED_PACKAGES = Arrays.asList("java.lang", "com.rabbitmq.jms");

@Override
protected void customise(RMQConnectionFactory connectionFactory) {
super.customise(connectionFactory);
connectionFactory.setTrustedPackages(TRUSTED_PACKAGES);
}

protected void testReceiveStreamMessageWithValue(Object value) throws Exception {
try {
queueConn.start();
QueueSession queueSession = queueConn.createQueueSession(false, Session.DUPS_OK_ACKNOWLEDGE);
Queue queue = queueSession.createQueue(QUEUE_NAME);

drainQueue(queueSession, queue);

QueueSender queueSender = queueSession.createSender(queue);
StreamMessage message = (StreamMessage) MessageTestType.STREAM.gen(queueSession, null);

// we simulate an attack from the sender by calling writeObject with a non-primitive value
// (StreamMessage supports only primitive types)
// the value is then sent to the destination and the consumer will have to
// deserialize it and can potentially execute malicious code
Method writeObjectMethod = RMQStreamMessage.class
.getDeclaredMethod("writeObject", Object.class, boolean.class);
writeObjectMethod.setAccessible(true);
writeObjectMethod.invoke(message, value, true);

queueSender.send(message);
} finally {
reconnect(Arrays.asList("java.lang", "com.rabbitmq.jms"));
}

queueConn.start();
QueueSession queueSession = queueConn.createQueueSession(false, Session.DUPS_OK_ACKNOWLEDGE);
Queue queue = queueSession.createQueue(QUEUE_NAME);
QueueReceiver queueReceiver = queueSession.createReceiver(queue);
RMQStreamMessage m = (RMQStreamMessage) queueReceiver.receive(TEST_RECEIVE_TIMEOUT);
MessageTestType.STREAM.check(m, null);
assertEquals(m.readObject(), value);
}

@Test
public void testReceiveStreamMessageWithPrimitiveValue() throws Exception {
testReceiveStreamMessageWithValue(1024L);
testReceiveStreamMessageWithValue("a string");
}

@Test
public void testReceiveStreamMessageWithTrustedValue() throws Exception {
testReceiveStreamMessageWithValue(new TestMessages.TestSerializable(8, "An object"));
}

@Test
public void testReceiveStreamMessageWithUntrustedValue1() throws Exception {
// StreamMessage cannot be used with a Map, unless the sender uses a trick
// this is to simulate an attack from the sender
// Note: java.util is not on the trusted package list
assertThrows(RMQJMSException.class, () -> {
Map<String, String> m = new HashMap<String, String>();
m.put("key", "value");
testReceiveStreamMessageWithValue(m);
});
}
@Test
public void testReceiveStreamMessageWithUntrustedValue2() throws Exception {
// StreamMessage cannot be used with a Map, unless the sender uses a trick
// this is to simulate an attack from the sender
// java.awt is not on the trusted package list
assertThrows(RMQJMSException.class, () -> {
testReceiveStreamMessageWithValue(Color.WHITE);
});
}

protected void reconnect(java.util.List<String> trustedPackages) throws Exception {
if (queueConn != null) {
this.queueConn.close();
((RMQConnectionFactory) connFactory).setTrustedPackages(trustedPackages);
this.queueConn = connFactory.createQueueConnection();
} else {
fail("Cannot reconnect");
}
}
}

0 comments on commit f647e5d

Please sign in to comment.