Skip to content

Commit

Permalink
Experimenting with handlers to add and remove headers from http reque…
Browse files Browse the repository at this point in the history
…sts.

Signed-off-by: Greg Schohn <[email protected]>
  • Loading branch information
gregschohn committed Sep 13, 2024
1 parent bce5d3d commit d8acd47
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

public class HeaderAdderHandler extends ChannelInboundHandlerAdapter {
boolean insertedHeader = false;
private final ByteBuf headerLineToAdd;

public HeaderAdderHandler(ByteBuf headerLineToAdd) {
this.headerLineToAdd = headerLineToAdd.retain();
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof ByteBuf) || insertedHeader) {
super.channelRead(ctx, msg);
return;
}
var buf = (ByteBuf) msg;
buf.markReaderIndex();
while (buf.isReadable()) {
if (buf.readByte() == '\n') {
final var upToIndex = buf.readerIndex();
var composite = Unpooled.compositeBuffer(3);
buf.resetReaderIndex();
composite.addComponent(true, buf.retainedSlice(0, upToIndex));
composite.addComponent(true, headerLineToAdd.duplicate());
composite.addComponent(true, buf.retainedSlice(upToIndex, buf.readableBytes()-upToIndex));
buf.release();
super.channelRead(ctx, composite);
insertedHeader = true;
return;
}
}
buf.resetReaderIndex();
super.channelRead(ctx, msg);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import lombok.SneakyThrows;

public class HeaderRemoverHandler extends ChannelInboundHandlerAdapter {
final String headerToRemove;
CompositeByteBuf previousRemaining;
// This handler has 3 states - copying, dropping, or testing. when previousRemaining != null, we're testing.
// when dropUntilNewline == true, we're dropping, otherwise, we're copying (when previousRemaining==null)
// The starting state is previousRemaining == null and dropUntilNewline = false
boolean dropUntilNewline;

public HeaderRemoverHandler(String headerToRemove) {
if (!headerToRemove.endsWith(":")) {
throw new IllegalArgumentException("The headerToRemove must end with a ':'");
}
this.headerToRemove = headerToRemove;
}

@SneakyThrows
void lambdaSafeSuperChannelRead(ChannelHandlerContext ctx, ByteBuf bb) {
super.channelRead(ctx, bb);
}

/**
* @return true if there's a discongruity in the incoming buf and the contents that preceded this call will
* need to be buffered by the caller
*/
boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) {
if (!buf.isReadable()) {
return false;
}
buf.markReaderIndex();
for (int i=previousRemaining.readerIndex(); ; ++i) {
if (!buf.isReadable()) { // partial match
previousRemaining.addComponent(true, buf);
return true;
}
if (i == headerToRemove.length()) { // match!
previousRemaining.release(); // drop those in limbo ...
previousRemaining = null;
dropUntilNewline = true; // ... plus other bytes until we reset
return true;
}
if (Character.toLowerCase(headerToRemove.charAt(i)) != Character.toLowerCase(buf.readByte())) { // no match
previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb));
previousRemaining = null;
dropUntilNewline = false;
return false;
}
}

}

boolean advanceByteBufUntilNewline(ByteBuf bb) {
while (bb.isReadable()) { // sonar lint doesn't like if the while statement has an empty body
if (bb.readByte() != '\n') { return true; }
}
return false;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof ByteBuf)) {
super.channelRead(ctx, msg);
return;
}

var sourceBuf = (ByteBuf) msg;
var startForNextSourceSegment = sourceBuf.readerIndex();
var cleanedIncomingBuf = ctx.alloc().compositeBuffer(4);

while (true) {
if (previousRemaining != null) {
final var sourceReaderIdx = sourceBuf.readerIndex();
if (matchNextBytes(ctx, sourceBuf.slice(sourceReaderIdx, sourceBuf.readableBytes())) &&
sourceReaderIdx != startForNextSourceSegment) // would be 0-length
{
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(startForNextSourceSegment, sourceReaderIdx));
startForNextSourceSegment = -1;
}
} else {
var foundNewline = advanceByteBufUntilNewline(sourceBuf);
if (dropUntilNewline) {
if (foundNewline) {
// took care of previous bytes in the source buffer in the previousRemaining != null branch
startForNextSourceSegment = sourceBuf.readerIndex();
}
}
if (foundNewline) {
previousRemaining = ctx.alloc().compositeBuffer(16);
} else {
break;
}
}
}
if (startForNextSourceSegment >= 0) {
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(startForNextSourceSegment, sourceBuf.readerIndex()-startForNextSourceSegment));
}
super.channelRead(ctx, cleanedIncomingBuf);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import static org.junit.jupiter.api.Assertions.*;

class HeaderAdderHandlerTest {

public static Stream<Arguments> makeArgs() {
return Stream.of(
Arguments.of("\n"),
Arguments.of("\r\n"));
}

@ParameterizedTest
@MethodSource("makeArgs")
public void simpleCheck(String lineEnding) {
var extraHeader = "host: my.host\n";
var newHeader = Unpooled.wrappedBuffer(extraHeader.getBytes(StandardCharsets.UTF_8));
final var msg = makeMessage(lineEnding, "");

var channel = new EmbeddedChannel(new HeaderAdderHandler(newHeader));
channel.writeInbound(Unpooled.wrappedBuffer(msg.getBytes(StandardCharsets.UTF_8)));
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));

Assertions.assertEquals(makeMessage(lineEnding, extraHeader), output.toString(StandardCharsets.UTF_8));
}

@ParameterizedTest
@MethodSource("makeArgs")
public void fragmentedCheck(String lineEnding) {
var extraHeader = "host: my.host\n";
var newHeader = Unpooled.wrappedBuffer(extraHeader.getBytes(StandardCharsets.UTF_8));
final var msg = makeMessage(lineEnding, "");

var channel = new EmbeddedChannel(new HeaderAdderHandler(newHeader));
msg.chars().forEach(c -> channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{(byte) c})));
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));

Assertions.assertEquals(makeMessage(lineEnding, extraHeader), output.toString(StandardCharsets.UTF_8));
}

String makeMessage(String lineEnding, String extraHeader) {
return "GET / HTTP/1.1" + lineEnding +
extraHeader +
"NICEHeader: v1" + lineEnding +
"silLYHeader: yyy" + lineEnding +
lineEnding;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class HeaderRemoverHandlerTest {

public static Stream<Arguments> makeArgs() {
return Stream.of(
Arguments.of("\n"),
Arguments.of("\r\n"));
}

@ParameterizedTest
@MethodSource("makeArgs")
public void simpleCheck(String lineEnding) {
var extraHeader = "host: my.host\n";
var newHeader = Unpooled.wrappedBuffer(extraHeader.getBytes(StandardCharsets.UTF_8));
final var msg = makeMessage(lineEnding, true);

var channel = new EmbeddedChannel(new HeaderAdderHandler(newHeader));
channel.writeInbound(Unpooled.wrappedBuffer(msg.getBytes(StandardCharsets.UTF_8)));
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));

Assertions.assertEquals(makeMessage(lineEnding, false), output.toString(StandardCharsets.UTF_8));
}

@ParameterizedTest
@MethodSource("makeArgs")
public void fragmentedCheck(String lineEnding) {
var headerToRemove = "host";
final var msg = makeMessage(lineEnding, true);

var channel = new EmbeddedChannel(new HeaderRemoverHandler(headerToRemove));
msg.chars().forEach(c -> channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{(byte) c})));
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));

Assertions.assertEquals(makeMessage(lineEnding, false), output.toString(StandardCharsets.UTF_8));
}

String makeMessage(String lineEnding, boolean withHosts) {
return "GET / HTTP/1.1" + lineEnding +
"hoststays: v1" + lineEnding +
(withHosts ? ("HOST: begone" + lineEnding) : "") +
"different: v1" + lineEnding +
(withHosts ? ("HosT: begone" + lineEnding) : "") +
lineEnding;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.regex.Pattern;

import com.google.common.base.Strings;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

@Slf4j
public class MatcherTest {

public static final ByteBuf BIG_BUF =
Unpooled.wrappedBuffer(Strings.repeat("ha", 100_000).getBytes(StandardCharsets.UTF_8));
public static final ByteBuf SMALL_BUF =
Unpooled.wrappedBuffer(Strings.repeat("ha", 1).getBytes(StandardCharsets.UTF_8));

@Test
public void test() {
var p = Pattern.compile("^host:.*", Pattern.CASE_INSENSITIVE);

Assertions.assertTrue(
bufMatches(p, Unpooled.wrappedBuffer("host: MYHOST".getBytes(StandardCharsets.UTF_8))));

getMatchTime(p, BIG_BUF, 1000);
getMatchTime(p, BIG_BUF, 1000);

for (int i=0; i<1; ++i) {
final var MATCH_REPS = 100_000_000;
var smallTime = getMatchTime(p, SMALL_BUF, MATCH_REPS);
var bigTime = getMatchTime(p, BIG_BUF, MATCH_REPS);
log.info("smallTime = "+smallTime);
log.info("bigTime = "+bigTime);
}
}

private static Duration getMatchTime(Pattern p, ByteBuf input, int i) {
final var start = System.nanoTime();
boolean didMatch = false;
for (; i > 0; --i) {
didMatch |= bufMatches(p, input);
}
try {
return Duration.ofNanos(System.nanoTime() - start);
} finally {
Assertions.assertFalse(didMatch);
}
}

public static boolean bufMatches(Pattern p, ByteBuf b) {
return p.matcher(b.getCharSequence(0, b.readableBytes(),StandardCharsets.UTF_8)).matches();
}
}

0 comments on commit d8acd47

Please sign in to comment.