Skip to content

Commit

Permalink
Bugfixes for proxy header translations and better tests
Browse files Browse the repository at this point in the history
Signed-off-by: Greg Schohn <[email protected]>
  • Loading branch information
gregschohn committed Sep 13, 2024
1 parent d8acd47 commit d1d2be0
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.nio.charset.StandardCharsets;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class HeaderRemoverHandler extends ChannelInboundHandlerAdapter {
final String headerToRemove;
CompositeByteBuf previousRemaining;
Expand All @@ -34,13 +38,11 @@ void lambdaSafeSuperChannelRead(ChannelHandlerContext ctx, ByteBuf bb) {
* need to be buffered by the caller
*/
boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) {
if (!buf.isReadable()) {
return false;
}
buf.markReaderIndex();
for (int i=previousRemaining.readerIndex(); ; ++i) {
final var sourceReaderIdx = buf.readerIndex();
for (int i=previousRemaining.writerIndex(); ; ++i) {
if (!buf.isReadable()) { // partial match
previousRemaining.addComponent(true, buf);
previousRemaining.addComponent(true,
buf.retainedSlice(sourceReaderIdx, i-previousRemaining.writerIndex()));
return true;
}
if (i == headerToRemove.length()) { // match!
Expand All @@ -49,19 +51,22 @@ boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) {
dropUntilNewline = true; // ... plus other bytes until we reset
return true;
}
buf.markReaderIndex();
if (Character.toLowerCase(headerToRemove.charAt(i)) != Character.toLowerCase(buf.readByte())) { // no match
previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb));
previousRemaining.removeComponents(0, previousRemaining.numComponents());
previousRemaining.release();
previousRemaining = null;
buf.resetReaderIndex();
dropUntilNewline = false;
return false;
}
}

}

boolean advanceByteBufUntilNewline(ByteBuf bb) {
while (bb.isReadable()) { // sonar lint doesn't like if the while statement has an empty body
if (bb.readByte() != '\n') { return true; }
if (bb.readByte() == '\n') { return true; }
}
return false;
}
Expand All @@ -74,38 +79,41 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}

var sourceBuf = (ByteBuf) msg;
var startForNextSourceSegment = sourceBuf.readerIndex();
var currentSourceSegmentStart = (previousRemaining != null || dropUntilNewline) ? -1 : sourceBuf.readerIndex();
var cleanedIncomingBuf = ctx.alloc().compositeBuffer(4);

while (true) {
while (sourceBuf.isReadable()) {
if (previousRemaining != null) {
final var sourceReaderIdx = sourceBuf.readerIndex();
if (matchNextBytes(ctx, sourceBuf.slice(sourceReaderIdx, sourceBuf.readableBytes())) &&
sourceReaderIdx != startForNextSourceSegment) // would be 0-length
{
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(startForNextSourceSegment, sourceReaderIdx));
startForNextSourceSegment = -1;
}
} else {
var foundNewline = advanceByteBufUntilNewline(sourceBuf);
if (dropUntilNewline) {
if (foundNewline) {
// took care of previous bytes in the source buffer in the previousRemaining != null branch
startForNextSourceSegment = sourceBuf.readerIndex();
if (matchNextBytes(ctx, sourceBuf)) {
if (currentSourceSegmentStart >= 0 &&
sourceReaderIdx != currentSourceSegmentStart) // would be 0-length
{
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(currentSourceSegmentStart, sourceReaderIdx-currentSourceSegmentStart));
currentSourceSegmentStart = -1;
}
} else if (currentSourceSegmentStart == -1) {
currentSourceSegmentStart = sourceReaderIdx;
}
if (foundNewline) {
} else {
if (advanceByteBufUntilNewline(sourceBuf)) {
previousRemaining = ctx.alloc().compositeBuffer(16);
} else {
break;
}
}
}
if (startForNextSourceSegment >= 0) {
if (currentSourceSegmentStart >= 0) {
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(startForNextSourceSegment, sourceBuf.readerIndex()-startForNextSourceSegment));
sourceBuf.retainedSlice(currentSourceSegmentStart, sourceBuf.readerIndex()-currentSourceSegmentStart));
}
super.channelRead(ctx, cleanedIncomingBuf);
}

@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
ReferenceCountUtil.release(previousRemaining);
super.channelUnregistered(ctx);
}
}
Original file line number Diff line number Diff line change
@@ -1,57 +1,81 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.params.provider.ValueSource;

@WrapWithNettyLeakDetection()
@Slf4j
class HeaderAdderHandlerTest {

public static Stream<Arguments> makeArgs() {
return Stream.of(
Arguments.of("\n"),
Arguments.of("\r\n"));
private void runTestsWithSize(Supplier<IntStream> sizesSupplier) {
runTestWithSize("\n", sizesSupplier.get());
runTestWithSize("\r\n", sizesSupplier.get());
}

@ParameterizedTest
@MethodSource("makeArgs")
public void simpleCheck(String lineEnding) {
var extraHeader = "host: my.host\n";
var newHeader = Unpooled.wrappedBuffer(extraHeader.getBytes(StandardCharsets.UTF_8));
final var msg = makeMessage(lineEnding, "");

var channel = new EmbeddedChannel(new HeaderAdderHandler(newHeader));
channel.writeInbound(Unpooled.wrappedBuffer(msg.getBytes(StandardCharsets.UTF_8)));
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));
@Test
public void simpleCheck() {
runTestsWithSize(() -> IntStream.of(Integer.MAX_VALUE));
}

Assertions.assertEquals(makeMessage(lineEnding, extraHeader), output.toString(StandardCharsets.UTF_8));
@Test
public void individualBytesCheck() {
runTestsWithSize(() -> IntStream.generate(()->1));
}

@ParameterizedTest
@MethodSource("makeArgs")
public void fragmentedCheck(String lineEnding) {
@ValueSource(strings = {
"8,27,9999",
"8,12,16,999"
})
public void fragmentedBytesCheck(String sizesStr) {
runTestsWithSize(() -> Arrays.stream(sizesStr.split(",")).mapToInt(Integer::parseInt));
}

private void runTestWithSize(String lineEnding, IntStream sizes) {
var extraHeader = "host: my.host\n";
var newHeader = Unpooled.wrappedBuffer(extraHeader.getBytes(StandardCharsets.UTF_8));
final var msg = makeMessage(lineEnding, "");

var channel = new EmbeddedChannel(new HeaderAdderHandler(newHeader));
msg.chars().forEach(c -> channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{(byte) c})));
sliceMessageIntoChannelWrites(channel, msg, sizes);
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));
channel.close();

Assertions.assertEquals(makeMessage(lineEnding, extraHeader), output.toString(StandardCharsets.UTF_8));
}

public static void sliceMessageIntoChannelWrites(EmbeddedChannel channel, String msg, IntStream sizes) {
final var lastStart = new AtomicInteger();
sizes
.mapToObj(len -> {
var startIdx = lastStart.get();
if (startIdx >= msg.length()) { return null; }
var endIdx = startIdx + len;
var substr = msg.substring(lastStart.get(), Math.min(endIdx, msg.length()));
lastStart.set(endIdx);
log.atTrace().setMessage(() -> "s: " + substr).log();
return substr;
})
.takeWhile(Objects::nonNull)
.forEach(substr -> channel.writeInbound(Unpooled.wrappedBuffer(substr.getBytes(StandardCharsets.UTF_8))));
}

String makeMessage(String lineEnding, String extraHeader) {
return "GET / HTTP/1.1" + lineEnding +
extraHeader +
Expand Down
Loading

0 comments on commit d1d2be0

Please sign in to comment.