From 74dd2f7c9637d02067993505d7a50eb6900a063c Mon Sep 17 00:00:00 2001 From: Greg Schohn Date: Tue, 17 Sep 2024 11:53:41 -0400 Subject: [PATCH] Proxy header translations: PR feedback and a critical bugfix. The bugfix was that the state machine for striking header lines would continue even AFTER the header was parsed. That meant that if a header name ever appeared at the beginning of a line within the contents of a payload, we'd strike that whole line. A test was added to exhibit the behavior, requiring some refactoring to the SimpleNettyHttpServer to get the payload that hits that server. That's used to confirm that we're leaving the body exactly intact in the new test. All other previous tests continue to pass. Signed-off-by: Greg Schohn --- .../proxyserver/CaptureProxy.java | 42 +++++++++--- .../proxyserver/netty/HeaderAdderHandler.java | 9 ++- .../netty/HeaderRemoverHandler.java | 44 ++++++++++-- .../proxyserver/TestHeaderRewrites.java | 68 +++++++++++++++++-- .../netty/HeaderRemoverHandlerTest.java | 8 +++ .../testutils/SimpleNettyHttpServer.java | 36 ++++++++-- 6 files changed, 178 insertions(+), 29 deletions(-) diff --git a/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/CaptureProxy.java b/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/CaptureProxy.java index c73b060bf..0d891910e 100644 --- a/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/CaptureProxy.java +++ b/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/CaptureProxy.java @@ -9,11 +9,12 @@ import java.nio.file.Paths; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Properties; -import java.util.TreeMap; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; @@ -52,6 +53,7 @@ import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.socket.SocketChannel; import io.netty.handler.ssl.SslContext; @@ -348,7 +350,7 @@ protected static Map convertPairListToMap(List list) { if (list == null) { return Map.of(); } - var map = new TreeMap(); + var map = new LinkedHashMap(); for (int i = 0; i < list.size(); i += 2) { map.put(list.get(i), list.get(i + 1)); } @@ -425,10 +427,22 @@ static ProxyChannelInitializer buildProxyChannelInitializer(RootCaptureContext r BacksideConnectionPool backsideConnectionPool, Supplier sslEngineSupplier, @NonNull RequestCapturePredicate headerCapturePredicate, - List headerOverrides, + List headerOverridesArgs, IConnectionCaptureFactory connectionFactory) { - var headers = convertPairListToMap(headerOverrides); + var headers = new ArrayList<>(convertPairListToMap(headerOverridesArgs).entrySet()); + Collections.reverse(headers); + final var removeStrings = new ArrayList(headers.size()); + final var addBufs = new ArrayList(headers.size()); + + for (var kvp : headers) { + addBufs.add( + Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer( + (kvp.getKey() + ": " + kvp.getValue()).getBytes(StandardCharsets.UTF_8)))); + removeStrings.add(kvp.getKey() + ":"); + } + return new ProxyChannelInitializer( rootContext, backsideConnectionPool, @@ -439,14 +453,20 @@ static ProxyChannelInitializer buildProxyChannelInitializer(RootCaptureContext r @Override protected void initChannel(@NonNull SocketChannel ch) throws IOException { super.initChannel(ch); - for (var kvp : headers.entrySet()) { - var lineBytes = (kvp.getKey() + ": " + kvp.getValue()).getBytes(StandardCharsets.UTF_8); - ch.pipeline().addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "AddHeader-" + kvp.getKey(), - new HeaderAdderHandler(Unpooled.unreleasableBuffer(Unpooled.wrappedBuffer(lineBytes)))); + final var pipeline = ch.pipeline(); + { + int i = 0; + for (var kvp : headers) { + pipeline.addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "AddHeader-" + kvp.getKey(), + new HeaderAdderHandler(addBufs.get(i++))); + } } - for (var k : headers.keySet()) { - ch.pipeline().addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "RemoveHeader-" + k, - new HeaderRemoverHandler(k + ":")); + { + int i = 0; + for (var kvp : headers) { + pipeline.addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "RemoveHeader-" + kvp.getKey(), + new HeaderRemoverHandler(removeStrings.get(i++))); + } } } }; diff --git a/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderAdderHandler.java b/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderAdderHandler.java index 147684898..e4bf3d528 100644 --- a/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderAdderHandler.java +++ b/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderAdderHandler.java @@ -6,6 +6,7 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; public class HeaderAdderHandler extends ChannelInboundHandlerAdapter { private static final ByteBuf CRLF_BYTE_BUF = @@ -37,7 +38,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception var composite = Unpooled.compositeBuffer(4); buf.resetReaderIndex(); composite.addComponent(true, buf.retainedSlice(0, upToIndex)); - composite.addComponent(true, headerLineToAdd.duplicate()); + composite.addComponent(true, headerLineToAdd.retainedDuplicate()); composite.addComponent(true, (useCarriageReturn ? CRLF_BYTE_BUF : LF_BYTE_BUF).duplicate()); composite.addComponent(true, buf.retainedSlice(upToIndex, buf.readableBytes()-upToIndex)); buf.release(); @@ -49,4 +50,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception buf.resetReaderIndex(); super.channelRead(ctx, msg); } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + ReferenceCountUtil.release(headerLineToAdd); + super.channelUnregistered(ctx); + } } diff --git a/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandler.java b/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandler.java index 56accf96b..1edf7afbe 100644 --- a/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandler.java +++ b/TrafficCapture/trafficCaptureProxyServer/src/main/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandler.java @@ -16,6 +16,11 @@ public class HeaderRemoverHandler extends ChannelInboundHandlerAdapter { // when dropUntilNewline == true, we're dropping, otherwise, we're copying (when previousRemaining==null) // The starting state is previousRemaining == null and dropUntilNewline = false boolean dropUntilNewline; + MessagePosition requestPosition = MessagePosition.IN_HEADER; + + private enum MessagePosition { + IN_HEADER, ONE_NEW_LINE, AFTER_HEADERS, + } public HeaderRemoverHandler(String headerToRemove) { if (!headerToRemove.endsWith(":")) { @@ -49,10 +54,7 @@ boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) { } buf.markReaderIndex(); if (Character.toLowerCase(headerToRemove.charAt(i)) != Character.toLowerCase(buf.readByte())) { // no match - previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb.retain())); - previousRemaining.removeComponents(0, previousRemaining.numComponents()); - previousRemaining.release(); - previousRemaining = null; + flushAndClearPreviousRemaining(ctx); buf.resetReaderIndex(); dropUntilNewline = false; return false; @@ -60,6 +62,13 @@ boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) { } } + void flushAndClearPreviousRemaining(ChannelHandlerContext ctx) { + previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb.retain())); + previousRemaining.removeComponents(0, previousRemaining.numComponents()); + previousRemaining.release(); + previousRemaining = null; + } + boolean advanceByteBufUntilNewline(ByteBuf bb) { while (bb.isReadable()) { // sonar lint doesn't like if the while statement has an empty body if (bb.readByte() == '\n') { return true; } @@ -81,16 +90,36 @@ CompositeByteBuf addSliceToComposite(ChannelHandlerContext ctx, CompositeByteBuf @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (!(msg instanceof ByteBuf)) { + if (!(msg instanceof ByteBuf) || requestPosition == MessagePosition.AFTER_HEADERS) { super.channelRead(ctx, msg); return; } var sourceBuf = (ByteBuf) msg; - var currentSourceSegmentStart = (previousRemaining != null || dropUntilNewline) ? -1 : sourceBuf.readerIndex(); + var currentSourceSegmentStart = + (previousRemaining != null || dropUntilNewline || requestPosition == MessagePosition.ONE_NEW_LINE) + ? -1 : sourceBuf.readerIndex(); CompositeByteBuf cleanedIncomingBuf = null; + sourceBuf.markReaderIndex(); while (sourceBuf.isReadable()) { + if (requestPosition == MessagePosition.ONE_NEW_LINE) { + final var nextByte = sourceBuf.readByte(); + if (nextByte == '\n' || nextByte == '\r') { + requestPosition = MessagePosition.AFTER_HEADERS; + if (currentSourceSegmentStart == -1) { + currentSourceSegmentStart = sourceBuf.readerIndex() - 1; + } + sourceBuf.readerIndex(sourceBuf.writerIndex()); + break; + } else { + previousRemaining = ctx.alloc().compositeBuffer(16); + requestPosition = MessagePosition.IN_HEADER; + sourceBuf.resetReaderIndex(); + continue; + } + } + if (previousRemaining != null) { final var sourceReaderIdx = sourceBuf.readerIndex(); if (matchNextBytes(ctx, sourceBuf)) { @@ -106,7 +135,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } else { if (advanceByteBufUntilNewline(sourceBuf)) { - previousRemaining = ctx.alloc().compositeBuffer(16); + sourceBuf.markReaderIndex(); + requestPosition = MessagePosition.ONE_NEW_LINE; } else { break; } diff --git a/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/TestHeaderRewrites.java b/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/TestHeaderRewrites.java index b45039166..a110a3362 100644 --- a/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/TestHeaderRewrites.java +++ b/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/TestHeaderRewrites.java @@ -1,5 +1,6 @@ package org.opensearch.migrations.trafficcapture.proxyserver; +import java.io.ByteArrayInputStream; import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Duration; @@ -22,8 +23,15 @@ @Slf4j public class TestHeaderRewrites { + + public static final String ONLY_FOR_HEADERS_VALUE = "this is only for headers"; +public static final String BODY_WITH_HEADERS_CONTENTS = "\n" + + "body: should stay\n" + + "body: untouched\n" + + "body:\n"; + @Test - public void testRewrites() throws Exception { + public void testHeaderRewrites() throws Exception { final var payloadBytes = "Success".getBytes(StandardCharsets.UTF_8); final var headers = Map.of( "Content-Type", @@ -44,7 +52,7 @@ public void testRewrites() throws Exception { Duration.ofMinutes(10), fl -> { capturedRequestList.add(fl); - log.error("headers: " + fl.getHeaders().stream().map(kvp->kvp.getKey()+": "+kvp.getValue()) + log.trace("headers: " + fl.getHeaders().stream().map(kvp->kvp.getKey()+": "+kvp.getValue()) .collect(Collectors.joining())); return new SimpleHttpResponse(headers, payloadBytes, "OK", 200); }); @@ -55,16 +63,68 @@ public void testRewrites() throws Exception { proxy.start(); final var proxyEndpoint = CaptureProxyContainer.getUriFromContainer(proxy); - var allHeaders = new LinkedHashMap(); allHeaders.put("Host", "localhost"); allHeaders.put("User-Agent", "UnitTest"); var response = client.makeGetRequest(new URI(proxyEndpoint), allHeaders.entrySet().stream()); - log.error("response=" + response); var capturedRequest = capturedRequestList.get(capturedRequestList.size()-1).getHeaders().stream() .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); Assertions.assertEquals("localhost", capturedRequest.get("host")); Assertions.assertEquals("insignificant value", capturedRequest.get("X-new-header")); } } + + @Test + public void testBodyDoesntRewrite() throws Exception { + final var payloadBytes = "Success".getBytes(StandardCharsets.UTF_8); + final var headers = Map.of( + "Content-Type", + "text/plain", + "Content-Length", + "" + payloadBytes.length + ); + var rewriteArgs = List.of( + "--setHeader", + "host", + "localhost", + "--setHeader", + "body", + ONLY_FOR_HEADERS_VALUE + ); + var capturedRequestList = new ArrayList(); + var capturedBodies = new ArrayList(); + try (var destinationServer = SimpleNettyHttpServer.makeNettyServer(false, + Duration.ofMinutes(10), + fullRequest -> { + var request = new SimpleNettyHttpServer.RequestToAdapter(fullRequest); + capturedRequestList.add(request); + log.atTrace().setMessage(() -> "headers: " + + request.getHeaders().stream().map(kvp->kvp.getKey()+": "+kvp.getValue()) + .collect(Collectors.joining())).log(); + capturedBodies.add(fullRequest.content().toString(StandardCharsets.UTF_8)); + return new SimpleHttpResponse(headers, payloadBytes, "OK", 200); + }); + var proxy = new CaptureProxyContainer(() -> destinationServer.localhostEndpoint().toString(), null, + rewriteArgs.stream()); + var client = new SimpleHttpClientForTesting(); + var bodyStream = new ByteArrayInputStream(BODY_WITH_HEADERS_CONTENTS.getBytes(StandardCharsets.UTF_8))) + { + proxy.start(); + final var proxyEndpoint = CaptureProxyContainer.getUriFromContainer(proxy); + + var allHeaders = new LinkedHashMap(); + allHeaders.put("Host", "localhost"); + allHeaders.put("User-Agent", "UnitTest"); + var response = client.makePutRequest(new URI(proxyEndpoint), allHeaders.entrySet().stream(), + new SimpleHttpClientForTesting.PayloadAndContentType(bodyStream, "text/plain")); + log.error("response=" + response); + var capturedRequest = capturedRequestList.get(capturedRequestList.size()-1).getHeaders().stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + Assertions.assertEquals("localhost", capturedRequest.get("host")); + Assertions.assertEquals(ONLY_FOR_HEADERS_VALUE, capturedRequest.get("body")); + + var lastBody = capturedBodies.get(capturedBodies.size()-1); + Assertions.assertEquals(BODY_WITH_HEADERS_CONTENTS, lastBody); + } + } } diff --git a/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandlerTest.java b/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandlerTest.java index d1011e7d7..133057c06 100644 --- a/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandlerTest.java +++ b/TrafficCapture/trafficCaptureProxyServer/src/test/java/org/opensearch/migrations/trafficcapture/proxyserver/netty/HeaderRemoverHandlerTest.java @@ -49,6 +49,12 @@ public void runTestWithSize(Function messageMaker, IntStream siz outputBuf.release(); } + @Test + public void newlinesArePreserved() { + runTestsWithSize((b,s) -> "GET / HTTP/1.1\r\n" + (b ? "host: localhost\r\n" : "") + "\r\n", + () -> IntStream.of(Integer.MAX_VALUE)); + } + @Test public void throwsOnHostFormatError() { Assertions.assertThrows(IllegalArgumentException.class, () -> new HeaderRemoverHandler("host")); @@ -87,6 +93,7 @@ public void randomFragmentedCheckInterlaced() { final var bound = getBound(HeaderRemoverHandlerTest::makeInterlacedMessage); for (int i=0; i "random run={}").addArgument(i).log(); runTestsWithSize(HeaderRemoverHandlerTest::makeInterlacedMessage, () -> IntStream.generate(() -> r.nextInt(bound))); } @@ -107,6 +114,7 @@ public void randomFragmentedCheckConsecutive() { final var bound = getBound(HeaderRemoverHandlerTest::makeConsecutiveMessage); for (int i=0; i "random run={}").addArgument(i).log(); runTestsWithSize(HeaderRemoverHandlerTest::makeConsecutiveMessage, () -> IntStream.generate(() -> r.nextInt(bound))); } diff --git a/testHelperFixtures/src/testFixtures/java/org/opensearch/migrations/testutils/SimpleNettyHttpServer.java b/testHelperFixtures/src/testFixtures/java/org/opensearch/migrations/testutils/SimpleNettyHttpServer.java index 9713e4676..41c2fc260 100644 --- a/testHelperFixtures/src/testFixtures/java/org/opensearch/migrations/testutils/SimpleNettyHttpServer.java +++ b/testHelperFixtures/src/testFixtures/java/org/opensearch/migrations/testutils/SimpleNettyHttpServer.java @@ -30,6 +30,7 @@ import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.SslHandler; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.concurrent.DefaultThreadFactory; @@ -59,13 +60,28 @@ public static SimpleNettyHttpServer makeServer( boolean useTls, Function makeContext ) throws PortFinder.ExceededMaxPortAssigmentAttemptException { - return makeServer(useTls, null, makeContext); + return makeNettyServer(useTls, null, r -> makeContext.apply(new RequestToAdapter(r))); + } + + public static SimpleNettyHttpServer makeNettyServer( + boolean useTls, + Function makeContext + ) throws PortFinder.ExceededMaxPortAssigmentAttemptException { + return makeNettyServer(useTls, null, makeContext); } public static SimpleNettyHttpServer makeServer( boolean useTls, Duration readTimeout, Function makeContext + ) throws PortFinder.ExceededMaxPortAssigmentAttemptException { + return makeNettyServer(useTls, readTimeout, r -> makeContext.apply(new RequestToAdapter(r))); + } + + public static SimpleNettyHttpServer makeNettyServer( + boolean useTls, + Duration readTimeout, + Function makeContext ) throws PortFinder.ExceededMaxPortAssigmentAttemptException { var testServerRef = new AtomicReference(); PortFinder.retryWithNewPortUntilNoThrow(port -> { @@ -112,8 +128,13 @@ HttpHeaders convertHeaders(Map headers) { } private SimpleChannelInboundHandler makeHandlerFromResponseContext( - Function responseBuilder - ) { + Function responseBuilder) { + return makeHandlerFromNettyResponseContext(r -> responseBuilder.apply(new RequestToAdapter(r))); + } + + private SimpleChannelInboundHandler makeHandlerFromNettyResponseContext( + Function responseBuilder) + { return new SimpleChannelInboundHandler<>() { @Override protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { @@ -122,7 +143,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { ctx.close(); return; } - var specifiedResponse = responseBuilder.apply(new RequestToAdapter(req)); + var specifiedResponse = responseBuilder.apply(req); var fullResponse = new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(specifiedResponse.statusCode, specifiedResponse.statusText), @@ -150,7 +171,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { boolean useTLS, int port, Duration timeout, - Function responseBuilder + Function responseBuilder ) throws Exception { this.useTls = useTLS; this.port = port; @@ -172,10 +193,13 @@ protected void initChannel(SocketChannel ch) { if (timeout != null) { pipeline.addLast(new ReadTimeoutHandler(timeout.toMillis(), TimeUnit.MILLISECONDS)); } + pipeline.addLast(new LoggingHandler("A")); pipeline.addLast(new HttpRequestDecoder()); + pipeline.addLast(new LoggingHandler("B")); pipeline.addLast(new HttpObjectAggregator(16 * 1024)); + pipeline.addLast(new LoggingHandler("C")); pipeline.addLast(new HttpResponseEncoder()); - pipeline.addLast(makeHandlerFromResponseContext(responseBuilder)); + pipeline.addLast(makeHandlerFromNettyResponseContext(responseBuilder)); } }); serverChannel = b.bind(port).sync().channel();