Skip to content

Commit

Permalink
harden serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ceki committed Feb 7, 2017
1 parent 909afb9 commit f46044b
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package ch.qos.logback.classic.net;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ch.qos.logback.classic.net.server;

import java.util.ArrayList;
import java.util.List;

import org.slf4j.helpers.BasicMarker;

import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.LoggerContextVO;
import ch.qos.logback.classic.spi.LoggingEventVO;
import ch.qos.logback.classic.spi.ThrowableProxyVO;

public class LogbackClassicSerializationHelper {



static public List<String> getWhilelist() {
List<String> whitelist = new ArrayList<String>();
whitelist.add(LoggingEventVO.class.getName());
whitelist.add(LoggerContextVO.class.getName());
whitelist.add(ThrowableProxyVO.class.getName());
whitelist.add(StackTraceElement.class.getName());
whitelist.add(BasicMarker.class.getName());
whitelist.add(BasicMarker.class.getName());
whitelist.add(Logger.class.getName());
return whitelist;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
package ch.qos.logback.classic;

import java.io.*;
import java.util.List;

import ch.qos.logback.classic.net.server.LogbackClassicSerializationHelper;
import ch.qos.logback.core.net.HardenedObjectInputStream;
import ch.qos.logback.core.util.CoreTestConstants;
import org.junit.After;
import org.junit.Before;
Expand All @@ -36,7 +39,8 @@ public class LoggerSerializationTest {
ByteArrayOutputStream bos;
ObjectOutputStream oos;
ObjectInputStream inputStream;

List<String> whitelist ;

@Before
public void setUp() throws Exception {
lc = new LoggerContext();
Expand All @@ -45,6 +49,8 @@ public void setUp() throws Exception {
// create the byte output stream
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(bos);
whitelist = LogbackClassicSerializationHelper.getWhilelist();
whitelist.add(Foo.class.getName());
}

@After
Expand Down Expand Up @@ -110,7 +116,7 @@ public void deepTreeSerialization() throws IOException {
private Foo writeAndRead(Foo foo) throws IOException, ClassNotFoundException {
writeObject(oos, foo);
ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
inputStream = new ObjectInputStream(bis);
inputStream = new HardenedObjectInputStream(bis, whitelist);
Foo fooBack = readFooObject(inputStream);
inputStream.close();
return fooBack;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package ch.qos.logback.core.net;

import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
*
* @author Ceki G&uuml;lc&uuml;
* @since 1.2.0
*/
public class HardenedObjectInputStream extends ObjectInputStream {

List<String> whitelistedClassNames;
String[] javaPackages = new String[] {"java.lang", "java.util"};

public HardenedObjectInputStream(InputStream in, List<String> whilelist) throws IOException {
super(in);
this.whitelistedClassNames = Collections.synchronizedList(new ArrayList<String>(whilelist));
}

@Override
protected Class<?> resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException {
String incomingClassName = anObjectStreamClass.getName();
if(!isWhitelisted(incomingClassName)) {
throw new InvalidClassException("Unauthorized deserialization attempt", anObjectStreamClass.getName());
}

return super.resolveClass(anObjectStreamClass);
}

private boolean isWhitelisted(String incomingClassName) {
for(int i = 0; i < javaPackages.length; i++) {
if(incomingClassName.startsWith(javaPackages[i]))
return true;
}
for(String className: whitelistedClassNames) {
if(incomingClassName.equals(className))
return true;
}
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package ch.qos.logback.core.net;

import static org.junit.Assert.*;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

public class HardenedObjectInputStreamTest {

ByteArrayOutputStream bos;
ObjectOutputStream oos;
HardenedObjectInputStream inputStream;
List<String> whitelist = new ArrayList<String>();

@Before
public void setUp() throws Exception {
whitelist.add(Innocent.class.getName());
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(bos);
}

@After
public void tearDown() throws Exception {
}

@Test
public void smoke() throws ClassNotFoundException, IOException {
Innocent innocent = new Innocent();
innocent.setAnInt(1);
innocent.setAnInteger(2);
innocent.setaString("smoke");
Innocent back = writeAndRead(innocent);
assertEquals(innocent, back);
}



private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
writeObject(oos, innocent);
ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
inputStream = new HardenedObjectInputStream(bis, whitelist);
Innocent fooBack = (Innocent) inputStream.readObject();
inputStream.close();
return fooBack;
}

private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
oos.writeObject(o);
oos.flush();
oos.close();
}

}
69 changes: 69 additions & 0 deletions logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package ch.qos.logback.core.net;

public class Innocent implements java.io.Serializable {

private static final long serialVersionUID = -1227008349289885025L;

int anInt;
Integer anInteger;
String aString;

public int getAnInt() {
return anInt;
}

public void setAnInt(int anInt) {
this.anInt = anInt;
}

public Integer getAnInteger() {
return anInteger;
}

public void setAnInteger(Integer anInteger) {
this.anInteger = anInteger;
}

public String getaString() {
return aString;
}

public void setaString(String aString) {
this.aString = aString;
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((aString == null) ? 0 : aString.hashCode());
result = prime * result + anInt;
result = prime * result + ((anInteger == null) ? 0 : anInteger.hashCode());
return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Innocent other = (Innocent) obj;
if (aString == null) {
if (other.aString != null)
return false;
} else if (!aString.equals(other.aString))
return false;
if (anInt != other.anInt)
return false;
if (anInteger == null) {
if (other.anInteger != null)
return false;
} else if (!anInteger.equals(other.anInteger))
return false;
return true;
}

}

1 comment on commit f46044b

@carnil
Copy link

@carnil carnil commented on f46044b Mar 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.