Skip to content

Commit

Permalink
Fix flaky test StreamingDataflowWorkerTest (#28173)
Browse files Browse the repository at this point in the history
Co-authored-by: Fei Xie <[email protected]>
  • Loading branch information
olalamichelle and Fei Xie authored Aug 28, 2023
1 parent e26735d commit 505f942
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.util;

import java.time.Clock;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
Expand Down Expand Up @@ -48,6 +49,24 @@ public BoundedQueueExecutor(
int maximumElementsOutstanding,
long maximumBytesOutstanding,
ThreadFactory threadFactory) {
this(
maximumPoolSize,
keepAliveTime,
unit,
maximumElementsOutstanding,
maximumBytesOutstanding,
threadFactory,
Clock.systemUTC());
}

public BoundedQueueExecutor(
int maximumPoolSize,
long keepAliveTime,
TimeUnit unit,
int maximumElementsOutstanding,
long maximumBytesOutstanding,
ThreadFactory threadFactory,
Clock clock) {
executor =
new ThreadPoolExecutor(
maximumPoolSize,
Expand All @@ -61,7 +80,7 @@ protected void beforeExecute(Thread t, Runnable r) {
super.beforeExecute(t, r);
synchronized (this) {
if (activeCount.getAndIncrement() >= maximumPoolSize - 1) {
startTimeMaxActiveThreadsUsed = System.currentTimeMillis();
startTimeMaxActiveThreadsUsed = clock.millis();
}
}
}
Expand All @@ -71,8 +90,7 @@ protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
synchronized (this) {
if (activeCount.getAndDecrement() == maximumPoolSize) {
totalTimeMaxActiveThreadsUsed +=
(System.currentTimeMillis() - startTimeMaxActiveThreadsUsed);
totalTimeMaxActiveThreadsUsed += (clock.millis() - startTimeMaxActiveThreadsUsed);
startTimeMaxActiveThreadsUsed = 0;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import com.google.api.services.dataflow.model.CounterUpdate;
import com.google.api.services.dataflow.model.InstructionInput;
Expand All @@ -56,6 +52,7 @@
import com.google.api.services.dataflow.model.WriteInstruction;
import java.io.IOException;
import java.io.InputStream;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -170,7 +167,6 @@
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ErrorCollector;
Expand Down Expand Up @@ -2855,10 +2851,24 @@ public void testActiveWorkForShardedKeys() throws Exception {
}

@Test
@Ignore // Test is flaky on Jenkins (#27555)
public void testMaxThreadMetric() throws Exception {
int maxThreads = 2;
int threadExpiration = 60;

Clock mockClock = Mockito.mock(Clock.class);
CountDownLatch latch = new CountDownLatch(2);
doAnswer(
invocation -> {
latch.countDown();
// Return 0 until we are called once (reach max thread count).
if (latch.getCount() == 1) {
return 0L;
}
return 1000L;
})
.when(mockClock)
.millis();

// setting up actual implementation of executor instead of mocking to keep track of
// active thread count.
BoundedQueueExecutor executor =
Expand All @@ -2871,7 +2881,8 @@ public void testMaxThreadMetric() throws Exception {
new ThreadFactoryBuilder()
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build());
.build(),
mockClock);

StreamingDataflowWorker.ComputationState computationState =
new StreamingDataflowWorker.ComputationState(
Expand All @@ -2883,15 +2894,17 @@ public void testMaxThreadMetric() throws Exception {

ShardedKey key1Shard1 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 1);

// overriding definition of MockWork to add sleep, which will help us keep track of how
// long each work item takes to process and therefore let us manipulate how long the time
// at which we're at max threads is.
MockWork m2 =
new MockWork(2) {
@Override
public void run() {
try {
Thread.sleep(1000);
// Make sure we don't finish before both MockWork are executed, thus afterExecute must
// be called after
// beforeExecute.
while (latch.getCount() > 1) {
Thread.sleep(50);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
Expand All @@ -2903,7 +2916,9 @@ public void run() {
@Override
public void run() {
try {
Thread.sleep(1000);
while (latch.getCount() > 1) {
Thread.sleep(50);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
Expand All @@ -2913,13 +2928,11 @@ public void run() {
assertTrue(computationState.activateWork(key1Shard1, m2));
assertTrue(computationState.activateWork(key1Shard1, m3));
executor.execute(m2, m2.getWorkItem().getSerializedSize());

executor.execute(m3, m3.getWorkItem().getSerializedSize());
// Wait until the afterExecute is called.
latch.await();

// Will get close to 1000ms that both work items are processing (sleeping, really)
// give or take a few ms.
long i = 990L;
assertTrue(executor.allThreadsActiveTime() >= i);
assertEquals(1000L, executor.allThreadsActiveTime());
executor.shutdown();
}

Expand Down

0 comments on commit 505f942

Please sign in to comment.