diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
deleted file mode 100644
index d808d4f4ab58..000000000000
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
+++ /dev/null
@@ -1,355 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.dataflow.worker;
-
-import com.google.auto.value.AutoBuilder;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicInteger;
-import javax.annotation.concurrent.GuardedBy;
-import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
-import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
-import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
-import org.apache.beam.sdk.annotations.Internal;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import org.joda.time.Duration;
-
-/**
- * Wrapper around a {@link WindmillServerStub} that tracks metrics for the number of in-flight
- * requests and throttles requests when memory pressure is high.
- *
- * <p>External API: individual worker threads request state for their computation via {@link
- * #getStateData}. However, requests are either issued using a pool of streaming rpcs or possibly
- * batched requests.
- */
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class MetricTrackingWindmillServerStub {
-
- private static final int MAX_READS_PER_BATCH = 60;
- private static final int MAX_ACTIVE_READS = 10;
- private static final Duration STREAM_TIMEOUT = Duration.standardSeconds(30);
- private final AtomicInteger activeSideInputs = new AtomicInteger();
- private final AtomicInteger activeStateReads = new AtomicInteger();
- private final AtomicInteger activeHeartbeats = new AtomicInteger();
- private final WindmillServerStub server;
- private final MemoryMonitor gcThrashingMonitor;
- private final boolean useStreamingRequests;
-
- private final WindmillStreamPool<GetDataStream> getDataStreamPool;
-
- // This may be the same instance as getDataStreamPool based upon options.
- private final WindmillStreamPool<GetDataStream> heartbeatStreamPool;
-
- @GuardedBy("this")
- private final List<ReadBatch> pendingReadBatches;
-
- @GuardedBy("this")
- private int activeReadThreads = 0;
-
- @Internal
- @AutoBuilder(ofClass = MetricTrackingWindmillServerStub.class)
- public abstract static class Builder {
-
- abstract Builder setServer(WindmillServerStub server);
-
- abstract Builder setGcThrashingMonitor(MemoryMonitor gcThrashingMonitor);
-
- abstract Builder setUseStreamingRequests(boolean useStreamingRequests);
-
- abstract Builder setUseSeparateHeartbeatStreams(boolean useSeparateHeartbeatStreams);
-
- abstract Builder setNumGetDataStreams(int numGetDataStreams);
-
- abstract MetricTrackingWindmillServerStub build();
- }
-
- public static Builder builder(WindmillServerStub server, MemoryMonitor gcThrashingMonitor) {
- return new AutoBuilder_MetricTrackingWindmillServerStub_Builder()
- .setServer(server)
- .setGcThrashingMonitor(gcThrashingMonitor)
- .setUseStreamingRequests(false)
- .setUseSeparateHeartbeatStreams(false)
- .setNumGetDataStreams(1);
- }
-
- MetricTrackingWindmillServerStub(
- WindmillServerStub server,
- MemoryMonitor gcThrashingMonitor,
- boolean useStreamingRequests,
- boolean useSeparateHeartbeatStreams,
- int numGetDataStreams) {
- this.server = server;
- this.gcThrashingMonitor = gcThrashingMonitor;
- this.useStreamingRequests = useStreamingRequests;
- if (useStreamingRequests) {
- getDataStreamPool =
- WindmillStreamPool.create(
- Math.max(1, numGetDataStreams), STREAM_TIMEOUT, this.server::getDataStream);
- if (useSeparateHeartbeatStreams) {
- heartbeatStreamPool =
- WindmillStreamPool.create(1, STREAM_TIMEOUT, this.server::getDataStream);
- } else {
- heartbeatStreamPool = getDataStreamPool;
- }
- } else {
- getDataStreamPool = heartbeatStreamPool = null;
- }
- // This is used as a queue but is expected to be less than 10 batches.
- this.pendingReadBatches = new ArrayList<>();
- }
-
- // Adds the entry to a read batch for sending to the windmill server. If a non-null batch is
- // returned, this thread will be responsible for sending the batch and should wait for the batch
- // startRead to be notified.
- // If null is returned, the entry was added to a read batch that will be issued by another thread.
- private @Nullable ReadBatch addToReadBatch(QueueEntry entry) {
- synchronized (this) {
- ReadBatch batch;
- if (activeReadThreads < MAX_ACTIVE_READS) {
- assert (pendingReadBatches.isEmpty());
- activeReadThreads += 1;
- // fall through to below synchronized block
- } else if (pendingReadBatches.isEmpty()
- || pendingReadBatches.get(pendingReadBatches.size() - 1).reads.size()
- >= MAX_READS_PER_BATCH) {
- // This is the first read of a batch, it will be responsible for sending the batch.
- batch = new ReadBatch();
- pendingReadBatches.add(batch);
- batch.reads.add(entry);
- return batch;
- } else {
- // This fits within an existing batch, it will be sent by the first blocking thread in the
- // batch.
- pendingReadBatches.get(pendingReadBatches.size() - 1).reads.add(entry);
- return null;
- }
- }
- ReadBatch batch = new ReadBatch();
- batch.reads.add(entry);
- batch.startRead.set(true);
- return batch;
- }
-
- private void issueReadBatch(ReadBatch batch) {
- try {
- boolean read = batch.startRead.get();
- assert (read);
- } catch (InterruptedException e) {
- // We don't expect this thread to be interrupted. To simplify handling, we just fall through
- // to issuing
- // the call.
- assert (false);
- Thread.currentThread().interrupt();
- } catch (ExecutionException e) {
- // startRead is a SettableFuture so this should never occur.
- throw new AssertionError("Should not have exception on startRead", e);
- }
- Map<WindmillComputationKey, SettableFuture<Windmill.KeyedGetDataResponse>> pendingResponses =
- new HashMap<>(batch.reads.size());
- Map<String, Windmill.ComputationGetDataRequest.Builder> computationBuilders = new HashMap<>();
- for (QueueEntry entry : batch.reads) {
- Windmill.ComputationGetDataRequest.Builder computationBuilder =
- computationBuilders.computeIfAbsent(
- entry.computation,
- k -> Windmill.ComputationGetDataRequest.newBuilder().setComputationId(k));
-
- computationBuilder.addRequests(entry.request);
- pendingResponses.put(
- WindmillComputationKey.create(
- entry.computation, entry.request.getKey(), entry.request.getShardingKey()),
- entry.response);
- }
-
- // Build the full GetDataRequest from the KeyedGetDataRequests pulled from the queue.
- Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder();
- for (Windmill.ComputationGetDataRequest.Builder computationBuilder :
- computationBuilders.values()) {
- builder.addRequests(computationBuilder);
- }
-
- try {
- Windmill.GetDataResponse response = server.getData(builder.build());
-
- // Dispatch the per-key responses back to the waiting threads.
- for (Windmill.ComputationGetDataResponse computationResponse : response.getDataList()) {
- for (Windmill.KeyedGetDataResponse keyResponse : computationResponse.getDataList()) {
- pendingResponses
- .get(
- WindmillComputationKey.create(
- computationResponse.getComputationId(),
- keyResponse.getKey(),
- keyResponse.getShardingKey()))
- .set(keyResponse);
- }
- }
- } catch (RuntimeException e) {
- // Fan the exception out to the reads.
- for (QueueEntry entry : batch.reads) {
- entry.response.setException(e);
- }
- } finally {
- synchronized (this) {
- assert (activeReadThreads >= 1);
- if (pendingReadBatches.isEmpty()) {
- activeReadThreads--;
- } else {
- // Notify the thread responsible for issuing the next batch read.
- ReadBatch startBatch = pendingReadBatches.remove(0);
- startBatch.startRead.set(true);
- }
- }
- }
- }
-
- public Windmill.KeyedGetDataResponse getStateData(
- String computation, Windmill.KeyedGetDataRequest request) {
- gcThrashingMonitor.waitForResources("GetStateData");
- activeStateReads.getAndIncrement();
-
- try {
- if (useStreamingRequests) {
- GetDataStream stream = getDataStreamPool.getStream();
- try {
- return stream.requestKeyedData(computation, request);
- } finally {
- getDataStreamPool.releaseStream(stream);
- }
- } else {
- SettableFuture<Windmill.KeyedGetDataResponse> response = SettableFuture.create();
- ReadBatch batch = addToReadBatch(new QueueEntry(computation, request, response));
- if (batch != null) {
- issueReadBatch(batch);
- }
- return response.get();
- }
- } catch (Exception e) {
- throw new RuntimeException(e);
- } finally {
- activeStateReads.getAndDecrement();
- }
- }
-
- public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) {
- gcThrashingMonitor.waitForResources("GetSideInputData");
- activeSideInputs.getAndIncrement();
- try {
- if (useStreamingRequests) {
- GetDataStream stream = getDataStreamPool.getStream();
- try {
- return stream.requestGlobalData(request);
- } finally {
- getDataStreamPool.releaseStream(stream);
- }
- } else {
- return server
- .getData(
- Windmill.GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build())
- .getGlobalData(0);
- }
- } catch (Exception e) {
- throw new RuntimeException("Failed to get side input: ", e);
- } finally {
- activeSideInputs.getAndDecrement();
- }
- }
-
- /** Tells windmill processing is ongoing for the given keys. */
- public void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats) {
- if (heartbeats.isEmpty()) {
- return;
- }
- activeHeartbeats.set(heartbeats.size());
- try {
- if (useStreamingRequests) {
- GetDataStream stream = heartbeatStreamPool.getStream();
- try {
- stream.refreshActiveWork(heartbeats);
- } finally {
- heartbeatStreamPool.releaseStream(stream);
- }
- } else {
- // This code path is only used by appliance which sends heartbeats (used to refresh active
- // work) as KeyedGetDataRequests. So we must translate the HeartbeatRequest to a
- // KeyedGetDataRequest here regardless of the value of sendKeyedGetDataRequests.
- Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder();
- for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
- Windmill.ComputationGetDataRequest.Builder perComputationBuilder =
- Windmill.ComputationGetDataRequest.newBuilder();
- perComputationBuilder.setComputationId(entry.getKey());
- for (HeartbeatRequest request : entry.getValue()) {
- perComputationBuilder.addRequests(
- Windmill.KeyedGetDataRequest.newBuilder()
- .setShardingKey(request.getShardingKey())
- .setWorkToken(request.getWorkToken())
- .setCacheToken(request.getCacheToken())
- .addAllLatencyAttribution(request.getLatencyAttributionList())
- .build());
- }
- builder.addRequests(perComputationBuilder.build());
- }
- server.getData(builder.build());
- }
- } finally {
- activeHeartbeats.set(0);
- }
- }
-
- public void printHtml(PrintWriter writer) {
- writer.println("Active Fetches:");
- writer.println(" Side Inputs: " + activeSideInputs.get());
- writer.println(" State Reads: " + activeStateReads.get());
- if (!useStreamingRequests) {
- synchronized (this) {
- writer.println(" Read threads: " + activeReadThreads);
- writer.println(" Pending read batches: " + pendingReadBatches.size());
- }
- }
- writer.println("Heartbeat Keys Active: " + activeHeartbeats.get());
- }
-
- private static final class ReadBatch {
- ArrayList<QueueEntry> reads = new ArrayList<>();
- SettableFuture<Boolean> startRead = SettableFuture.create();
- }
-
- private static final class QueueEntry {
-
- final String computation;
- final Windmill.KeyedGetDataRequest request;
- final SettableFuture<Windmill.KeyedGetDataResponse> response;
-
- QueueEntry(
- String computation,
- Windmill.KeyedGetDataRequest request,
- SettableFuture<Windmill.KeyedGetDataResponse> response) {
- this.computation = computation;
- this.request = request;
- this.response = response;
- }
- }
-}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index a07bbfa7f5f3..f196852b2253 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -66,12 +66,17 @@
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ApplianceGetDataClient;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamPoolGetDataClient;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer;
@@ -87,7 +92,9 @@
import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.StreamingEngineFailureTracker;
import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor;
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ActiveWorkRefresher;
-import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ActiveWorkRefreshers;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ApplianceHeartbeatSender;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.StreamPoolHeartbeatSender;
import org.apache.beam.sdk.fn.IdGenerator;
import org.apache.beam.sdk.fn.IdGenerators;
import org.apache.beam.sdk.fn.JvmInitializers;
@@ -127,6 +134,7 @@ public class StreamingDataflowWorker {
static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3;
static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1);
private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class);
+ private static final Duration GET_DATA_STREAM_TIMEOUT = Duration.standardSeconds(30);
/** The idGenerator to generate unique id globally. */
private static final IdGenerator ID_GENERATOR = IdGenerators.decrementingLongs();
@@ -150,7 +158,7 @@ public class StreamingDataflowWorker {
private final AtomicBoolean running = new AtomicBoolean();
private final DataflowWorkerHarnessOptions options;
private final long clientId;
- private final MetricTrackingWindmillServerStub metricTrackingWindmillServer;
+ private final GetDataClient getDataClient;
private final MemoryMonitor memoryMonitor;
private final Thread memoryMonitorThread;
private final ReaderCache readerCache;
@@ -160,6 +168,7 @@ public class StreamingDataflowWorker {
private final StreamingWorkerStatusReporter workerStatusReporter;
private final StreamingCounters streamingCounters;
private final StreamingWorkScheduler streamingWorkScheduler;
+ private final HeartbeatSender heartbeatSender;
private StreamingDataflowWorker(
WindmillServerStub windmillServer,
@@ -181,6 +190,9 @@ private StreamingDataflowWorker(
GrpcWindmillStreamFactory windmillStreamFactory,
Function<String, ScheduledExecutorService> executorSupplier,
ConcurrentMap stageInfoMap) {
+ // Register standard file systems.
+ FileSystems.setDefaultPipelineOptions(options);
+
this.configFetcher = configFetcher;
this.computationStateCache = computationStateCache;
this.stateCache = windmillStateCache;
@@ -199,12 +211,16 @@ private StreamingDataflowWorker(
this.workCommitter =
windmillServiceEnabled
- ? StreamingEngineWorkCommitter.create(
- WindmillStreamPool.create(
- numCommitThreads, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream)
- ::getCloseableStream,
- numCommitThreads,
- this::onCompleteCommit)
+ ? StreamingEngineWorkCommitter.builder()
+ .setCommitWorkStreamFactory(
+ WindmillStreamPool.create(
+ numCommitThreads,
+ COMMIT_STREAM_TIMEOUT,
+ windmillServer::commitWorkStream)
+ ::getCloseableStream)
+ .setNumCommitSenders(numCommitThreads)
+ .setOnCommitComplete(this::onCompleteCommit)
+ .build()
: StreamingApplianceWorkCommitter.create(
windmillServer::commitWork, this::onCompleteCommit);
@@ -230,29 +246,41 @@ private StreamingDataflowWorker(
dispatchThread.setName("DispatchThread");
this.clientId = clientId;
this.windmillServer = windmillServer;
- this.metricTrackingWindmillServer =
- MetricTrackingWindmillServerStub.builder(windmillServer, memoryMonitor)
- .setUseStreamingRequests(windmillServiceEnabled)
- .setUseSeparateHeartbeatStreams(options.getUseSeparateWindmillHeartbeatStreams())
- .setNumGetDataStreams(options.getWindmillGetDataStreamCount())
- .build();
- // Register standard file systems.
- FileSystems.setDefaultPipelineOptions(options);
+ ThrottlingGetDataMetricTracker getDataMetricTracker =
+ new ThrottlingGetDataMetricTracker(memoryMonitor);
+
+ int stuckCommitDurationMillis;
+ if (windmillServiceEnabled) {
+ WindmillStreamPool<GetDataStream> getDataStreamPool =
+ WindmillStreamPool.create(
+ Math.max(1, options.getWindmillGetDataStreamCount()),
+ GET_DATA_STREAM_TIMEOUT,
+ windmillServer::getDataStream);
+ this.getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool);
+ this.heartbeatSender =
+ new StreamPoolHeartbeatSender(
+ options.getUseSeparateWindmillHeartbeatStreams()
+ ? WindmillStreamPool.create(
+ 1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream)
+ : getDataStreamPool);
+ stuckCommitDurationMillis =
+ options.getStuckCommitDurationMillis() > 0 ? options.getStuckCommitDurationMillis() : 0;
+ } else {
+ this.getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker);
+ this.heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData);
+ stuckCommitDurationMillis = 0;
+ }
- int stuckCommitDurationMillis =
- windmillServiceEnabled && options.getStuckCommitDurationMillis() > 0
- ? options.getStuckCommitDurationMillis()
- : 0;
this.activeWorkRefresher =
- ActiveWorkRefreshers.createDispatchedActiveWorkRefresher(
+ new ActiveWorkRefresher(
clock,
options.getActiveWorkRefreshPeriodMillis(),
stuckCommitDurationMillis,
computationStateCache::getAllPresentComputations,
sampler,
- metricTrackingWindmillServer::refreshActiveWork,
- executorSupplier.apply("RefreshWork"));
+ executorSupplier.apply("RefreshWork"),
+ getDataMetricTracker::trackHeartbeats);
WorkerStatusPages workerStatusPages =
WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor);
@@ -265,7 +293,7 @@ private StreamingDataflowWorker(
.setStateCache(stateCache)
.setComputationStateCache(computationStateCache)
.setCurrentActiveCommitBytes(workCommitter::currentActiveCommitBytes)
- .setGetDataStatusProvider(metricTrackingWindmillServer::printHtml)
+ .setGetDataStatusProvider(getDataClient::printHtml)
.setWorkUnitExecutor(workUnitExecutor);
this.statusPages =
@@ -281,7 +309,6 @@ private StreamingDataflowWorker(
this.workerStatusReporter = workerStatusReporter;
this.streamingCounters = streamingCounters;
this.memoryMonitor = memoryMonitor;
-
this.streamingWorkScheduler =
StreamingWorkScheduler.create(
options,
@@ -290,7 +317,6 @@ private StreamingDataflowWorker(
mapTaskExecutorFactory,
workUnitExecutor,
stateCache::forComputation,
- metricTrackingWindmillServer::getSideInputData,
failureTracker,
workFailureProcessor,
streamingCounters,
@@ -829,7 +855,7 @@ private void dispatchLoop() {
workItem,
watermarks.setOutputDataWatermark(workItem.getOutputDataWatermark()).build(),
Work.createProcessingContext(
- computationId, metricTrackingWindmillServer::getStateData, workCommitter::commit),
+ computationId, getDataClient, workCommitter::commit, heartbeatSender),
/* getWorkStreamLatencies= */ Collections.emptyList());
}
}
@@ -865,8 +891,9 @@ void streamingDispatchLoop() {
.build(),
Work.createProcessingContext(
computationState.getComputationId(),
- metricTrackingWindmillServer::getStateData,
- workCommitter::commit),
+ getDataClient,
+ workCommitter::commit,
+ heartbeatSender),
getWorkStreamLatencies);
}));
try {
@@ -874,7 +901,7 @@ void streamingDispatchLoop() {
// If at any point the server closes the stream, we will reconnect immediately; otherwise
// we half-close the stream after some time and create a new one.
if (!stream.awaitTermination(GET_WORK_STREAM_TIMEOUT_MINUTES, TimeUnit.MINUTES)) {
- stream.close();
+ stream.halfClose();
}
} catch (InterruptedException e) {
// Continue processing until !running.get()
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
index 934977fe0985..ec5122a8732a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
@@ -26,6 +26,10 @@ public WorkItemCancelledException(long sharding_key) {
super("Work item cancelled for key " + sharding_key);
}
+ public WorkItemCancelledException(Throwable e) {
+ super(e);
+ }
+
/** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */
public static boolean isWorkItemCancelledException(Throwable t) {
while (t != null) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index 3e226514d57e..c80c3a882e52 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker.streaming;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap;
import java.io.PrintWriter;
import java.util.ArrayDeque;
@@ -31,13 +32,9 @@
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
-import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
-import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
@@ -45,6 +42,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
import org.joda.time.Duration;
@@ -106,29 +104,6 @@ private static String elapsedString(Instant start, Instant end) {
return activeFor.toString().substring(2);
}
- private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
- Entry<ShardedKey, Deque<ExecutableWork>> shardedKeyAndWorkQueue,
- Instant refreshDeadline,
- DataflowExecutionStateSampler sampler) {
- ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey();
- Deque<ExecutableWork> workQueue = shardedKeyAndWorkQueue.getValue();
-
- return workQueue.stream()
- .map(ExecutableWork::work)
- .filter(work -> work.getStartTime().isBefore(refreshDeadline))
- // Don't send heartbeats for queued work we already know is failed.
- .filter(work -> !work.isFailed())
- .map(
- work ->
- Windmill.HeartbeatRequest.newBuilder()
- .setShardingKey(shardedKey.shardingKey())
- .setWorkToken(work.getWorkItem().getWorkToken())
- .setCacheToken(work.getWorkItem().getCacheToken())
- .addAllLatencyAttribution(
- work.getLatencyAttributions(/* isHeartbeat= */ true, sampler))
- .build());
- }
-
/**
* Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 4 {@link
* ActivateWorkResult}
@@ -219,6 +194,31 @@ synchronized void failWorkForKey(Multimap failedWork) {
}
}
+ /**
+ * Returns a read only view of current active work.
+ *
+ * @implNote Do not return a reference to the underlying workQueue as iterations over it will
+ * cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
+ * structure.
+ */
+ synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiveWork() {
+ return activeWork.entrySet().stream()
+ .collect(
+ flatteningToImmutableListMultimap(
+ Entry::getKey,
+ e ->
+ e.getValue().stream()
+ .map(executableWork -> (RefreshableWork) executableWork.work())));
+ }
+
+ synchronized ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
+ return activeWork.values().stream()
+ .flatMap(Deque::stream)
+ .map(ExecutableWork::work)
+ .filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline))
+ .collect(toImmutableList());
+ }
+
private void incrementActiveWorkBudget(Work work) {
activeGetWorkBudget.updateAndGet(
getWorkBudget -> getWorkBudget.apply(1, work.getWorkItem().getSerializedSize()));
@@ -324,13 +324,6 @@ private synchronized ImmutableMap getStuckCommitsAt(
return stuckCommits.build();
}
- synchronized ImmutableList<HeartbeatRequest> getKeyHeartbeats(
- Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
- return activeWork.entrySet().stream()
- .flatMap(entry -> toHeartbeatRequestStream(entry, refreshDeadline, sampler))
- .collect(toImmutableList());
- }
-
/**
* Returns the current aggregate {@link GetWorkBudget} that is active on the user worker. Active
* means that the work is received from Windmill, being processed or queued to be processed in
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index 434e78484799..f3b0ba16fbe2 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -23,14 +23,13 @@
import java.util.Optional;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.annotation.Nullable;
-import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
import org.joda.time.Instant;
@@ -147,10 +146,12 @@ private void forceExecute(ExecutableWork executableWork) {
executor.forceExecute(executableWork, executableWork.work().getWorkItem().getSerializedSize());
}
- /** Gets HeartbeatRequests for any work started before refreshDeadline. */
- public ImmutableList<HeartbeatRequest> getKeyHeartbeats(
- Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
- return activeWorkState.getKeyHeartbeats(refreshDeadline, sampler);
+ public ImmutableListMultimap<ShardedKey, RefreshableWork> currentActiveWorkReadOnly() {
+ return activeWorkState.getReadOnlyActiveWork();
+ }
+
+ public ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
+ return activeWorkState.getRefreshableWork(refreshDeadline);
}
public GetWorkBudget getActiveWorkBudget() {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
index bdf8a7814ea3..db279f066630 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
@@ -31,7 +31,7 @@ public static ExecutableWork create(Work work, Consumer executeWorkFn) {
public abstract Work work();
- abstract Consumer<Work> executeWorkFn();
+ public abstract Consumer<Work> executeWorkFn();
@Override
public void run() {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java
new file mode 100644
index 000000000000..c51b04f23719
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+
+/** View of {@link Work} that exposes an interface for work refreshing. */
+@Internal
+public interface RefreshableWork {
+
+ WorkId id();
+
+ ShardedKey getShardedKey();
+
+ HeartbeatSender heartbeatSender();
+
+ ImmutableList<Windmill.LatencyAttribution> getHeartbeatLatencyAttributions(
+ DataflowExecutionStateSampler sampler);
+
+ void setFailed();
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index ed3f2671b40c..e77823602eda 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -27,14 +27,14 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
-import java.util.function.BiFunction;
import java.util.function.Consumer;
-import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
import org.apache.beam.runners.dataflow.worker.ActiveMessageMetadata;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
@@ -45,7 +45,9 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Duration;
@@ -58,7 +60,7 @@
*/
@NotThreadSafe
@Internal
-public final class Work {
+public final class Work implements RefreshableWork {
private final ShardedKey shardedKey;
private final WorkItem workItem;
private final ProcessingContext processingContext;
@@ -105,9 +107,10 @@ public static Work create(
public static ProcessingContext createProcessingContext(
String computationId,
- BiFunction<String, KeyedGetDataRequest, KeyedGetDataResponse> getKeyedDataFn,
- Consumer<Commit> workCommitter) {
- return ProcessingContext.create(computationId, getKeyedDataFn, workCommitter);
+ GetDataClient getDataClient,
+ Consumer<Commit> workCommitter,
+ HeartbeatSender heartbeatSender) {
+ return ProcessingContext.create(computationId, getDataClient, workCommitter, heartbeatSender);
}
private static LatencyAttribution.Builder createLatencyAttributionWithActiveLatencyBreakdown(
@@ -151,12 +154,17 @@ public WorkItem getWorkItem() {
return workItem;
}
+ @Override
public ShardedKey getShardedKey() {
return shardedKey;
}
public Optional<KeyedGetDataResponse> fetchKeyedState(KeyedGetDataRequest keyedGetDataRequest) {
- return processingContext.keyedDataFetcher().apply(keyedGetDataRequest);
+ return processingContext.fetchKeyedState(keyedGetDataRequest);
+ }
+
+ public GlobalData fetchSideInput(GlobalDataRequest request) {
+ return processingContext.getDataClient().getSideInputData(request);
}
public Watermarks watermarks() {
@@ -180,6 +188,7 @@ public void setState(State state) {
this.currentState = TimedState.create(state, now);
}
+ @Override
public void setFailed() {
this.isFailed = true;
}
@@ -196,6 +205,11 @@ public String getLatencyTrackingId() {
return latencyTrackingId;
}
+ @Override
+ public HeartbeatSender heartbeatSender() {
+ return processingContext.heartbeatSender();
+ }
+
public void queueCommit(WorkItemCommitRequest commitRequest, ComputationState computationState) {
setState(State.COMMIT_QUEUED);
processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this));
@@ -205,6 +219,7 @@ public WindmillStateReader createWindmillStateReader() {
return WindmillStateReader.forWork(this);
}
+ @Override
public WorkId id() {
return id;
}
@@ -216,7 +231,25 @@ private void recordGetWorkStreamLatencies(Collection getWork
}
}
+ @Override
+ public ImmutableList<LatencyAttribution> getHeartbeatLatencyAttributions(
+ DataflowExecutionStateSampler sampler) {
+ return getLatencyAttributions(/* isHeartbeat= */ true, sampler);
+ }
+
public ImmutableList<LatencyAttribution> getLatencyAttributions(
+ DataflowExecutionStateSampler sampler) {
+ return getLatencyAttributions(/* isHeartbeat= */ false, sampler);
+ }
+
+ private Duration getTotalDurationAtLatencyAttributionState(LatencyAttribution.State state) {
+ Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO);
+ return state == this.currentState.state().toLatencyAttributionState()
+ ? duration.plus(new Duration(this.currentState.startTime(), clock.get()))
+ : duration;
+ }
+
+ private ImmutableList<LatencyAttribution> getLatencyAttributions(
boolean isHeartbeat, DataflowExecutionStateSampler sampler) {
return Arrays.stream(LatencyAttribution.State.values())
.map(state -> Pair.of(state, getTotalDurationAtLatencyAttributionState(state)))
@@ -233,13 +266,6 @@ public ImmutableList getLatencyAttributions(
.collect(toImmutableList());
}
- private Duration getTotalDurationAtLatencyAttributionState(LatencyAttribution.State state) {
- Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO);
- return state == this.currentState.state().toLatencyAttributionState()
- ? duration.plus(new Duration(this.currentState.startTime(), clock.get()))
- : duration;
- }
-
private LatencyAttribution createLatencyAttribution(
LatencyAttribution.State state,
boolean isHeartbeat,
@@ -314,25 +340,29 @@ public abstract static class ProcessingContext {
private static ProcessingContext create(
String computationId,
- BiFunction<String, KeyedGetDataRequest, KeyedGetDataResponse> getKeyedDataFn,
- Consumer<Commit> workCommitter) {
+ GetDataClient getDataClient,
+ Consumer<Commit> workCommitter,
+ HeartbeatSender heartbeatSender) {
return new AutoValue_Work_ProcessingContext(
- computationId,
- request -> Optional.ofNullable(getKeyedDataFn.apply(computationId, request)),
- workCommitter);
+ computationId, getDataClient, heartbeatSender, workCommitter);
}
/** Computation that the {@link Work} belongs to. */
public abstract String computationId();
/** Handles GetData requests to streaming backend. */
- public abstract Function<KeyedGetDataRequest, Optional<KeyedGetDataResponse>>
- keyedDataFetcher();
+ public abstract GetDataClient getDataClient();
+
+ public abstract HeartbeatSender heartbeatSender();
/**
* {@link WorkCommitter} that commits completed work to the backend Windmill worker handling the
* {@link WorkItem}.
*/
public abstract Consumer<Commit> workCommitter();
+
+ private Optional<KeyedGetDataResponse> fetchKeyedState(KeyedGetDataRequest request) {
+ return Optional.ofNullable(getDataClient().getStateData(computationId(), request));
+ }
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java
index f8f8d1901914..d4e7f05d255f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java
@@ -41,9 +41,9 @@ public static WorkId of(Windmill.WorkItem workItem) {
.build();
}
- abstract long cacheToken();
+ public abstract long cacheToken();
- abstract long workToken();
+ public abstract long workToken();
@AutoValue.Builder
public abstract static class Builder {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java
index 7fd2487575c2..303cdeb94f8c 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java
@@ -30,11 +30,11 @@
import java.util.function.Function;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.beam.runners.core.InMemoryMultimapSideInputView;
-import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions;
import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
@@ -46,14 +46,14 @@
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-/** Class responsible for fetching state from the windmill server. */
+/** Class responsible for fetching side input state from the streaming backend. */
@NotThreadSafe
+@Internal
public class SideInputStateFetcher {
private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class);
@@ -64,13 +64,6 @@ public class SideInputStateFetcher {
private final Function<GlobalDataRequest, GlobalData> fetchGlobalDataFn;
private long bytesRead = 0L;
- public SideInputStateFetcher(
- Function<GlobalDataRequest, GlobalData> fetchGlobalDataFn,
- DataflowStreamingPipelineOptions options) {
- this(fetchGlobalDataFn, SideInputCache.create(options));
- }
-
- @VisibleForTesting
SideInputStateFetcher(
Function<GlobalDataRequest, GlobalData> fetchGlobalDataFn, SideInputCache sideInputCache) {
this.fetchGlobalDataFn = fetchGlobalDataFn;
@@ -103,12 +96,56 @@ private static Coder> getCoder(PCollectionView view) {
return view.getCoderInternal();
}
- /** Returns a view of the underlying cache that keeps track of bytes read separately. */
- public SideInputStateFetcher byteTrackingView() {
- return new SideInputStateFetcher(fetchGlobalDataFn, sideInputCache);
+ private static <T> SideInput<T> createSideInputCacheEntry(
+ PCollectionView<T> view, GlobalData data) throws IOException {
+ Iterable<?> rawData = decodeRawData(view, data);
+ switch (getViewFn(view).getMaterialization().getUrn()) {
+ case ITERABLE_MATERIALIZATION_URN:
+ {
+ @SuppressWarnings({
+ "unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn.
+ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
+ })
+ ViewFn viewFn = (ViewFn) getViewFn(view);
+ return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size());
+ }
+ case MULTIMAP_MATERIALIZATION_URN:
+ {
+ @SuppressWarnings({
+ "unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn.
+ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
+ })
+ ViewFn viewFn = (ViewFn) getViewFn(view);
+ Coder<?> keyCoder = ((KvCoder<?, ?>) getCoder(view)).getKeyCoder();
+
+ @SuppressWarnings({
+ "unchecked", // Safe since multimap rawData is of type Iterable>
+ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
+ })
+ T multimapSideInputValue =
+ viewFn.apply(
+ InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData));
+ return SideInput.ready(multimapSideInputValue, data.getData().size());
+ }
+ default:
+ {
+ throw new IllegalStateException(
+ "Unknown side input materialization format requested: "
+ + getViewFn(view).getMaterialization().getUrn());
+ }
+ }
}
- public long getBytesRead() {
+ private static void validateViewMaterialization(PCollectionView<?> view) {
+ String materializationUrn = getViewFn(view).getMaterialization().getUrn();
+ checkState(
+ SUPPORTED_MATERIALIZATIONS.contains(materializationUrn),
+ "Only materialization's of type %s supported, received %s",
+ SUPPORTED_MATERIALIZATIONS,
+ materializationUrn);
+ }
+
+ public final long getBytesRead() {
return bytesRead;
}
@@ -200,53 +237,4 @@ private SideInput loadSideInputFromWindmill(
bytesRead += data.getSerializedSize();
return data.getIsReady() ? createSideInputCacheEntry(view, data) : SideInput.notReady();
}
-
- private void validateViewMaterialization(PCollectionView<?> view) {
- String materializationUrn = getViewFn(view).getMaterialization().getUrn();
- checkState(
- SUPPORTED_MATERIALIZATIONS.contains(materializationUrn),
- "Only materialization's of type %s supported, received %s",
- SUPPORTED_MATERIALIZATIONS,
- materializationUrn);
- }
-
- private <T> SideInput<T> createSideInputCacheEntry(PCollectionView<T> view, GlobalData data)
- throws IOException {
- Iterable<?> rawData = decodeRawData(view, data);
- switch (getViewFn(view).getMaterialization().getUrn()) {
- case ITERABLE_MATERIALIZATION_URN:
- {
- @SuppressWarnings({
- "unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn.
- "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
- })
- ViewFn viewFn = (ViewFn) getViewFn(view);
- return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size());
- }
- case MULTIMAP_MATERIALIZATION_URN:
- {
- @SuppressWarnings({
- "unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn.
- "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
- })
- ViewFn viewFn = (ViewFn) getViewFn(view);
- Coder<?> keyCoder = ((KvCoder<?, ?>) getCoder(view)).getKeyCoder();
-
- @SuppressWarnings({
- "unchecked", // Safe since multimap rawData is of type Iterable>
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- })
- T multimapSideInputValue =
- viewFn.apply(
- InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData));
- return SideInput.ready(multimapSideInputValue, data.getData().size());
- }
- default:
- {
- throw new IllegalStateException(
- "Unknown side input materialization format requested: "
- + getViewFn(view).getMaterialization().getUrn());
- }
- }
- }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherFactory.java
new file mode 100644
index 000000000000..fd42b9ff1801
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherFactory.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.sideinput;
+
+import java.util.function.Function;
+import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import org.apache.beam.sdk.annotations.Internal;
+
+/**
+ * Factory class for generating {@link SideInputStateFetcher} instances that share a {@link
+ * SideInputCache}.
+ */
+@Internal
+public final class SideInputStateFetcherFactory {
+ private final SideInputCache globalSideInputCache;
+
+ private SideInputStateFetcherFactory(SideInputCache globalSideInputCache) {
+ this.globalSideInputCache = globalSideInputCache;
+ }
+
+ public static SideInputStateFetcherFactory fromOptions(DataflowStreamingPipelineOptions options) {
+ return new SideInputStateFetcherFactory(SideInputCache.create(options));
+ }
+
+ public SideInputStateFetcher createSideInputStateFetcher(
+ Function fetchGlobalDataFn) {
+ return new SideInputStateFetcher(fetchGlobalDataFn, globalSideInputCache);
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ApplianceWindmillClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ApplianceWindmillClient.java
new file mode 100644
index 000000000000..2cd3748eb31b
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ApplianceWindmillClient.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import org.apache.beam.sdk.annotations.Internal;
+
+/** Client for WindmillService via Streaming Appliance. */
+@Internal
+public interface ApplianceWindmillClient {
+ /** Get a batch of work to process. */
+ Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request);
+
+ /** Get additional data such as state needed to process work. */
+ Windmill.GetDataResponse getData(Windmill.GetDataRequest request);
+
+ /** Commit the work, issuing any output productions, state modifications etc. */
+ Windmill.CommitWorkResponse commitWork(Windmill.CommitWorkRequest request);
+
+ /** Get configuration data from the server. */
+ Windmill.GetConfigResponse getConfig(Windmill.GetConfigRequest request);
+
+ /** Report execution information to the server. */
+ Windmill.ReportStatsResponse reportStats(Windmill.ReportStatsRequest request);
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineWindmillClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineWindmillClient.java
new file mode 100644
index 000000000000..e02e6c112358
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineWindmillClient.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import java.util.Set;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
+
+/** Client for WindmillService via Streaming Engine. */
+@Internal
+public interface StreamingEngineWindmillClient {
+ /** Returns the windmill service endpoints set by setWindmillServiceEndpoints */
+ ImmutableSet<HostAndPort> getWindmillServiceEndpoints();
+
+ /**
+ * Sets the new endpoints used to talk to windmill. Upon first call, the stubs are initialized. On
+ * subsequent calls, if endpoints are different from previous values new stubs are created,
+ * replacing the previous ones.
+ */
+ void setWindmillServiceEndpoints(Set<HostAndPort> endpoints);
+
+ /**
+ * Gets work to process, returned as a stream.
+ *
+ * <p>Each time a WorkItem is received, it will be passed to the given receiver. The returned
+ * GetWorkStream object can be used to control the lifetime of the stream.
+ */
+ WindmillStream.GetWorkStream getWorkStream(
+ Windmill.GetWorkRequest request, WorkItemReceiver receiver);
+
+ /** Get additional data such as state needed to process work, returned as a stream. */
+ WindmillStream.GetDataStream getDataStream();
+
+ /** Returns a stream allowing individual WorkItemCommitRequests to be streamed to Windmill. */
+ WindmillStream.CommitWorkStream commitWorkStream();
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java
index a20c2f02b269..7d199afc0861 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java
@@ -27,6 +27,8 @@
@AutoValue
@Internal
public abstract class WindmillConnection {
+ private static final String NO_BACKEND_WORKER_TOKEN = "";
+
public static WindmillConnection from(
Endpoint windmillEndpoint,
Function endpointToStubFn) {
@@ -40,23 +42,24 @@ public static WindmillConnection from(
}
public static Builder builder() {
- return new AutoValue_WindmillConnection.Builder();
+ return new AutoValue_WindmillConnection.Builder()
+ .setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN);
}
- public abstract Optional<String> backendWorkerToken();
+ public abstract String backendWorkerToken();
public abstract Optional<WindmillServiceAddress> directEndpoint();
public abstract CloudWindmillServiceV1Alpha1Stub stub();
@AutoValue.Builder
- abstract static class Builder {
+ public abstract static class Builder {
abstract Builder setBackendWorkerToken(String backendWorkerToken);
public abstract Builder setDirectEndpoint(WindmillServiceAddress value);
- abstract Builder setStub(CloudWindmillServiceV1Alpha1Stub stub);
+ public abstract Builder setStub(CloudWindmillServiceV1Alpha1Stub stub);
- abstract WindmillConnection build();
+ public abstract WindmillConnection build();
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
index 0785ae96626e..5f7fd6da9d4b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
@@ -59,11 +59,6 @@ public ImmutableSet getWindmillServiceEndpoints() {
return ImmutableSet.of();
}
- @Override
- public boolean isReady() {
- return true;
- }
-
@Override
public Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest workRequest) {
try {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index 7d0c4f5aba32..cd753cb8ec91 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -18,65 +18,11 @@
package org.apache.beam.runners.dataflow.worker.windmill;
import java.io.PrintWriter;
-import java.util.Set;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
-import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
-import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
-import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
-import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
/** Stub for communicating with a Windmill server. */
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class WindmillServerStub implements StatusDataProvider {
-
- /**
- * Sets the new endpoints used to talk to windmill. Upon first call, the stubs are initialized. On
- * subsequent calls, if endpoints are different from previous values new stubs are created,
- * replacing the previous ones.
- */
- public abstract void setWindmillServiceEndpoints(Set<HostAndPort> endpoints);
-
- /*
- * Returns the windmill service endpoints set by setWindmillServiceEndpoints
- */
- public abstract ImmutableSet<HostAndPort> getWindmillServiceEndpoints();
-
- /** Returns true iff this WindmillServerStub is ready for making API calls. */
- public abstract boolean isReady();
-
- /** Get a batch of work to process. */
- public abstract Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request);
-
- /** Get additional data such as state needed to process work. */
- public abstract Windmill.GetDataResponse getData(Windmill.GetDataRequest request);
-
- /** Commit the work, issuing any output productions, state modifications etc. */
- public abstract Windmill.CommitWorkResponse commitWork(Windmill.CommitWorkRequest request);
-
- /** Get configuration data from the server. */
- public abstract Windmill.GetConfigResponse getConfig(Windmill.GetConfigRequest request);
-
- /** Report execution information to the server. */
- public abstract Windmill.ReportStatsResponse reportStats(Windmill.ReportStatsRequest request);
-
- /**
- * Gets work to process, returned as a stream.
- *
- * <p>Each time a WorkItem is received, it will be passed to the given receiver. The returned
- * GetWorkStream object can be used to control the lifetime of the stream.
- */
- public abstract GetWorkStream getWorkStream(
- Windmill.GetWorkRequest request, WorkItemReceiver receiver);
-
- /** Get additional data such as state needed to process work, returned as a stream. */
- public abstract GetDataStream getDataStream();
-
- /** Returns a stream allowing individual WorkItemCommitRequests to be streamed to Windmill. */
- public abstract CommitWorkStream commitWorkStream();
+public abstract class WindmillServerStub
+ implements ApplianceWindmillClient, StreamingEngineWindmillClient, StatusDataProvider {
/** Returns the amount of time the server has been throttled and resets the time to 0. */
public abstract long getAndResetThrottleTime();
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java
index 028a5c2e1d4b..58aecfc71e00 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java
@@ -69,6 +69,7 @@ public abstract class AbstractWindmillStream implements Win
protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20;
private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class);
protected final AtomicBoolean clientClosed;
+ private final AtomicBoolean isShutdown;
private final AtomicLong lastSendTimeMs;
private final Executor executor;
private final BackOff backoff;
@@ -84,19 +85,23 @@ public abstract class AbstractWindmillStream implements Win
private final Supplier<StreamObserver<RequestT>> requestObserverSupplier;
// Indicates if the current stream in requestObserver is closed by calling close() method
private final AtomicBoolean streamClosed;
+ private final String backendWorkerToken;
private @Nullable StreamObserver<RequestT> requestObserver;
protected AbstractWindmillStream(
+ String debugStreamType,
Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory,
BackOff backoff,
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
- int logEveryNStreamFailures) {
+ int logEveryNStreamFailures,
+ String backendWorkerToken) {
+ this.backendWorkerToken = backendWorkerToken;
this.executor =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
.setDaemon(true)
- .setNameFormat("WindmillStream-thread")
+ .setNameFormat(createThreadName(debugStreamType, backendWorkerToken))
.build());
this.backoff = backoff;
this.streamRegistry = streamRegistry;
@@ -111,12 +116,19 @@ protected AbstractWindmillStream(
this.lastErrorTime = new AtomicReference<>();
this.sleepUntil = new AtomicLong();
this.finishLatch = new CountDownLatch(1);
+ this.isShutdown = new AtomicBoolean(false);
this.requestObserverSupplier =
() ->
streamObserverFactory.from(
clientFactory, new AbstractWindmillStream.ResponseObserver());
}
+ private static String createThreadName(String streamType, String backendWorkerToken) {
+ return !backendWorkerToken.isEmpty()
+ ? String.format("%s-%s-WindmillStream-thread", streamType, backendWorkerToken)
+ : String.format("%s-WindmillStream-thread", streamType);
+ }
+
private static long debugDuration(long nowMs, long startMs) {
if (startMs <= 0) {
return -1;
@@ -140,6 +152,11 @@ private static long debugDuration(long nowMs, long startMs) {
*/
protected abstract void startThrottleTimer();
+ /** Reflects that {@link #shutdown()} was explicitly called. */
+ protected boolean isShutdown() {
+ return isShutdown.get();
+ }
+
private StreamObserver<RequestT> requestObserver() {
if (requestObserver == null) {
throw new NullPointerException(
@@ -175,7 +192,7 @@ protected final void startStream() {
requestObserver = requestObserverSupplier.get();
onNewStream();
if (clientClosed.get()) {
- close();
+ halfClose();
}
return;
}
@@ -238,7 +255,7 @@ public final void appendSummaryHtml(PrintWriter writer) {
protected abstract void appendSpecificHtml(PrintWriter writer);
@Override
- public final synchronized void close() {
+ public final synchronized void halfClose() {
// Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
clientClosed.set(true);
requestObserver().onCompleted();
@@ -255,6 +272,30 @@ public final Instant startTime() {
return new Instant(startTimeMs.get());
}
+ @Override
+ public String backendWorkerToken() {
+ return backendWorkerToken;
+ }
+
+ @Override
+ public void shutdown() {
+ if (isShutdown.compareAndSet(false, true)) {
+ requestObserver()
+ .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream."));
+ }
+ }
+
+ private void setLastError(String error) {
+ lastError.set(error);
+ lastErrorTime.set(DateTime.now());
+ }
+
+ public static class WindmillStreamShutdownException extends RuntimeException {
+ public WindmillStreamShutdownException(String message) {
+ super(message);
+ }
+ }
+
private class ResponseObserver implements StreamObserver<ResponseT> {
@Override
@@ -280,7 +321,7 @@ public void onCompleted() {
private void onStreamFinished(@Nullable Throwable t) {
synchronized (this) {
- if (clientClosed.get() && !hasPendingRequests()) {
+ if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) {
streamRegistry.remove(AbstractWindmillStream.this);
finishLatch.countDown();
return;
@@ -337,9 +378,4 @@ private void onStreamFinished(@Nullable Throwable t) {
executor.execute(AbstractWindmillStream.this::startStream);
}
}
-
- private void setLastError(String error) {
- lastError.set(error);
- lastErrorTime.set(DateTime.now());
- }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index d044e9300790..31bd4e146a78 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker.windmill.client;
import java.io.Closeable;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@@ -32,8 +33,12 @@
/** Superclass for streams returned by streaming Windmill methods. */
@ThreadSafe
public interface WindmillStream {
+
+ /** An identifier for the backend worker where the stream is sending/receiving RPCs. */
+ String backendWorkerToken();
+
/** Indicates that no more requests will be sent. */
- void close();
+ void halfClose();
/** Waits for the server to close its end of the connection, with timeout. */
boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException;
@@ -41,6 +46,12 @@ public interface WindmillStream {
/** Returns when the stream was opened. */
Instant startTime();
+ /**
+ * Shutdown the stream. There should be no further interactions with the stream once this has been
+ * called.
+ */
+ void shutdown();
+
/** Handle representing a stream of GetWork responses. */
@ThreadSafe
interface GetWorkStream extends WindmillStream {
@@ -62,7 +73,7 @@ Windmill.KeyedGetDataResponse requestKeyedData(
Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
/** Tells windmill processing is ongoing for the given keys. */
- void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats);
+ void refreshActiveWork(Map<String, Collection<HeartbeatRequest>> heartbeats);
void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse> responses);
}
@@ -70,6 +81,12 @@ Windmill.KeyedGetDataResponse requestKeyedData(
/** Interface for streaming CommitWorkRequests to Windmill. */
@ThreadSafe
interface CommitWorkStream extends WindmillStream {
+ /**
+ * Returns a builder that can be used for sending requests. Each builder is not thread-safe but
+ * different builders for the same stream may be used simultaneously.
+ */
+ CommitWorkStream.RequestBatcher batcher();
+
@NotThreadSafe
interface RequestBatcher extends Closeable {
/**
@@ -92,12 +109,6 @@ default void close() {
flush();
}
}
-
- /**
- * Returns a builder that can be used for sending requests. Each builder is not thread-safe but
- * different builders for the same stream may be used simultaneously.
- */
- RequestBatcher batcher();
}
/** Interface for streaming GetWorkerMetadata requests to Windmill. */
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
index 0e4e085c066c..f14fc40fdfdf 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
@@ -128,7 +128,7 @@ public StreamT getStream() {
return resultStream;
} finally {
if (closeThisStream != null) {
- closeThisStream.close();
+ closeThisStream.halfClose();
}
}
}
@@ -166,7 +166,7 @@ public void releaseStream(StreamT stream) {
}
if (closeStream) {
- stream.close();
+ stream.halfClose();
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
index ed4dcfa212f1..bf1007bc4bfb 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
@@ -17,9 +17,11 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+import com.google.auto.value.AutoBuilder;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
@@ -45,6 +47,7 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
private static final int TARGET_COMMIT_BATCH_KEYS = 5;
private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
+ private static final String NO_BACKEND_WORKER_TOKEN = "";
private final Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
private final WeightedBoundedQueue<Commit> commitQueue;
@@ -52,11 +55,13 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
private final AtomicLong activeCommitBytes;
private final Consumer<CompleteCommit> onCommitComplete;
private final int numCommitSenders;
+ private final AtomicBoolean isRunning;
- private StreamingEngineWorkCommitter(
+ StreamingEngineWorkCommitter(
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
int numCommitSenders,
- Consumer<CompleteCommit> onCommitComplete) {
+ Consumer<CompleteCommit> onCommitComplete,
+ String backendWorkerToken) {
this.commitWorkStreamFactory = commitWorkStreamFactory;
this.commitQueue =
WeightedBoundedQueue.create(
@@ -67,34 +72,48 @@ private StreamingEngineWorkCommitter(
new ThreadFactoryBuilder()
.setDaemon(true)
.setPriority(Thread.MAX_PRIORITY)
- .setNameFormat("CommitThread-%d")
+ .setNameFormat(
+ backendWorkerToken.isEmpty()
+ ? "CommitThread-%d"
+ : "CommitThread-" + backendWorkerToken + "-%d")
.build());
this.activeCommitBytes = new AtomicLong();
this.onCommitComplete = onCommitComplete;
this.numCommitSenders = numCommitSenders;
+ this.isRunning = new AtomicBoolean(false);
}
- public static StreamingEngineWorkCommitter create(
- Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
- int numCommitSenders,
- Consumer<CompleteCommit> onCommitComplete) {
- return new StreamingEngineWorkCommitter(
- commitWorkStreamFactory, numCommitSenders, onCommitComplete);
+ public static Builder builder() {
+ return new AutoBuilder_StreamingEngineWorkCommitter_Builder()
+ .setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN)
+ .setNumCommitSenders(1);
}
@Override
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
- if (!commitSenders.isShutdown()) {
- for (int i = 0; i < numCommitSenders; i++) {
- commitSenders.submit(this::streamingCommitLoop);
- }
+ Preconditions.checkState(
+ isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start().");
+ for (int i = 0; i < numCommitSenders; i++) {
+ commitSenders.submit(this::streamingCommitLoop);
}
}
@Override
public void commit(Commit commit) {
- commitQueue.put(commit);
+ boolean isShutdown = !this.isRunning.get();
+ if (commit.work().isFailed() || isShutdown) {
+ if (isShutdown) {
+ LOG.debug(
+ "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}, workId={} ].",
+ commit.computationId(),
+ commit.work().getShardedKey(),
+ commit.work().id());
+ }
+ failCommit(commit);
+ } else {
+ commitQueue.put(commit);
+ }
}
@Override
@@ -104,15 +123,14 @@ public long currentActiveCommitBytes() {
@Override
public void stop() {
- if (!commitSenders.isTerminated()) {
- commitSenders.shutdownNow();
- try {
- commitSenders.awaitTermination(10, TimeUnit.SECONDS);
- } catch (InterruptedException e) {
- LOG.warn(
- "Commit senders didn't complete shutdown within 10 seconds, continuing to drain queue",
- e);
- }
+ Preconditions.checkState(isRunning.compareAndSet(true, false));
+ commitSenders.shutdownNow();
+ try {
+ commitSenders.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ LOG.warn(
+ "Commit senders didn't complete shutdown within 10 seconds, continuing to drain queue.",
+ e);
}
drainCommitQueue();
}
@@ -138,12 +156,13 @@ public int parallelism() {
private void streamingCommitLoop() {
@Nullable Commit initialCommit = null;
try {
- while (true) {
+ while (isRunning.get()) {
if (initialCommit == null) {
try {
// Block until we have a commit or are shutting down.
initialCommit = commitQueue.take();
} catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
return;
}
}
@@ -156,17 +175,14 @@ private void streamingCommitLoop() {
}
try (CloseableStream<CommitWorkStream> closeableCommitStream =
- commitWorkStreamFactory.get()) {
- CommitWorkStream commitStream = closeableCommitStream.stream();
- try (CommitWorkStream.RequestBatcher batcher = commitStream.batcher()) {
- if (!tryAddToCommitBatch(initialCommit, batcher)) {
- throw new AssertionError(
- "Initial commit on flushed stream should always be accepted.");
- }
- // Batch additional commits to the stream and possibly make an un-batched commit the
- // next initial commit.
- initialCommit = expandBatch(batcher);
+ commitWorkStreamFactory.get();
+ CommitWorkStream.RequestBatcher batcher = closeableCommitStream.stream().batcher()) {
+ if (!tryAddToCommitBatch(initialCommit, batcher)) {
+ throw new AssertionError("Initial commit on flushed stream should always be accepted.");
}
+ // Batch additional commits to the stream and possibly make an un-batched commit the
+ // next initial commit.
+ initialCommit = expandBatch(batcher);
} catch (Exception e) {
LOG.error("Error occurred sending commits.", e);
}
@@ -187,7 +203,7 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch
batcher.commitWorkItem(
commit.computationId(),
commit.request(),
- (commitStatus) -> {
+ commitStatus -> {
onCommitComplete.accept(CompleteCommit.create(commit, commitStatus));
activeCommitBytes.addAndGet(-commit.getSize());
});
@@ -201,9 +217,11 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch
return isCommitAccepted;
}
- // Helper to batch additional commits into the commit batch as long as they fit.
- // Returns a commit that was removed from the queue but not consumed or null.
- private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) {
+ /**
+ * Helper to batch additional commits into the commit batch as long as they fit. Returns a commit
+ * that was removed from the queue but not consumed or null.
+ */
+ private @Nullable Commit expandBatch(CommitWorkStream.RequestBatcher batcher) {
int commits = 1;
while (true) {
Commit commit;
@@ -214,6 +232,7 @@ private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) {
commit = commitQueue.poll();
}
} catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
return null;
}
@@ -233,4 +252,22 @@ private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) {
commits++;
}
}
+
+ @AutoBuilder
+ public interface Builder {
+ Builder setCommitWorkStreamFactory(
+ Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory);
+
+ Builder setNumCommitSenders(int numCommitSenders);
+
+ Builder setOnCommitComplete(Consumer<CompleteCommit> onCommitComplete);
+
+ Builder setBackendWorkerToken(String backendWorkerToken);
+
+ StreamingEngineWorkCommitter autoBuild();
+
+ default WorkCommitter build() {
+ return autoBuild();
+ }
+ }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java
new file mode 100644
index 000000000000..e0500dde0c53
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.getdata;
+
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.WindmillComputationKey;
+import org.apache.beam.runners.dataflow.worker.windmill.ApplianceWindmillClient;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Appliance implementation of {@link GetDataClient}. */
+@Internal
+@ThreadSafe
+public final class ApplianceGetDataClient implements GetDataClient {
+ private static final int MAX_READS_PER_BATCH = 60;
+ private static final int MAX_ACTIVE_READS = 10;
+
+ private final ApplianceWindmillClient windmillClient;
+ private final ThrottlingGetDataMetricTracker getDataMetricTracker;
+
+ @GuardedBy("this")
+ private final List<ReadBatch> pendingReadBatches;
+
+ @GuardedBy("this")
+ private int activeReadThreads;
+
+ public ApplianceGetDataClient(
+ ApplianceWindmillClient windmillClient, ThrottlingGetDataMetricTracker getDataMetricTracker) {
+ this.windmillClient = windmillClient;
+ this.getDataMetricTracker = getDataMetricTracker;
+ this.pendingReadBatches = new ArrayList<>();
+ this.activeReadThreads = 0;
+ }
+
+ @Override
+ public Windmill.KeyedGetDataResponse getStateData(
+ String computationId, Windmill.KeyedGetDataRequest request) {
+ try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) {
+ SettableFuture<Windmill.KeyedGetDataResponse> response = SettableFuture.create();
+ ReadBatch batch = addToReadBatch(new QueueEntry(computationId, request, response));
+ if (batch != null) {
+ issueReadBatch(batch);
+ }
+ return response.get();
+ } catch (Exception e) {
+ throw new GetDataException(
+ "Error occurred fetching state for computation="
+ + computationId
+ + ", key="
+ + request.getShardingKey(),
+ e);
+ }
+ }
+
+ @Override
+ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) {
+ try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) {
+ return windmillClient
+ .getData(Windmill.GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build())
+ .getGlobalData(0);
+ } catch (Exception e) {
+ throw new GetDataException(
+ "Error occurred fetching side input for tag=" + request.getDataId(), e);
+ }
+ }
+
+ @Override
+ public synchronized void printHtml(PrintWriter writer) {
+ getDataMetricTracker.printHtml(writer);
+ writer.println(" Read threads: " + activeReadThreads);
+ writer.println(" Pending read batches: " + pendingReadBatches.size());
+ }
+
+ private void issueReadBatch(ReadBatch batch) {
+ try {
+ // Possibly block until the batch is allowed to start.
+ batch.startRead.get();
+ } catch (InterruptedException e) {
+ // We don't expect this thread to be interrupted. To simplify handling, we just fall through
+ // to issuing the call.
+ assert (false);
+ Thread.currentThread().interrupt();
+ } catch (ExecutionException e) {
+ // startRead is a SettableFuture so this should never occur.
+ throw new AssertionError("Should not have exception on startRead", e);
+ }
Map<WindmillComputationKey, SettableFuture<Windmill.KeyedGetDataResponse>> pendingResponses =
+ new HashMap<>(batch.reads.size());
Map<String, ComputationGetDataRequest.Builder> computationBuilders = new HashMap<>();
+ for (QueueEntry entry : batch.reads) {
+ ComputationGetDataRequest.Builder computationBuilder =
+ computationBuilders.computeIfAbsent(
+ entry.computation, k -> ComputationGetDataRequest.newBuilder().setComputationId(k));
+
+ computationBuilder.addRequests(entry.request);
+ pendingResponses.put(
+ WindmillComputationKey.create(
+ entry.computation, entry.request.getKey(), entry.request.getShardingKey()),
+ entry.response);
+ }
+
+ // Build the full GetDataRequest from the KeyedGetDataRequests pulled from the queue.
+ Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder();
+ for (ComputationGetDataRequest.Builder computationBuilder : computationBuilders.values()) {
+ builder.addRequests(computationBuilder);
+ }
+
+ try {
+ Windmill.GetDataResponse response = windmillClient.getData(builder.build());
+ // Dispatch the per-key responses back to the waiting threads.
+ for (Windmill.ComputationGetDataResponse computationResponse : response.getDataList()) {
+ for (Windmill.KeyedGetDataResponse keyResponse : computationResponse.getDataList()) {
+ pendingResponses
+ .get(
+ WindmillComputationKey.create(
+ computationResponse.getComputationId(),
+ keyResponse.getKey(),
+ keyResponse.getShardingKey()))
+ .set(keyResponse);
+ }
+ }
+ } catch (RuntimeException e) {
+ // Fan the exception out to the reads.
+ for (QueueEntry entry : batch.reads) {
+ entry.response.setException(e);
+ }
+ } finally {
+ synchronized (this) {
+ Preconditions.checkState(activeReadThreads >= 1);
+ if (pendingReadBatches.isEmpty()) {
+ activeReadThreads--;
+ } else {
+ // Notify the thread responsible for issuing the next batch read.
+ ReadBatch startBatch = pendingReadBatches.remove(0);
+ startBatch.startRead.set(null);
+ }
+ }
+ }
+ }
+
+ /**
+ * Adds the entry to a read batch for sending to the windmill server. If a non-null batch is
+ * returned, this thread will be responsible for sending the batch and should wait for the batch
+ * startRead to be notified. If null is returned, the entry was added to a read batch that will be
+ * issued by another thread.
+ */
+ private @Nullable ReadBatch addToReadBatch(QueueEntry entry) {
+ synchronized (this) {
+ ReadBatch batch;
+ if (activeReadThreads < MAX_ACTIVE_READS) {
+ assert (pendingReadBatches.isEmpty());
+ activeReadThreads += 1;
+ // fall through to below synchronized block
+ } else if (pendingReadBatches.isEmpty()
+ || pendingReadBatches.get(pendingReadBatches.size() - 1).reads.size()
+ >= MAX_READS_PER_BATCH) {
+ // This is the first read of a batch, it will be responsible for sending the batch.
+ batch = new ReadBatch();
+ pendingReadBatches.add(batch);
+ batch.reads.add(entry);
+ return batch;
+ } else {
+ // This fits within an existing batch, it will be sent by the first blocking thread in the
+ // batch.
+ pendingReadBatches.get(pendingReadBatches.size() - 1).reads.add(entry);
+ return null;
+ }
+ }
+ ReadBatch batch = new ReadBatch();
+ batch.reads.add(entry);
+ batch.startRead.set(null);
+ return batch;
+ }
+
+ private static final class ReadBatch {
+ ArrayList<QueueEntry> reads = new ArrayList<>();
+ SettableFuture<Void> startRead = SettableFuture.create();
+ }
+
+ private static final class QueueEntry {
+ final String computation;
+ final Windmill.KeyedGetDataRequest request;
+ final SettableFuture<Windmill.KeyedGetDataResponse> response;
+
+ QueueEntry(
+ String computation,
+ Windmill.KeyedGetDataRequest request,
+ SettableFuture<Windmill.KeyedGetDataResponse> response) {
+ this.computation = computation;
+ this.request = request;
+ this.response = response;
+ }
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java
new file mode 100644
index 000000000000..c732591bf12d
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.getdata;
+
+import java.io.PrintWriter;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
+import org.apache.beam.sdk.annotations.Internal;
+
+/** Client for streaming backend GetData API. */
+@Internal
+public interface GetDataClient {
+ /**
+ * Issues a blocking call to fetch state data for a specific computation and {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem}.
+ *
+ * @throws GetDataException when there was an unexpected error during the attempted fetch.
+ */
+ KeyedGetDataResponse getStateData(String computationId, KeyedGetDataRequest request)
+ throws GetDataException;
+
+ /**
+ * Issues a blocking call to fetch side input data.
+ *
+ * @throws GetDataException when there was an unexpected error during the attempted fetch.
+ */
+ GlobalData getSideInputData(GlobalDataRequest request) throws GetDataException;
+
+ void printHtml(PrintWriter writer);
+
+ final class GetDataException extends RuntimeException {
+ GetDataException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ GetDataException(String message) {
+ super(message);
+ }
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java
new file mode 100644
index 000000000000..c8e058e7e230
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.getdata;
+
+import java.io.PrintWriter;
+import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.sdk.annotations.Internal;
+
+/** {@link GetDataClient} that fetches data directly from a specific {@link GetDataStream}. */
+@Internal
+public final class StreamGetDataClient implements GetDataClient {
+
+ private final GetDataStream getDataStream;
+ private final Function<String, GetDataStream> sideInputGetDataStreamFactory;
+ private final ThrottlingGetDataMetricTracker getDataMetricTracker;
+
+ private StreamGetDataClient(
+ GetDataStream getDataStream,
+ Function<String, GetDataStream> sideInputGetDataStreamFactory,
+ ThrottlingGetDataMetricTracker getDataMetricTracker) {
+ this.getDataStream = getDataStream;
+ this.sideInputGetDataStreamFactory = sideInputGetDataStreamFactory;
+ this.getDataMetricTracker = getDataMetricTracker;
+ }
+
+ public static GetDataClient create(
+ GetDataStream getDataStream,
+ Function<String, GetDataStream> sideInputGetDataStreamFactory,
+ ThrottlingGetDataMetricTracker getDataMetricTracker) {
+ return new StreamGetDataClient(
+ getDataStream, sideInputGetDataStreamFactory, getDataMetricTracker);
+ }
+
+ /**
+ * @throws WorkItemCancelledException when the fetch fails due to the stream being shutdown,
+ * indicating that the {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} that triggered the
+ * fetch has been cancelled.
+ */
+ @Override
+ public Windmill.KeyedGetDataResponse getStateData(
+ String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException {
+ try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) {
+ return getDataStream.requestKeyedData(computationId, request);
+ } catch (AbstractWindmillStream.WindmillStreamShutdownException e) {
+ throw new WorkItemCancelledException(request.getShardingKey());
+ } catch (Exception e) {
+ throw new GetDataException(
+ "Error occurred fetching state for computation="
+ + computationId
+ + ", key="
+ + request.getShardingKey(),
+ e);
+ }
+ }
+
+ /**
+ * @throws WorkItemCancelledException when the fetch fails due to the stream being shutdown,
+ * indicating that the {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} that triggered the
+ * fetch has been cancelled.
+ */
+ @Override
+ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request)
+ throws GetDataException {
+ GetDataStream sideInputGetDataStream =
+ sideInputGetDataStreamFactory.apply(request.getDataId().getTag());
+ try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) {
+ return sideInputGetDataStream.requestGlobalData(request);
+ } catch (AbstractWindmillStream.WindmillStreamShutdownException e) {
+ throw new WorkItemCancelledException(e);
+ } catch (Exception e) {
+ throw new GetDataException(
+ "Error occurred fetching side input for tag=" + request.getDataId(), e);
+ }
+ }
+
+ @Override
+ public void printHtml(PrintWriter writer) {
+ getDataMetricTracker.printHtml(writer);
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java
new file mode 100644
index 000000000000..49fe3e4bdc15
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.getdata;
+
+import java.io.PrintWriter;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.sdk.annotations.Internal;
+
+/**
+ * StreamingEngine implementation of {@link GetDataClient}.
+ *
+ * @implNote Uses {@link WindmillStreamPool} to send requests.
+ */
+@Internal
+@ThreadSafe
+public final class StreamPoolGetDataClient implements GetDataClient {
+
+  private final WindmillStreamPool<GetDataStream> getDataStreamPool;
+ private final ThrottlingGetDataMetricTracker getDataMetricTracker;
+
+ public StreamPoolGetDataClient(
+ ThrottlingGetDataMetricTracker getDataMetricTracker,
+      WindmillStreamPool<GetDataStream> getDataStreamPool) {
+ this.getDataMetricTracker = getDataMetricTracker;
+ this.getDataStreamPool = getDataStreamPool;
+ }
+
+ @Override
+ public Windmill.KeyedGetDataResponse getStateData(
+ String computationId, KeyedGetDataRequest request) {
+ try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling();
+        CloseableStream<GetDataStream> closeableStream = getDataStreamPool.getCloseableStream()) {
+ return closeableStream.stream().requestKeyedData(computationId, request);
+ } catch (Exception e) {
+ throw new GetDataException(
+ "Error occurred fetching state for computation="
+ + computationId
+ + ", key="
+ + request.getShardingKey(),
+ e);
+ }
+ }
+
+ @Override
+ public Windmill.GlobalData getSideInputData(GlobalDataRequest request) {
+ try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling();
+        CloseableStream<GetDataStream> closeableStream = getDataStreamPool.getCloseableStream()) {
+ return closeableStream.stream().requestGlobalData(request);
+ } catch (Exception e) {
+ throw new GetDataException(
+ "Error occurred fetching side input for tag=" + request.getDataId(), e);
+ }
+ }
+
+ @Override
+ public void printHtml(PrintWriter writer) {
+ getDataMetricTracker.printHtml(writer);
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java
new file mode 100644
index 000000000000..6bb00292e29a
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.getdata;
+
+import com.google.auto.value.AutoValue;
+import java.io.PrintWriter;
+import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+
+/**
+ * Wraps GetData calls to track metrics for the number of in-flight requests and throttles requests
+ * when memory pressure is high.
+ */
+@Internal
+@ThreadSafe
+public final class ThrottlingGetDataMetricTracker {
+ private static final String GET_STATE_DATA_RESOURCE_CONTEXT = "GetStateData";
+ private static final String GET_SIDE_INPUT_RESOURCE_CONTEXT = "GetSideInputData";
+
+ private final MemoryMonitor gcThrashingMonitor;
+ private final AtomicInteger activeStateReads;
+ private final AtomicInteger activeSideInputs;
+ private final AtomicInteger activeHeartbeats;
+
+ public ThrottlingGetDataMetricTracker(MemoryMonitor gcThrashingMonitor) {
+ this.gcThrashingMonitor = gcThrashingMonitor;
+ this.activeStateReads = new AtomicInteger();
+ this.activeSideInputs = new AtomicInteger();
+ this.activeHeartbeats = new AtomicInteger();
+ }
+
+ /**
+ * Tracks a state data fetch. If there is memory pressure, may throttle requests. Returns an
+ * {@link AutoCloseable} that will decrement the metric after the call is finished.
+ */
+ AutoCloseable trackStateDataFetchWithThrottling() {
+ gcThrashingMonitor.waitForResources(GET_STATE_DATA_RESOURCE_CONTEXT);
+ activeStateReads.getAndIncrement();
+ return activeStateReads::getAndDecrement;
+ }
+
+ /**
+ * Tracks a side input fetch. If there is memory pressure, may throttle requests. Returns an
+ * {@link AutoCloseable} that will decrement the metric after the call is finished.
+ */
+ AutoCloseable trackSideInputFetchWithThrottling() {
+ gcThrashingMonitor.waitForResources(GET_SIDE_INPUT_RESOURCE_CONTEXT);
+ activeSideInputs.getAndIncrement();
+ return activeSideInputs::getAndDecrement;
+ }
+
+ /**
+ * Tracks heartbeat request metrics. Returns an {@link AutoCloseable} that will decrement the
+ * metric after the call is finished.
+ */
+ public AutoCloseable trackHeartbeats(int numHeartbeats) {
+ activeHeartbeats.getAndAdd(numHeartbeats);
+ return () -> activeHeartbeats.getAndAdd(-numHeartbeats);
+ }
+
+ public void printHtml(PrintWriter writer) {
+ writer.println("Active Fetches:");
+ writer.println(" Side Inputs: " + activeSideInputs.get());
+ writer.println(" State Reads: " + activeStateReads.get());
+ writer.println("Heartbeat Keys Active: " + activeHeartbeats.get());
+ }
+
+ @VisibleForTesting
+ ReadOnlySnapshot getMetricsSnapshot() {
+ return ReadOnlySnapshot.create(
+ activeSideInputs.get(), activeStateReads.get(), activeHeartbeats.get());
+ }
+
+ @VisibleForTesting
+ @AutoValue
+ abstract static class ReadOnlySnapshot {
+
+ private static ReadOnlySnapshot create(
+ int activeSideInputs, int activeStateReads, int activeHeartbeats) {
+ return new AutoValue_ThrottlingGetDataMetricTracker_ReadOnlySnapshot(
+ activeSideInputs, activeStateReads, activeHeartbeats);
+ }
+
+ abstract int activeSideInputs();
+
+ abstract int activeStateReads();
+
+ abstract int activeHeartbeats();
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
index f9f579119d61..053843a8af25 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
@@ -57,6 +57,7 @@ public final class GrpcCommitWorkStream
private final int streamingRpcBatchLimit;
private GrpcCommitWorkStream(
+ String backendWorkerToken,
Function, StreamObserver>
startCommitWorkRpcFn,
BackOff backoff,
@@ -68,11 +69,13 @@ private GrpcCommitWorkStream(
AtomicLong idGenerator,
int streamingRpcBatchLimit) {
super(
+ "CommitWorkStream",
startCommitWorkRpcFn,
backoff,
streamObserverFactory,
streamRegistry,
- logEveryNStreamFailures);
+ logEveryNStreamFailures,
+ backendWorkerToken);
pending = new ConcurrentHashMap<>();
this.idGenerator = idGenerator;
this.jobHeader = jobHeader;
@@ -81,6 +84,7 @@ private GrpcCommitWorkStream(
}
public static GrpcCommitWorkStream create(
+ String backendWorkerToken,
Function, StreamObserver>
startCommitWorkRpcFn,
BackOff backoff,
@@ -93,6 +97,7 @@ public static GrpcCommitWorkStream create(
int streamingRpcBatchLimit) {
GrpcCommitWorkStream commitWorkStream =
new GrpcCommitWorkStream(
+ backendWorkerToken,
startCommitWorkRpcFn,
backoff,
streamObserverFactory,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
index 6f4b5b7b33fb..58f72610e2d3 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
@@ -39,10 +39,12 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
@@ -80,8 +82,9 @@ public final class GrpcDirectGetWorkStream
private final GetWorkRequest request;
private final WorkItemScheduler workItemScheduler;
private final ThrottleTimer getWorkThrottleTimer;
-  private final Supplier<GetDataStream> getDataStream;
+  private final Supplier<HeartbeatSender> heartbeatSender;
   private final Supplier<WorkCommitter> workCommitter;
+  private final Supplier<GetDataClient> getDataClient;
/**
* Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they
@@ -91,6 +94,7 @@ public final class GrpcDirectGetWorkStream
private final ConcurrentMap workItemBuffers;
private GrpcDirectGetWorkStream(
+ String backendWorkerToken,
Function<
StreamObserver,
StreamObserver>
@@ -101,25 +105,32 @@ private GrpcDirectGetWorkStream(
Set> streamRegistry,
int logEveryNStreamFailures,
ThrottleTimer getWorkThrottleTimer,
- Supplier getDataStream,
+ Supplier heartbeatSender,
+ Supplier getDataClient,
Supplier workCommitter,
WorkItemScheduler workItemScheduler) {
super(
- startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures);
+ "GetWorkStream",
+ startGetWorkRpcFn,
+ backoff,
+ streamObserverFactory,
+ streamRegistry,
+ logEveryNStreamFailures,
+ backendWorkerToken);
this.request = request;
this.getWorkThrottleTimer = getWorkThrottleTimer;
this.workItemScheduler = workItemScheduler;
this.workItemBuffers = new ConcurrentHashMap<>();
- // Use the same GetDataStream and CommitWorkStream instances to process all the work in this
- // stream.
- this.getDataStream = Suppliers.memoize(getDataStream::get);
+ this.heartbeatSender = Suppliers.memoize(heartbeatSender::get);
this.workCommitter = Suppliers.memoize(workCommitter::get);
+ this.getDataClient = Suppliers.memoize(getDataClient::get);
this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget());
this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget());
this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget());
}
public static GrpcDirectGetWorkStream create(
+ String backendWorkerToken,
Function<
StreamObserver,
StreamObserver>
@@ -130,11 +141,13 @@ public static GrpcDirectGetWorkStream create(
Set> streamRegistry,
int logEveryNStreamFailures,
ThrottleTimer getWorkThrottleTimer,
- Supplier getDataStream,
+ Supplier heartbeatSender,
+ Supplier getDataClient,
Supplier workCommitter,
WorkItemScheduler workItemScheduler) {
GrpcDirectGetWorkStream getWorkStream =
new GrpcDirectGetWorkStream(
+ backendWorkerToken,
startGetWorkRpcFn,
request,
backoff,
@@ -142,7 +155,8 @@ public static GrpcDirectGetWorkStream create(
streamRegistry,
logEveryNStreamFailures,
getWorkThrottleTimer,
- getDataStream,
+ heartbeatSender,
+ getDataClient,
workCommitter,
workItemScheduler);
getWorkStream.startStream();
@@ -327,7 +341,7 @@ private void runAndReset() {
private Work.ProcessingContext createProcessingContext(String computationId) {
return Work.createProcessingContext(
- computationId, getDataStream.get()::requestKeyedData, workCommitter.get()::commit);
+ computationId, getDataClient.get(), workCommitter.get()::commit, heartbeatSender.get());
}
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index feb15c2ac83c..0e9a0c6316ee 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -23,6 +23,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
+import java.util.Collection;
import java.util.Deque;
import java.util.List;
import java.util.Map;
@@ -75,6 +76,7 @@ public final class GrpcGetDataStream
private final Consumer> processHeartbeatResponses;
private GrpcGetDataStream(
+ String backendWorkerToken,
Function, StreamObserver>
startGetDataRpcFn,
BackOff backoff,
@@ -88,7 +90,13 @@ private GrpcGetDataStream(
boolean sendKeyedGetDataRequests,
Consumer> processHeartbeatResponses) {
super(
- startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures);
+ "GetDataStream",
+ startGetDataRpcFn,
+ backoff,
+ streamObserverFactory,
+ streamRegistry,
+ logEveryNStreamFailures,
+ backendWorkerToken);
this.idGenerator = idGenerator;
this.getDataThrottleTimer = getDataThrottleTimer;
this.jobHeader = jobHeader;
@@ -100,6 +108,7 @@ private GrpcGetDataStream(
}
public static GrpcGetDataStream create(
+ String backendWorkerToken,
Function, StreamObserver>
startGetDataRpcFn,
BackOff backoff,
@@ -114,6 +123,7 @@ public static GrpcGetDataStream create(
Consumer> processHeartbeatResponses) {
GrpcGetDataStream getDataStream =
new GrpcGetDataStream(
+ backendWorkerToken,
startGetDataRpcFn,
backoff,
streamObserverFactory,
@@ -189,11 +199,15 @@ public GlobalData requestGlobalData(GlobalDataRequest request) {
}
@Override
-  public void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats) {
+  public void refreshActiveWork(Map<String, Collection<HeartbeatRequest>> heartbeats) {
+ if (isShutdown()) {
+ throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream.");
+ }
+
StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
if (sendKeyedGetDataRequests) {
long builderBytes = 0;
-      for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
+      for (Map.Entry<String, Collection<HeartbeatRequest>> entry : heartbeats.entrySet()) {
for (HeartbeatRequest request : entry.getValue()) {
// Calculate the bytes with some overhead for proto encoding.
long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
@@ -224,7 +238,7 @@ public void refreshActiveWork(Map> heartbeats) {
} else {
// No translation necessary, but we must still respect `RPC_STREAM_CHUNK_SIZE`.
long builderBytes = 0;
-      for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
+      for (Map.Entry<String, Collection<HeartbeatRequest>> entry : heartbeats.entrySet()) {
ComputationHeartbeatRequest.Builder computationHeartbeatBuilder =
ComputationHeartbeatRequest.newBuilder().setComputationId(entry.getKey());
for (HeartbeatRequest request : entry.getValue()) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
index 867180fb0d31..4b392e9190ed 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
@@ -60,6 +60,7 @@ public final class GrpcGetWorkStream
private final AtomicLong inflightBytes;
private GrpcGetWorkStream(
+ String backendWorkerToken,
Function<
StreamObserver,
StreamObserver>
@@ -72,7 +73,13 @@ private GrpcGetWorkStream(
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
super(
- startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures);
+ "GetWorkStream",
+ startGetWorkRpcFn,
+ backoff,
+ streamObserverFactory,
+ streamRegistry,
+ logEveryNStreamFailures,
+ backendWorkerToken);
this.request = request;
this.getWorkThrottleTimer = getWorkThrottleTimer;
this.receiver = receiver;
@@ -82,6 +89,7 @@ private GrpcGetWorkStream(
}
public static GrpcGetWorkStream create(
+ String backendWorkerToken,
Function<
StreamObserver,
StreamObserver>
@@ -95,6 +103,7 @@ public static GrpcGetWorkStream create(
WorkItemReceiver receiver) {
GrpcGetWorkStream getWorkStream =
new GrpcGetWorkStream(
+ backendWorkerToken,
startGetWorkRpcFn,
request,
backoff,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
index 3672f02c813f..44e21a9b18ed 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
@@ -65,11 +65,13 @@ private GrpcGetWorkerMetadataStream(
ThrottleTimer getWorkerMetadataThrottleTimer,
Consumer serverMappingConsumer) {
super(
+ "GetWorkerMetadataStream",
startGetWorkerMetadataRpcFn,
backoff,
streamObserverFactory,
streamRegistry,
- logEveryNStreamFailures);
+ logEveryNStreamFailures,
+ "");
this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build();
this.metadataVersion = metadataVersion;
this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
index 0ab03a803180..1fce4d238b2e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
@@ -254,11 +254,6 @@ public void setWindmillServiceEndpoints(Set endpoints) {
dispatcherClient.consumeWindmillDispatcherEndpoints(ImmutableSet.copyOf(endpoints));
}
- @Override
- public boolean isReady() {
- return dispatcherClient.hasInitializedEndpoints();
- }
-
private synchronized void initializeLocalHost(int port) {
this.maxBackoff = Duration.millis(500);
if (options.isEnableStreamingEngine()) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 14866f3f586b..92f031db9972 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -37,6 +37,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
@@ -44,10 +45,12 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.FluentBackoff;
@@ -69,6 +72,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider {
private static final int DEFAULT_STREAMING_RPC_BATCH_LIMIT = Integer.MAX_VALUE;
private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1;
private static final int NO_HEALTH_CHECKS = -1;
+ private static final String NO_BACKEND_WORKER_TOKEN = "";
private final JobHeader jobHeader;
private final int logEveryNStreamFailures;
@@ -179,6 +183,7 @@ public GetWorkStream createGetWorkStream(
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver processWorkItem) {
return GrpcGetWorkStream.create(
+ NO_BACKEND_WORKER_TOKEN,
responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver),
request,
grpcBackOff.get(),
@@ -190,21 +195,24 @@ public GetWorkStream createGetWorkStream(
}
public GetWorkStream createDirectGetWorkStream(
- CloudWindmillServiceV1Alpha1Stub stub,
+ WindmillConnection connection,
GetWorkRequest request,
ThrottleTimer getWorkThrottleTimer,
- Supplier getDataStream,
+ Supplier heartbeatSender,
+ Supplier getDataClient,
Supplier workCommitter,
WorkItemScheduler workItemScheduler) {
return GrpcDirectGetWorkStream.create(
- responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver),
+ connection.backendWorkerToken(),
+ responseObserver -> withDefaultDeadline(connection.stub()).getWorkStream(responseObserver),
request,
grpcBackOff.get(),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
getWorkThrottleTimer,
- getDataStream,
+ heartbeatSender,
+ getDataClient,
workCommitter,
workItemScheduler);
}
@@ -212,6 +220,7 @@ public GetWorkStream createDirectGetWorkStream(
public GetDataStream createGetDataStream(
CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer getDataThrottleTimer) {
return GrpcGetDataStream.create(
+ NO_BACKEND_WORKER_TOKEN,
responseObserver -> withDefaultDeadline(stub).getDataStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(),
@@ -228,6 +237,7 @@ public GetDataStream createGetDataStream(
public CommitWorkStream createCommitWorkStream(
CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer commitWorkThrottleTimer) {
return GrpcCommitWorkStream.create(
+ NO_BACKEND_WORKER_TOKEN,
responseObserver -> withDefaultDeadline(stub).commitWorkStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
index 4760062c5754..b9573ff94cc9 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java
@@ -45,6 +45,8 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamGetDataClient;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
@@ -91,6 +93,7 @@ public final class StreamingEngineClient {
private final Supplier getWorkerMetadataStream;
private final Queue newWindmillEndpoints;
private final Function workCommitterFactory;
+ private final ThrottlingGetDataMetricTracker getDataMetricTracker;
/** Writes are guarded by synchronization, reads are lock free. */
private final AtomicReference connections;
@@ -107,8 +110,10 @@ private StreamingEngineClient(
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
long clientId,
- Function workCommitterFactory) {
+ Function workCommitterFactory,
+ ThrottlingGetDataMetricTracker getDataMetricTracker) {
this.jobHeader = jobHeader;
+ this.getDataMetricTracker = getDataMetricTracker;
this.started = false;
this.streamFactory = streamFactory;
this.workItemScheduler = workItemScheduler;
@@ -171,7 +176,8 @@ public static StreamingEngineClient create(
ChannelCachingStubFactory channelCachingStubFactory,
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
- Function workCommitterFactory) {
+ Function workCommitterFactory,
+ ThrottlingGetDataMetricTracker getDataMetricTracker) {
return new StreamingEngineClient(
jobHeader,
totalGetWorkBudget,
@@ -181,7 +187,8 @@ public static StreamingEngineClient create(
getWorkBudgetDistributor,
dispatcherClient,
/* clientId= */ new Random().nextLong(),
- workCommitterFactory);
+ workCommitterFactory,
+ getDataMetricTracker);
}
@VisibleForTesting
@@ -194,7 +201,8 @@ static StreamingEngineClient forTesting(
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
long clientId,
- Function workCommitterFactory) {
+ Function workCommitterFactory,
+ ThrottlingGetDataMetricTracker getDataMetricTracker) {
StreamingEngineClient streamingEngineClient =
new StreamingEngineClient(
jobHeader,
@@ -205,7 +213,8 @@ static StreamingEngineClient forTesting(
getWorkBudgetDistributor,
dispatcherClient,
clientId,
- workCommitterFactory);
+ workCommitterFactory,
+ getDataMetricTracker);
streamingEngineClient.start();
return streamingEngineClient;
}
@@ -240,7 +249,7 @@ public ImmutableSet currentWindmillEndpoints() {
* Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link
* GetDataStream} pointing to dispatcher.
*/
- public GetDataStream getGlobalDataStream(String globalDataKey) {
+ private GetDataStream getGlobalDataStream(String globalDataKey) {
return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey))
.map(Supplier::get)
.orElseGet(
@@ -263,7 +272,7 @@ private void startWorkerMetadataConsumer() {
@VisibleForTesting
public synchronized void finish() {
Preconditions.checkState(started, "StreamingEngineClient never started.");
- getWorkerMetadataStream.get().close();
+ getWorkerMetadataStream.get().halfClose();
getWorkBudgetRefresher.stop();
newWorkerMetadataPublisher.shutdownNow();
newWorkerMetadataConsumer.shutdownNow();
@@ -390,7 +399,7 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor(
// GetWorkBudgetDistributor.
WindmillStreamSender windmillStreamSender =
WindmillStreamSender.create(
- connection.stub(),
+ connection,
GetWorkRequest.newBuilder()
.setClientId(clientId)
.setJobId(jobHeader.getJobId())
@@ -400,6 +409,9 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor(
GetWorkBudget.noBudget(),
streamFactory,
workItemScheduler,
+ getDataStream ->
+ StreamGetDataClient.create(
+ getDataStream, this::getGlobalDataStream, getDataMetricTracker),
workCommitterFactory);
windmillStreamSender.startStreams();
return windmillStreamSender;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java
index e9f008eb522e..7d09726e4b28 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java
@@ -22,15 +22,17 @@
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
-import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.FixedStreamHeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
@@ -65,11 +67,12 @@ public class WindmillStreamSender {
private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
private WindmillStreamSender(
- CloudWindmillServiceV1Alpha1Stub stub,
+ WindmillConnection connection,
GetWorkRequest getWorkRequest,
AtomicReference<GetWorkBudget> getWorkBudget,
GrpcWindmillStreamFactory streamingEngineStreamFactory,
WorkItemScheduler workItemScheduler,
+ Function<GetDataStream, GetDataClient> getDataClientFactory,
Function<CommitWorkStream, WorkCommitter> workCommitterFactory) {
this.started = new AtomicBoolean(false);
this.getWorkBudget = getWorkBudget;
@@ -83,39 +86,42 @@ private WindmillStreamSender(
Suppliers.memoize(
() ->
streamingEngineStreamFactory.createGetDataStream(
- stub, streamingEngineThrottleTimers.getDataThrottleTimer()));
+ connection.stub(), streamingEngineThrottleTimers.getDataThrottleTimer()));
this.commitWorkStream =
Suppliers.memoize(
() ->
streamingEngineStreamFactory.createCommitWorkStream(
- stub, streamingEngineThrottleTimers.commitWorkThrottleTimer()));
+ connection.stub(), streamingEngineThrottleTimers.commitWorkThrottleTimer()));
this.workCommitter =
Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get()));
this.getWorkStream =
Suppliers.memoize(
() ->
streamingEngineStreamFactory.createDirectGetWorkStream(
- stub,
+ connection,
withRequestBudget(getWorkRequest, getWorkBudget.get()),
streamingEngineThrottleTimers.getWorkThrottleTimer(),
- getDataStream,
+ () -> FixedStreamHeartbeatSender.create(getDataStream.get()),
+ () -> getDataClientFactory.apply(getDataStream.get()),
workCommitter,
workItemScheduler));
}
public static WindmillStreamSender create(
- CloudWindmillServiceV1Alpha1Stub stub,
+ WindmillConnection connection,
GetWorkRequest getWorkRequest,
GetWorkBudget getWorkBudget,
GrpcWindmillStreamFactory streamingEngineStreamFactory,
WorkItemScheduler workItemScheduler,
+ Function<GetDataStream, GetDataClient> getDataClientFactory,
Function<CommitWorkStream, WorkCommitter> workCommitterFactory) {
return new WindmillStreamSender(
- stub,
+ connection,
getWorkRequest,
new AtomicReference<>(getWorkBudget),
streamingEngineStreamFactory,
workItemScheduler,
+ getDataClientFactory,
workCommitterFactory);
}
@@ -138,10 +144,10 @@ void closeAllStreams() {
// streaming RPCs by possibly making calls over the network. Do not close the streams unless
// they have already been started.
if (started.get()) {
- getWorkStream.get().close();
- getDataStream.get().close();
+ getWorkStream.get().shutdown();
+ getDataStream.get().shutdown();
workCommitter.get().stop();
- commitWorkStream.get().close();
+ commitWorkStream.get().shutdown();
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java
new file mode 100644
index 000000000000..4ea209f31b1d
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers;
+
+import org.apache.beam.sdk.annotations.Internal;
+
+@Internal
+public final class StreamObserverCancelledException extends RuntimeException {
+ public StreamObserverCancelledException(Throwable cause) {
+ super(cause);
+ }
+
+ public StreamObserverCancelledException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
index e9ffa982925b..b0b6377dd8b1 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
@@ -44,6 +44,7 @@
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters;
import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
+import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit;
@@ -74,7 +75,7 @@ public final class StreamingWorkScheduler {
private final DataflowWorkerHarnessOptions options;
private final Supplier<Instant> clock;
private final ComputationWorkExecutorFactory computationWorkExecutorFactory;
- private final SideInputStateFetcher sideInputStateFetcher;
+ private final SideInputStateFetcherFactory sideInputStateFetcherFactory;
private final FailureTracker failureTracker;
private final WorkFailureProcessor workFailureProcessor;
private final StreamingCommitFinalizer commitFinalizer;
@@ -88,7 +89,7 @@ public StreamingWorkScheduler(
DataflowWorkerHarnessOptions options,
Supplier<Instant> clock,
ComputationWorkExecutorFactory computationWorkExecutorFactory,
- SideInputStateFetcher sideInputStateFetcher,
+ SideInputStateFetcherFactory sideInputStateFetcherFactory,
FailureTracker failureTracker,
WorkFailureProcessor workFailureProcessor,
StreamingCommitFinalizer commitFinalizer,
@@ -100,7 +101,7 @@ public StreamingWorkScheduler(
this.options = options;
this.clock = clock;
this.computationWorkExecutorFactory = computationWorkExecutorFactory;
- this.sideInputStateFetcher = sideInputStateFetcher;
+ this.sideInputStateFetcherFactory = sideInputStateFetcherFactory;
this.failureTracker = failureTracker;
this.workFailureProcessor = workFailureProcessor;
this.commitFinalizer = commitFinalizer;
@@ -118,7 +119,6 @@ public static StreamingWorkScheduler create(
DataflowMapTaskExecutorFactory mapTaskExecutorFactory,
BoundedQueueExecutor workExecutor,
Function<String, WindmillStateCache.ForComputation> stateCacheFactory,
- Function<GlobalDataRequest, GlobalData> fetchGlobalDataFn,
FailureTracker failureTracker,
WorkFailureProcessor workFailureProcessor,
StreamingCounters streamingCounters,
@@ -141,7 +141,7 @@ public static StreamingWorkScheduler create(
options,
clock,
computationWorkExecutorFactory,
- new SideInputStateFetcher(fetchGlobalDataFn, options),
+ SideInputStateFetcherFactory.fromOptions(options),
failureTracker,
workFailureProcessor,
StreamingCommitFinalizer.create(workExecutor),
@@ -348,7 +348,8 @@ private ExecuteWorkResult executeWork(
try {
WindmillStateReader stateReader = work.createWindmillStateReader();
- SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcher.byteTrackingView();
+ SideInputStateFetcher localSideInputStateFetcher =
+ sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
// If the read output KVs, then we can decode Windmill's byte key into userland
// key object and provide it to the execution context for use with per-key state.
@@ -403,8 +404,7 @@ private ExecuteWorkResult executeWork(
computationState.releaseComputationWorkExecutor(computationWorkExecutor);
work.setState(Work.State.COMMIT_QUEUED);
- outputBuilder.addAllPerWorkItemLatencyAttributions(
- work.getLatencyAttributions(false, sampler));
+ outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
return ExecuteWorkResult.create(
outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead());
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java
index 96a6feec1da0..499d2e5b6943 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java
@@ -17,13 +17,26 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import java.util.ArrayList;
import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
@@ -37,29 +50,39 @@
* threshold is determined by {@link #activeWorkRefreshPeriodMillis}
*/
@ThreadSafe
-public abstract class ActiveWorkRefresher {
+@Internal
+public final class ActiveWorkRefresher {
private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkRefresher.class);
+ private static final String FAN_OUT_REFRESH_WORK_EXECUTOR_NAME =
+ "FanOutActiveWorkRefreshExecutor-%d";
- protected final Supplier<Instant> clock;
- protected final int activeWorkRefreshPeriodMillis;
- protected final Supplier<Collection<ComputationState>> computations;
- protected final DataflowExecutionStateSampler sampler;
+ private final Supplier<Instant> clock;
+ private final int activeWorkRefreshPeriodMillis;
+ private final Supplier<Collection<ComputationState>> computations;
+ private final DataflowExecutionStateSampler sampler;
private final int stuckCommitDurationMillis;
+ private final HeartbeatTracker heartbeatTracker;
private final ScheduledExecutorService activeWorkRefreshExecutor;
+ private final ExecutorService fanOutActiveWorkRefreshExecutor;
- protected ActiveWorkRefresher(
+ public ActiveWorkRefresher(
Supplier<Instant> clock,
int activeWorkRefreshPeriodMillis,
int stuckCommitDurationMillis,
Supplier<Collection<ComputationState>> computations,
DataflowExecutionStateSampler sampler,
- ScheduledExecutorService activeWorkRefreshExecutor) {
+ ScheduledExecutorService activeWorkRefreshExecutor,
+ HeartbeatTracker heartbeatTracker) {
this.clock = clock;
this.activeWorkRefreshPeriodMillis = activeWorkRefreshPeriodMillis;
this.stuckCommitDurationMillis = stuckCommitDurationMillis;
this.computations = computations;
this.sampler = sampler;
this.activeWorkRefreshExecutor = activeWorkRefreshExecutor;
+ this.heartbeatTracker = heartbeatTracker;
+ this.fanOutActiveWorkRefreshExecutor =
+ Executors.newCachedThreadPool(
+ new ThreadFactoryBuilder().setNameFormat(FAN_OUT_REFRESH_WORK_EXECUTOR_NAME).build());
}
@SuppressWarnings("FutureReturnValueIgnored")
@@ -103,5 +126,67 @@ private void invalidateStuckCommits() {
}
}
- protected abstract void refreshActiveWork();
+ private void refreshActiveWork() {
+ Instant refreshDeadline = clock.get().minus(Duration.millis(activeWorkRefreshPeriodMillis));
+ Map<HeartbeatSender, Heartbeats> heartbeatsBySender =
+ aggregateHeartbeatsBySender(refreshDeadline);
+
+ List<CompletableFuture<Void>> fanOutRefreshActiveWork = new ArrayList<>();
+
+ // Send the first heartbeat on the calling thread, and fan out the rest via the
+ // fanOutActiveWorkRefreshExecutor.
+ @Nullable Map.Entry<HeartbeatSender, Heartbeats> firstHeartbeat = null;
+ for (Map.Entry<HeartbeatSender, Heartbeats> heartbeat : heartbeatsBySender.entrySet()) {
+ if (firstHeartbeat == null) {
+ firstHeartbeat = heartbeat;
+ } else {
+ fanOutRefreshActiveWork.add(
+ CompletableFuture.runAsync(
+ () -> sendHeartbeatSafely(heartbeat), fanOutActiveWorkRefreshExecutor));
+ }
+ }
+
+ sendHeartbeatSafely(firstHeartbeat);
+ fanOutRefreshActiveWork.forEach(CompletableFuture::join);
+ }
+
+ /** Aggregate the heartbeats across computations by HeartbeatSender for correct fan out. */
+ private Map<HeartbeatSender, Heartbeats> aggregateHeartbeatsBySender(Instant refreshDeadline) {
+ Map<HeartbeatSender, Heartbeats.Builder> heartbeatsBySender = new HashMap<>();
+
+ // Aggregate the heartbeats across computations by HeartbeatSender for correct fan out.
+ for (ComputationState computationState : computations.get()) {
+ for (RefreshableWork work : computationState.getRefreshableWork(refreshDeadline)) {
+ heartbeatsBySender
+ .computeIfAbsent(work.heartbeatSender(), ignored -> Heartbeats.builder())
+ .add(computationState.getComputationId(), work, sampler);
+ }
+ }
+
+ return heartbeatsBySender.entrySet().stream()
+ .collect(toImmutableMap(Map.Entry::getKey, e -> e.getValue().build()));
+ }
+
+ /**
+ * Send the {@link Heartbeats} using the {@link HeartbeatSender}. Safe since exceptions are caught
+ * and logged.
+ */
+ private void sendHeartbeatSafely(Map.Entry<HeartbeatSender, Heartbeats> heartbeat) {
+ try (AutoCloseable ignored = heartbeatTracker.trackHeartbeats(heartbeat.getValue().size())) {
+ HeartbeatSender sender = heartbeat.getKey();
+ Heartbeats heartbeats = heartbeat.getValue();
+ sender.sendHeartbeats(heartbeats);
+ } catch (Exception e) {
+ LOG.error(
+ "Unable to send {} heartbeats to {}.",
+ heartbeat.getValue().size(),
+ heartbeat.getKey(),
+ e);
+ }
+ }
+
+ @FunctionalInterface
+ public interface HeartbeatTracker {
+ AutoCloseable trackHeartbeats(int numHeartbeats);
+ }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefreshers.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefreshers.java
deleted file mode 100644
index 5a59a7f1ae01..000000000000
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefreshers.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;
-
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
-import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
-import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
-import org.joda.time.Instant;
-
-/** Utility class for {@link ActiveWorkRefresher}. */
-public final class ActiveWorkRefreshers {
- public static ActiveWorkRefresher createDispatchedActiveWorkRefresher(
- Supplier<Instant> clock,
- int activeWorkRefreshPeriodMillis,
- int stuckCommitDurationMillis,
- Supplier<Collection<ComputationState>> computations,
- DataflowExecutionStateSampler sampler,
- Consumer