Skip to content

Commit

Permalink
Create POOLED buffers in the proxy translation handler unit tests.
Browse files Browse the repository at this point in the history
That tweaked a couple other refCounting errors that are also fixed here.
  • Loading branch information
gregschohn committed Sep 13, 2024
1 parent d1d2be0 commit 030ed0c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package org.opensearch.migrations.trafficcapture.proxyserver.netty;

import java.nio.charset.StandardCharsets;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
Expand Down Expand Up @@ -53,7 +49,7 @@ boolean matchNextBytes(ChannelHandlerContext ctx, ByteBuf buf) {
}
buf.markReaderIndex();
if (Character.toLowerCase(headerToRemove.charAt(i)) != Character.toLowerCase(buf.readByte())) { // no match
previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb));
previousRemaining.forEach(bb -> lambdaSafeSuperChannelRead(ctx, bb.retain()));
previousRemaining.removeComponents(0, previousRemaining.numComponents());
previousRemaining.release();
previousRemaining = null;
Expand All @@ -71,6 +67,18 @@ boolean advanceByteBufUntilNewline(ByteBuf bb) {
return false;
}

CompositeByteBuf addSliceToComposite(ChannelHandlerContext ctx, CompositeByteBuf priorBuf, ByteBuf sourceBuf,
int start, int len) {
if (len == 0) {
return priorBuf;
}
if (priorBuf == null) {
priorBuf = ctx.alloc().compositeBuffer(4);
}
priorBuf.addComponent(true, sourceBuf.retainedSlice(start, len));
return priorBuf;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof ByteBuf)) {
Expand All @@ -80,7 +88,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception

var sourceBuf = (ByteBuf) msg;
var currentSourceSegmentStart = (previousRemaining != null || dropUntilNewline) ? -1 : sourceBuf.readerIndex();
var cleanedIncomingBuf = ctx.alloc().compositeBuffer(4);
CompositeByteBuf cleanedIncomingBuf = null;

while (sourceBuf.isReadable()) {
if (previousRemaining != null) {
Expand All @@ -89,8 +97,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
if (currentSourceSegmentStart >= 0 &&
sourceReaderIdx != currentSourceSegmentStart) // would be 0-length
{
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(currentSourceSegmentStart, sourceReaderIdx-currentSourceSegmentStart));
cleanedIncomingBuf = addSliceToComposite(ctx, cleanedIncomingBuf, sourceBuf,
currentSourceSegmentStart, sourceReaderIdx-currentSourceSegmentStart);
currentSourceSegmentStart = -1;
}
} else if (currentSourceSegmentStart == -1) {
Expand All @@ -104,11 +112,15 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
}
}

if (currentSourceSegmentStart >= 0) {
cleanedIncomingBuf.addComponent(true,
sourceBuf.retainedSlice(currentSourceSegmentStart, sourceBuf.readerIndex()-currentSourceSegmentStart));
cleanedIncomingBuf = addSliceToComposite(ctx, cleanedIncomingBuf, sourceBuf,
currentSourceSegmentStart, sourceBuf.readerIndex()-currentSourceSegmentStart);
}
sourceBuf.release();
if (cleanedIncomingBuf != null) {
super.channelRead(ctx, cleanedIncomingBuf);
}
super.channelRead(ctx, cleanedIncomingBuf);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ private void runTestWithSize(String lineEnding, IntStream sizes) {
var channel = new EmbeddedChannel(new HeaderAdderHandler(newHeader));
sliceMessageIntoChannelWrites(channel, msg, sizes);
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, (ByteBuf) v));
channel.close();
channel.inboundMessages().forEach(v -> output.addComponent(true, ((ByteBuf) v).retain()));
channel.finishAndReleaseAll();

Assertions.assertEquals(makeMessage(lineEnding, extraHeader), output.toString(StandardCharsets.UTF_8));
output.release();
}

public static void sliceMessageIntoChannelWrites(EmbeddedChannel channel, String msg, IntStream sizes) {
Expand All @@ -73,7 +74,11 @@ public static void sliceMessageIntoChannelWrites(EmbeddedChannel channel, String
return substr;
})
.takeWhile(Objects::nonNull)
.forEach(substr -> channel.writeInbound(Unpooled.wrappedBuffer(substr.getBytes(StandardCharsets.UTF_8))));
.forEach(substr -> {
var bytes = substr.getBytes(StandardCharsets.UTF_8);
var buf = channel.alloc().buffer(bytes.length);
channel.writeInbound(buf.writeBytes(bytes));
});
}

String makeMessage(String lineEnding, String extraHeader) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -22,7 +21,7 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

@WrapWithNettyLeakDetection()
@WrapWithNettyLeakDetection(repetitions = 1)
@Slf4j
class HeaderRemoverHandlerTest {

Expand All @@ -41,13 +40,14 @@ public void runTestWithSize(Function<Boolean,String> messageMaker, IntStream siz

var channel = new EmbeddedChannel(new HeaderRemoverHandler("host:"));
HeaderAdderHandlerTest.sliceMessageIntoChannelWrites(channel, sourceMsg, sizes);
var output = Unpooled.compositeBuffer();
channel.inboundMessages().forEach(v -> output.addComponent(true, ((ByteBuf) v)));
channel.close();
var outputBuf = channel.alloc().compositeBuffer();
channel.inboundMessages().forEach(v -> outputBuf.addComponent(true, ((ByteBuf) v).retain()));
channel.finishAndReleaseAll();

Assertions.assertEquals(messageMaker.apply(false), output.toString(StandardCharsets.UTF_8),
var outputString = outputBuf.toString(StandardCharsets.UTF_8);
Assertions.assertEquals(messageMaker.apply(false), outputString,
"Error converting source message: " + sourceMsg);
output.release();
outputBuf.release();
}

@Test
Expand Down

0 comments on commit 030ed0c

Please sign in to comment.