Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Nov 2, 2024
1 parent ec3b662 commit 27a9758
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
Expand Down Expand Up @@ -119,17 +120,19 @@ public final class StreamingDataflowWorker {
*/
public static final int MAX_SINK_BYTES = 10_000_000;

public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL =
"streaming_engine_use_job_settings_for_heartbeat_pool";
private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class);

/**
* Maximum number of threads for processing. Currently, each thread processes one key at a time.
*/
private static final int MAX_PROCESSING_THREADS = 300;

/** The idGenerator to generate unique id globally. */
private static final IdGenerator ID_GENERATOR = IdGenerators.decrementingLongs();

/** Maximum size of the result of a GetWork request. */
private static final long MAX_GET_WORK_FETCH_BYTES = 64L << 20; // 64m

/** Maximum number of failure stacktraces to report in each update sent to backend. */
private static final int MAX_FAILURES_TO_REPORT_IN_UPDATE = 1000;

Expand Down Expand Up @@ -197,6 +200,7 @@ private StreamingDataflowWorker(
this.workCommitter =
windmillServiceEnabled
? StreamingEngineWorkCommitter.builder()
.setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(
WindmillStreamPool.create(
numCommitThreads,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
public final class WeightedBoundedQueue<V> {

private final LinkedBlockingQueue<V> queue;
private final WeightedSemaphore<V> weigher;
private final WeightedSemaphore<V> weightedSemaphore;

private WeightedBoundedQueue(
LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V> weigher) {
LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V> weightedSemaphore) {
this.queue = linkedBlockingQueue;
this.weigher = weigher;
this.weightedSemaphore = weightedSemaphore;
}

public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V> weigher) {
Expand All @@ -43,15 +43,15 @@ public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V> weigher) {
* limit.
*/
public void put(V value) {
weigher.acquire(value);
weightedSemaphore.acquireUninterruptibly(value);
queue.add(value);
}

/** Returns and removes the next value, or null if there is no such value. */
public @Nullable V poll() {
V result = queue.poll();
@Nullable V result = queue.poll();
if (result != null) {
weigher.release(result);
weightedSemaphore.release(result);
}
return result;
}
Expand All @@ -67,26 +67,20 @@ public void put(V value) {
* @throws InterruptedException if interrupted while waiting
*/
public @Nullable V poll(long timeout, TimeUnit unit) throws InterruptedException {
V result = queue.poll(timeout, unit);
@Nullable V result = queue.poll(timeout, unit);
if (result != null) {
weigher.release(result);
weightedSemaphore.release(result);
}
return result;
}

/** Returns and removes the next value, or blocks until one is available. */
public @Nullable V take() throws InterruptedException {
public V take() throws InterruptedException {
V result = queue.take();
weigher.release(result);
weightedSemaphore.release(result);
return result;
}

/** Returns the current weight of the queue. */
@VisibleForTesting
int queuedElementsWeight() {
return weigher.currentWeight();
}

@VisibleForTesting
int size() {
return queue.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ public static <V> WeightedSemaphore<V> create(int maxWeight, Function<V, Integer
return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn);
}

void acquire(V value) {
public void acquireUninterruptibly(V value) {
limit.acquireUninterruptibly(weigher.apply(value));
}

void release(V value) {
public void release(V value) {
limit.release(weigher.apply(value));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.commits;

import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;

/** Utility class for commits. */
@Internal
public final class Commits {

/** Max bytes of commits queued on the user worker. */
@VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB

private Commits() {}

public static WeightedSemaphore<Commit> maxCommitByteSemaphore() {
return WeightedSemaphore.create(
MAX_QUEUED_COMMITS_BYTES, commit -> Math.min(MAX_QUEUED_COMMITS_BYTES, commit.getSize()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
public final class StreamingEngineWorkCommitter implements WorkCommitter {
private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
private static final int TARGET_COMMIT_BATCH_KEYS = 5;
private static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB
private static final String NO_BACKEND_WORKER_TOKEN = "";

private final Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
Expand All @@ -63,9 +62,9 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
int numCommitSenders,
Consumer<CompleteCommit> onCommitComplete,
String backendWorkerToken,
WeightedSemaphore<Commit> weigher) {
WeightedSemaphore<Commit> commitByteSemaphore) {
this.commitWorkStreamFactory = commitWorkStreamFactory;
this.commitQueue = WeightedBoundedQueue.create(weigher);
this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore);
this.commitSenders =
Executors.newFixedThreadPool(
numCommitSenders,
Expand All @@ -86,10 +85,6 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
public static Builder builder() {
return new AutoBuilder_StreamingEngineWorkCommitter_Builder()
.setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN)
.setWeigher(
WeightedSemaphore.create(
MAX_QUEUED_COMMITS_BYTES,
commit -> Math.min(MAX_QUEUED_COMMITS_BYTES, commit.getSize())))
.setNumCommitSenders(1);
}

Expand Down Expand Up @@ -169,6 +164,8 @@ private void streamingCommitLoop() {
return;
}
}

// take() blocks until a value is available in the commitQueue.
Preconditions.checkNotNull(initialCommit);

if (initialCommit.work().isFailed()) {
Expand Down Expand Up @@ -261,7 +258,7 @@ public interface Builder {
Builder setCommitWorkStreamFactory(
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory);

Builder setWeigher(WeightedSemaphore<Commit> weigher);
Builder setCommitByteSemaphore(WeightedSemaphore<Commit> commitByteSemaphore);

Builder setNumCommitSenders(int numCommitSenders);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
Expand All @@ -35,24 +36,24 @@ public class WeightBoundedQueueTest {

@Test
public void testPut_hasCapacity() {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));
WeightedSemaphore<Integer> weightedSemaphore =
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue<Integer> queue = WeightedBoundedQueue.create(weightedSemaphore);

int insertedValue = 1;

queue.put(insertedValue);

assertEquals(insertedValue, queue.queuedElementsWeight());
assertEquals(insertedValue, weightedSemaphore.currentWeight());
assertEquals(1, queue.size());
assertEquals(insertedValue, (int) queue.poll());
}

@Test
public void testPut_noCapacity() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));
WeightedSemaphore<Integer> weightedSemaphore =
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue<Integer> queue = WeightedBoundedQueue.create(weightedSemaphore);

// Insert value that takes all the capacity into the queue.
queue.put(MAX_WEIGHT);
Expand All @@ -73,7 +74,7 @@ public void testPut_noCapacity() throws InterruptedException {

// Should only see the first value in the queue, since the queue is at capacity. thread2
// should be blocked.
assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight());
assertEquals(1, queue.size());

// Poll the queue, pulling off the only value inside and freeing up the capacity in the queue.
Expand All @@ -82,23 +83,23 @@ public void testPut_noCapacity() throws InterruptedException {
// Wait for the putThread which was previously blocked due to the queue being at capacity.
putThread.join();

assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight());
assertEquals(1, queue.size());
}

@Test
public void testPoll() {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));
WeightedSemaphore<Integer> weightedSemaphore =
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue<Integer> queue = WeightedBoundedQueue.create(weightedSemaphore);

int insertedValue1 = 1;
int insertedValue2 = 2;

queue.put(insertedValue1);
queue.put(insertedValue2);

assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight());
assertEquals(insertedValue1 + insertedValue2, weightedSemaphore.currentWeight());
assertEquals(2, queue.size());
assertEquals(insertedValue1, (int) queue.poll());
assertEquals(1, queue.size());
Expand Down Expand Up @@ -149,13 +150,17 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException {
Thread pollThread =
new Thread(
() -> {
int polled;
@Nullable Integer polled;
try {
polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS);
pollResult.set(polled);
if (polled != null) {
pollResult.set(polled);
}
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

assertNull(polled);
});

pollThread.start();
Expand Down Expand Up @@ -214,24 +219,15 @@ public void testPut_sharedWeigher() throws InterruptedException {

// Try to insert a value into the queue2. This will block since there is no capacity in the
// weigher.
Thread putThread =
new Thread(
() -> {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
queue2.put(MAX_WEIGHT);
});
Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT));
putThread.start();

// Should only see the first value in the queue, since the queue is at capacity. putThread
// should be blocked. The weight should be the same however, since queue1 and queue2 are sharing
// the weigher.
assertEquals(MAX_WEIGHT, queue1.queuedElementsWeight());
Thread.sleep(100);
assertEquals(MAX_WEIGHT, weigher.currentWeight());
assertEquals(1, queue1.size());
assertEquals(MAX_WEIGHT, queue2.queuedElementsWeight());
assertEquals(MAX_WEIGHT, weigher.currentWeight());
assertEquals(0, queue2.size());

// Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher.
Expand All @@ -240,7 +236,7 @@ public void testPut_sharedWeigher() throws InterruptedException {
// Wait for the putThread which was previously blocked due to the weigher being at capacity.
putThread.join();

assertEquals(MAX_WEIGHT, queue2.queuedElementsWeight());
assertEquals(MAX_WEIGHT, weigher.currentWeight());
assertEquals(1, queue2.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ public void setUp() throws IOException {

private WorkCommitter createWorkCommitter(Consumer<CompleteCommit> onCommitComplete) {
return StreamingEngineWorkCommitter.builder()
.setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(commitWorkStreamFactory)
.setOnCommitComplete(onCommitComplete)
.build();
Expand Down Expand Up @@ -342,6 +343,7 @@ public void testMultipleCommitSendersSingleStream() {
Set<CompleteCommit> completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>());
workCommitter =
StreamingEngineWorkCommitter.builder()
.setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(commitWorkStreamFactory)
.setNumCommitSenders(5)
.setOnCommitComplete(completeCommits::add)
Expand Down

0 comments on commit 27a9758

Please sign in to comment.