Skip to content

Commit

Permalink
Proxy header translations: PR feedback and a critical bugfix.
Browse files Browse the repository at this point in the history
The bugfix was that the state machine for striking header lines would continue even AFTER the header was parsed.  That meant that if a header name ever appeared at the beginning of a line within the contents of a payload, we'd strike that whole line.  A test was added to exhibit the behavior, requiring some refactoring to the SimpleNettyHttpServer to get the payload that hits that server.  That's used to confirm that we're leaving the body exactly intact in the new test.  All other previous tests continue to pass.

Signed-off-by: Greg Schohn <[email protected]>
  • Loading branch information
gregschohn committed Sep 17, 2024
1 parent 58a1892 commit 74dd2f7
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import java.nio.file.Paths;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.TreeMap;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
Expand Down Expand Up @@ -52,6 +53,7 @@
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslContext;
Expand Down Expand Up @@ -348,7 +350,7 @@ protected static Map<String, String> convertPairListToMap(List<String> list) {
if (list == null) {
return Map.of();
}
var map = new TreeMap<String, String>();
var map = new LinkedHashMap<String, String>();
for (int i = 0; i < list.size(); i += 2) {
map.put(list.get(i), list.get(i + 1));
}
Expand Down Expand Up @@ -425,10 +427,22 @@ static ProxyChannelInitializer buildProxyChannelInitializer(RootCaptureContext r
BacksideConnectionPool backsideConnectionPool,
Supplier<SSLEngine> sslEngineSupplier,
@NonNull RequestCapturePredicate headerCapturePredicate,
List<String> headerOverrides,
List<String> headerOverridesArgs,
IConnectionCaptureFactory connectionFactory)
{
var headers = convertPairListToMap(headerOverrides);
var headers = new ArrayList<>(convertPairListToMap(headerOverridesArgs).entrySet());
Collections.reverse(headers);
final var removeStrings = new ArrayList<String>(headers.size());
final var addBufs = new ArrayList<ByteBuf>(headers.size());

for (var kvp : headers) {
addBufs.add(
Unpooled.unreleasableBuffer(
Unpooled.wrappedBuffer(
(kvp.getKey() + ": " + kvp.getValue()).getBytes(StandardCharsets.UTF_8))));
removeStrings.add(kvp.getKey() + ":");
}

return new ProxyChannelInitializer(
rootContext,
backsideConnectionPool,
Expand All @@ -439,14 +453,20 @@ static ProxyChannelInitializer buildProxyChannelInitializer(RootCaptureContext r
@Override
protected void initChannel(@NonNull SocketChannel ch) throws IOException {
super.initChannel(ch);
for (var kvp : headers.entrySet()) {
var lineBytes = (kvp.getKey() + ": " + kvp.getValue()).getBytes(StandardCharsets.UTF_8);
ch.pipeline().addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "AddHeader-" + kvp.getKey(),
new HeaderAdderHandler(Unpooled.unreleasableBuffer(Unpooled.wrappedBuffer(lineBytes))));
final var pipeline = ch.pipeline();
{
int i = 0;
for (var kvp : headers) {
pipeline.addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "AddHeader-" + kvp.getKey(),
new HeaderAdderHandler(addBufs.get(i++)));
}
}
for (var k : headers.keySet()) {
ch.pipeline().addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "RemoveHeader-" + k,
new HeaderRemoverHandler(k + ":"));
{
int i = 0;
for (var kvp : headers) {
pipeline.addAfter(ProxyChannelInitializer.CAPTURE_HANDLER_NAME, "RemoveHeader-" + kvp.getKey(),
new HeaderRemoverHandler(removeStrings.get(i++)));
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;

public class HeaderAdderHandler extends ChannelInboundHandlerAdapter {
private static final ByteBuf CRLF_BYTE_BUF =
Expand Down Expand Up @@ -37,7 +38,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
var composite = Unpooled.compositeBuffer(4);
buf.resetReaderIndex();
composite.addComponent(true, buf.retainedSlice(0, upToIndex));
composite.addComponent(true, headerLineToAdd.duplicate());
composite.addComponent(true, headerLineToAdd.retainedDuplicate());
composite.addComponent(true, (useCarriageReturn ? CRLF_BYTE_BUF : LF_BYTE_BUF).duplicate());
composite.addComponent(true, buf.retainedSlice(upToIndex, buf.readableBytes()-upToIndex));
buf.release();
Expand All @@ -49,4 +50,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
buf.resetReaderIndex();
super.channelRead(ctx, msg);
}

@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
ReferenceCountUtil.release(headerLineToAdd);
super.channelUnregistered(ctx);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ public class HeaderRemoverHandler extends ChannelInboundHandlerAdapter {
// when dropUntilNewline == true, we're dropping, otherwise, we're copying (when previousRemaining==null)
// The starting state is previousRemaining == null and dropUntilNewline = false
boolean dropUntilNewline;
MessagePosition requestPosition = MessagePosition.IN_HEADER;

private enum MessagePosition {
IN_HEADER, ONE_NEW_LINE, AFTER_HEADERS,
}

public HeaderRemoverHandler(String headerToRemove) {
if (!headerToRemove.endsWith(":")) {
Expand Down Expand Up @@ -49,17 +54,21 @@ boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) {
}
buf.markReaderIndex();
if (Character.toLowerCase(headerToRemove.charAt(i)) != Character.toLowerCase(buf.readByte())) { // no match
previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb.retain()));
previousRemaining.removeComponents(0, previousRemaining.numComponents());
previousRemaining.release();
previousRemaining = null;
flushAndClearPreviousRemaining(ctx);
buf.resetReaderIndex();
dropUntilNewline = false;
return false;
}
}
}

void flushAndClearPreviousRemaining(ChannelHandlerContext ctx) {
previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb.retain()));
previousRemaining.removeComponents(0, previousRemaining.numComponents());
previousRemaining.release();
previousRemaining = null;
}

boolean advanceByteBufUntilNewline(ByteBuf bb) {
while (bb.isReadable()) { // sonar lint doesn't like if the while statement has an empty body
if (bb.readByte() == '\n') { return true; }
Expand All @@ -81,16 +90,36 @@ CompositeByteBuf addSliceToComposite(ChannelHandlerContext ctx, CompositeByteBuf

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof ByteBuf)) {
if (!(msg instanceof ByteBuf) || requestPosition == MessagePosition.AFTER_HEADERS) {
super.channelRead(ctx, msg);
return;
}

var sourceBuf = (ByteBuf) msg;
var currentSourceSegmentStart = (previousRemaining != null || dropUntilNewline) ? -1 : sourceBuf.readerIndex();
var currentSourceSegmentStart =
(previousRemaining != null || dropUntilNewline || requestPosition == MessagePosition.ONE_NEW_LINE)
? -1 : sourceBuf.readerIndex();
CompositeByteBuf cleanedIncomingBuf = null;
sourceBuf.markReaderIndex();

while (sourceBuf.isReadable()) {
if (requestPosition == MessagePosition.ONE_NEW_LINE) {
final var nextByte = sourceBuf.readByte();
if (nextByte == '\n' || nextByte == '\r') {
requestPosition = MessagePosition.AFTER_HEADERS;
if (currentSourceSegmentStart == -1) {
currentSourceSegmentStart = sourceBuf.readerIndex() - 1;
}
sourceBuf.readerIndex(sourceBuf.writerIndex());
break;
} else {
previousRemaining = ctx.alloc().compositeBuffer(16);
requestPosition = MessagePosition.IN_HEADER;
sourceBuf.resetReaderIndex();
continue;
}
}

if (previousRemaining != null) {
final var sourceReaderIdx = sourceBuf.readerIndex();
if (matchNextBytes(ctx, sourceBuf)) {
Expand All @@ -106,7 +135,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
} else {
if (advanceByteBufUntilNewline(sourceBuf)) {
previousRemaining = ctx.alloc().compositeBuffer(16);
sourceBuf.markReaderIndex();
requestPosition = MessagePosition.ONE_NEW_LINE;
} else {
break;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.opensearch.migrations.trafficcapture.proxyserver;

import java.io.ByteArrayInputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
Expand All @@ -22,8 +23,15 @@

@Slf4j
public class TestHeaderRewrites {

public static final String ONLY_FOR_HEADERS_VALUE = "this is only for headers";
public static final String BODY_WITH_HEADERS_CONTENTS = "\n" +
"body: should stay\n" +
"body: untouched\n" +
"body:\n";

@Test
public void testRewrites() throws Exception {
public void testHeaderRewrites() throws Exception {
final var payloadBytes = "Success".getBytes(StandardCharsets.UTF_8);
final var headers = Map.of(
"Content-Type",
Expand All @@ -44,7 +52,7 @@ public void testRewrites() throws Exception {
Duration.ofMinutes(10),
fl -> {
capturedRequestList.add(fl);
log.error("headers: " + fl.getHeaders().stream().map(kvp->kvp.getKey()+": "+kvp.getValue())
log.trace("headers: " + fl.getHeaders().stream().map(kvp->kvp.getKey()+": "+kvp.getValue())
.collect(Collectors.joining()));
return new SimpleHttpResponse(headers, payloadBytes, "OK", 200);
});
Expand All @@ -55,16 +63,68 @@ public void testRewrites() throws Exception {
proxy.start();
final var proxyEndpoint = CaptureProxyContainer.getUriFromContainer(proxy);


var allHeaders = new LinkedHashMap<String, String>();
allHeaders.put("Host", "localhost");
allHeaders.put("User-Agent", "UnitTest");
var response = client.makeGetRequest(new URI(proxyEndpoint), allHeaders.entrySet().stream());
log.error("response=" + response);
var capturedRequest = capturedRequestList.get(capturedRequestList.size()-1).getHeaders().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Assertions.assertEquals("localhost", capturedRequest.get("host"));
Assertions.assertEquals("insignificant value", capturedRequest.get("X-new-header"));
}
}

@Test
public void testBodyDoesntRewrite() throws Exception {
final var payloadBytes = "Success".getBytes(StandardCharsets.UTF_8);
final var headers = Map.of(
"Content-Type",
"text/plain",
"Content-Length",
"" + payloadBytes.length
);
var rewriteArgs = List.of(
"--setHeader",
"host",
"localhost",
"--setHeader",
"body",
ONLY_FOR_HEADERS_VALUE
);
var capturedRequestList = new ArrayList<HttpRequest>();
var capturedBodies = new ArrayList<String>();
try (var destinationServer = SimpleNettyHttpServer.makeNettyServer(false,
Duration.ofMinutes(10),
fullRequest -> {
var request = new SimpleNettyHttpServer.RequestToAdapter(fullRequest);
capturedRequestList.add(request);
log.atTrace().setMessage(() -> "headers: " +
request.getHeaders().stream().map(kvp->kvp.getKey()+": "+kvp.getValue())
.collect(Collectors.joining())).log();
capturedBodies.add(fullRequest.content().toString(StandardCharsets.UTF_8));
return new SimpleHttpResponse(headers, payloadBytes, "OK", 200);
});
var proxy = new CaptureProxyContainer(() -> destinationServer.localhostEndpoint().toString(), null,
rewriteArgs.stream());
var client = new SimpleHttpClientForTesting();
var bodyStream = new ByteArrayInputStream(BODY_WITH_HEADERS_CONTENTS.getBytes(StandardCharsets.UTF_8)))
{
proxy.start();
final var proxyEndpoint = CaptureProxyContainer.getUriFromContainer(proxy);

var allHeaders = new LinkedHashMap<String, String>();
allHeaders.put("Host", "localhost");
allHeaders.put("User-Agent", "UnitTest");
var response = client.makePutRequest(new URI(proxyEndpoint), allHeaders.entrySet().stream(),
new SimpleHttpClientForTesting.PayloadAndContentType(bodyStream, "text/plain"));
log.error("response=" + response);
var capturedRequest = capturedRequestList.get(capturedRequestList.size()-1).getHeaders().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Assertions.assertEquals("localhost", capturedRequest.get("host"));
Assertions.assertEquals(ONLY_FOR_HEADERS_VALUE, capturedRequest.get("body"));

var lastBody = capturedBodies.get(capturedBodies.size()-1);
Assertions.assertEquals(BODY_WITH_HEADERS_CONTENTS, lastBody);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ public void runTestWithSize(Function<Boolean,String> messageMaker, IntStream siz
outputBuf.release();
}

@Test
public void newlinesArePreserved() {
runTestsWithSize((b,s) -> "GET / HTTP/1.1\r\n" + (b ? "host: localhost\r\n" : "") + "\r\n",
() -> IntStream.of(Integer.MAX_VALUE));
}

@Test
public void throwsOnHostFormatError() {
Assertions.assertThrows(IllegalArgumentException.class, () -> new HeaderRemoverHandler("host"));
Expand Down Expand Up @@ -87,6 +93,7 @@ public void randomFragmentedCheckInterlaced() {
final var bound = getBound(HeaderRemoverHandlerTest::makeInterlacedMessage);
for (int i=0; i<NUM_RANDOM_RUNS; ++i) {
Random r = new Random(i);
log.atDebug().setMessage(() -> "random run={}").addArgument(i).log();
runTestsWithSize(HeaderRemoverHandlerTest::makeInterlacedMessage,
() -> IntStream.generate(() -> r.nextInt(bound)));
}
Expand All @@ -107,6 +114,7 @@ public void randomFragmentedCheckConsecutive() {
final var bound = getBound(HeaderRemoverHandlerTest::makeConsecutiveMessage);
for (int i=0; i<NUM_RANDOM_RUNS; ++i) {
Random r = new Random(i);
log.atDebug().setMessage(() -> "random run={}").addArgument(i).log();
runTestsWithSize(HeaderRemoverHandlerTest::makeConsecutiveMessage,
() -> IntStream.generate(() -> r.nextInt(bound)));
}
Expand Down
Loading

0 comments on commit 74dd2f7

Please sign in to comment.