diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/CallbackNotifier.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/CallbackNotifier.java new file mode 100644 index 00000000..3bb6ba7f --- /dev/null +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/CallbackNotifier.java @@ -0,0 +1,8 @@ +package co.freeside.betamax.proxy.netty; + +import io.netty.channel.ChannelHandlerContext; + +public interface CallbackNotifier { + void onSuccess(ChannelHandlerContext outboundCtx); + void onFailure(ChannelHandlerContext outboundCtx, Throwable cause); +} diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/DirectClientHandler.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/DirectClientHandler.java new file mode 100644 index 00000000..2e578d89 --- /dev/null +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/DirectClientHandler.java @@ -0,0 +1,28 @@ +package co.freeside.betamax.proxy.netty; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +public final class DirectClientHandler extends ChannelInboundHandlerAdapter { + private static final String name = "DIRECT_CLIENT_HANDLER"; + + public static String getName() { + return name; + } + private final CallbackNotifier cb; + + public DirectClientHandler(CallbackNotifier cb) { + this.cb = cb; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.pipeline().remove(this); + cb.onSuccess(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable throwable) { + cb.onFailure(ctx, throwable); + } +} diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/DirectClientInitializer.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/DirectClientInitializer.java new file mode 100644 index 00000000..5ec3817f --- /dev/null +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/DirectClientInitializer.java @@ -0,0 +1,20 @@ +package co.freeside.betamax.proxy.netty; + +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; + +public final class DirectClientInitializer extends ChannelInitializer { + + private final CallbackNotifier callbackNotifier; + + public DirectClientInitializer(CallbackNotifier callbackNotifier) { + this.callbackNotifier = callbackNotifier; + } + + @Override + public void initChannel(SocketChannel socketChannel) { + ChannelPipeline channelPipeline = socketChannel.pipeline(); + channelPipeline.addLast(DirectClientHandler.getName(), new DirectClientHandler(callbackNotifier)); + } +} diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/HttpChannelInitializer.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/HttpChannelInitializer.java index 5f5ba275..94c701dc 100644 --- a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/HttpChannelInitializer.java +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/HttpChannelInitializer.java @@ -16,8 +16,13 @@ public class HttpChannelInitializer extends ChannelInitializer { public static final int MAX_CONTENT_LENGTH = 65536; + static final String HANDLER_HTTP_DECODER = "decoder"; + static final String HANDLER_HTTP_AGGREGATOR = "aggregator"; + static final String HANDLER_HTTP_ENCODER = "encoder"; + static final String HANDLER_CHUNKED_WRITER = "chunkedWriter"; + private final ChannelHandler handler; - private final EventLoopGroup workerGroup; + protected final EventLoopGroup workerGroup; public HttpChannelInitializer(int workerThreads, ChannelHandler handler) { this.handler = handler; @@ -32,10 +37,10 @@ public HttpChannelInitializer(int workerThreads, ChannelHandler handler) { @Override public void initChannel(SocketChannel channel) throws Exception { ChannelPipeline pipeline = channel.pipeline(); - pipeline.addLast("decoder", new HttpRequestDecoder()); - pipeline.addLast("aggregator", new HttpObjectAggregator(MAX_CONTENT_LENGTH)); - pipeline.addLast("encoder", new HttpResponseEncoder()); - pipeline.addLast("chunkedWriter", new ChunkedWriteHandler()); + pipeline.addLast(HANDLER_HTTP_DECODER, new HttpRequestDecoder()); + pipeline.addLast(HANDLER_HTTP_AGGREGATOR, new HttpObjectAggregator(MAX_CONTENT_LENGTH)); + pipeline.addLast(HANDLER_HTTP_ENCODER, new HttpResponseEncoder()); + pipeline.addLast(HANDLER_CHUNKED_WRITER, new ChunkedWriteHandler()); if (workerGroup == null) { pipeline.addLast("betamaxHandler", handler); } else { diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/NettyHelpers.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/NettyHelpers.java new file mode 100644 index 00000000..303d5e90 --- /dev/null +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/NettyHelpers.java @@ -0,0 +1,14 @@ +package co.freeside.betamax.proxy.netty; + +import io.netty.buffer.*; +import io.netty.channel.*; + +class NettyHelpers { + public static void closeOnFlush(Channel ch) { + if (ch.isActive()) { + ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + } + } + + private NettyHelpers() {} +} diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/ProxyConnectHandler.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/ProxyConnectHandler.java index 0aa6e48c..44fa4fde 100644 --- a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/ProxyConnectHandler.java +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/ProxyConnectHandler.java @@ -1,23 +1,62 @@ package co.freeside.betamax.proxy.netty; +import java.net.*; +import io.netty.bootstrap.*; import io.netty.channel.*; +import io.netty.channel.socket.nio.*; import io.netty.handler.codec.http.*; import static io.netty.handler.codec.http.HttpMethod.*; -import static io.netty.handler.codec.http.HttpVersion.*; public class ProxyConnectHandler extends SimpleChannelInboundHandler { + private final SocketAddress proxyAddress; + private final Bootstrap bootstrap = new Bootstrap(); + + public ProxyConnectHandler(SocketAddress proxyAddress) { + this.proxyAddress = proxyAddress; + } + @Override public boolean acceptInboundMessage(Object message) throws Exception { System.err.printf("Evaluating %s%n", message); - return super.acceptInboundMessage(message) && CONNECT.equals(((HttpRequest) message).getMethod()); + return super.acceptInboundMessage(message) && isConnectRequest((HttpRequest) message); } @Override - protected void channelRead0(ChannelHandlerContext context, HttpRequest request) { - System.err.println("I actually got a CONNECT. Now what?"); - FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.METHOD_NOT_ALLOWED); - context.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + protected void channelRead0(final ChannelHandlerContext context, final HttpRequest request) { + CallbackNotifier callback = new CallbackNotifier() { + @Override + public void onSuccess(final ChannelHandlerContext outboundContext) { + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + context.channel().writeAndFlush(response) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture channelFuture) { + outboundContext.channel().pipeline().addLast(new RelayHandler(context.channel())); + context.channel().pipeline().addLast(new RelayHandler(outboundContext.channel())); + } + }); + } + + @Override + public void onFailure(ChannelHandlerContext outboundCtx, Throwable cause) { + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_GATEWAY); + context.channel().writeAndFlush(response); + NettyHelpers.closeOnFlush(context.channel()); + } + }; + + final Channel inboundChannel = context.channel(); + bootstrap.group(inboundChannel.eventLoop()) + .channel(NioSocketChannel.class) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10000) + .option(ChannelOption.SO_KEEPALIVE, true) + .handler(new DirectClientInitializer(callback)); + + bootstrap.connect(proxyAddress); } + private boolean isConnectRequest(HttpRequest request) { + return CONNECT.equals(request.getMethod()); + } } diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/RelayHandler.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/RelayHandler.java new file mode 100644 index 00000000..931c22f5 --- /dev/null +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/RelayHandler.java @@ -0,0 +1,45 @@ +package co.freeside.betamax.proxy.netty; + +import io.netty.buffer.*; +import io.netty.channel.*; +import io.netty.util.*; + +public final class RelayHandler extends ChannelInboundHandlerAdapter { + private static final String name = "RELAY_HANDLER"; + + public static String getName() { + return name; + } + + private final Channel relayChannel; + + public RelayHandler(Channel relayChannel) { + this.relayChannel = relayChannel; + } + + @Override + public void channelActive(ChannelHandlerContext context) { + context.writeAndFlush(Unpooled.EMPTY_BUFFER); + } + + @Override + public void channelRead(ChannelHandlerContext context, Object message) { + if (relayChannel.isActive()) { + relayChannel.writeAndFlush(message); + } else { + ReferenceCountUtil.release(message); + } + } + + @Override + public void channelInactive(ChannelHandlerContext context) { + NettyHelpers.closeOnFlush(relayChannel); + } + + @Override + public void exceptionCaught(ChannelHandlerContext context, Throwable cause) { + cause.printStackTrace(); + context.close(); + } + +} diff --git a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/TunnelingHttpChannelInitializer.java b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/TunnelingHttpChannelInitializer.java index 2b5dbb9a..c15d0496 100644 --- a/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/TunnelingHttpChannelInitializer.java +++ b/betamax-netty/src/main/java/co/freeside/betamax/proxy/netty/TunnelingHttpChannelInitializer.java @@ -1,18 +1,24 @@ package co.freeside.betamax.proxy.netty; +import java.net.*; import io.netty.channel.*; import io.netty.channel.socket.*; public class TunnelingHttpChannelInitializer extends HttpChannelInitializer { - public TunnelingHttpChannelInitializer(int workerThreads, ChannelHandler handler) { + static final String HANDLER_HTTP_CONNECT = "connector"; + + private final SocketAddress proxyAddress; + + public TunnelingHttpChannelInitializer(int workerThreads, ChannelHandler handler, SocketAddress proxyAddress) { super(workerThreads, handler); + this.proxyAddress = proxyAddress; } @Override public void initChannel(SocketChannel channel) throws Exception { super.initChannel(channel); - channel.pipeline().addAfter("decoder", "connector", new ProxyConnectHandler()); + channel.pipeline().addAfter("betamaxHandler", HANDLER_HTTP_CONNECT, new ProxyConnectHandler(proxyAddress)); } } diff --git a/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/NettyRequestAdapter.java b/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/NettyRequestAdapter.java index d69b3a39..9e3e856f 100644 --- a/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/NettyRequestAdapter.java +++ b/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/NettyRequestAdapter.java @@ -40,7 +40,11 @@ public String getMethod() { @Override public URI getUri() { try { - return new URI(delegate.getUri()); + String uri = delegate.getUri(); + if (!uri.startsWith("http")) { + uri = "https://" + uri; + } + return new URI(uri); } catch (URISyntaxException e) { throw new IllegalStateException("Invalid URI in underlying request", e); } diff --git a/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/ProxyServer.groovy b/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/ProxyServer.groovy index 3a7fe38b..7a4be647 100644 --- a/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/ProxyServer.groovy +++ b/betamax-proxy/src/main/groovy/co/freeside/betamax/proxy/netty/ProxyServer.groovy @@ -41,9 +41,10 @@ class ProxyServer implements HttpInterceptor { proxyHandler = new BetamaxChannelHandler() proxyHandler << new DefaultHandlerChain(recorder, newHttpClient()) - def standardInitializer = new TunnelingHttpChannelInitializer(0, proxyHandler); // TODO: correct worker threads? After all nothing in Betamax is actually async so we should probably not tie up the main thread + final int sslPort = recorder.proxyPort + 1 + def standardInitializer = new TunnelingHttpChannelInitializer(0, proxyHandler, new InetSocketAddress("localhost", sslPort)); // TODO: correct worker threads? After all nothing in Betamax is actually async so we should probably not tie up the main thread def secureInitializer = new HttpsChannelInitializer(0, proxyHandler); // TODO: correct worker threads? After all nothing in Betamax is actually async so we should probably not tie up the main thread - proxyServer = new NettyBetamaxServer(recorder.proxyPort, recorder.proxyPort + 1, standardInitializer, secureInitializer) + proxyServer = new NettyBetamaxServer(recorder.proxyPort, sslPort, standardInitializer, secureInitializer) } @Override