diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java deleted file mode 100644 index d808d4f4ab58..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker; - -import com.google.auto.value.AutoBuilder; -import java.io.PrintWriter; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.atomic.AtomicInteger; -import javax.annotation.concurrent.GuardedBy; -import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; - -/** - * Wrapper around a {@link WindmillServerStub} that tracks metrics for the number of in-flight - * requests and throttles requests when memory pressure is high. - * - *

External API: individual worker threads request state for their computation via {@link - * #getStateData}. However, requests are either issued using a pool of streaming rpcs or possibly - * batched requests. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class MetricTrackingWindmillServerStub { - - private static final int MAX_READS_PER_BATCH = 60; - private static final int MAX_ACTIVE_READS = 10; - private static final Duration STREAM_TIMEOUT = Duration.standardSeconds(30); - private final AtomicInteger activeSideInputs = new AtomicInteger(); - private final AtomicInteger activeStateReads = new AtomicInteger(); - private final AtomicInteger activeHeartbeats = new AtomicInteger(); - private final WindmillServerStub server; - private final MemoryMonitor gcThrashingMonitor; - private final boolean useStreamingRequests; - - private final WindmillStreamPool getDataStreamPool; - - // This may be the same instance as getDataStreamPool based upon options. - private final WindmillStreamPool heartbeatStreamPool; - - @GuardedBy("this") - private final List pendingReadBatches; - - @GuardedBy("this") - private int activeReadThreads = 0; - - @Internal - @AutoBuilder(ofClass = MetricTrackingWindmillServerStub.class) - public abstract static class Builder { - - abstract Builder setServer(WindmillServerStub server); - - abstract Builder setGcThrashingMonitor(MemoryMonitor gcThrashingMonitor); - - abstract Builder setUseStreamingRequests(boolean useStreamingRequests); - - abstract Builder setUseSeparateHeartbeatStreams(boolean useSeparateHeartbeatStreams); - - abstract Builder setNumGetDataStreams(int numGetDataStreams); - - abstract MetricTrackingWindmillServerStub build(); - } - - public static Builder builder(WindmillServerStub server, MemoryMonitor gcThrashingMonitor) { - return new AutoBuilder_MetricTrackingWindmillServerStub_Builder() - .setServer(server) - .setGcThrashingMonitor(gcThrashingMonitor) - .setUseStreamingRequests(false) - .setUseSeparateHeartbeatStreams(false) - .setNumGetDataStreams(1); - } - - MetricTrackingWindmillServerStub( - WindmillServerStub server, - MemoryMonitor gcThrashingMonitor, - boolean useStreamingRequests, - boolean useSeparateHeartbeatStreams, - int numGetDataStreams) { - this.server = server; - this.gcThrashingMonitor = gcThrashingMonitor; - this.useStreamingRequests = useStreamingRequests; - if (useStreamingRequests) { - getDataStreamPool = - WindmillStreamPool.create( - Math.max(1, numGetDataStreams), STREAM_TIMEOUT, this.server::getDataStream); - if (useSeparateHeartbeatStreams) { - heartbeatStreamPool = - WindmillStreamPool.create(1, STREAM_TIMEOUT, this.server::getDataStream); - } else { - heartbeatStreamPool = getDataStreamPool; - } - } else { - getDataStreamPool = heartbeatStreamPool = null; - } - // This is used as a queue but is expected to be less than 10 batches. - this.pendingReadBatches = new ArrayList<>(); - } - - // Adds the entry to a read batch for sending to the windmill server. If a non-null batch is - // returned, this thread will be responsible for sending the batch and should wait for the batch - // startRead to be notified. - // If null is returned, the entry was added to a read batch that will be issued by another thread. 
- private @Nullable ReadBatch addToReadBatch(QueueEntry entry) { - synchronized (this) { - ReadBatch batch; - if (activeReadThreads < MAX_ACTIVE_READS) { - assert (pendingReadBatches.isEmpty()); - activeReadThreads += 1; - // fall through to below synchronized block - } else if (pendingReadBatches.isEmpty() - || pendingReadBatches.get(pendingReadBatches.size() - 1).reads.size() - >= MAX_READS_PER_BATCH) { - // This is the first read of a batch, it will be responsible for sending the batch. - batch = new ReadBatch(); - pendingReadBatches.add(batch); - batch.reads.add(entry); - return batch; - } else { - // This fits within an existing batch, it will be sent by the first blocking thread in the - // batch. - pendingReadBatches.get(pendingReadBatches.size() - 1).reads.add(entry); - return null; - } - } - ReadBatch batch = new ReadBatch(); - batch.reads.add(entry); - batch.startRead.set(true); - return batch; - } - - private void issueReadBatch(ReadBatch batch) { - try { - boolean read = batch.startRead.get(); - assert (read); - } catch (InterruptedException e) { - // We don't expect this thread to be interrupted. To simplify handling, we just fall through - // to issuing - // the call. - assert (false); - Thread.currentThread().interrupt(); - } catch (ExecutionException e) { - // startRead is a SettableFuture so this should never occur. - throw new AssertionError("Should not have exception on startRead", e); - } - Map> pendingResponses = - new HashMap<>(batch.reads.size()); - Map computationBuilders = new HashMap<>(); - for (QueueEntry entry : batch.reads) { - Windmill.ComputationGetDataRequest.Builder computationBuilder = - computationBuilders.computeIfAbsent( - entry.computation, - k -> Windmill.ComputationGetDataRequest.newBuilder().setComputationId(k)); - - computationBuilder.addRequests(entry.request); - pendingResponses.put( - WindmillComputationKey.create( - entry.computation, entry.request.getKey(), entry.request.getShardingKey()), - entry.response); - } - - // Build the full GetDataRequest from the KeyedGetDataRequests pulled from the queue. - Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder(); - for (Windmill.ComputationGetDataRequest.Builder computationBuilder : - computationBuilders.values()) { - builder.addRequests(computationBuilder); - } - - try { - Windmill.GetDataResponse response = server.getData(builder.build()); - - // Dispatch the per-key responses back to the waiting threads. - for (Windmill.ComputationGetDataResponse computationResponse : response.getDataList()) { - for (Windmill.KeyedGetDataResponse keyResponse : computationResponse.getDataList()) { - pendingResponses - .get( - WindmillComputationKey.create( - computationResponse.getComputationId(), - keyResponse.getKey(), - keyResponse.getShardingKey())) - .set(keyResponse); - } - } - } catch (RuntimeException e) { - // Fan the exception out to the reads. - for (QueueEntry entry : batch.reads) { - entry.response.setException(e); - } - } finally { - synchronized (this) { - assert (activeReadThreads >= 1); - if (pendingReadBatches.isEmpty()) { - activeReadThreads--; - } else { - // Notify the thread responsible for issuing the next batch read. 
- ReadBatch startBatch = pendingReadBatches.remove(0); - startBatch.startRead.set(true); - } - } - } - } - - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - gcThrashingMonitor.waitForResources("GetStateData"); - activeStateReads.getAndIncrement(); - - try { - if (useStreamingRequests) { - GetDataStream stream = getDataStreamPool.getStream(); - try { - return stream.requestKeyedData(computation, request); - } finally { - getDataStreamPool.releaseStream(stream); - } - } else { - SettableFuture response = SettableFuture.create(); - ReadBatch batch = addToReadBatch(new QueueEntry(computation, request, response)); - if (batch != null) { - issueReadBatch(batch); - } - return response.get(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - activeStateReads.getAndDecrement(); - } - } - - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - gcThrashingMonitor.waitForResources("GetSideInputData"); - activeSideInputs.getAndIncrement(); - try { - if (useStreamingRequests) { - GetDataStream stream = getDataStreamPool.getStream(); - try { - return stream.requestGlobalData(request); - } finally { - getDataStreamPool.releaseStream(stream); - } - } else { - return server - .getData( - Windmill.GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build()) - .getGlobalData(0); - } - } catch (Exception e) { - throw new RuntimeException("Failed to get side input: ", e); - } finally { - activeSideInputs.getAndDecrement(); - } - } - - /** Tells windmill processing is ongoing for the given keys. */ - public void refreshActiveWork(Map> heartbeats) { - if (heartbeats.isEmpty()) { - return; - } - activeHeartbeats.set(heartbeats.size()); - try { - if (useStreamingRequests) { - GetDataStream stream = heartbeatStreamPool.getStream(); - try { - stream.refreshActiveWork(heartbeats); - } finally { - heartbeatStreamPool.releaseStream(stream); - } - } else { - // This code path is only used by appliance which sends heartbeats (used to refresh active - // work) as KeyedGetDataRequests. So we must translate the HeartbeatRequest to a - // KeyedGetDataRequest here regardless of the value of sendKeyedGetDataRequests. 
- Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder(); - for (Map.Entry> entry : heartbeats.entrySet()) { - Windmill.ComputationGetDataRequest.Builder perComputationBuilder = - Windmill.ComputationGetDataRequest.newBuilder(); - perComputationBuilder.setComputationId(entry.getKey()); - for (HeartbeatRequest request : entry.getValue()) { - perComputationBuilder.addRequests( - Windmill.KeyedGetDataRequest.newBuilder() - .setShardingKey(request.getShardingKey()) - .setWorkToken(request.getWorkToken()) - .setCacheToken(request.getCacheToken()) - .addAllLatencyAttribution(request.getLatencyAttributionList()) - .build()); - } - builder.addRequests(perComputationBuilder.build()); - } - server.getData(builder.build()); - } - } finally { - activeHeartbeats.set(0); - } - } - - public void printHtml(PrintWriter writer) { - writer.println("Active Fetches:"); - writer.println(" Side Inputs: " + activeSideInputs.get()); - writer.println(" State Reads: " + activeStateReads.get()); - if (!useStreamingRequests) { - synchronized (this) { - writer.println(" Read threads: " + activeReadThreads); - writer.println(" Pending read batches: " + pendingReadBatches.size()); - } - } - writer.println("Heartbeat Keys Active: " + activeHeartbeats.get()); - } - - private static final class ReadBatch { - ArrayList reads = new ArrayList<>(); - SettableFuture startRead = SettableFuture.create(); - } - - private static final class QueueEntry { - - final String computation; - final Windmill.KeyedGetDataRequest request; - final SettableFuture response; - - QueueEntry( - String computation, - Windmill.KeyedGetDataRequest request, - SettableFuture response) { - this.computation = computation; - this.request = request; - this.response = response; - } - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index a07bbfa7f5f3..f196852b2253 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -66,12 +66,17 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ApplianceGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamPoolGetDataClient; +import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; @@ -87,7 +92,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.StreamingEngineFailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ActiveWorkRefresher; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ActiveWorkRefreshers; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ApplianceHeartbeatSender; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.StreamPoolHeartbeatSender; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.sdk.fn.JvmInitializers; @@ -127,6 +134,7 @@ public class StreamingDataflowWorker { static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3; static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1); private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class); + private static final Duration GET_DATA_STREAM_TIMEOUT = Duration.standardSeconds(30); /** The idGenerator to generate unique id globally. */ private static final IdGenerator ID_GENERATOR = IdGenerators.decrementingLongs(); @@ -150,7 +158,7 @@ public class StreamingDataflowWorker { private final AtomicBoolean running = new AtomicBoolean(); private final DataflowWorkerHarnessOptions options; private final long clientId; - private final MetricTrackingWindmillServerStub metricTrackingWindmillServer; + private final GetDataClient getDataClient; private final MemoryMonitor memoryMonitor; private final Thread memoryMonitorThread; private final ReaderCache readerCache; @@ -160,6 +168,7 @@ public class StreamingDataflowWorker { private final StreamingWorkerStatusReporter workerStatusReporter; private final StreamingCounters streamingCounters; private final StreamingWorkScheduler streamingWorkScheduler; + private final HeartbeatSender heartbeatSender; private StreamingDataflowWorker( WindmillServerStub windmillServer, @@ -181,6 +190,9 @@ private StreamingDataflowWorker( GrpcWindmillStreamFactory windmillStreamFactory, Function executorSupplier, ConcurrentMap stageInfoMap) { + // Register standard file systems. + FileSystems.setDefaultPipelineOptions(options); + this.configFetcher = configFetcher; this.computationStateCache = computationStateCache; this.stateCache = windmillStateCache; @@ -199,12 +211,16 @@ private StreamingDataflowWorker( this.workCommitter = windmillServiceEnabled - ? StreamingEngineWorkCommitter.create( - WindmillStreamPool.create( - numCommitThreads, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream) - ::getCloseableStream, - numCommitThreads, - this::onCompleteCommit) + ? 
StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory( + WindmillStreamPool.create( + numCommitThreads, + COMMIT_STREAM_TIMEOUT, + windmillServer::commitWorkStream) + ::getCloseableStream) + .setNumCommitSenders(numCommitThreads) + .setOnCommitComplete(this::onCompleteCommit) + .build() : StreamingApplianceWorkCommitter.create( windmillServer::commitWork, this::onCompleteCommit); @@ -230,29 +246,41 @@ private StreamingDataflowWorker( dispatchThread.setName("DispatchThread"); this.clientId = clientId; this.windmillServer = windmillServer; - this.metricTrackingWindmillServer = - MetricTrackingWindmillServerStub.builder(windmillServer, memoryMonitor) - .setUseStreamingRequests(windmillServiceEnabled) - .setUseSeparateHeartbeatStreams(options.getUseSeparateWindmillHeartbeatStreams()) - .setNumGetDataStreams(options.getWindmillGetDataStreamCount()) - .build(); - // Register standard file systems. - FileSystems.setDefaultPipelineOptions(options); + ThrottlingGetDataMetricTracker getDataMetricTracker = + new ThrottlingGetDataMetricTracker(memoryMonitor); + + int stuckCommitDurationMillis; + if (windmillServiceEnabled) { + WindmillStreamPool getDataStreamPool = + WindmillStreamPool.create( + Math.max(1, options.getWindmillGetDataStreamCount()), + GET_DATA_STREAM_TIMEOUT, + windmillServer::getDataStream); + this.getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); + this.heartbeatSender = + new StreamPoolHeartbeatSender( + options.getUseSeparateWindmillHeartbeatStreams() + ? WindmillStreamPool.create( + 1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream) + : getDataStreamPool); + stuckCommitDurationMillis = + options.getStuckCommitDurationMillis() > 0 ? options.getStuckCommitDurationMillis() : 0; + } else { + this.getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker); + this.heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData); + stuckCommitDurationMillis = 0; + } - int stuckCommitDurationMillis = - windmillServiceEnabled && options.getStuckCommitDurationMillis() > 0 - ? 
options.getStuckCommitDurationMillis() - : 0; this.activeWorkRefresher = - ActiveWorkRefreshers.createDispatchedActiveWorkRefresher( + new ActiveWorkRefresher( clock, options.getActiveWorkRefreshPeriodMillis(), stuckCommitDurationMillis, computationStateCache::getAllPresentComputations, sampler, - metricTrackingWindmillServer::refreshActiveWork, - executorSupplier.apply("RefreshWork")); + executorSupplier.apply("RefreshWork"), + getDataMetricTracker::trackHeartbeats); WorkerStatusPages workerStatusPages = WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor); @@ -265,7 +293,7 @@ private StreamingDataflowWorker( .setStateCache(stateCache) .setComputationStateCache(computationStateCache) .setCurrentActiveCommitBytes(workCommitter::currentActiveCommitBytes) - .setGetDataStatusProvider(metricTrackingWindmillServer::printHtml) + .setGetDataStatusProvider(getDataClient::printHtml) .setWorkUnitExecutor(workUnitExecutor); this.statusPages = @@ -281,7 +309,6 @@ private StreamingDataflowWorker( this.workerStatusReporter = workerStatusReporter; this.streamingCounters = streamingCounters; this.memoryMonitor = memoryMonitor; - this.streamingWorkScheduler = StreamingWorkScheduler.create( options, @@ -290,7 +317,6 @@ private StreamingDataflowWorker( mapTaskExecutorFactory, workUnitExecutor, stateCache::forComputation, - metricTrackingWindmillServer::getSideInputData, failureTracker, workFailureProcessor, streamingCounters, @@ -829,7 +855,7 @@ private void dispatchLoop() { workItem, watermarks.setOutputDataWatermark(workItem.getOutputDataWatermark()).build(), Work.createProcessingContext( - computationId, metricTrackingWindmillServer::getStateData, workCommitter::commit), + computationId, getDataClient, workCommitter::commit, heartbeatSender), /* getWorkStreamLatencies= */ Collections.emptyList()); } } @@ -865,8 +891,9 @@ void streamingDispatchLoop() { .build(), Work.createProcessingContext( computationState.getComputationId(), - metricTrackingWindmillServer::getStateData, - workCommitter::commit), + getDataClient, + workCommitter::commit, + heartbeatSender), getWorkStreamLatencies); })); try { @@ -874,7 +901,7 @@ void streamingDispatchLoop() { // If at any point the server closes the stream, we will reconnect immediately; otherwise // we half-close the stream after some time and create a new one. if (!stream.awaitTermination(GET_WORK_STREAM_TIMEOUT_MINUTES, TimeUnit.MINUTES)) { - stream.close(); + stream.halfClose(); } } catch (InterruptedException e) { // Continue processing until !running.get() diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index 934977fe0985..ec5122a8732a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -26,6 +26,10 @@ public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } + public WorkItemCancelledException(Throwable e) { + super(e); + } + /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. 
*/ public static boolean isWorkItemCancelledException(Throwable t) { while (t != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index 3e226514d57e..c80c3a882e52 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.streaming; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap; import java.io.PrintWriter; import java.util.ArrayDeque; @@ -31,13 +32,9 @@ import java.util.Queue; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; -import java.util.stream.Stream; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; @@ -45,6 +42,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; import org.joda.time.Duration; @@ -106,29 +104,6 @@ private static String elapsedString(Instant start, Instant end) { return activeFor.toString().substring(2); } - private static Stream toHeartbeatRequestStream( - Entry> shardedKeyAndWorkQueue, - Instant refreshDeadline, - DataflowExecutionStateSampler sampler) { - ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey(); - Deque workQueue = shardedKeyAndWorkQueue.getValue(); - - return workQueue.stream() - .map(ExecutableWork::work) - .filter(work -> work.getStartTime().isBefore(refreshDeadline)) - // Don't send heartbeats for queued work we already know is failed. - .filter(work -> !work.isFailed()) - .map( - work -> - Windmill.HeartbeatRequest.newBuilder() - .setShardingKey(shardedKey.shardingKey()) - .setWorkToken(work.getWorkItem().getWorkToken()) - .setCacheToken(work.getWorkItem().getCacheToken()) - .addAllLatencyAttribution( - work.getLatencyAttributions(/* isHeartbeat= */ true, sampler)) - .build()); - } - /** * Activates {@link Work} for the {@link ShardedKey}. 
Outcome can be 1 of 4 {@link * ActivateWorkResult} @@ -219,6 +194,31 @@ synchronized void failWorkForKey(Multimap failedWork) { } } + /** + * Returns a read only view of current active work. + * + * @implNote Do not return a reference to the underlying workQueue as iterations over it will + * cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data + * structure. + */ + synchronized ImmutableListMultimap getReadOnlyActiveWork() { + return activeWork.entrySet().stream() + .collect( + flatteningToImmutableListMultimap( + Entry::getKey, + e -> + e.getValue().stream() + .map(executableWork -> (RefreshableWork) executableWork.work()))); + } + + synchronized ImmutableList getRefreshableWork(Instant refreshDeadline) { + return activeWork.values().stream() + .flatMap(Deque::stream) + .map(ExecutableWork::work) + .filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline)) + .collect(toImmutableList()); + } + private void incrementActiveWorkBudget(Work work) { activeGetWorkBudget.updateAndGet( getWorkBudget -> getWorkBudget.apply(1, work.getWorkItem().getSerializedSize())); @@ -324,13 +324,6 @@ private synchronized ImmutableMap getStuckCommitsAt( return stuckCommits.build(); } - synchronized ImmutableList getKeyHeartbeats( - Instant refreshDeadline, DataflowExecutionStateSampler sampler) { - return activeWork.entrySet().stream() - .flatMap(entry -> toHeartbeatRequestStream(entry, refreshDeadline, sampler)) - .collect(toImmutableList()); - } - /** * Returns the current aggregate {@link GetWorkBudget} that is active on the user worker. Active * means that the work is received from Windmill, being processed or queued to be processed in diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java index 434e78484799..f3b0ba16fbe2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java @@ -23,14 +23,13 @@ import java.util.Optional; import java.util.concurrent.ConcurrentLinkedQueue; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; import org.joda.time.Instant; @@ -147,10 +146,12 @@ private void forceExecute(ExecutableWork executableWork) { executor.forceExecute(executableWork, 
executableWork.work().getWorkItem().getSerializedSize()); } - /** Gets HeartbeatRequests for any work started before refreshDeadline. */ - public ImmutableList getKeyHeartbeats( - Instant refreshDeadline, DataflowExecutionStateSampler sampler) { - return activeWorkState.getKeyHeartbeats(refreshDeadline, sampler); + public ImmutableListMultimap currentActiveWorkReadOnly() { + return activeWorkState.getReadOnlyActiveWork(); + } + + public ImmutableList getRefreshableWork(Instant refreshDeadline) { + return activeWorkState.getRefreshableWork(refreshDeadline); } public GetWorkBudget getActiveWorkBudget() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java index bdf8a7814ea3..db279f066630 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java @@ -31,7 +31,7 @@ public static ExecutableWork create(Work work, Consumer executeWorkFn) { public abstract Work work(); - abstract Consumer executeWorkFn(); + public abstract Consumer executeWorkFn(); @Override public void run() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java new file mode 100644 index 000000000000..c51b04f23719 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + +/** View of {@link Work} that exposes an interface for work refreshing. 
*/ +@Internal +public interface RefreshableWork { + + WorkId id(); + + ShardedKey getShardedKey(); + + HeartbeatSender heartbeatSender(); + + ImmutableList getHeartbeatLatencyAttributions( + DataflowExecutionStateSampler sampler); + + void setFailed(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index ed3f2671b40c..e77823602eda 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -27,14 +27,14 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.function.BiFunction; import java.util.function.Consumer; -import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.ActiveMessageMetadata; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; @@ -45,7 +45,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.joda.time.Duration; @@ -58,7 +60,7 @@ */ @NotThreadSafe @Internal -public final class Work { +public final class Work implements RefreshableWork { private final ShardedKey shardedKey; private final WorkItem workItem; private final ProcessingContext processingContext; @@ -105,9 +107,10 @@ public static Work create( public static ProcessingContext createProcessingContext( String computationId, - BiFunction getKeyedDataFn, - Consumer workCommitter) { - return ProcessingContext.create(computationId, getKeyedDataFn, workCommitter); + GetDataClient getDataClient, + Consumer workCommitter, + HeartbeatSender heartbeatSender) { + return ProcessingContext.create(computationId, getDataClient, workCommitter, heartbeatSender); } private static LatencyAttribution.Builder createLatencyAttributionWithActiveLatencyBreakdown( @@ -151,12 +154,17 @@ public WorkItem getWorkItem() { return workItem; } + @Override public ShardedKey getShardedKey() { return shardedKey; } public Optional fetchKeyedState(KeyedGetDataRequest keyedGetDataRequest) { - return processingContext.keyedDataFetcher().apply(keyedGetDataRequest); + return 
processingContext.fetchKeyedState(keyedGetDataRequest); + } + + public GlobalData fetchSideInput(GlobalDataRequest request) { + return processingContext.getDataClient().getSideInputData(request); } public Watermarks watermarks() { @@ -180,6 +188,7 @@ public void setState(State state) { this.currentState = TimedState.create(state, now); } + @Override public void setFailed() { this.isFailed = true; } @@ -196,6 +205,11 @@ public String getLatencyTrackingId() { return latencyTrackingId; } + @Override + public HeartbeatSender heartbeatSender() { + return processingContext.heartbeatSender(); + } + public void queueCommit(WorkItemCommitRequest commitRequest, ComputationState computationState) { setState(State.COMMIT_QUEUED); processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this)); @@ -205,6 +219,7 @@ public WindmillStateReader createWindmillStateReader() { return WindmillStateReader.forWork(this); } + @Override public WorkId id() { return id; } @@ -216,7 +231,25 @@ private void recordGetWorkStreamLatencies(Collection getWork } } + @Override + public ImmutableList getHeartbeatLatencyAttributions( + DataflowExecutionStateSampler sampler) { + return getLatencyAttributions(/* isHeartbeat= */ true, sampler); + } + public ImmutableList getLatencyAttributions( + DataflowExecutionStateSampler sampler) { + return getLatencyAttributions(/* isHeartbeat= */ false, sampler); + } + + private Duration getTotalDurationAtLatencyAttributionState(LatencyAttribution.State state) { + Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO); + return state == this.currentState.state().toLatencyAttributionState() + ? duration.plus(new Duration(this.currentState.startTime(), clock.get())) + : duration; + } + + private ImmutableList getLatencyAttributions( boolean isHeartbeat, DataflowExecutionStateSampler sampler) { return Arrays.stream(LatencyAttribution.State.values()) .map(state -> Pair.of(state, getTotalDurationAtLatencyAttributionState(state))) @@ -233,13 +266,6 @@ public ImmutableList getLatencyAttributions( .collect(toImmutableList()); } - private Duration getTotalDurationAtLatencyAttributionState(LatencyAttribution.State state) { - Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO); - return state == this.currentState.state().toLatencyAttributionState() - ? duration.plus(new Duration(this.currentState.startTime(), clock.get())) - : duration; - } - private LatencyAttribution createLatencyAttribution( LatencyAttribution.State state, boolean isHeartbeat, @@ -314,25 +340,29 @@ public abstract static class ProcessingContext { private static ProcessingContext create( String computationId, - BiFunction getKeyedDataFn, - Consumer workCommitter) { + GetDataClient getDataClient, + Consumer workCommitter, + HeartbeatSender heartbeatSender) { return new AutoValue_Work_ProcessingContext( - computationId, - request -> Optional.ofNullable(getKeyedDataFn.apply(computationId, request)), - workCommitter); + computationId, getDataClient, heartbeatSender, workCommitter); } /** Computation that the {@link Work} belongs to. */ public abstract String computationId(); /** Handles GetData requests to streaming backend. */ - public abstract Function> - keyedDataFetcher(); + public abstract GetDataClient getDataClient(); + + public abstract HeartbeatSender heartbeatSender(); /** * {@link WorkCommitter} that commits completed work to the backend Windmill worker handling the * {@link WorkItem}. 
*/ public abstract Consumer workCommitter(); + + private Optional fetchKeyedState(KeyedGetDataRequest request) { + return Optional.ofNullable(getDataClient().getStateData(computationId(), request)); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java index f8f8d1901914..d4e7f05d255f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java @@ -41,9 +41,9 @@ public static WorkId of(Windmill.WorkItem workItem) { .build(); } - abstract long cacheToken(); + public abstract long cacheToken(); - abstract long workToken(); + public abstract long workToken(); @AutoValue.Builder public abstract static class Builder { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java index 7fd2487575c2..303cdeb94f8c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java @@ -30,11 +30,11 @@ import java.util.function.Function; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.InMemoryMultimapSideInputView; -import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; @@ -46,14 +46,14 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Class responsible for fetching state from the windmill server. */ +/** Class responsible for fetching side input state from the streaming backend. 
*/ @NotThreadSafe +@Internal public class SideInputStateFetcher { private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class); @@ -64,13 +64,6 @@ public class SideInputStateFetcher { private final Function fetchGlobalDataFn; private long bytesRead = 0L; - public SideInputStateFetcher( - Function fetchGlobalDataFn, - DataflowStreamingPipelineOptions options) { - this(fetchGlobalDataFn, SideInputCache.create(options)); - } - - @VisibleForTesting SideInputStateFetcher( Function fetchGlobalDataFn, SideInputCache sideInputCache) { this.fetchGlobalDataFn = fetchGlobalDataFn; @@ -103,12 +96,56 @@ private static Coder getCoder(PCollectionView view) { return view.getCoderInternal(); } - /** Returns a view of the underlying cache that keeps track of bytes read separately. */ - public SideInputStateFetcher byteTrackingView() { - return new SideInputStateFetcher(fetchGlobalDataFn, sideInputCache); + private static SideInput createSideInputCacheEntry( + PCollectionView view, GlobalData data) throws IOException { + Iterable rawData = decodeRawData(view, data); + switch (getViewFn(view).getMaterialization().getUrn()) { + case ITERABLE_MATERIALIZATION_URN: + { + @SuppressWarnings({ + "unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn. + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) + }) + ViewFn viewFn = (ViewFn) getViewFn(view); + return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size()); + } + case MULTIMAP_MATERIALIZATION_URN: + { + @SuppressWarnings({ + "unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn. + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) + }) + ViewFn viewFn = (ViewFn) getViewFn(view); + Coder keyCoder = ((KvCoder) getCoder(view)).getKeyCoder(); + + @SuppressWarnings({ + "unchecked", // Safe since multimap rawData is of type Iterable> + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + }) + T multimapSideInputValue = + viewFn.apply( + InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)); + return SideInput.ready(multimapSideInputValue, data.getData().size()); + } + default: + { + throw new IllegalStateException( + "Unknown side input materialization format requested: " + + getViewFn(view).getMaterialization().getUrn()); + } + } } - public long getBytesRead() { + private static void validateViewMaterialization(PCollectionView view) { + String materializationUrn = getViewFn(view).getMaterialization().getUrn(); + checkState( + SUPPORTED_MATERIALIZATIONS.contains(materializationUrn), + "Only materialization's of type %s supported, received %s", + SUPPORTED_MATERIALIZATIONS, + materializationUrn); + } + + public final long getBytesRead() { return bytesRead; } @@ -200,53 +237,4 @@ private SideInput loadSideInputFromWindmill( bytesRead += data.getSerializedSize(); return data.getIsReady() ? 
createSideInputCacheEntry(view, data) : SideInput.notReady(); } - - private void validateViewMaterialization(PCollectionView view) { - String materializationUrn = getViewFn(view).getMaterialization().getUrn(); - checkState( - SUPPORTED_MATERIALIZATIONS.contains(materializationUrn), - "Only materialization's of type %s supported, received %s", - SUPPORTED_MATERIALIZATIONS, - materializationUrn); - } - - private SideInput createSideInputCacheEntry(PCollectionView view, GlobalData data) - throws IOException { - Iterable rawData = decodeRawData(view, data); - switch (getViewFn(view).getMaterialization().getUrn()) { - case ITERABLE_MATERIALIZATION_URN: - { - @SuppressWarnings({ - "unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn. - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) - }) - ViewFn viewFn = (ViewFn) getViewFn(view); - return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size()); - } - case MULTIMAP_MATERIALIZATION_URN: - { - @SuppressWarnings({ - "unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn. - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) - }) - ViewFn viewFn = (ViewFn) getViewFn(view); - Coder keyCoder = ((KvCoder) getCoder(view)).getKeyCoder(); - - @SuppressWarnings({ - "unchecked", // Safe since multimap rawData is of type Iterable> - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - }) - T multimapSideInputValue = - viewFn.apply( - InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)); - return SideInput.ready(multimapSideInputValue, data.getData().size()); - } - default: - { - throw new IllegalStateException( - "Unknown side input materialization format requested: " - + getViewFn(view).getMaterialization().getUrn()); - } - } - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherFactory.java new file mode 100644 index 000000000000..fd42b9ff1801 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherFactory.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming.sideinput; + +import java.util.function.Function; +import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.sdk.annotations.Internal; + +/** + * Factory class for generating {@link SideInputStateFetcher} instances that share a {@link + * SideInputCache}. + */ +@Internal +public final class SideInputStateFetcherFactory { + private final SideInputCache globalSideInputCache; + + private SideInputStateFetcherFactory(SideInputCache globalSideInputCache) { + this.globalSideInputCache = globalSideInputCache; + } + + public static SideInputStateFetcherFactory fromOptions(DataflowStreamingPipelineOptions options) { + return new SideInputStateFetcherFactory(SideInputCache.create(options)); + } + + public SideInputStateFetcher createSideInputStateFetcher( + Function fetchGlobalDataFn) { + return new SideInputStateFetcher(fetchGlobalDataFn, globalSideInputCache); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ApplianceWindmillClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ApplianceWindmillClient.java new file mode 100644 index 000000000000..2cd3748eb31b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ApplianceWindmillClient.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill; + +import org.apache.beam.sdk.annotations.Internal; + +/** Client for WindmillService via Streaming Appliance. */ +@Internal +public interface ApplianceWindmillClient { + /** Get a batch of work to process. */ + Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request); + + /** Get additional data such as state needed to process work. */ + Windmill.GetDataResponse getData(Windmill.GetDataRequest request); + + /** Commit the work, issuing any output productions, state modifications etc. */ + Windmill.CommitWorkResponse commitWork(Windmill.CommitWorkRequest request); + + /** Get configuration data from the server. */ + Windmill.GetConfigResponse getConfig(Windmill.GetConfigRequest request); + + /** Report execution information to the server. 
*/ + Windmill.ReportStatsResponse reportStats(Windmill.ReportStatsRequest request); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineWindmillClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineWindmillClient.java new file mode 100644 index 000000000000..e02e6c112358 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamingEngineWindmillClient.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill; + +import java.util.Set; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; + +/** Client for WindmillService via Streaming Engine. */ +@Internal +public interface StreamingEngineWindmillClient { + /** Returns the windmill service endpoints set by setWindmillServiceEndpoints */ + ImmutableSet getWindmillServiceEndpoints(); + + /** + * Sets the new endpoints used to talk to windmill. Upon first call, the stubs are initialized. On + * subsequent calls, if endpoints are different from previous values new stubs are created, + * replacing the previous ones. + */ + void setWindmillServiceEndpoints(Set endpoints); + + /** + * Gets work to process, returned as a stream. + * + *

Each time a WorkItem is received, it will be passed to the given receiver. The returned + * GetWorkStream object can be used to control the lifetime of the stream. + */ + WindmillStream.GetWorkStream getWorkStream( + Windmill.GetWorkRequest request, WorkItemReceiver receiver); + + /** Get additional data such as state needed to process work, returned as a stream. */ + WindmillStream.GetDataStream getDataStream(); + + /** Returns a stream allowing individual WorkItemCommitRequests to be streamed to Windmill. */ + WindmillStream.CommitWorkStream commitWorkStream(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java index a20c2f02b269..7d199afc0861 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java @@ -27,6 +27,8 @@ @AutoValue @Internal public abstract class WindmillConnection { + private static final String NO_BACKEND_WORKER_TOKEN = ""; + public static WindmillConnection from( Endpoint windmillEndpoint, Function endpointToStubFn) { @@ -40,23 +42,24 @@ public static WindmillConnection from( } public static Builder builder() { - return new AutoValue_WindmillConnection.Builder(); + return new AutoValue_WindmillConnection.Builder() + .setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN); } - public abstract Optional backendWorkerToken(); + public abstract String backendWorkerToken(); public abstract Optional directEndpoint(); public abstract CloudWindmillServiceV1Alpha1Stub stub(); @AutoValue.Builder - abstract static class Builder { + public abstract static class Builder { abstract Builder setBackendWorkerToken(String backendWorkerToken); public abstract Builder setDirectEndpoint(WindmillServiceAddress value); - abstract Builder setStub(CloudWindmillServiceV1Alpha1Stub stub); + public abstract Builder setStub(CloudWindmillServiceV1Alpha1Stub stub); - abstract WindmillConnection build(); + public abstract WindmillConnection build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java index 0785ae96626e..5f7fd6da9d4b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java @@ -59,11 +59,6 @@ public ImmutableSet getWindmillServiceEndpoints() { return ImmutableSet.of(); } - @Override - public boolean isReady() { - return true; - } - @Override public Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest workRequest) { try { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java index 7d0c4f5aba32..cd753cb8ec91 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java @@ -18,65 +18,11 @@ package org.apache.beam.runners.dataflow.worker.windmill; import java.io.PrintWriter; -import java.util.Set; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Stub for communicating with a Windmill server. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public abstract class WindmillServerStub implements StatusDataProvider { - - /** - * Sets the new endpoints used to talk to windmill. Upon first call, the stubs are initialized. On - * subsequent calls, if endpoints are different from previous values new stubs are created, - * replacing the previous ones. - */ - public abstract void setWindmillServiceEndpoints(Set endpoints); - - /* - * Returns the windmill service endpoints set by setWindmillServiceEndpoints - */ - public abstract ImmutableSet getWindmillServiceEndpoints(); - - /** Returns true iff this WindmillServerStub is ready for making API calls. */ - public abstract boolean isReady(); - - /** Get a batch of work to process. */ - public abstract Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request); - - /** Get additional data such as state needed to process work. */ - public abstract Windmill.GetDataResponse getData(Windmill.GetDataRequest request); - - /** Commit the work, issuing any output productions, state modifications etc. */ - public abstract Windmill.CommitWorkResponse commitWork(Windmill.CommitWorkRequest request); - - /** Get configuration data from the server. */ - public abstract Windmill.GetConfigResponse getConfig(Windmill.GetConfigRequest request); - - /** Report execution information to the server. */ - public abstract Windmill.ReportStatsResponse reportStats(Windmill.ReportStatsRequest request); - - /** - * Gets work to process, returned as a stream. - * - *
<p>
Each time a WorkItem is received, it will be passed to the given receiver. The returned - * GetWorkStream object can be used to control the lifetime of the stream. - */ - public abstract GetWorkStream getWorkStream( - Windmill.GetWorkRequest request, WorkItemReceiver receiver); - - /** Get additional data such as state needed to process work, returned as a stream. */ - public abstract GetDataStream getDataStream(); - - /** Returns a stream allowing individual WorkItemCommitRequests to be streamed to Windmill. */ - public abstract CommitWorkStream commitWorkStream(); +public abstract class WindmillServerStub + implements ApplianceWindmillClient, StreamingEngineWindmillClient, StatusDataProvider { /** Returns the amount of time the server has been throttled and resets the time to 0. */ public abstract long getAndResetThrottleTime(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 028a5c2e1d4b..58aecfc71e00 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -69,6 +69,7 @@ public abstract class AbstractWindmillStream implements Win protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class); protected final AtomicBoolean clientClosed; + private final AtomicBoolean isShutdown; private final AtomicLong lastSendTimeMs; private final Executor executor; private final BackOff backoff; @@ -84,19 +85,23 @@ public abstract class AbstractWindmillStream implements Win private final Supplier> requestObserverSupplier; // Indicates if the current stream in requestObserver is closed by calling close() method private final AtomicBoolean streamClosed; + private final String backendWorkerToken; private @Nullable StreamObserver requestObserver; protected AbstractWindmillStream( + String debugStreamType, Function, StreamObserver> clientFactory, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, - int logEveryNStreamFailures) { + int logEveryNStreamFailures, + String backendWorkerToken) { + this.backendWorkerToken = backendWorkerToken; this.executor = Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat("WindmillStream-thread") + .setNameFormat(createThreadName(debugStreamType, backendWorkerToken)) .build()); this.backoff = backoff; this.streamRegistry = streamRegistry; @@ -111,12 +116,19 @@ protected AbstractWindmillStream( this.lastErrorTime = new AtomicReference<>(); this.sleepUntil = new AtomicLong(); this.finishLatch = new CountDownLatch(1); + this.isShutdown = new AtomicBoolean(false); this.requestObserverSupplier = () -> streamObserverFactory.from( clientFactory, new AbstractWindmillStream.ResponseObserver()); } + private static String createThreadName(String streamType, String backendWorkerToken) { + return !backendWorkerToken.isEmpty() + ? 
String.format("%s-%s-WindmillStream-thread", streamType, backendWorkerToken) + : String.format("%s-WindmillStream-thread", streamType); + } + private static long debugDuration(long nowMs, long startMs) { if (startMs <= 0) { return -1; @@ -140,6 +152,11 @@ private static long debugDuration(long nowMs, long startMs) { */ protected abstract void startThrottleTimer(); + /** Reflects that {@link #shutdown()} was explicitly called. */ + protected boolean isShutdown() { + return isShutdown.get(); + } + private StreamObserver requestObserver() { if (requestObserver == null) { throw new NullPointerException( @@ -175,7 +192,7 @@ protected final void startStream() { requestObserver = requestObserverSupplier.get(); onNewStream(); if (clientClosed.get()) { - close(); + halfClose(); } return; } @@ -238,7 +255,7 @@ public final void appendSummaryHtml(PrintWriter writer) { protected abstract void appendSpecificHtml(PrintWriter writer); @Override - public final synchronized void close() { + public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. clientClosed.set(true); requestObserver().onCompleted(); @@ -255,6 +272,30 @@ public final Instant startTime() { return new Instant(startTimeMs.get()); } + @Override + public String backendWorkerToken() { + return backendWorkerToken; + } + + @Override + public void shutdown() { + if (isShutdown.compareAndSet(false, true)) { + requestObserver() + .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + } + } + + private void setLastError(String error) { + lastError.set(error); + lastErrorTime.set(DateTime.now()); + } + + public static class WindmillStreamShutdownException extends RuntimeException { + public WindmillStreamShutdownException(String message) { + super(message); + } + } + private class ResponseObserver implements StreamObserver { @Override @@ -280,7 +321,7 @@ public void onCompleted() { private void onStreamFinished(@Nullable Throwable t) { synchronized (this) { - if (clientClosed.get() && !hasPendingRequests()) { + if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); return; @@ -337,9 +378,4 @@ private void onStreamFinished(@Nullable Throwable t) { executor.execute(AbstractWindmillStream.this::startStream); } } - - private void setLastError(String error) { - lastError.set(error); - lastErrorTime.set(DateTime.now()); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index d044e9300790..31bd4e146a78 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.client; import java.io.Closeable; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -32,8 +33,12 @@ /** Superclass for streams returned by streaming Windmill methods. */ @ThreadSafe public interface WindmillStream { + + /** An identifier for the backend worker where the stream is sending/receiving RPCs. 
*/ + String backendWorkerToken(); + /** Indicates that no more requests will be sent. */ - void close(); + void halfClose(); /** Waits for the server to close its end of the connection, with timeout. */ boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException; @@ -41,6 +46,12 @@ public interface WindmillStream { /** Returns when the stream was opened. */ Instant startTime(); + /** + * Shutdown the stream. There should be no further interactions with the stream once this has been + * called. + */ + void shutdown(); + /** Handle representing a stream of GetWork responses. */ @ThreadSafe interface GetWorkStream extends WindmillStream { @@ -62,7 +73,7 @@ Windmill.KeyedGetDataResponse requestKeyedData( Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request); /** Tells windmill processing is ongoing for the given keys. */ - void refreshActiveWork(Map> heartbeats); + void refreshActiveWork(Map> heartbeats); void onHeartbeatResponse(List responses); } @@ -70,6 +81,12 @@ Windmill.KeyedGetDataResponse requestKeyedData( /** Interface for streaming CommitWorkRequests to Windmill. */ @ThreadSafe interface CommitWorkStream extends WindmillStream { + /** + * Returns a builder that can be used for sending requests. Each builder is not thread-safe but + * different builders for the same stream may be used simultaneously. + */ + CommitWorkStream.RequestBatcher batcher(); + @NotThreadSafe interface RequestBatcher extends Closeable { /** @@ -92,12 +109,6 @@ default void close() { flush(); } } - - /** - * Returns a builder that can be used for sending requests. Each builder is not thread-safe but - * different builders for the same stream may be used simultaneously. - */ - RequestBatcher batcher(); } /** Interface for streaming GetWorkerMetadata requests to Windmill. 
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java index 0e4e085c066c..f14fc40fdfdf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java @@ -128,7 +128,7 @@ public StreamT getStream() { return resultStream; } finally { if (closeThisStream != null) { - closeThisStream.close(); + closeThisStream.halfClose(); } } } @@ -166,7 +166,7 @@ public void releaseStream(StreamT stream) { } if (closeStream) { - stream.close(); + stream.halfClose(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index ed4dcfa212f1..bf1007bc4bfb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -17,9 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; +import com.google.auto.value.AutoBuilder; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Supplier; @@ -45,6 +47,7 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB + private static final String NO_BACKEND_WORKER_TOKEN = ""; private final Supplier> commitWorkStreamFactory; private final WeightedBoundedQueue commitQueue; @@ -52,11 +55,13 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private final AtomicLong activeCommitBytes; private final Consumer onCommitComplete; private final int numCommitSenders; + private final AtomicBoolean isRunning; - private StreamingEngineWorkCommitter( + StreamingEngineWorkCommitter( Supplier> commitWorkStreamFactory, int numCommitSenders, - Consumer onCommitComplete) { + Consumer onCommitComplete, + String backendWorkerToken) { this.commitWorkStreamFactory = commitWorkStreamFactory; this.commitQueue = WeightedBoundedQueue.create( @@ -67,34 +72,48 @@ private StreamingEngineWorkCommitter( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MAX_PRIORITY) - .setNameFormat("CommitThread-%d") + .setNameFormat( + backendWorkerToken.isEmpty() + ? 
"CommitThread-%d" + : "CommitThread-" + backendWorkerToken + "-%d") .build()); this.activeCommitBytes = new AtomicLong(); this.onCommitComplete = onCommitComplete; this.numCommitSenders = numCommitSenders; + this.isRunning = new AtomicBoolean(false); } - public static StreamingEngineWorkCommitter create( - Supplier> commitWorkStreamFactory, - int numCommitSenders, - Consumer onCommitComplete) { - return new StreamingEngineWorkCommitter( - commitWorkStreamFactory, numCommitSenders, onCommitComplete); + public static Builder builder() { + return new AutoBuilder_StreamingEngineWorkCommitter_Builder() + .setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN) + .setNumCommitSenders(1); } @Override @SuppressWarnings("FutureReturnValueIgnored") public void start() { - if (!commitSenders.isShutdown()) { - for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); - } + Preconditions.checkState( + isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); + for (int i = 0; i < numCommitSenders; i++) { + commitSenders.submit(this::streamingCommitLoop); } } @Override public void commit(Commit commit) { - commitQueue.put(commit); + boolean isShutdown = !this.isRunning.get(); + if (commit.work().isFailed() || isShutdown) { + if (isShutdown) { + LOG.debug( + "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}, workId={} ].", + commit.computationId(), + commit.work().getShardedKey(), + commit.work().id()); + } + failCommit(commit); + } else { + commitQueue.put(commit); + } } @Override @@ -104,15 +123,14 @@ public long currentActiveCommitBytes() { @Override public void stop() { - if (!commitSenders.isTerminated()) { - commitSenders.shutdownNow(); - try { - commitSenders.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - LOG.warn( - "Commit senders didn't complete shutdown within 10 seconds, continuing to drain queue", - e); - } + Preconditions.checkState(isRunning.compareAndSet(true, false)); + commitSenders.shutdownNow(); + try { + commitSenders.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOG.warn( + "Commit senders didn't complete shutdown within 10 seconds, continuing to drain queue.", + e); } drainCommitQueue(); } @@ -138,12 +156,13 @@ public int parallelism() { private void streamingCommitLoop() { @Nullable Commit initialCommit = null; try { - while (true) { + while (isRunning.get()) { if (initialCommit == null) { try { // Block until we have a commit or are shutting down. initialCommit = commitQueue.take(); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); return; } } @@ -156,17 +175,14 @@ private void streamingCommitLoop() { } try (CloseableStream closeableCommitStream = - commitWorkStreamFactory.get()) { - CommitWorkStream commitStream = closeableCommitStream.stream(); - try (CommitWorkStream.RequestBatcher batcher = commitStream.batcher()) { - if (!tryAddToCommitBatch(initialCommit, batcher)) { - throw new AssertionError( - "Initial commit on flushed stream should always be accepted."); - } - // Batch additional commits to the stream and possibly make an un-batched commit the - // next initial commit. 
- initialCommit = expandBatch(batcher); + commitWorkStreamFactory.get(); + CommitWorkStream.RequestBatcher batcher = closeableCommitStream.stream().batcher()) { + if (!tryAddToCommitBatch(initialCommit, batcher)) { + throw new AssertionError("Initial commit on flushed stream should always be accepted."); } + // Batch additional commits to the stream and possibly make an un-batched commit the + // next initial commit. + initialCommit = expandBatch(batcher); } catch (Exception e) { LOG.error("Error occurred sending commits.", e); } @@ -187,7 +203,7 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch batcher.commitWorkItem( commit.computationId(), commit.request(), - (commitStatus) -> { + commitStatus -> { onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); activeCommitBytes.addAndGet(-commit.getSize()); }); @@ -201,9 +217,11 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch return isCommitAccepted; } - // Helper to batch additional commits into the commit batch as long as they fit. - // Returns a commit that was removed from the queue but not consumed or null. - private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { + /** + * Helper to batch additional commits into the commit batch as long as they fit. Returns a commit + * that was removed from the queue but not consumed or null. + */ + private @Nullable Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { int commits = 1; while (true) { Commit commit; @@ -214,6 +232,7 @@ private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { commit = commitQueue.poll(); } } catch (InterruptedException e) { + Thread.currentThread().interrupt(); return null; } @@ -233,4 +252,22 @@ private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { commits++; } } + + @AutoBuilder + public interface Builder { + Builder setCommitWorkStreamFactory( + Supplier> commitWorkStreamFactory); + + Builder setNumCommitSenders(int numCommitSenders); + + Builder setOnCommitComplete(Consumer onCommitComplete); + + Builder setBackendWorkerToken(String backendWorkerToken); + + StreamingEngineWorkCommitter autoBuild(); + + default WorkCommitter build() { + return autoBuild(); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java new file mode 100644 index 000000000000..e0500dde0c53 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.WindmillComputationKey; +import org.apache.beam.runners.dataflow.worker.windmill.ApplianceWindmillClient; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Appliance implementation of {@link GetDataClient}. */ +@Internal +@ThreadSafe +public final class ApplianceGetDataClient implements GetDataClient { + private static final int MAX_READS_PER_BATCH = 60; + private static final int MAX_ACTIVE_READS = 10; + + private final ApplianceWindmillClient windmillClient; + private final ThrottlingGetDataMetricTracker getDataMetricTracker; + + @GuardedBy("this") + private final List pendingReadBatches; + + @GuardedBy("this") + private int activeReadThreads; + + public ApplianceGetDataClient( + ApplianceWindmillClient windmillClient, ThrottlingGetDataMetricTracker getDataMetricTracker) { + this.windmillClient = windmillClient; + this.getDataMetricTracker = getDataMetricTracker; + this.pendingReadBatches = new ArrayList<>(); + this.activeReadThreads = 0; + } + + @Override + public Windmill.KeyedGetDataResponse getStateData( + String computationId, Windmill.KeyedGetDataRequest request) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { + SettableFuture response = SettableFuture.create(); + ReadBatch batch = addToReadBatch(new QueueEntry(computationId, request, response)); + if (batch != null) { + issueReadBatch(batch); + } + return response.get(); + } catch (Exception e) { + throw new GetDataException( + "Error occurred fetching state for computation=" + + computationId + + ", key=" + + request.getShardingKey(), + e); + } + } + + @Override + public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { + return windmillClient + .getData(Windmill.GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build()) + .getGlobalData(0); + } catch (Exception e) { + throw new GetDataException( + "Error occurred fetching side input for tag=" + request.getDataId(), e); + } + } + + @Override + public synchronized void printHtml(PrintWriter writer) { + getDataMetricTracker.printHtml(writer); + writer.println(" Read threads: " + activeReadThreads); + writer.println(" Pending read batches: " + pendingReadBatches.size()); + } + + private void issueReadBatch(ReadBatch batch) { + try { + // Possibly block until the batch is allowed to start. + batch.startRead.get(); + } catch (InterruptedException e) { + // We don't expect this thread to be interrupted. To simplify handling, we just fall through + // to issuing the call. 
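// ---------------------------------------------------------------------------------------------
// Editor's note: a standalone sketch, not part of this change, of the SettableFuture
// handoff that issueReadBatch() blocks on here: a batch's creator parks on batch.startRead
// until the thread finishing the previous batch completes it. Plain Guava is imported
// (rather than Beam's vendored copy) only so the snippet compiles on its own.
import com.google.common.util.concurrent.SettableFuture;

public final class StartReadHandoffSketch {
  public static void main(String[] args) throws Exception {
    SettableFuture<Void> startRead = SettableFuture.create();

    Thread batchOwner =
        new Thread(
            () -> {
              try {
                startRead.get(); // parks, like issueReadBatch() waiting for its turn
                System.out.println("released: batch owner now issues its GetData call");
              } catch (Exception e) {
                Thread.currentThread().interrupt();
              }
            });
    batchOwner.start();

    // A thread finishing the previous batch releases exactly one waiting batch owner,
    // mirroring `pendingReadBatches.remove(0).startRead.set(null)` in the finally block.
    startRead.set(null);
    batchOwner.join();
  }
}
// ---------------------------------------------------------------------------------------------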
+ assert (false); + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + // startRead is a SettableFuture so this should never occur. + throw new AssertionError("Should not have exception on startRead", e); + } + Map> pendingResponses = + new HashMap<>(batch.reads.size()); + Map computationBuilders = new HashMap<>(); + for (QueueEntry entry : batch.reads) { + ComputationGetDataRequest.Builder computationBuilder = + computationBuilders.computeIfAbsent( + entry.computation, k -> ComputationGetDataRequest.newBuilder().setComputationId(k)); + + computationBuilder.addRequests(entry.request); + pendingResponses.put( + WindmillComputationKey.create( + entry.computation, entry.request.getKey(), entry.request.getShardingKey()), + entry.response); + } + + // Build the full GetDataRequest from the KeyedGetDataRequests pulled from the queue. + Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder(); + for (ComputationGetDataRequest.Builder computationBuilder : computationBuilders.values()) { + builder.addRequests(computationBuilder); + } + + try { + Windmill.GetDataResponse response = windmillClient.getData(builder.build()); + // Dispatch the per-key responses back to the waiting threads. + for (Windmill.ComputationGetDataResponse computationResponse : response.getDataList()) { + for (Windmill.KeyedGetDataResponse keyResponse : computationResponse.getDataList()) { + pendingResponses + .get( + WindmillComputationKey.create( + computationResponse.getComputationId(), + keyResponse.getKey(), + keyResponse.getShardingKey())) + .set(keyResponse); + } + } + } catch (RuntimeException e) { + // Fan the exception out to the reads. + for (QueueEntry entry : batch.reads) { + entry.response.setException(e); + } + } finally { + synchronized (this) { + Preconditions.checkState(activeReadThreads >= 1); + if (pendingReadBatches.isEmpty()) { + activeReadThreads--; + } else { + // Notify the thread responsible for issuing the next batch read. + ReadBatch startBatch = pendingReadBatches.remove(0); + startBatch.startRead.set(null); + } + } + } + } + + /** + * Adds the entry to a read batch for sending to the windmill server. If a non-null batch is + * returned, this thread will be responsible for sending the batch and should wait for the batch + * startRead to be notified. If null is returned, the entry was added to a read batch that will be + * issued by another thread. + */ + private @Nullable ReadBatch addToReadBatch(QueueEntry entry) { + synchronized (this) { + ReadBatch batch; + if (activeReadThreads < MAX_ACTIVE_READS) { + assert (pendingReadBatches.isEmpty()); + activeReadThreads += 1; + // fall through to below synchronized block + } else if (pendingReadBatches.isEmpty() + || pendingReadBatches.get(pendingReadBatches.size() - 1).reads.size() + >= MAX_READS_PER_BATCH) { + // This is the first read of a batch, it will be responsible for sending the batch. + batch = new ReadBatch(); + pendingReadBatches.add(batch); + batch.reads.add(entry); + return batch; + } else { + // This fits within an existing batch, it will be sent by the first blocking thread in the + // batch. 
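+ // Editor's note: a worked trace of this method's three branches, using the constants
+ // MAX_ACTIVE_READS = 10 and MAX_READS_PER_BATCH = 60 declared at the top of the class.
+ // Suppose 71 threads call getStateData() before any RPC completes:
+ //   * readers 1-10  take the first branch: each claims an active-read slot and returns a
+ //     batch whose startRead is already completed, so they issue their RPCs immediately;
+ //   * reader 11     finds every slot busy and no pending batch, so it creates a pending
+ //     batch (second branch) and blocks on startRead until a slot frees up;
+ //   * readers 12-70 fit inside reader 11's batch (size < 60) and return null: reader 11
+ //     carries their requests in its single GetData RPC (third branch, below);
+ //   * reader 71     sees the pending batch full at 60 entries and opens a new one.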
+ pendingReadBatches.get(pendingReadBatches.size() - 1).reads.add(entry); + return null; + } + } + ReadBatch batch = new ReadBatch(); + batch.reads.add(entry); + batch.startRead.set(null); + return batch; + } + + private static final class ReadBatch { + ArrayList reads = new ArrayList<>(); + SettableFuture startRead = SettableFuture.create(); + } + + private static final class QueueEntry { + final String computation; + final Windmill.KeyedGetDataRequest request; + final SettableFuture response; + + QueueEntry( + String computation, + Windmill.KeyedGetDataRequest request, + SettableFuture response) { + this.computation = computation; + this.request = request; + this.response = response; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java new file mode 100644 index 000000000000..c732591bf12d --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import java.io.PrintWriter; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; +import org.apache.beam.sdk.annotations.Internal; + +/** Client for streaming backend GetData API. */ +@Internal +public interface GetDataClient { + /** + * Issues a blocking call to fetch state data for a specific computation and {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem}. + * + * @throws GetDataException when there was an unexpected error during the attempted fetch. + */ + KeyedGetDataResponse getStateData(String computationId, KeyedGetDataRequest request) + throws GetDataException; + + /** + * Issues a blocking call to fetch side input data. + * + * @throws GetDataException when there was an unexpected error during the attempted fetch. 
+ */ + GlobalData getSideInputData(GlobalDataRequest request) throws GetDataException; + + void printHtml(PrintWriter writer); + + final class GetDataException extends RuntimeException { + GetDataException(String message, Throwable cause) { + super(message, cause); + } + + GetDataException(String message) { + super(message); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java new file mode 100644 index 000000000000..c8e058e7e230 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import java.io.PrintWriter; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; + +/** {@link GetDataClient} that fetches data directly from a specific {@link GetDataStream}. */ +@Internal +public final class StreamGetDataClient implements GetDataClient { + + private final GetDataStream getDataStream; + private final Function sideInputGetDataStreamFactory; + private final ThrottlingGetDataMetricTracker getDataMetricTracker; + + private StreamGetDataClient( + GetDataStream getDataStream, + Function sideInputGetDataStreamFactory, + ThrottlingGetDataMetricTracker getDataMetricTracker) { + this.getDataStream = getDataStream; + this.sideInputGetDataStreamFactory = sideInputGetDataStreamFactory; + this.getDataMetricTracker = getDataMetricTracker; + } + + public static GetDataClient create( + GetDataStream getDataStream, + Function sideInputGetDataStreamFactory, + ThrottlingGetDataMetricTracker getDataMetricTracker) { + return new StreamGetDataClient( + getDataStream, sideInputGetDataStreamFactory, getDataMetricTracker); + } + + /** + * @throws WorkItemCancelledException when the fetch fails due to the stream being shutdown, + * indicating that the {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} that triggered the + * fetch has been cancelled. 
+ */ + @Override + public Windmill.KeyedGetDataResponse getStateData( + String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { + return getDataStream.requestKeyedData(computationId, request); + } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + throw new WorkItemCancelledException(request.getShardingKey()); + } catch (Exception e) { + throw new GetDataException( + "Error occurred fetching state for computation=" + + computationId + + ", key=" + + request.getShardingKey(), + e); + } + } + + /** + * @throws WorkItemCancelledException when the fetch fails due to the stream being shutdown, + * indicating that the {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} that triggered the + * fetch has been cancelled. + */ + @Override + public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) + throws GetDataException { + GetDataStream sideInputGetDataStream = + sideInputGetDataStreamFactory.apply(request.getDataId().getTag()); + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { + return sideInputGetDataStream.requestGlobalData(request); + } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + throw new WorkItemCancelledException(e); + } catch (Exception e) { + throw new GetDataException( + "Error occurred fetching side input for tag=" + request.getDataId(), e); + } + } + + @Override + public void printHtml(PrintWriter writer) { + getDataMetricTracker.printHtml(writer); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java new file mode 100644 index 000000000000..49fe3e4bdc15 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import java.io.PrintWriter; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.sdk.annotations.Internal; + +/** + * StreamingEngine implementation of {@link GetDataClient}. + * + * @implNote Uses {@link WindmillStreamPool} to send requests. + */ +@Internal +@ThreadSafe +public final class StreamPoolGetDataClient implements GetDataClient { + + private final WindmillStreamPool getDataStreamPool; + private final ThrottlingGetDataMetricTracker getDataMetricTracker; + + public StreamPoolGetDataClient( + ThrottlingGetDataMetricTracker getDataMetricTracker, + WindmillStreamPool getDataStreamPool) { + this.getDataMetricTracker = getDataMetricTracker; + this.getDataStreamPool = getDataStreamPool; + } + + @Override + public Windmill.KeyedGetDataResponse getStateData( + String computationId, KeyedGetDataRequest request) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling(); + CloseableStream closeableStream = getDataStreamPool.getCloseableStream()) { + return closeableStream.stream().requestKeyedData(computationId, request); + } catch (Exception e) { + throw new GetDataException( + "Error occurred fetching state for computation=" + + computationId + + ", key=" + + request.getShardingKey(), + e); + } + } + + @Override + public Windmill.GlobalData getSideInputData(GlobalDataRequest request) { + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling(); + CloseableStream closeableStream = getDataStreamPool.getCloseableStream()) { + return closeableStream.stream().requestGlobalData(request); + } catch (Exception e) { + throw new GetDataException( + "Error occurred fetching side input for tag=" + request.getDataId(), e); + } + } + + @Override + public void printHtml(PrintWriter writer) { + getDataMetricTracker.printHtml(writer); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java new file mode 100644 index 000000000000..6bb00292e29a --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import com.google.auto.value.AutoValue; +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** + * Wraps GetData calls to track metrics for the number of in-flight requests and throttles requests + * when memory pressure is high. + */ +@Internal +@ThreadSafe +public final class ThrottlingGetDataMetricTracker { + private static final String GET_STATE_DATA_RESOURCE_CONTEXT = "GetStateData"; + private static final String GET_SIDE_INPUT_RESOURCE_CONTEXT = "GetSideInputData"; + + private final MemoryMonitor gcThrashingMonitor; + private final AtomicInteger activeStateReads; + private final AtomicInteger activeSideInputs; + private final AtomicInteger activeHeartbeats; + + public ThrottlingGetDataMetricTracker(MemoryMonitor gcThrashingMonitor) { + this.gcThrashingMonitor = gcThrashingMonitor; + this.activeStateReads = new AtomicInteger(); + this.activeSideInputs = new AtomicInteger(); + this.activeHeartbeats = new AtomicInteger(); + } + + /** + * Tracks a state data fetch. If there is memory pressure, may throttle requests. Returns an + * {@link AutoCloseable} that will decrement the metric after the call is finished. + */ + AutoCloseable trackStateDataFetchWithThrottling() { + gcThrashingMonitor.waitForResources(GET_STATE_DATA_RESOURCE_CONTEXT); + activeStateReads.getAndIncrement(); + return activeStateReads::getAndDecrement; + } + + /** + * Tracks a side input fetch. If there is memory pressure, may throttle requests. Returns an + * {@link AutoCloseable} that will decrement the metric after the call is finished. + */ + AutoCloseable trackSideInputFetchWithThrottling() { + gcThrashingMonitor.waitForResources(GET_SIDE_INPUT_RESOURCE_CONTEXT); + activeSideInputs.getAndIncrement(); + return activeSideInputs::getAndDecrement; + } + + /** + * Tracks heartbeat request metrics. Returns an {@link AutoCloseable} that will decrement the + * metric after the call is finished. 
+ */ + public AutoCloseable trackHeartbeats(int numHeartbeats) { + activeHeartbeats.getAndAdd(numHeartbeats); + return () -> activeHeartbeats.getAndAdd(-numHeartbeats); + } + + public void printHtml(PrintWriter writer) { + writer.println("Active Fetches:"); + writer.println(" Side Inputs: " + activeSideInputs.get()); + writer.println(" State Reads: " + activeStateReads.get()); + writer.println("Heartbeat Keys Active: " + activeHeartbeats.get()); + } + + @VisibleForTesting + ReadOnlySnapshot getMetricsSnapshot() { + return ReadOnlySnapshot.create( + activeSideInputs.get(), activeStateReads.get(), activeHeartbeats.get()); + } + + @VisibleForTesting + @AutoValue + abstract static class ReadOnlySnapshot { + + private static ReadOnlySnapshot create( + int activeSideInputs, int activeStateReads, int activeHeartbeats) { + return new AutoValue_ThrottlingGetDataMetricTracker_ReadOnlySnapshot( + activeSideInputs, activeStateReads, activeHeartbeats); + } + + abstract int activeSideInputs(); + + abstract int activeStateReads(); + + abstract int activeHeartbeats(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index f9f579119d61..053843a8af25 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -57,6 +57,7 @@ public final class GrpcCommitWorkStream private final int streamingRpcBatchLimit; private GrpcCommitWorkStream( + String backendWorkerToken, Function, StreamObserver> startCommitWorkRpcFn, BackOff backoff, @@ -68,11 +69,13 @@ private GrpcCommitWorkStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( + "CommitWorkStream", startCommitWorkRpcFn, backoff, streamObserverFactory, streamRegistry, - logEveryNStreamFailures); + logEveryNStreamFailures, + backendWorkerToken); pending = new ConcurrentHashMap<>(); this.idGenerator = idGenerator; this.jobHeader = jobHeader; @@ -81,6 +84,7 @@ private GrpcCommitWorkStream( } public static GrpcCommitWorkStream create( + String backendWorkerToken, Function, StreamObserver> startCommitWorkRpcFn, BackOff backoff, @@ -93,6 +97,7 @@ public static GrpcCommitWorkStream create( int streamingRpcBatchLimit) { GrpcCommitWorkStream commitWorkStream = new GrpcCommitWorkStream( + backendWorkerToken, startCommitWorkRpcFn, backoff, streamObserverFactory, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 6f4b5b7b33fb..58f72610e2d3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -39,10 +39,12 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; @@ -80,8 +82,9 @@ public final class GrpcDirectGetWorkStream private final GetWorkRequest request; private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; - private final Supplier getDataStream; + private final Supplier heartbeatSender; private final Supplier workCommitter; + private final Supplier getDataClient; /** * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they @@ -91,6 +94,7 @@ public final class GrpcDirectGetWorkStream private final ConcurrentMap workItemBuffers; private GrpcDirectGetWorkStream( + String backendWorkerToken, Function< StreamObserver, StreamObserver> @@ -101,25 +105,32 @@ private GrpcDirectGetWorkStream( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier getDataStream, + Supplier heartbeatSender, + Supplier getDataClient, Supplier workCommitter, WorkItemScheduler workItemScheduler) { super( - startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); + "GetWorkStream", + startGetWorkRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + backendWorkerToken); this.request = request; this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemBuffers = new ConcurrentHashMap<>(); - // Use the same GetDataStream and CommitWorkStream instances to process all the work in this - // stream. 
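// ---------------------------------------------------------------------------------------------
// Editor's note: a standalone sketch, not part of this change, of the Suppliers.memoize(...)
// behavior this constructor relies on: the factory passed in runs at most once, so every
// work item handled by this stream shares one HeartbeatSender, GetDataClient, and
// WorkCommitter. Plain Guava imports are used only so the snippet compiles on its own;
// Beam uses its vendored copy.
import com.google.common.base.Suppliers;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

public final class MemoizeSketch {
  public static void main(String[] args) {
    AtomicInteger factoryRuns = new AtomicInteger();
    Supplier<String> factory = () -> "client-" + factoryRuns.incrementAndGet();

    // Same shape as `Suppliers.memoize(heartbeatSender::get)`: adapts a JDK supplier to
    // Guava's interface and caches the first result.
    com.google.common.base.Supplier<String> memoized = Suppliers.memoize(factory::get);

    System.out.println(memoized.get()); // "client-1" -- the factory runs here, exactly once
    System.out.println(memoized.get()); // "client-1" -- cached, factory not re-invoked
  }
}
// ---------------------------------------------------------------------------------------------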
- this.getDataStream = Suppliers.memoize(getDataStream::get); + this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); this.workCommitter = Suppliers.memoize(workCommitter::get); + this.getDataClient = Suppliers.memoize(getDataClient::get); this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); } public static GrpcDirectGetWorkStream create( + String backendWorkerToken, Function< StreamObserver, StreamObserver> @@ -130,11 +141,13 @@ public static GrpcDirectGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier getDataStream, + Supplier heartbeatSender, + Supplier getDataClient, Supplier workCommitter, WorkItemScheduler workItemScheduler) { GrpcDirectGetWorkStream getWorkStream = new GrpcDirectGetWorkStream( + backendWorkerToken, startGetWorkRpcFn, request, backoff, @@ -142,7 +155,8 @@ public static GrpcDirectGetWorkStream create( streamRegistry, logEveryNStreamFailures, getWorkThrottleTimer, - getDataStream, + heartbeatSender, + getDataClient, workCommitter, workItemScheduler); getWorkStream.startStream(); @@ -327,7 +341,7 @@ private void runAndReset() { private Work.ProcessingContext createProcessingContext(String computationId) { return Work.createProcessingContext( - computationId, getDataStream.get()::requestKeyedData, workCommitter.get()::commit); + computationId, getDataClient.get(), workCommitter.get()::commit, heartbeatSender.get()); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index feb15c2ac83c..0e9a0c6316ee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; +import java.util.Collection; import java.util.Deque; import java.util.List; import java.util.Map; @@ -75,6 +76,7 @@ public final class GrpcGetDataStream private final Consumer> processHeartbeatResponses; private GrpcGetDataStream( + String backendWorkerToken, Function, StreamObserver> startGetDataRpcFn, BackOff backoff, @@ -88,7 +90,13 @@ private GrpcGetDataStream( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { super( - startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); + "GetDataStream", + startGetDataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + backendWorkerToken); this.idGenerator = idGenerator; this.getDataThrottleTimer = getDataThrottleTimer; this.jobHeader = jobHeader; @@ -100,6 +108,7 @@ private GrpcGetDataStream( } public static GrpcGetDataStream create( + String backendWorkerToken, Function, StreamObserver> startGetDataRpcFn, BackOff backoff, @@ -114,6 +123,7 @@ public static GrpcGetDataStream create( Consumer> processHeartbeatResponses) { GrpcGetDataStream getDataStream = new GrpcGetDataStream( + backendWorkerToken, startGetDataRpcFn, backoff, streamObserverFactory, @@ 
-189,11 +199,15 @@ public GlobalData requestGlobalData(GlobalDataRequest request) { } @Override - public void refreshActiveWork(Map> heartbeats) { + public void refreshActiveWork(Map> heartbeats) { + if (isShutdown()) { + throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); + } + StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); if (sendKeyedGetDataRequests) { long builderBytes = 0; - for (Map.Entry> entry : heartbeats.entrySet()) { + for (Map.Entry> entry : heartbeats.entrySet()) { for (HeartbeatRequest request : entry.getValue()) { // Calculate the bytes with some overhead for proto encoding. long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10; @@ -224,7 +238,7 @@ public void refreshActiveWork(Map> heartbeats) { } else { // No translation necessary, but we must still respect `RPC_STREAM_CHUNK_SIZE`. long builderBytes = 0; - for (Map.Entry> entry : heartbeats.entrySet()) { + for (Map.Entry> entry : heartbeats.entrySet()) { ComputationHeartbeatRequest.Builder computationHeartbeatBuilder = ComputationHeartbeatRequest.newBuilder().setComputationId(entry.getKey()); for (HeartbeatRequest request : entry.getValue()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 867180fb0d31..4b392e9190ed 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -60,6 +60,7 @@ public final class GrpcGetWorkStream private final AtomicLong inflightBytes; private GrpcGetWorkStream( + String backendWorkerToken, Function< StreamObserver, StreamObserver> @@ -72,7 +73,13 @@ private GrpcGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { super( - startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); + "GetWorkStream", + startGetWorkRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + backendWorkerToken); this.request = request; this.getWorkThrottleTimer = getWorkThrottleTimer; this.receiver = receiver; @@ -82,6 +89,7 @@ private GrpcGetWorkStream( } public static GrpcGetWorkStream create( + String backendWorkerToken, Function< StreamObserver, StreamObserver> @@ -95,6 +103,7 @@ public static GrpcGetWorkStream create( WorkItemReceiver receiver) { GrpcGetWorkStream getWorkStream = new GrpcGetWorkStream( + backendWorkerToken, startGetWorkRpcFn, request, backoff, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 3672f02c813f..44e21a9b18ed 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -65,11 +65,13 @@ 
private GrpcGetWorkerMetadataStream( ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingConsumer) { super( + "GetWorkerMetadataStream", startGetWorkerMetadataRpcFn, backoff, streamObserverFactory, streamRegistry, - logEveryNStreamFailures); + logEveryNStreamFailures, + ""); this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build(); this.metadataVersion = metadataVersion; this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index 0ab03a803180..1fce4d238b2e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -254,11 +254,6 @@ public void setWindmillServiceEndpoints(Set endpoints) { dispatcherClient.consumeWindmillDispatcherEndpoints(ImmutableSet.copyOf(endpoints)); } - @Override - public boolean isReady() { - return dispatcherClient.hasInitializedEndpoints(); - } - private synchronized void initializeLocalHost(int port) { this.maxBackoff = Duration.millis(500); if (options.isEnableStreamingEngine()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 14866f3f586b..92f031db9972 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -37,6 +37,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; @@ -44,10 +45,12 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; @@ -69,6 +72,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider { private static final int DEFAULT_STREAMING_RPC_BATCH_LIMIT = Integer.MAX_VALUE; private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1; private static final int NO_HEALTH_CHECKS = -1; + private static final String NO_BACKEND_WORKER_TOKEN = ""; private final JobHeader jobHeader; private final int logEveryNStreamFailures; @@ -179,6 +183,7 @@ public GetWorkStream createGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver processWorkItem) { return GrpcGetWorkStream.create( + NO_BACKEND_WORKER_TOKEN, responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver), request, grpcBackOff.get(), @@ -190,21 +195,24 @@ public GetWorkStream createGetWorkStream( } public GetWorkStream createDirectGetWorkStream( - CloudWindmillServiceV1Alpha1Stub stub, + WindmillConnection connection, GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, - Supplier getDataStream, + Supplier heartbeatSender, + Supplier getDataClient, Supplier workCommitter, WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( - responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver), + connection.backendWorkerToken(), + responseObserver -> withDefaultDeadline(connection.stub()).getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, getWorkThrottleTimer, - getDataStream, + heartbeatSender, + getDataClient, workCommitter, workItemScheduler); } @@ -212,6 +220,7 @@ public GetWorkStream createDirectGetWorkStream( public GetDataStream createGetDataStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer getDataThrottleTimer) { return GrpcGetDataStream.create( + NO_BACKEND_WORKER_TOKEN, responseObserver -> withDefaultDeadline(stub).getDataStream(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), @@ -228,6 +237,7 @@ public GetDataStream createGetDataStream( public CommitWorkStream createCommitWorkStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer commitWorkThrottleTimer) { return GrpcCommitWorkStream.create( + NO_BACKEND_WORKER_TOKEN, responseObserver -> withDefaultDeadline(stub).commitWorkStream(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java index 4760062c5754..b9573ff94cc9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java @@ -45,6 +45,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; @@ -91,6 +93,7 @@ public final class StreamingEngineClient { private final Supplier getWorkerMetadataStream; private final Queue newWindmillEndpoints; private final Function workCommitterFactory; + private final ThrottlingGetDataMetricTracker getDataMetricTracker; /** Writes are guarded by synchronization, reads are lock free. */ private final AtomicReference connections; @@ -107,8 +110,10 @@ private StreamingEngineClient( GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, long clientId, - Function workCommitterFactory) { + Function workCommitterFactory, + ThrottlingGetDataMetricTracker getDataMetricTracker) { this.jobHeader = jobHeader; + this.getDataMetricTracker = getDataMetricTracker; this.started = false; this.streamFactory = streamFactory; this.workItemScheduler = workItemScheduler; @@ -171,7 +176,8 @@ public static StreamingEngineClient create( ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - Function workCommitterFactory) { + Function workCommitterFactory, + ThrottlingGetDataMetricTracker getDataMetricTracker) { return new StreamingEngineClient( jobHeader, totalGetWorkBudget, @@ -181,7 +187,8 @@ public static StreamingEngineClient create( getWorkBudgetDistributor, dispatcherClient, /* clientId= */ new Random().nextLong(), - workCommitterFactory); + workCommitterFactory, + getDataMetricTracker); } @VisibleForTesting @@ -194,7 +201,8 @@ static StreamingEngineClient forTesting( GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, long clientId, - Function workCommitterFactory) { + Function workCommitterFactory, + ThrottlingGetDataMetricTracker getDataMetricTracker) { StreamingEngineClient streamingEngineClient = new StreamingEngineClient( jobHeader, @@ -205,7 +213,8 @@ static StreamingEngineClient forTesting( getWorkBudgetDistributor, dispatcherClient, clientId, - workCommitterFactory); + workCommitterFactory, + getDataMetricTracker); streamingEngineClient.start(); return streamingEngineClient; } @@ -240,7 +249,7 @@ public ImmutableSet currentWindmillEndpoints() { * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link * GetDataStream} pointing to dispatcher. */ - public GetDataStream getGlobalDataStream(String globalDataKey) { + private GetDataStream getGlobalDataStream(String globalDataKey) { return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) .map(Supplier::get) .orElseGet( @@ -263,7 +272,7 @@ private void startWorkerMetadataConsumer() { @VisibleForTesting public synchronized void finish() { Preconditions.checkState(started, "StreamingEngineClient never started."); - getWorkerMetadataStream.get().close(); + getWorkerMetadataStream.get().halfClose(); getWorkBudgetRefresher.stop(); newWorkerMetadataPublisher.shutdownNow(); newWorkerMetadataConsumer.shutdownNow(); @@ -390,7 +399,7 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor( // GetWorkBudgetDistributor. 
WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( - connection.stub(), + connection, GetWorkRequest.newBuilder() .setClientId(clientId) .setJobId(jobHeader.getJobId()) @@ -400,6 +409,9 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor( GetWorkBudget.noBudget(), streamFactory, workItemScheduler, + getDataStream -> + StreamGetDataClient.create( + getDataStream, this::getGlobalDataStream, getDataMetricTracker), workCommitterFactory); windmillStreamSender.startStreams(); return windmillStreamSender; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java index e9f008eb522e..7d09726e4b28 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java @@ -22,15 +22,17 @@ import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.FixedStreamHeartbeatSender; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; @@ -65,11 +67,12 @@ public class WindmillStreamSender { private final StreamingEngineThrottleTimers streamingEngineThrottleTimers; private WindmillStreamSender( - CloudWindmillServiceV1Alpha1Stub stub, + WindmillConnection connection, GetWorkRequest getWorkRequest, AtomicReference getWorkBudget, GrpcWindmillStreamFactory streamingEngineStreamFactory, WorkItemScheduler workItemScheduler, + Function getDataClientFactory, Function workCommitterFactory) { this.started = new AtomicBoolean(false); this.getWorkBudget = getWorkBudget; @@ -83,39 +86,42 @@ private WindmillStreamSender( Suppliers.memoize( () -> streamingEngineStreamFactory.createGetDataStream( - stub, streamingEngineThrottleTimers.getDataThrottleTimer())); + connection.stub(), streamingEngineThrottleTimers.getDataThrottleTimer())); this.commitWorkStream = Suppliers.memoize( () -> streamingEngineStreamFactory.createCommitWorkStream( - stub, streamingEngineThrottleTimers.commitWorkThrottleTimer())); + connection.stub(), 
streamingEngineThrottleTimers.commitWorkThrottleTimer())); this.workCommitter = Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get())); this.getWorkStream = Suppliers.memoize( () -> streamingEngineStreamFactory.createDirectGetWorkStream( - stub, + connection, withRequestBudget(getWorkRequest, getWorkBudget.get()), streamingEngineThrottleTimers.getWorkThrottleTimer(), - getDataStream, + () -> FixedStreamHeartbeatSender.create(getDataStream.get()), + () -> getDataClientFactory.apply(getDataStream.get()), workCommitter, workItemScheduler)); } public static WindmillStreamSender create( - CloudWindmillServiceV1Alpha1Stub stub, + WindmillConnection connection, GetWorkRequest getWorkRequest, GetWorkBudget getWorkBudget, GrpcWindmillStreamFactory streamingEngineStreamFactory, WorkItemScheduler workItemScheduler, + Function getDataClientFactory, Function workCommitterFactory) { return new WindmillStreamSender( - stub, + connection, getWorkRequest, new AtomicReference<>(getWorkBudget), streamingEngineStreamFactory, workItemScheduler, + getDataClientFactory, workCommitterFactory); } @@ -138,10 +144,10 @@ void closeAllStreams() { // streaming RPCs by possibly making calls over the network. Do not close the streams unless // they have already been started. if (started.get()) { - getWorkStream.get().close(); - getDataStream.get().close(); + getWorkStream.get().shutdown(); + getDataStream.get().shutdown(); workCommitter.get().stop(); - commitWorkStream.get().close(); + commitWorkStream.get().shutdown(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java new file mode 100644 index 000000000000..4ea209f31b1d --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; + +import org.apache.beam.sdk.annotations.Internal; + +@Internal +public final class StreamObserverCancelledException extends RuntimeException { + public StreamObserverCancelledException(Throwable cause) { + super(cause); + } + + public StreamObserverCancelledException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index e9ffa982925b..b0b6377dd8b1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -44,6 +44,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; @@ -74,7 +75,7 @@ public final class StreamingWorkScheduler { private final DataflowWorkerHarnessOptions options; private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; - private final SideInputStateFetcher sideInputStateFetcher; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private final FailureTracker failureTracker; private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer commitFinalizer; @@ -88,7 +89,7 @@ public StreamingWorkScheduler( DataflowWorkerHarnessOptions options, Supplier clock, ComputationWorkExecutorFactory computationWorkExecutorFactory, - SideInputStateFetcher sideInputStateFetcher, + SideInputStateFetcherFactory sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, @@ -100,7 +101,7 @@ public StreamingWorkScheduler( this.options = options; this.clock = clock; this.computationWorkExecutorFactory = computationWorkExecutorFactory; - this.sideInputStateFetcher = sideInputStateFetcher; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; @@ -118,7 +119,6 @@ public static StreamingWorkScheduler create( DataflowMapTaskExecutorFactory mapTaskExecutorFactory, BoundedQueueExecutor workExecutor, Function stateCacheFactory, - Function fetchGlobalDataFn, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCounters streamingCounters, @@ -141,7 +141,7 @@ public static StreamingWorkScheduler create( options, clock, computationWorkExecutorFactory, - new SideInputStateFetcher(fetchGlobalDataFn, options), + SideInputStateFetcherFactory.fromOptions(options), failureTracker, 
workFailureProcessor, StreamingCommitFinalizer.create(workExecutor), @@ -348,7 +348,8 @@ private ExecuteWorkResult executeWork( try { WindmillStateReader stateReader = work.createWindmillStateReader(); - SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcher.byteTrackingView(); + SideInputStateFetcher localSideInputStateFetcher = + sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); // If the read output KVs, then we can decode Windmill's byte key into userland // key object and provide it to the execution context for use with per-key state. @@ -403,8 +404,7 @@ private ExecuteWorkResult executeWork( computationState.releaseComputationWorkExecutor(computationWorkExecutor); work.setState(Work.State.COMMIT_QUEUED); - outputBuilder.addAllPerWorkItemLatencyAttributions( - work.getLatencyAttributions(false, sampler)); + outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)); return ExecuteWorkResult.create( outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java index 96a6feec1da0..499d2e5b6943 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java @@ -17,13 +17,26 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; + +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; @@ -37,29 +50,39 @@ * threshold is determined by {@link #activeWorkRefreshPeriodMillis} */ @ThreadSafe -public abstract class ActiveWorkRefresher { +@Internal +public final class ActiveWorkRefresher { private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkRefresher.class); + private static final String FAN_OUT_REFRESH_WORK_EXECUTOR_NAME = + "FanOutActiveWorkRefreshExecutor-%d"; - protected final Supplier clock; - protected final int activeWorkRefreshPeriodMillis; - protected final Supplier> computations; - protected final DataflowExecutionStateSampler sampler; + private final Supplier clock; + private final int activeWorkRefreshPeriodMillis; + private final 
Supplier<Collection<ComputationState>> computations;
+  private final DataflowExecutionStateSampler sampler;
   private final int stuckCommitDurationMillis;
+  private final HeartbeatTracker heartbeatTracker;
   private final ScheduledExecutorService activeWorkRefreshExecutor;
+  private final ExecutorService fanOutActiveWorkRefreshExecutor;
 
-  protected ActiveWorkRefresher(
+  public ActiveWorkRefresher(
       Supplier<Instant> clock,
       int activeWorkRefreshPeriodMillis,
       int stuckCommitDurationMillis,
       Supplier<Collection<ComputationState>> computations,
       DataflowExecutionStateSampler sampler,
-      ScheduledExecutorService activeWorkRefreshExecutor) {
+      ScheduledExecutorService activeWorkRefreshExecutor,
+      HeartbeatTracker heartbeatTracker) {
     this.clock = clock;
     this.activeWorkRefreshPeriodMillis = activeWorkRefreshPeriodMillis;
     this.stuckCommitDurationMillis = stuckCommitDurationMillis;
     this.computations = computations;
     this.sampler = sampler;
     this.activeWorkRefreshExecutor = activeWorkRefreshExecutor;
+    this.heartbeatTracker = heartbeatTracker;
+    this.fanOutActiveWorkRefreshExecutor =
+        Executors.newCachedThreadPool(
+            new ThreadFactoryBuilder().setNameFormat(FAN_OUT_REFRESH_WORK_EXECUTOR_NAME).build());
   }
 
   @SuppressWarnings("FutureReturnValueIgnored")
@@ -103,5 +126,67 @@ private void invalidateStuckCommits() {
     }
   }
 
-  protected abstract void refreshActiveWork();
+  private void refreshActiveWork() {
+    Instant refreshDeadline = clock.get().minus(Duration.millis(activeWorkRefreshPeriodMillis));
+    Map<HeartbeatSender, Heartbeats> heartbeatsBySender =
+        aggregateHeartbeatsBySender(refreshDeadline);
+    if (heartbeatsBySender.isEmpty()) {
+      // Guard against passing a null entry to sendHeartbeatSafely when there is no work to
+      // refresh.
+      return;
+    }
+
+    List<CompletableFuture<Void>> fanOutRefreshActiveWork = new ArrayList<>();
+
+    // Send the first heartbeat on the calling thread, and fan out the rest via the
+    // fanOutActiveWorkRefreshExecutor.
+    @Nullable Map.Entry<HeartbeatSender, Heartbeats> firstHeartbeat = null;
+    for (Map.Entry<HeartbeatSender, Heartbeats> heartbeat : heartbeatsBySender.entrySet()) {
+      if (firstHeartbeat == null) {
+        firstHeartbeat = heartbeat;
+      } else {
+        fanOutRefreshActiveWork.add(
+            CompletableFuture.runAsync(
+                () -> sendHeartbeatSafely(heartbeat), fanOutActiveWorkRefreshExecutor));
+      }
+    }
+
+    sendHeartbeatSafely(firstHeartbeat);
+    fanOutRefreshActiveWork.forEach(CompletableFuture::join);
+  }
+
+  /** Aggregates the heartbeats across computations by {@link HeartbeatSender} for correct fan out. */
+  private Map<HeartbeatSender, Heartbeats> aggregateHeartbeatsBySender(Instant refreshDeadline) {
+    Map<HeartbeatSender, Heartbeats.Builder> heartbeatsBySender = new HashMap<>();
+    for (ComputationState computationState : computations.get()) {
+      for (RefreshableWork work : computationState.getRefreshableWork(refreshDeadline)) {
+        heartbeatsBySender
+            .computeIfAbsent(work.heartbeatSender(), ignored -> Heartbeats.builder())
+            .add(computationState.getComputationId(), work, sampler);
+      }
+    }
+
+    return heartbeatsBySender.entrySet().stream()
+        .collect(toImmutableMap(Map.Entry::getKey, e -> e.getValue().build()));
+  }
+
+  /**
+   * Sends the {@link Heartbeats} using the {@link HeartbeatSender}. Safe since exceptions are
+   * caught and logged.
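+   *
+   * <p>Illustrative sketch of a {@link HeartbeatTracker} that counts in-flight heartbeats for the
+   * duration of each send. The {@code inFlight} counter is an assumed caller-side metric, not
+   * part of this class:
+   *
+   * <pre>{@code
+   * AtomicInteger inFlight = new AtomicInteger();
+   * HeartbeatTracker tracker =
+   *     numHeartbeats -> {
+   *       inFlight.addAndGet(numHeartbeats);
+   *       // The returned AutoCloseable runs once the send succeeds or fails.
+   *       return () -> inFlight.addAndGet(-numHeartbeats);
+   *     };
+   * }</pre>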
+   */
+  private void sendHeartbeatSafely(Map.Entry<HeartbeatSender, Heartbeats> heartbeat) {
+    try (AutoCloseable ignored = heartbeatTracker.trackHeartbeats(heartbeat.getValue().size())) {
+      HeartbeatSender sender = heartbeat.getKey();
+      Heartbeats heartbeats = heartbeat.getValue();
+      sender.sendHeartbeats(heartbeats);
+    } catch (Exception e) {
+      LOG.error(
+          "Unable to send {} heartbeats to {}.",
+          heartbeat.getValue().size(),
+          heartbeat.getKey(),
+          e);
+    }
+  }
+
+  @FunctionalInterface
+  public interface HeartbeatTracker {
+    AutoCloseable trackHeartbeats(int numHeartbeats);
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefreshers.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefreshers.java
deleted file mode 100644
index 5a59a7f1ae01..000000000000
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefreshers.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;
-
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
-import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
-import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
-import org.joda.time.Instant;
-
-/** Utility class for {@link ActiveWorkRefresher}.
*/ -public final class ActiveWorkRefreshers { - public static ActiveWorkRefresher createDispatchedActiveWorkRefresher( - Supplier clock, - int activeWorkRefreshPeriodMillis, - int stuckCommitDurationMillis, - Supplier> computations, - DataflowExecutionStateSampler sampler, - Consumer>> activeWorkRefresherFn, - ScheduledExecutorService scheduledExecutorService) { - return new DispatchedActiveWorkRefresher( - clock, - activeWorkRefreshPeriodMillis, - stuckCommitDurationMillis, - computations, - sampler, - activeWorkRefresherFn, - scheduledExecutorService); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ApplianceHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ApplianceHeartbeatSender.java new file mode 100644 index 000000000000..b0f714433805 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ApplianceHeartbeatSender.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; + +import java.util.Collection; +import java.util.Map; +import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.sdk.annotations.Internal; + +/** Streaming appliance implementation of {@link HeartbeatSender}. */ +@Internal +public final class ApplianceHeartbeatSender implements HeartbeatSender { + private final Consumer sendHeartbeatFn; + + public ApplianceHeartbeatSender(Consumer sendHeartbeatFn) { + this.sendHeartbeatFn = sendHeartbeatFn; + } + + /** + * Appliance which sends heartbeats (used to refresh active work) as KeyedGetDataRequests. So we + * must translate the HeartbeatRequest to a KeyedGetDataRequest here. 
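+   *
+   * <p>For example, a heartbeat with sharding key 1, work token 2, and cache token 3 (values
+   * illustrative) is translated to the equivalent of:
+   *
+   * <pre>{@code
+   * Windmill.KeyedGetDataRequest.newBuilder()
+   *     .setShardingKey(1)
+   *     .setWorkToken(2)
+   *     .setCacheToken(3)
+   *     .addAllLatencyAttribution(request.getLatencyAttributionList())
+   *     .build();
+   * }</pre>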
+ */ + @Override + public void sendHeartbeats(Heartbeats heartbeats) { + Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder(); + + for (Map.Entry> entry : + heartbeats.heartbeatRequests().asMap().entrySet()) { + Windmill.ComputationGetDataRequest.Builder perComputationBuilder = + Windmill.ComputationGetDataRequest.newBuilder(); + perComputationBuilder.setComputationId(entry.getKey()); + for (Windmill.HeartbeatRequest request : entry.getValue()) { + perComputationBuilder.addRequests( + Windmill.KeyedGetDataRequest.newBuilder() + .setShardingKey(request.getShardingKey()) + .setWorkToken(request.getWorkToken()) + .setCacheToken(request.getCacheToken()) + .addAllLatencyAttribution(request.getLatencyAttributionList()) + .build()); + } + builder.addRequests(perComputationBuilder.build()); + } + + sendHeartbeatFn.accept(builder.build()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresher.java deleted file mode 100644 index f81233498fe3..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresher.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; - -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; -import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.joda.time.Duration; -import org.joda.time.Instant; - -final class DispatchedActiveWorkRefresher extends ActiveWorkRefresher { - - private final Consumer>> activeWorkRefresherFn; - - DispatchedActiveWorkRefresher( - Supplier clock, - int activeWorkRefreshPeriodMillis, - int stuckCommitDurationMillis, - Supplier> computations, - DataflowExecutionStateSampler sampler, - Consumer>> activeWorkRefresherFn, - ScheduledExecutorService scheduledExecutorService) { - super( - clock, - activeWorkRefreshPeriodMillis, - stuckCommitDurationMillis, - computations, - sampler, - scheduledExecutorService); - this.activeWorkRefresherFn = activeWorkRefresherFn; - } - - @Override - protected void refreshActiveWork() { - Map> heartbeats = new HashMap<>(); - Instant refreshDeadline = clock.get().minus(Duration.millis(activeWorkRefreshPeriodMillis)); - - for (ComputationState computationState : computations.get()) { - heartbeats.put( - computationState.getComputationId(), - computationState.getKeyHeartbeats(refreshDeadline, sampler)); - } - - activeWorkRefresherFn.accept(heartbeats); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java new file mode 100644 index 000000000000..33a55d1927f8 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; + +import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; +import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link HeartbeatSender} implementation that sends heartbeats directly on the underlying stream if + * the stream is not closed. + * + * @implNote + *

+ *     <p>{@link #equals(Object)} and {@link #hashCode()} implementations delegate to the
+ *     internal {@link GetDataStream} so that requests can be grouped and sent on the same
+ *     stream instance.
+ *
+ *     <p>This class is a stateless decorator to the underlying stream.
+ */
+@Internal
+public final class FixedStreamHeartbeatSender implements HeartbeatSender {
+  private static final Logger LOG = LoggerFactory.getLogger(FixedStreamHeartbeatSender.class);
+  private final GetDataStream getDataStream;
+
+  private FixedStreamHeartbeatSender(GetDataStream getDataStream) {
+    this.getDataStream = getDataStream;
+  }
+
+  public static FixedStreamHeartbeatSender create(GetDataStream getDataStream) {
+    return new FixedStreamHeartbeatSender(getDataStream);
+  }
+
+  @Override
+  public void sendHeartbeats(Heartbeats heartbeats) {
+    @Nullable String originalThreadName = null;
+    try {
+      String backendWorkerToken = getDataStream.backendWorkerToken();
+      if (!backendWorkerToken.isEmpty()) {
+        // Decorate the thread name w/ the backendWorkerToken for debugging. Resets the thread's
+        // name after sending the heartbeats succeeds or fails.
+        originalThreadName = Thread.currentThread().getName();
+        Thread.currentThread().setName(originalThreadName + "-" + backendWorkerToken);
+      }
+      getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap());
+    } catch (AbstractWindmillStream.WindmillStreamShutdownException e) {
+      LOG.warn(
+          "Trying to refresh work on stream={} w/ {} heartbeats after work has moved off of the"
+              + " worker.",
+          getDataStream.backendWorkerToken(),
+          heartbeats.heartbeatRequests().size());
+      heartbeats.work().forEach(RefreshableWork::setFailed);
+    } finally {
+      if (originalThreadName != null) {
+        Thread.currentThread().setName(originalThreadName);
+      }
+    }
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(FixedStreamHeartbeatSender.class, getDataStream);
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof FixedStreamHeartbeatSender
+        && getDataStream.equals(((FixedStreamHeartbeatSender) obj).getDataStream);
+  }
+
+  @Override
+  public String toString() {
+    return "HeartbeatSender-" + getDataStream.backendWorkerToken();
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java
new file mode 100644
index 000000000000..06559344332c
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;
+
+/**
+ * Interface for sending heartbeats.
+ *
+ * @implNote Batching/grouping of heartbeats is performed by {@link HeartbeatSender} equality.
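+ *     <p>Senders that compare equal are grouped so their heartbeats are batched into a single
+ *     send. A minimal sketch of a custom implementation (logging only, for illustration; real
+ *     implementations send the heartbeats over a Windmill stream):
+ *     <pre>{@code
+ * HeartbeatSender loggingSender =
+ *     heartbeats -> System.out.println("sending " + heartbeats.size() + " heartbeats");
+ * }</pre>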
+ */
+@FunctionalInterface
+public interface HeartbeatSender {
+  /**
+   * Sends heartbeats. Each heartbeat represents a work item that is actively being processed and
+   * belongs to the computation.
+   */
+  void sendHeartbeats(Heartbeats heartbeats);
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java
new file mode 100644
index 000000000000..071bf7fa3d43
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
+import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
+
+/** Heartbeat requests and the work that was used to generate the heartbeat requests.
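+ *
+ * <p>A minimal usage sketch. The {@code work} value is an assumed {@code RefreshableWork}
+ * instance supplied by the caller:
+ *
+ * <pre>{@code
+ * Heartbeats heartbeats =
+ *     Heartbeats.builder()
+ *         .add("computationId", work, DataflowExecutionStateSampler.instance())
+ *         .build();
+ * }</pre>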
+ */
+@AutoValue
+abstract class Heartbeats {
+
+  static Heartbeats.Builder builder() {
+    return new AutoValue_Heartbeats.Builder();
+  }
+
+  abstract ImmutableList<RefreshableWork> work();
+
+  abstract ImmutableListMultimap<String, Windmill.HeartbeatRequest> heartbeatRequests();
+
+  final int size() {
+    return heartbeatRequests().asMap().size();
+  }
+
+  @AutoValue.Builder
+  abstract static class Builder {
+
+    abstract ImmutableList.Builder<RefreshableWork> workBuilder();
+
+    abstract ImmutableListMultimap.Builder<String, Windmill.HeartbeatRequest>
+        heartbeatRequestsBuilder();
+
+    final Builder add(
+        String computationId, RefreshableWork work, DataflowExecutionStateSampler sampler) {
+      workBuilder().add(work);
+      heartbeatRequestsBuilder().put(computationId, createHeartbeatRequest(work, sampler));
+      return this;
+    }
+
+    private Windmill.HeartbeatRequest createHeartbeatRequest(
+        RefreshableWork work, DataflowExecutionStateSampler sampler) {
+      return Windmill.HeartbeatRequest.newBuilder()
+          .setShardingKey(work.getShardedKey().shardingKey())
+          .setWorkToken(work.id().workToken())
+          .setCacheToken(work.id().cacheToken())
+          .addAllLatencyAttribution(work.getHeartbeatLatencyAttributions(sampler))
+          .build();
+    }
+
+    abstract Heartbeats build();
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java
new file mode 100644
index 000000000000..e571f89f142c
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;
+
+import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.sdk.annotations.Internal;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** StreamingEngine stream pool based implementation of {@link HeartbeatSender}.
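+ *
+ * <p>A minimal construction sketch. The pool size and timeout are illustrative, and {@code
+ * server::getDataStream} is an assumed supplier of new streams:
+ *
+ * <pre>{@code
+ * StreamPoolHeartbeatSender sender =
+ *     new StreamPoolHeartbeatSender(
+ *         WindmillStreamPool.create(1, Duration.standardSeconds(30), server::getDataStream));
+ * }</pre>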
*/ +@Internal +public final class StreamPoolHeartbeatSender implements HeartbeatSender { + private static final Logger LOG = LoggerFactory.getLogger(StreamPoolHeartbeatSender.class); + + private final WindmillStreamPool heartbeatStreamPool; + + public StreamPoolHeartbeatSender( + WindmillStreamPool heartbeatStreamPool) { + this.heartbeatStreamPool = heartbeatStreamPool; + } + + @Override + public void sendHeartbeats(Heartbeats heartbeats) { + try (CloseableStream closeableStream = + heartbeatStreamPool.getCloseableStream()) { + closeableStream.stream().refreshActiveWork(heartbeats.heartbeatRequests().asMap()); + } catch (Exception e) { + LOG.warn("Error occurred sending heartbeats=[{}].", heartbeats, e); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 127d46b7caf6..b3f7467cdbd3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -28,6 +28,7 @@ import static org.junit.Assert.assertFalse; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -89,11 +90,10 @@ public final class FakeWindmillServer extends WindmillServerStub { private final AtomicInteger expectedExceptionCount; private final ErrorCollector errorCollector; private final ConcurrentHashMap> droppedStreamingCommits; - private int commitsRequested = 0; private final List getDataRequests = new ArrayList<>(); - private boolean isReady = true; - private boolean dropStreamingCommits = false; private final Consumer> processHeartbeatResponses; + private int commitsRequested = 0; + private boolean dropStreamingCommits = false; @GuardedBy("this") private ImmutableSet dispatcherEndpoints; @@ -232,7 +232,15 @@ public GetWorkStream getWorkStream(Windmill.GetWorkRequest request, WorkItemRece final CountDownLatch done = new CountDownLatch(1); return new GetWorkStream() { @Override - public void close() { + public String backendWorkerToken() { + return ""; + } + + @Override + public void shutdown() {} + + @Override + public void halfClose() { done.countDown(); } @@ -257,7 +265,7 @@ public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedExcep try { sleepMillis(500); } catch (InterruptedException e) { - close(); + halfClose(); Thread.currentThread().interrupt(); } continue; @@ -294,6 +302,14 @@ public Instant startTime() { public GetDataStream getDataStream() { Instant startTime = Instant.now(); return new GetDataStream() { + @Override + public String backendWorkerToken() { + return ""; + } + + @Override + public void shutdown() {} + @Override public Windmill.KeyedGetDataResponse requestKeyedData( String computation, KeyedGetDataRequest request) { @@ -330,9 +346,9 @@ public Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) } @Override - public void refreshActiveWork(Map> heartbeats) { + public void refreshActiveWork(Map> heartbeats) { Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder(); - for (Map.Entry> entry : heartbeats.entrySet()) { + for (Map.Entry> entry : heartbeats.entrySet()) { builder.addComputationHeartbeatRequest( ComputationHeartbeatRequest.newBuilder() 
.setComputationId(entry.getKey()) @@ -348,7 +364,7 @@ public void onHeartbeatResponse(List responses) { } @Override - public void close() {} + public void halfClose() {} @Override public boolean awaitTermination(int time, TimeUnit unit) { @@ -368,18 +384,16 @@ public CommitWorkStream commitWorkStream() { return new CommitWorkStream() { @Override - public RequestBatcher batcher() { - return new RequestBatcher() { - class RequestAndDone { - final Consumer onDone; - final WorkItemCommitRequest request; + public String backendWorkerToken() { + return ""; + } - RequestAndDone(WorkItemCommitRequest request, Consumer onDone) { - this.request = request; - this.onDone = onDone; - } - } + @Override + public void shutdown() {} + @Override + public RequestBatcher batcher() { + return new RequestBatcher() { final List requests = new ArrayList<>(); @Override @@ -427,11 +441,21 @@ public void flush() { } requests.clear(); } + + class RequestAndDone { + final Consumer onDone; + final WorkItemCommitRequest request; + + RequestAndDone(WorkItemCommitRequest request, Consumer onDone) { + this.request = request; + this.onDone = onDone; + } + } }; } @Override - public void close() {} + public void halfClose() {} @Override public boolean awaitTermination(int time, TimeUnit unit) { @@ -523,27 +547,13 @@ public ArrayList getStatsReceived() { } @Override - public void setWindmillServiceEndpoints(Set endpoints) { - synchronized (this) { - this.dispatcherEndpoints = ImmutableSet.copyOf(endpoints); - isReady = true; - } + public synchronized ImmutableSet getWindmillServiceEndpoints() { + return dispatcherEndpoints; } @Override - public ImmutableSet getWindmillServiceEndpoints() { - synchronized (this) { - return dispatcherEndpoints; - } - } - - @Override - public boolean isReady() { - return isReady; - } - - public void setIsReady(boolean ready) { - this.isReady = ready; + public synchronized void setWindmillServiceEndpoints(Set endpoints) { + this.dispatcherEndpoints = ImmutableSet.copyOf(endpoints); } public static class ResponseQueue { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 8a4369fdbd8d..5855057c4210 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -126,6 +126,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer.Type; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WatermarkHold; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.coders.CollectionCoder; @@ -330,9 +332,7 @@ private static ExecutableWork createMockWork( .build(), Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( - computationId, - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + computationId, new FakeGetDataClient(), ignored -> {}, 
mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()), processWorkFn); @@ -887,7 +887,6 @@ private void runTestBasic(int numCommitThreads) throws Exception { makeSourceInstruction(StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - server.setIsReady(false); StreamingConfigTask streamingConfig = new StreamingConfigTask(); streamingConfig.setStreamingComputationConfigs( ImmutableList.of(makeDefaultStreamingComputationConfig(instructions))); @@ -935,8 +934,6 @@ public void testHotKeyLogging() throws Exception { makeSourceInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())), makeSinkInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), 0)); - server.setIsReady(false); - StreamingConfigTask streamingConfig = new StreamingConfigTask(); streamingConfig.setStreamingComputationConfigs( ImmutableList.of(makeDefaultStreamingComputationConfig(instructions))); @@ -974,8 +971,6 @@ public void testHotKeyLoggingNotEnabled() throws Exception { makeSourceInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())), makeSinkInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), 0)); - server.setIsReady(false); - StreamingConfigTask streamingConfig = new StreamingConfigTask(); streamingConfig.setStreamingComputationConfigs( ImmutableList.of(makeDefaultStreamingComputationConfig(instructions))); @@ -3484,8 +3479,9 @@ public void testLatencyAttributionProtobufsPopulated() { Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + new FakeGetDataClient(), + ignored -> {}, + mock(HeartbeatSender.class)), clock, Collections.emptyList()); @@ -3502,7 +3498,7 @@ public void testLatencyAttributionProtobufsPopulated() { clock.sleep(Duration.millis(60)); Iterator it = - work.getLatencyAttributions(false, DataflowExecutionStateSampler.instance()).iterator(); + work.getLatencyAttributions(DataflowExecutionStateSampler.instance()).iterator(); assertTrue(it.hasNext()); LatencyAttribution lat = it.next(); assertSame(State.QUEUED, lat.getState()); @@ -3787,7 +3783,7 @@ public void testDoFnActiveMessageMetadataReportedOnHeartbeat() throws Exception Map result = server.waitForAndGetCommits(1); assertThat(server.numGetDataRequests(), greaterThan(0)); - Windmill.GetDataRequest heartbeat = server.getGetDataRequests().get(2); + Windmill.GetDataRequest heartbeat = server.getGetDataRequests().get(1); for (LatencyAttribution la : heartbeat diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 7988212efde0..8445e8ede852 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; import com.google.api.services.dataflow.model.CounterMetadata; import com.google.api.services.dataflow.model.CounterStructuredName; @@ -60,8 +61,10 @@ import 
org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -82,7 +85,6 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; /** Tests for {@link StreamingModeExecutionContext}. */ @@ -133,9 +135,7 @@ private static Work createMockWork(Windmill.WorkItem workItem, Watermarks waterm workItem, watermarks, Work.createProcessingContext( - COMPUTATION_ID, - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); } @@ -243,8 +243,8 @@ public void testSideInputReaderReconstituted() { @Test public void extractMsecCounters() { - MetricsContainer metricsContainer = Mockito.mock(MetricsContainer.class); - ProfileScope profileScope = Mockito.mock(ProfileScope.class); + MetricsContainer metricsContainer = mock(MetricsContainer.class); + ProfileScope profileScope = mock(ProfileScope.class); ExecutionState start1 = executionContext.executionStateRegistry.getState( NameContext.create("stage", "original-1", "system-1", "user-1"), diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index c79d947ca227..5c149a65f4ce 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -95,8 +95,10 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; @@ -197,9 +199,7 @@ private static Work createMockWork(Windmill.WorkItem workItem, Watermarks waterm workItem, watermarks, Work.createProcessingContext( - COMPUTATION_ID, - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); } @@ -1000,8 +1000,9 @@ 
public void testFailedWorkItemsAbort() throws Exception { Watermarks.builder().setInputDataWatermark(new Instant(0)).build(), Work.createProcessingContext( COMPUTATION_ID, - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + new FakeGetDataClient(), + ignored -> {}, + mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); context.start( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index 3a3e0a34c217..a373dffd1dc4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.streaming; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; @@ -26,20 +25,18 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.Deque; import java.util.HashMap; import java.util.Map; import java.util.Optional; -import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.joda.time.Instant; import org.junit.Before; import org.junit.Rule; @@ -85,9 +82,7 @@ private static ExecutableWork expiredWork(Windmill.WorkItem workItem) { private static Work.ProcessingContext createWorkProcessingContext() { return Work.createProcessingContext( - "computationId", - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}); + "computationId", new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)); } private static WorkId workId(long workToken, long cacheToken) { @@ -447,70 +442,4 @@ public void testActivateWorkForKey_matchingCacheTokens_newWorkTokenLesser_STALE( assertFalse(readOnlyActiveWork.get(shardedKey).contains(newWork)); assertEquals(queuedWork, readOnlyActiveWork.get(shardedKey).peek()); } - - @Test - public void testGetKeyHeartbeats() { - Instant refreshDeadline = Instant.now(); - ShardedKey shardedKey1 = shardedKey("someKey", 1L); - ShardedKey shardedKey2 = shardedKey("anotherKey", 2L); - - ExecutableWork freshWork = createWork(createWorkItem(3L, 3L, shardedKey1)); - ExecutableWork refreshableWork1 = 
expiredWork(createWorkItem(1L, 1L, shardedKey1)); - refreshableWork1.work().setState(Work.State.COMMITTING); - ExecutableWork refreshableWork2 = expiredWork(createWorkItem(2L, 2L, shardedKey2)); - refreshableWork2.work().setState(Work.State.COMMITTING); - - activeWorkState.activateWorkForKey(refreshableWork1); - activeWorkState.activateWorkForKey(freshWork); - activeWorkState.activateWorkForKey(refreshableWork2); - - ImmutableList requests = - activeWorkState.getKeyHeartbeats(refreshDeadline, DataflowExecutionStateSampler.instance()); - - ImmutableList expected = - ImmutableList.of( - HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from( - shardedKey1, refreshableWork1.work()), - HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from( - shardedKey2, refreshableWork2.work())); - - ImmutableList actual = - requests.stream() - .map(HeartbeatRequestShardingKeyWorkTokenAndCacheToken::from) - .collect(toImmutableList()); - - assertThat(actual).containsExactlyElementsIn(expected); - } - - @AutoValue - abstract static class HeartbeatRequestShardingKeyWorkTokenAndCacheToken { - - private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken create( - long shardingKey, long workToken, long cacheToken) { - return new AutoValue_ActiveWorkStateTest_HeartbeatRequestShardingKeyWorkTokenAndCacheToken( - shardingKey, workToken, cacheToken); - } - - private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from( - HeartbeatRequest heartbeatRequest) { - return create( - heartbeatRequest.getShardingKey(), - heartbeatRequest.getWorkToken(), - heartbeatRequest.getCacheToken()); - } - - private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from( - ShardedKey shardedKey, Work work) { - return create( - shardedKey.shardingKey(), - work.getWorkItem().getWorkToken(), - work.getWorkItem().getCacheToken()); - } - - abstract long shardingKey(); - - abstract long workToken(); - - abstract long cacheToken(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java index 3c1683ecf436..1f70c2476325 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java @@ -36,8 +36,10 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -68,8 +70,9 @@ private static ExecutableWork createWork(ShardedKey shardedKey, long workToken, Watermarks.builder().setInputDataWatermark(Instant.now()).build(), Work.createProcessingContext( 
"computationId", - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + new FakeGetDataClient(), + ignored -> {}, + mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()), ignored -> {}); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java index ad2ac6baeabb..24a93f58b12a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java @@ -33,8 +33,8 @@ import java.util.List; import java.util.concurrent.TimeUnit; import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; -import org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -67,13 +67,46 @@ @SuppressWarnings("deprecation") @RunWith(JUnit4.class) public class SideInputStateFetcherTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final String STATE_FAMILY = "state"; - - @Mock private MetricTrackingWindmillServerStub server; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + @Mock private GetDataClient server; @Mock private Supplier readStateSupplier; + private static Windmill.GlobalData buildGlobalDataResponse( + String tag, boolean isReady, ByteString data) { + Windmill.GlobalData.Builder builder = + Windmill.GlobalData.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag(tag) + .setVersion(ByteString.EMPTY) + .build()); + + if (isReady) { + builder.setIsReady(true).setData(data); + } else { + builder.setIsReady(false); + } + return builder.build(); + } + + private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag, ByteString version) { + Windmill.GlobalDataId id = + Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build(); + + return Windmill.GlobalDataRequest.newBuilder() + .setDataId(id) + .setStateFamily(STATE_FAMILY) + .setExistenceWatermarkDeadline( + TimeUnit.MILLISECONDS.toMicros(GlobalWindow.INSTANCE.maxTimestamp().getMillis())) + .build(); + } + + private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) { + return buildGlobalDataRequest(tag, ByteString.EMPTY); + } + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -81,10 +114,10 @@ public void setUp() { @Test public void testFetchGlobalDataBasic() throws Exception { - SideInputStateFetcher fetcher = - new SideInputStateFetcher( - server::getSideInputData, + SideInputStateFetcherFactory factory = + SideInputStateFetcherFactory.fromOptions( PipelineOptionsFactory.as(DataflowStreamingPipelineOptions.class)); + SideInputStateFetcher fetcher = factory.createSideInputStateFetcher(server::getSideInputData); ByteStringOutputStream stream = new ByteStringOutputStream(); ListCoder.of(StringUtf8Coder.of()) @@ -152,10 +185,10 @@ public void 
testFetchGlobalDataBasic() throws Exception { @Test public void testFetchGlobalDataNull() throws Exception { - SideInputStateFetcher fetcher = - new SideInputStateFetcher( - server::getSideInputData, + SideInputStateFetcherFactory factory = + SideInputStateFetcherFactory.fromOptions( PipelineOptionsFactory.as(DataflowStreamingPipelineOptions.class)); + SideInputStateFetcher fetcher = factory.createSideInputStateFetcher(server::getSideInputData); ByteStringOutputStream stream = new ByteStringOutputStream(); ListCoder.of(VoidCoder.of()) @@ -311,10 +344,10 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { @Test public void testEmptyFetchGlobalData() { - SideInputStateFetcher fetcher = - new SideInputStateFetcher( - server::getSideInputData, + SideInputStateFetcherFactory factory = + SideInputStateFetcherFactory.fromOptions( PipelineOptionsFactory.as(DataflowStreamingPipelineOptions.class)); + SideInputStateFetcher fetcher = factory.createSideInputStateFetcher(server::getSideInputData); ByteString encodedIterable = ByteString.EMPTY; @@ -346,38 +379,4 @@ public void testEmptyFetchGlobalData() { verify(server).getSideInputData(buildGlobalDataRequest(tag)); verifyNoMoreInteractions(server); } - - private static Windmill.GlobalData buildGlobalDataResponse( - String tag, boolean isReady, ByteString data) { - Windmill.GlobalData.Builder builder = - Windmill.GlobalData.newBuilder() - .setDataId( - Windmill.GlobalDataId.newBuilder() - .setTag(tag) - .setVersion(ByteString.EMPTY) - .build()); - - if (isReady) { - builder.setIsReady(true).setData(data); - } else { - builder.setIsReady(false); - } - return builder.build(); - } - - private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag, ByteString version) { - Windmill.GlobalDataId id = - Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build(); - - return Windmill.GlobalDataRequest.newBuilder() - .setDataId(id) - .setStateFamily(STATE_FAMILY) - .setExistenceWatermarkDeadline( - TimeUnit.MILLISECONDS.toMicros(GlobalWindow.INSTANCE.maxTimestamp().getMillis())) - .build(); - } - - private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) { - return buildGlobalDataRequest(tag, ByteString.EMPTY); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index e08c951975fa..ad77958837a1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; import java.util.Collections; import java.util.concurrent.CountDownLatch; @@ -31,6 +32,8 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import 
org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Instant; @@ -65,24 +68,23 @@ private static ExecutableWork createWork(Consumer executeWorkFn) { Watermarks.builder().setInputDataWatermark(Instant.now()).build(), Work.createProcessingContext( "computationId", - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), - ignored -> {}), + new FakeGetDataClient(), + ignored -> {}, + mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()), executeWorkFn); } private Runnable createSleepProcessWorkFn(CountDownLatch start, CountDownLatch stop) { - Runnable runnable = - () -> { - start.countDown(); - try { - stop.await(); - } catch (Exception e) { - throw new RuntimeException(e); - } - }; - return runnable; + return () -> { + start.countDown(); + try { + stop.await(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }; } @Before diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java index a2f5e71d04c3..bdad382c9af2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java @@ -38,12 +38,12 @@ @RunWith(JUnit4.class) public class WindmillStreamPoolTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final int DEFAULT_NUM_STREAMS = 10; private static final int NEW_STREAM_HOLDS = 2; private final ConcurrentHashMap< TestWindmillStream, WindmillStreamPool.StreamData> holds = new ConcurrentHashMap<>(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private List> streams; @Before @@ -237,7 +237,7 @@ private TestWindmillStream(Instant startTime) { } @Override - public void close() { + public void halfClose() { closed = true; } @@ -250,5 +250,15 @@ public boolean awaitTermination(int time, TimeUnit unit) { public Instant startTime() { return startTime; } + + @Override + public String backendWorkerToken() { + return ""; + } + + @Override + public void shutdown() { + halfClose(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index 85e07c3bd797..51cd83d17fab 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; import com.google.api.services.dataflow.model.MapTask; import com.google.common.truth.Correspondence; @@ -35,6 +36,8 @@ import 
org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Instant; @@ -45,7 +48,6 @@ import org.junit.rules.ErrorCollector; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mockito; @RunWith(JUnit4.class) public class StreamingApplianceWorkCommitterTest { @@ -64,10 +66,11 @@ private static Work createMockWork(long workToken) { Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + new FakeGetDataClient(), ignored -> { throw new UnsupportedOperationException(); - }), + }, + mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); } @@ -76,7 +79,7 @@ private static ComputationState createComputationState(String computationId) { return new ComputationState( computationId, new MapTask().setSystemName("system").setStageName("stage"), - Mockito.mock(BoundedQueueExecutor.class), + mock(BoundedQueueExecutor.class), ImmutableMap.of(), null); } @@ -90,7 +93,7 @@ private StreamingApplianceWorkCommitter createWorkCommitter( public void setUp() { fakeWindmillServer = new FakeWindmillServer( - errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class))); + errorCollector, ignored -> Optional.of(mock(ComputationState.class))); } @After diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index d53690938aef..546a2883e3b2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -21,6 +21,7 @@ import static org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus.OK; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; import com.google.api.services.dataflow.model.MapTask; import java.io.IOException; @@ -49,24 +50,24 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; -import 
org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mockito; @RunWith(JUnit4.class) public class StreamingEngineWorkCommitterTest { @Rule public ErrorCollector errorCollector = new ErrorCollector(); - private StreamingEngineWorkCommitter workCommitter; + private WorkCommitter workCommitter; private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; @@ -81,10 +82,11 @@ private static Work createMockWork(long workToken) { Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + new FakeGetDataClient(), ignored -> { throw new UnsupportedOperationException(); - }), + }, + mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); } @@ -93,7 +95,7 @@ private static ComputationState createComputationState(String computationId) { return new ComputationState( computationId, new MapTask().setSystemName("system").setStageName("stage"), - Mockito.mock(BoundedQueueExecutor.class), + mock(BoundedQueueExecutor.class), ImmutableMap.of(), null); } @@ -110,21 +112,18 @@ private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitSta public void setUp() throws IOException { fakeWindmillServer = new FakeWindmillServer( - errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class))); + errorCollector, ignored -> Optional.of(mock(ComputationState.class))); commitWorkStreamFactory = WindmillStreamPool.create( 1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream) ::getCloseableStream; } - @After - public void cleanUp() { - workCommitter.stop(); - } - - private StreamingEngineWorkCommitter createWorkCommitter( - Consumer onCommitComplete) { - return StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 1, onCommitComplete); + private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { + return StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory(commitWorkStreamFactory) + .setOnCommitComplete(onCommitComplete) + .build(); } @Test @@ -156,6 +155,8 @@ public void testCommit_sendsCommitsToStreamingEngine() { assertThat(request).isEqualTo(commit.request()); assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); } + + workCommitter.stop(); } @Test @@ -196,6 +197,8 @@ public void testCommit_handlesFailedCommits() { .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); } } + + workCommitter.stop(); } @Test @@ -248,6 +251,8 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); } assertThat(completeCommits.size()).isEqualTo(commits.size()); + + workCommitter.stop(); } @Test @@ -273,7 +278,7 @@ public void flush() {} } @Override - public void close() {} + public void halfClose() {} @Override public boolean awaitTermination(int time, TimeUnit unit) { @@ -284,6 +289,14 @@ public boolean awaitTermination(int time, TimeUnit unit) { public Instant startTime() { return Instant.now(); } + + @Override + public String backendWorkerToken() { + return ""; + } + + @Override + public void shutdown() {} }; commitWorkStreamFactory = @@ -328,7 +341,12 @@ public void testMultipleCommitSendersSingleStream() { ::getCloseableStream; Set completeCommits = Collections.newSetFromMap(new 
ConcurrentHashMap<>()); workCommitter = - StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 5, completeCommits::add); + StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory(commitWorkStreamFactory) + .setNumCommitSenders(5) + .setOnCommitComplete(completeCommits::add) + .build(); + List commits = new ArrayList<>(); for (int i = 1; i <= 500; i++) { Work work = createMockWork(i); @@ -353,5 +371,7 @@ public void testMultipleCommitSendersSingleStream() { assertThat(request).isEqualTo(commit.request()); assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); } + + workCommitter.stop(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FakeGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FakeGetDataClient.java new file mode 100644 index 000000000000..ca89e9647153 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FakeGetDataClient.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import java.io.PrintWriter; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; + +/** Fake {@link GetDataClient} implementation for testing. */ +public final class FakeGetDataClient implements GetDataClient { + @Override + public Windmill.KeyedGetDataResponse getStateData( + String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { + return Windmill.KeyedGetDataResponse.getDefaultInstance(); + } + + @Override + public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) + throws GetDataException { + return Windmill.GlobalData.getDefaultInstance(); + } + + @Override + public void printHtml(PrintWriter writer) {} +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java new file mode 100644 index 000000000000..d687434edff4 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("FutureReturnValueIgnored") +public class ThrottlingGetDataMetricTrackerTest { + + private final MemoryMonitor memoryMonitor = mock(MemoryMonitor.class); + private final ThrottlingGetDataMetricTracker getDataMetricTracker = + new ThrottlingGetDataMetricTracker(memoryMonitor); + private final ExecutorService getDataProcessor = Executors.newCachedThreadPool(); + + @Test + public void testTrackFetchStateDataWithThrottling() throws InterruptedException { + doNothing().when(memoryMonitor).waitForResources(anyString()); + CountDownLatch processCall = new CountDownLatch(1); + CountDownLatch callProcessing = new CountDownLatch(1); + CountDownLatch processingDone = new CountDownLatch(1); + getDataProcessor.submit( + () -> { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { + callProcessing.countDown(); + processCall.await(); + } catch (Exception e) { + // Do nothing. 
+ } + processingDone.countDown(); + }); + + callProcessing.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = + getDataMetricTracker.getMetricsSnapshot(); + + assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(1); + assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(0); + assertThat(metricsWhileProcessing.activeSideInputs()).isEqualTo(0); + + // Free the thread inside the AutoCloseable, wait for processingDone and check that metrics get + // decremented. + processCall.countDown(); + processingDone.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); + assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); + assertThat(metricsAfterProcessing.activeSideInputs()).isEqualTo(0); + } + + @Test + public void testTrackSideInputFetchWithThrottling() throws InterruptedException { + doNothing().when(memoryMonitor).waitForResources(anyString()); + CountDownLatch processCall = new CountDownLatch(1); + CountDownLatch callProcessing = new CountDownLatch(1); + CountDownLatch processingDone = new CountDownLatch(1); + getDataProcessor.submit( + () -> { + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { + callProcessing.countDown(); + processCall.await(); + } catch (Exception e) { + // Do nothing. + } + processingDone.countDown(); + }); + + callProcessing.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = + getDataMetricTracker.getMetricsSnapshot(); + + assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(0); + assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(0); + assertThat(metricsWhileProcessing.activeSideInputs()).isEqualTo(1); + + // Free the thread inside the AutoCloseable, wait for processingDone and check that metrics get + // decremented. + processCall.countDown(); + processingDone.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); + assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); + assertThat(metricsAfterProcessing.activeSideInputs()).isEqualTo(0); + }
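The tests in this new file all exercise one contract: acquiring a tracking handle blocks while the MemoryMonitor reports pressure, then bumps the matching active-call counter, and close() decrements that counter even on the exception path. A minimal sketch of a caller of that contract, using only the tracker and GetDataClient APIs visible in this diff (the wrapper class and fetch method are illustrative, not part of the change):

```java
import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker;

/** Illustrative caller of the tracking contract exercised by these tests. */
final class TrackedStateFetcher {
  private final ThrottlingGetDataMetricTracker tracker;
  private final GetDataClient client;

  TrackedStateFetcher(MemoryMonitor memoryMonitor, GetDataClient client) {
    this.tracker = new ThrottlingGetDataMetricTracker(memoryMonitor);
    this.client = client;
  }

  Windmill.KeyedGetDataResponse fetch(String computationId, Windmill.KeyedGetDataRequest request)
      throws Exception {
    // Blocks while the memory monitor reports pressure, then increments
    // activeStateReads; close() decrements it even if getStateData throws.
    try (AutoCloseable ignored = tracker.trackStateDataFetchWithThrottling()) {
      return client.getStateData(computationId, request);
    }
  }
}
```

The try-with-resources shape is what makes the exception tests below meaningful: the counter decrement lives in close(), so it runs whether the fetch returns or throws.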
+ + @Test + public void testThrottledTrackSingleCallWithThrottling() throws InterruptedException { + CountDownLatch mockThrottler = simulateMemoryPressure(); + CountDownLatch processCall = new CountDownLatch(1); + CountDownLatch callProcessing = new CountDownLatch(1); + CountDownLatch processingDone = new CountDownLatch(1); + getDataProcessor.submit( + () -> { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { + callProcessing.countDown(); + processCall.await(); + } catch (Exception e) { + // Do nothing. + } + processingDone.countDown(); + }); + + assertFalse(callProcessing.await(10, TimeUnit.MILLISECONDS)); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsBeforeProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsBeforeProcessing.activeStateReads()).isEqualTo(0); + assertThat(metricsBeforeProcessing.activeHeartbeats()).isEqualTo(0); + assertThat(metricsBeforeProcessing.activeSideInputs()).isEqualTo(0); + + // Stop throttling. + mockThrottler.countDown(); + callProcessing.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = + getDataMetricTracker.getMetricsSnapshot(); + + assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(1); + + // Free the thread inside the AutoCloseable, wait for processingDone and check that metrics get + // decremented. + processCall.countDown(); + processingDone.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); + } + + @Test + public void testTrackSingleCall_exceptionThrown() throws InterruptedException { + doNothing().when(memoryMonitor).waitForResources(anyString()); + CountDownLatch callProcessing = new CountDownLatch(1); + CountDownLatch beforeException = new CountDownLatch(1); + CountDownLatch afterException = new CountDownLatch(1); + + // Catch the exception outside the try-with-resources block to ensure that + // AutoCloseable.close() runs in the midst of an exception. + getDataProcessor.submit( + () -> { + try { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { + callProcessing.countDown(); + beforeException.await(); + throw new RuntimeException("something bad happened"); + } + } catch (RuntimeException e) { + afterException.countDown(); + throw e; + } + }); + + callProcessing.await(); + + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = + getDataMetricTracker.getMetricsSnapshot(); + + assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(1); + beforeException.countDown(); + + // In the midst of an exception, close() should still run. + afterException.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); + } + + @Test + public void testTrackHeartbeats() throws InterruptedException { + CountDownLatch processCall = new CountDownLatch(1); + CountDownLatch callProcessing = new CountDownLatch(1); + CountDownLatch processingDone = new CountDownLatch(1); + int numHeartbeats = 5; + getDataProcessor.submit( + () -> { + try (AutoCloseable ignored = getDataMetricTracker.trackHeartbeats(numHeartbeats)) { + callProcessing.countDown(); + processCall.await(); + } catch (Exception e) { + // Do nothing. + } + processingDone.countDown(); + }); + + callProcessing.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = + getDataMetricTracker.getMetricsSnapshot(); + + assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(5); + + // Free the thread inside the AutoCloseable, wait for processingDone and check that metrics get + // decremented. + processCall.countDown(); + processingDone.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); + } + + @Test + public void testTrackHeartbeats_exceptionThrown() throws InterruptedException { + CountDownLatch callProcessing = new CountDownLatch(1); + CountDownLatch beforeException = new CountDownLatch(1); + CountDownLatch afterException = new CountDownLatch(1); + int numHeartbeats = 10; + // Catch the exception outside the try-with-resources block to ensure that + // AutoCloseable.close() runs in the midst of an exception. 
+ getDataProcessor.submit( + () -> { + try { + try (AutoCloseable ignored = getDataMetricTracker.trackHeartbeats(numHeartbeats)) { + callProcessing.countDown(); + beforeException.await(); + throw new RuntimeException("something bad happened"); + } + } catch (RuntimeException e) { + afterException.countDown(); + throw e; + } + }); + + callProcessing.await(); + + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = + getDataMetricTracker.getMetricsSnapshot(); + + assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(numHeartbeats); + beforeException.countDown(); + + // In the midst of an exception, close() should still run. + afterException.await(); + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = + getDataMetricTracker.getMetricsSnapshot(); + assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); + } + + /** Have the memory monitor block when waitForResources is called simulating memory pressure. */ + private CountDownLatch simulateMemoryPressure() { + CountDownLatch mockThrottler = new CountDownLatch(1); + doAnswer( + invocationOnMock -> { + mockThrottler.await(); + return null; + }) + .when(memoryMonitor) + .waitForResources(anyString()); + return mockThrottler; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 515beba0c88d..4439c409b32f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -261,7 +261,7 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { .build()); assertTrue(streamRegistry.contains(stream)); - stream.close(); + stream.halfClose(); assertFalse(streamRegistry.contains(stream)); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 6473d5527a81..5cfc19ac07df 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -110,14 +110,13 @@ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) }) public class GrpcWindmillServerTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - @Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - @Rule public ErrorCollector errorCollector = new ErrorCollector(); - private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); private static final int STREAM_CHUNK_SIZE = 2 << 20; private final long clientId = 10L; private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + @Rule public GrpcCleanupRule grpcCleanup = new 
GrpcCleanupRule(); + @Rule public ErrorCollector errorCollector = new ErrorCollector(); private Server server; private GrpcWindmillServer client; private int remainingErrors = 20; @@ -329,7 +328,7 @@ public void onCompleted() { }); assertTrue(latch.await(30, TimeUnit.SECONDS)); - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); } @@ -490,7 +489,7 @@ private void flushResponse() { }); } done.await(); - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); executor.shutdown(); } @@ -688,7 +687,7 @@ public StreamObserver commitWorkStream( // Make the commit requests, waiting for each of them to be verified and acknowledged. CommitWorkStream stream = client.commitWorkStream(); commitWorkTestHelper(stream, commitRequests, 0, 500); - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); } @@ -723,7 +722,7 @@ public StreamObserver commitWorkStream( for (Future f : futures) { f.get(); } - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); executor.shutdown(); } @@ -825,7 +824,7 @@ public void onCompleted() { } } - stream.close(); + stream.halfClose(); isClientClosed.set(true); deadline = System.currentTimeMillis() + 60_000; // 1 min @@ -957,13 +956,13 @@ public void onCompleted() { Map> expectedKeyedGetDataRequests = new HashMap<>(); expectedKeyedGetDataRequests.put("Computation1", makeGetDataHeartbeatRequest(computation1Keys)); expectedKeyedGetDataRequests.put("Computation2", makeGetDataHeartbeatRequest(computation2Keys)); - Map> heartbeatsToRefresh = new HashMap<>(); + Map> heartbeatsToRefresh = new HashMap<>(); heartbeatsToRefresh.put("Computation1", makeHeartbeatRequest(computation1Keys)); heartbeatsToRefresh.put("Computation2", makeHeartbeatRequest(computation2Keys)); GetDataStream stream = client.getDataStream(); stream.refreshActiveWork(heartbeatsToRefresh); - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); boolean receivedAllGetDataHeartbeats = false; @@ -1058,13 +1057,13 @@ public void onCompleted() { } expectedHeartbeats.add(comp1Builder.build()); expectedHeartbeats.add(comp2Builder.build()); - Map> heartbeatRequestMap = new HashMap<>(); + Map> heartbeatRequestMap = new HashMap<>(); heartbeatRequestMap.put("Computation1", makeHeartbeatRequest(computation1Keys)); heartbeatRequestMap.put("Computation2", makeHeartbeatRequest(computation2Keys)); GetDataStream stream = client.getDataStream(); stream.refreshActiveWork(heartbeatRequestMap); - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); boolean receivedAllHeartbeatRequests = false; @@ -1185,7 +1184,7 @@ public void onCompleted() { // actually report more due to backoff in restarting streams. 
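Note the systematic close() → halfClose() rename running through this suite (and the stream-pool and worker-metadata tests above): halfClose() now only signals that the caller will send no further requests, while the new shutdown() method tears the stream down outright. A sketch of the resulting teardown idiom, assuming the WindmillStream interface exactly as it appears in this diff; the helper class itself is illustrative:

```java
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;

final class StreamTeardown {
  /** Drain first, force only if draining stalls. */
  static void closeGracefully(GetDataStream stream) throws InterruptedException {
    stream.halfClose(); // no further requests will be sent; in-flight replies may still arrive
    if (!stream.awaitTermination(60, TimeUnit.SECONDS)) {
      stream.shutdown(); // hard teardown once the drain deadline passes
    }
  }
}
```

This is why the tests pair each halfClose() with awaitTermination(...): half-closing alone does not guarantee the stream has terminated.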
assertTrue(this.client.getAndResetThrottleTime() > throttleTime); - stream.close(); + stream.halfClose(); assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java index bc3afaff1b38..1999dbe31902 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java @@ -33,13 +33,13 @@ import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -48,6 +48,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; @@ -97,7 +98,6 @@ public class StreamingEngineClientTest { .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); private final GrpcWindmillStreamFactory streamFactory = spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); @@ -109,7 +109,7 @@ public class StreamingEngineClientTest { private final GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( stubFactory, new ArrayList<>(), new ArrayList<>(), new HashSet<>()); - + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private Server fakeStreamingEngineServer; private CountDownLatch getWorkerMetadataReady; private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub; @@ -181,7 +181,8 @@ private StreamingEngineClient newStreamingEngineClient( getWorkBudgetDistributor, dispatcherClient, CLIENT_ID, - ignored -> mock(WorkCommitter.class)); + ignored -> mock(WorkCommitter.class), + new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); } @Test @@ -222,8 +223,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { Set workerTokens = currentConnections.windmillConnections().values().stream() .map(WindmillConnection::backendWorkerToken) - .filter(Optional::isPresent) - .map(Optional::get) 
.collect(Collectors.toSet()); assertTrue(workerTokens.contains(workerToken)); @@ -235,7 +234,13 @@ public void testStreamsStartCorrectly() throws InterruptedException { verify(streamFactory, times(2)) .createDirectGetWorkStream( - any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn())); + any(), + eq(getWorkRequest(0, 0)), + any(), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); verify(streamFactory, times(2)).createGetDataStream(any(), any()); verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); @@ -312,8 +317,6 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() Set workerTokens = streamingEngineClient.getCurrentConnections().windmillConnections().values().stream() .map(WindmillConnection::backendWorkerToken) - .filter(Optional::isPresent) - .map(Optional::get) .collect(Collectors.toSet()); assertFalse(workerTokens.contains(workerToken)); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java index 162c69509ae1..9d49c3ef3146 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java @@ -27,13 +27,14 @@ import static org.mockito.Mockito.when; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; @@ -66,7 +67,7 @@ public class WindmillStreamSenderTest { (workItem, watermarks, processingContext, ackWorkItemQueued, getWorkStreamLatencies) -> {}; @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; - private CloudWindmillServiceV1Alpha1Stub stub; + private WindmillConnection connection; @Before public void setUp() { @@ -74,7 +75,10 @@ public void setUp() { grpcCleanup.register( InProcessChannelBuilder.forName("WindmillStreamSenderTest").directExecutor().build()); grpcCleanup.register(inProcessChannel); - stub = CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel); + connection = + WindmillConnection.builder() + 
+            .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+            .build();
   }
 
   @After
@@ -95,7 +99,7 @@ public void testStartStream_startsAllStreams() {
     verify(streamFactory)
         .createDirectGetWorkStream(
-            eq(stub),
+            eq(connection),
             eq(
                 GET_WORK_REQUEST
                     .toBuilder()
@@ -105,10 +109,11 @@ public void testStartStream_startsAllStreams() {
             any(ThrottleTimer.class),
             any(),
             any(),
+            any(),
             eq(workItemScheduler));
-    verify(streamFactory).createGetDataStream(eq(stub), any(ThrottleTimer.class));
-    verify(streamFactory).createCommitWorkStream(eq(stub), any(ThrottleTimer.class));
+    verify(streamFactory).createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class));
+    verify(streamFactory).createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class));
   }
 
   @Test
@@ -126,7 +131,7 @@ public void testStartStream_onlyStartsStreamsOnce() {
     verify(streamFactory, times(1))
         .createDirectGetWorkStream(
-            eq(stub),
+            eq(connection),
             eq(
                 GET_WORK_REQUEST
                     .toBuilder()
@@ -136,10 +141,13 @@ public void testStartStream_onlyStartsStreamsOnce() {
             any(ThrottleTimer.class),
             any(),
             any(),
+            any(),
             eq(workItemScheduler));
-    verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class));
-    verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class));
+    verify(streamFactory, times(1))
+        .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class));
+    verify(streamFactory, times(1))
+        .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class));
   }
 
   @Test
@@ -160,7 +168,7 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted
     verify(streamFactory, times(1))
         .createDirectGetWorkStream(
-            eq(stub),
+            eq(connection),
             eq(
                 GET_WORK_REQUEST
                     .toBuilder()
@@ -170,10 +178,13 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted
             any(ThrottleTimer.class),
             any(),
             any(),
+            any(),
             eq(workItemScheduler));
-    verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class));
-    verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class));
+    verify(streamFactory, times(1))
+        .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class));
+    verify(streamFactory, times(1))
+        .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class));
   }
 
   @Test
@@ -198,17 +209,18 @@ public void testCloseAllStreams_closesAllStreams() {
     CommitWorkStream mockCommitWorkStream = mock(CommitWorkStream.class);
 
     when(mockStreamFactory.createDirectGetWorkStream(
-            eq(stub),
+            eq(connection),
             eq(getWorkRequestWithBudget),
             any(ThrottleTimer.class),
             any(),
             any(),
+            any(),
             eq(workItemScheduler)))
         .thenReturn(mockGetWorkStream);
 
-    when(mockStreamFactory.createGetDataStream(eq(stub), any(ThrottleTimer.class)))
+    when(mockStreamFactory.createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)))
         .thenReturn(mockGetDataStream);
 
-    when(mockStreamFactory.createCommitWorkStream(eq(stub), any(ThrottleTimer.class)))
+    when(mockStreamFactory.createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)))
         .thenReturn(mockCommitWorkStream);
 
     WindmillStreamSender windmillStreamSender =
@@ -219,9 +231,9 @@ public void testCloseAllStreams_closesAllStreams() {
     windmillStreamSender.startStreams();
     windmillStreamSender.closeAllStreams();
 
-    verify(mockGetWorkStream).close();
-    verify(mockGetDataStream).close();
-    verify(mockCommitWorkStream).close();
+    verify(mockGetWorkStream).shutdown();
+    verify(mockGetDataStream).shutdown();
+    verify(mockCommitWorkStream).shutdown();
   }
 
   private WindmillStreamSender newWindmillStreamSender(GetWorkBudget budget) {
@@ -231,11 +243,12 @@ private WindmillStreamSender newWindmillStreamSender(
       GetWorkBudget budget, GrpcWindmillStreamFactory streamFactory) {
     return WindmillStreamSender.create(
-        stub,
+        connection,
         GET_WORK_REQUEST,
         budget,
         streamFactory,
         workItemScheduler,
+        ignored -> mock(GetDataClient.class),
         ignored -> mock(WorkCommitter.class));
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java
index 3460fc4cab92..8dbfc35192b7 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java
@@ -35,13 +35,13 @@
 import java.util.Optional;
 import java.util.concurrent.Future;
 import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException;
-import org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub;
 import org.apache.beam.runners.dataflow.worker.WindmillStateTestUtils;
 import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListRange;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -97,7 +97,7 @@ private static void assertNoReader(Object obj) throws Exception {
     WindmillStateTestUtils.assertNoReference(obj, WindmillStateReader.class);
   }
 
-  @Mock private MetricTrackingWindmillServerStub mockWindmill;
+  @Mock private GetDataClient mockWindmill;
 
   private WindmillStateReader underTest;
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
index 83ae8aa22ce3..b0c305dc4ec4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
@@ -31,7 +31,9 @@
 import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
 import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
 import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory;
 import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
 import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
@@ -244,7 +246,7 @@ public void testDistributeBudget_distributesFairlyWhenNotEven() {
 
   private WindmillStreamSender createWindmillStreamSender(GetWorkBudget getWorkBudget) {
     return WindmillStreamSender.create(
-        stub,
+        WindmillConnection.builder().setStub(stub).build(),
         Windmill.GetWorkRequest.newBuilder()
             .setClientId(1L)
             .setJobId("job")
@@ -259,6 +261,7 @@ private WindmillStreamSender createWindmillStreamSend
                 .build())
             .build(),
         (workItem, watermarks, processingContext, ackWorkItemQueued, getWorkStreamLatencies) -> {},
+        ignored -> mock(GetDataClient.class),
         ignored -> mock(WorkCommitter.class));
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
index bd55595da135..146b05bb7e35 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.mock;
 
 import java.util.ArrayList;
 import java.util.HashSet;
@@ -34,6 +35,8 @@
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
 import org.joda.time.Duration;
@@ -86,8 +89,9 @@ private static ExecutableWork createWork(Supplier<Instant> clock, Consumer<Work>
         Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(),
         Work.createProcessingContext(
             "computationId",
-            (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(),
-            ignored -> {}),
+            new FakeGetDataClient(),
+            ignored -> {},
+            mock(HeartbeatSender.class)),
         clock,
         new ArrayList<>()),
     processWorkFn);
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
similarity index 84%
rename from runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
rename to runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
index 13019116767c..9dce3392c60c 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
@@ -38,6 +38,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
@@ -46,7 +47,7 @@
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import org.apache.beam.runners.direct.Clock;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
@@ -59,13 +60,14 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
 
 @RunWith(JUnit4.class)
-public class DispatchedActiveWorkRefresherTest {
-
+public class ActiveWorkRefresherTest {
   private static final Supplier<Instant> A_LONG_TIME_AGO =
       () -> Instant.parse("1998-09-04T00:00:00Z");
   private static final String COMPUTATION_ID_PREFIX = "ComputationId-";
+  private final HeartbeatSender heartbeatSender = mock(HeartbeatSender.class);
 
   private static BoundedQueueExecutor workExecutor() {
     return new BoundedQueueExecutor(
@@ -97,15 +99,20 @@ private ActiveWorkRefresher createActiveWorkRefresher(
       int activeWorkRefreshPeriodMillis,
       int stuckCommitDurationMillis,
       Supplier<Collection<ComputationState>> computations,
-      Consumer<Map<String, List<HeartbeatRequest>>> activeWorkRefresherFn) {
-    return new DispatchedActiveWorkRefresher(
+      ActiveWorkRefresher.HeartbeatTracker heartbeatTracker) {
+    return new ActiveWorkRefresher(
         clock,
         activeWorkRefreshPeriodMillis,
         stuckCommitDurationMillis,
         computations,
         DataflowExecutionStateSampler.instance(),
-        activeWorkRefresherFn,
-        Executors.newSingleThreadScheduledExecutor());
+        Executors.newSingleThreadScheduledExecutor(),
+        heartbeatTracker);
+  }
+
+  private ExecutableWork createOldWork(int workIds, Consumer<Work> processWork) {
+    ShardedKey shardedKey = ShardedKey.create(ByteString.EMPTY, workIds);
+    return createOldWork(shardedKey, workIds, processWork);
   }
 
   private ExecutableWork createOldWork(
@@ -120,10 +127,8 @@ private ExecutableWork createOldWork(
             .build(),
         Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(),
         Work.createProcessingContext(
-            "computationId",
-            (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(),
-            ignored -> {}),
-        DispatchedActiveWorkRefresherTest.A_LONG_TIME_AGO,
+            "computationId", new FakeGetDataClient(), ignored -> {}, heartbeatSender),
+        A_LONG_TIME_AGO,
         ImmutableList.of()),
         processWork);
   }
@@ -147,8 +152,7 @@ public void testActiveWorkRefresh() throws InterruptedException {
     Map<String, List<ExecutableWork>> computationsAndWork = new HashMap<>();
     for (int i = 0; i < 5; i++) {
       ComputationState computationState = createComputationState(i);
-      ExecutableWork fakeWork =
-          createOldWork(ShardedKey.create(ByteString.EMPTY, i), i, processWork);
+      ExecutableWork fakeWork = createOldWork(i, processWork);
       computationState.activateWork(fakeWork);
 
       computations.add(computationState);
@@ -158,38 +162,39 @@ public void testActiveWorkRefresh() throws InterruptedException {
       activeWorkForComputation.add(fakeWork);
     }
 
-    Map<String, List<HeartbeatRequest>> expectedHeartbeats = new HashMap<>();
     CountDownLatch heartbeatsSent = new CountDownLatch(1);
     TestClock fakeClock = new TestClock(Instant.now());
-
     ActiveWorkRefresher activeWorkRefresher =
         createActiveWorkRefresher(
             fakeClock::now,
             activeWorkRefreshPeriodMillis,
             0,
             () -> computations,
-            heartbeats -> {
-              expectedHeartbeats.putAll(heartbeats);
-              heartbeatsSent.countDown();
-            });
+            heartbeats -> heartbeatsSent::countDown);
+    ArgumentCaptor<Heartbeats> heartbeatsCaptor = ArgumentCaptor.forClass(Heartbeats.class);
 
     activeWorkRefresher.start();
     fakeClock.advance(Duration.millis(activeWorkRefreshPeriodMillis * 2));
     heartbeatsSent.await();
     activeWorkRefresher.stop();
+    verify(heartbeatSender).sendHeartbeats(heartbeatsCaptor.capture());
+    Heartbeats fanoutExpectedHeartbeats = heartbeatsCaptor.getValue();
+    assertThat(computationsAndWork.size())
+        .isEqualTo(fanoutExpectedHeartbeats.heartbeatRequests().size());
 
-    assertThat(computationsAndWork.size()).isEqualTo(expectedHeartbeats.size());
-    for (Map.Entry<String, List<HeartbeatRequest>> expectedHeartbeat :
-        expectedHeartbeats.entrySet()) {
+    for (Map.Entry<String, Collection<Windmill.HeartbeatRequest>> expectedHeartbeat :
+        fanoutExpectedHeartbeats.heartbeatRequests().asMap().entrySet()) {
       String computationId = expectedHeartbeat.getKey();
-      List<HeartbeatRequest> heartbeatRequests = expectedHeartbeat.getValue();
-      List<ExecutableWork> work = computationsAndWork.get(computationId);
-
+      Collection<Windmill.HeartbeatRequest> heartbeatRequests = expectedHeartbeat.getValue();
+      List<Work> work =
+          computationsAndWork.get(computationId).stream()
+              .map(ExecutableWork::work)
+              .collect(Collectors.toList());
       // Compare the heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys.
       assertThat(heartbeatRequests)
           .comparingElementsUsing(
               Correspondence.from(
-                  (HeartbeatRequest h, ExecutableWork w) ->
+                  (Windmill.HeartbeatRequest h, Work w) ->
                       h.getWorkToken() == w.getWorkItem().getWorkToken()
                           && h.getCacheToken() == w.getWorkItem().getWorkToken()
                           && h.getShardingKey() == w.getWorkItem().getShardingKey(),
@@ -240,7 +245,7 @@ public void testInvalidateStuckCommits() throws InterruptedException {
             0,
             stuckCommitDurationMillis,
             computations.rowMap()::keySet,
-            ignored -> {});
+            ignored -> () -> {});
 
     activeWorkRefresher.start();
     fakeClock.advance(Duration.millis(stuckCommitDurationMillis));