diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorker.java index 9144729faca2..7407c97619b4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorker.java @@ -60,22 +60,8 @@ }) public class BatchDataflowWorker implements Closeable { private static final Logger LOG = LoggerFactory.getLogger(BatchDataflowWorker.class); - - /** A client to get and update work items. */ - private final WorkUnitClient workUnitClient; - - /** - * Pipeline options, initially provided via the constructor and partially provided via each work - * work unit. - */ - private final DataflowWorkerHarnessOptions options; - - /** The factory to create {@link DataflowMapTaskExecutor DataflowMapTaskExecutors}. */ - private final DataflowMapTaskExecutorFactory mapTaskExecutorFactory; - /** The idGenerator to generate unique id globally. */ private static final IdGenerator idGenerator = IdGenerators.decrementingLongs(); - /** * Function which converts map tasks to their network representation for execution. * @@ -90,30 +76,7 @@ public class BatchDataflowWorker implements Closeable { new FixMultiOutputInfosOnParDoInstructions(idGenerator) .andThen(new MapTaskToNetworkFunction(idGenerator)); - /** Registry of known {@link ReaderFactory ReaderFactories}. */ - private final ReaderRegistry readerRegistry = ReaderRegistry.defaultRegistry(); - - /** Registry of known {@link SinkFactory SinkFactories}. */ - private final SinkRegistry sinkRegistry = SinkRegistry.defaultRegistry(); - - /** A side input cache shared between all execution contexts. */ - private final Cache> sideInputDataCache; - - /** - * A side input cache shared between all execution contexts. This cache is meant to store values - * as weak references. This allows for insertion of logical keys with zero weight since they will - * only be scoped to the lifetime of the value being cached. - */ - private final Cache sideInputWeakReferenceCache; - private static final int DEFAULT_STATUS_PORT = 8081; - - /** Status pages returning health of worker. */ - private WorkerStatusPages statusPages; - - /** Periodic sender of debug information to the debug capture service. */ - private DebugCapture.Manager debugCaptureManager = null; - /** * A weight in "bytes" for the overhead of a {@link Weighted} wrapper in the cache. It is just an * approximation so it is OK for it to be fairly arbitrary as long as it is nonzero. @@ -121,33 +84,42 @@ public class BatchDataflowWorker implements Closeable { private static final int OVERHEAD_WEIGHT = 8; private static final long MEGABYTES = 1024 * 1024; - /** * Limit the number of logical references. Weak references may never be cleared if the object is * long lived irrespective if the user actually is interested in the key lookup anymore. */ private static final int MAX_LOGICAL_REFERENCES = 1_000_000; - /** How many concurrent write operations to a cache should we allow. */ private static final int CACHE_CONCURRENCY_LEVEL = 4 * Runtime.getRuntime().availableProcessors(); + /** A client to get and update work items. 
*/ + private final WorkUnitClient workUnitClient; + /** + * Pipeline options, initially provided via the constructor and partially provided via each work + * work unit. + */ + private final DataflowWorkerHarnessOptions options; + /** The factory to create {@link DataflowMapTaskExecutor DataflowMapTaskExecutors}. */ + private final DataflowMapTaskExecutorFactory mapTaskExecutorFactory; + /** Registry of known {@link ReaderFactory ReaderFactories}. */ + private final ReaderRegistry readerRegistry = ReaderRegistry.defaultRegistry(); + /** Registry of known {@link SinkFactory SinkFactories}. */ + private final SinkRegistry sinkRegistry = SinkRegistry.defaultRegistry(); + /** A side input cache shared between all execution contexts. */ + private final Cache> sideInputDataCache; + /** + * A side input cache shared between all execution contexts. This cache is meant to store values + * as weak references. This allows for insertion of logical keys with zero weight since they will + * only be scoped to the lifetime of the value being cached. + */ + private final Cache sideInputWeakReferenceCache; private final Function> mapTaskToNetwork; - private final MemoryMonitor memoryMonitor; private final Thread memoryMonitorThread; - - /** - * Returns a {@link BatchDataflowWorker} configured to execute user functions via intrinsic Java - * execution. - * - *
<p>
This is also known as the "legacy" or "pre-portability" approach. It is not yet deprecated - * as there is not a compatible path forward for users. - */ - static BatchDataflowWorker forBatchIntrinsicWorkerHarness( - WorkUnitClient workUnitClient, DataflowWorkerHarnessOptions options) { - return new BatchDataflowWorker( - workUnitClient, IntrinsicMapTaskExecutorFactory.defaultFactory(), options); - } + /** Status pages returning health of worker. */ + private final WorkerStatusPages statusPages; + /** Periodic sender of debug information to the debug capture service. */ + private DebugCapture.Manager debugCaptureManager = null; protected BatchDataflowWorker( WorkUnitClient workUnitClient, @@ -188,6 +160,19 @@ protected BatchDataflowWorker( ExecutionStateSampler.instance().start(); } + /** + * Returns a {@link BatchDataflowWorker} configured to execute user functions via intrinsic Java + * execution. + * + *
<p>
This is also known as the "legacy" or "pre-portability" approach. It is not yet deprecated + * as there is not a compatible path forward for users. + */ + static BatchDataflowWorker forBatchIntrinsicWorkerHarness( + WorkUnitClient workUnitClient, DataflowWorkerHarnessOptions options) { + return new BatchDataflowWorker( + workUnitClient, IntrinsicMapTaskExecutorFactory.defaultFactory(), options); + } + private static DebugCapture.Manager initializeAndStartDebugCaptureManager( DataflowWorkerHarnessOptions options, Collection debugCapturePages) { DebugCapture.Manager result = new DebugCapture.Manager(options, debugCapturePages); @@ -215,7 +200,7 @@ private static Thread startMemoryMonitorThread(MemoryMonitor memoryMonitor) { */ public boolean getAndPerformWork() throws IOException { while (true) { - Optional work = workUnitClient.getWorkItem(); + Optional work = Optional.fromJavaUtil(workUnitClient.getWorkItem()); if (work.isPresent()) { WorkItemStatusClient statusProvider = new WorkItemStatusClient(workUnitClient, work.get()); return doWork(work.get(), statusProvider); @@ -243,7 +228,7 @@ boolean doWork(WorkItem workItem, WorkItemStatusClient workItemStatusClient) thr } else if (workItem.getSourceOperationTask() != null) { stageName = workItem.getSourceOperationTask().getStageName(); } else { - throw new RuntimeException("Unknown kind of work item: " + workItem.toString()); + throw new RuntimeException("Unknown kind of work item: " + workItem); } CounterSet counterSet = new CounterSet(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java index bf809cfd0121..ffa377fd3f82 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java @@ -39,13 +39,13 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkProgressUpdater; import org.apache.beam.sdk.extensions.gcp.util.Transport; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.joda.time.DateTime; @@ -87,7 +87,7 @@ class DataflowWorkUnitClient implements WorkUnitClient { } /** - * Gets a {@link WorkItem} from the Dataflow service, or returns {@link Optional#absent()} if no + * Gets a {@link WorkItem} from the Dataflow service, or returns {@link Optional#empty()} if no * work was found. * *
<p>
If work is returned, the calling thread should call reportWorkItemStatus after completing it @@ -116,11 +116,11 @@ public Optional getWorkItem() throws IOException { if (!workItem.isPresent()) { // Normal case, this means that the response contained no work, i.e. no work is available // at this time. - return Optional.absent(); + return Optional.empty(); } - if (workItem.isPresent() && workItem.get().getId() == null) { - logger.debug("Discarding invalid work item {}", workItem.orNull()); - return Optional.absent(); + if (workItem.get().getId() == null) { + logger.debug("Discarding invalid work item {}", workItem.get()); + return Optional.empty(); } WorkItem work = workItem.get(); @@ -148,7 +148,7 @@ public Optional getWorkItem() throws IOException { /** * Gets a global streaming config {@link WorkItem} from the Dataflow service, or returns {@link - * Optional#absent()} if no work was found. + * Optional#empty()} if no work was found. */ @Override public Optional getGlobalStreamingConfigWorkItem() throws IOException { @@ -158,7 +158,7 @@ public Optional getGlobalStreamingConfigWorkItem() throws IOException /** * Gets a streaming config {@link WorkItem} for the given computation from the Dataflow service, - * or returns {@link Optional#absent()} if no work was found. + * or returns {@link Optional#empty()} if no work was found. */ @Override public Optional getStreamingConfigWorkItem(String computationId) throws IOException { @@ -197,7 +197,7 @@ private Optional getWorkItemInternal( List workItems = response.getWorkItems(); if (workItems == null || workItems.isEmpty()) { // We didn't lease any work. - return Optional.absent(); + return Optional.empty(); } else if (workItems.size() > 1) { throw new IOException( "This version of the SDK expects no more than one work item from the service: " diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 8629b7116973..7110fee29362 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -19,10 +19,8 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.beam.runners.dataflow.DataflowRunner.hasExperiment; -import static org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics.THROTTLING_MSECS_METRIC_NAME; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import com.google.api.services.dataflow.model.CounterStructuredName; import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.MapTask; import com.google.api.services.dataflow.model.Status; @@ -30,23 +28,19 @@ import com.google.api.services.dataflow.model.StreamingConfigTask; import com.google.api.services.dataflow.model.WorkItem; import com.google.api.services.dataflow.model.WorkItemStatus; -import com.google.auto.value.AutoValue; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.File; import java.io.IOException; import java.io.PrintWriter; -import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.Deque; -import java.util.EnumMap; import 
java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Queue; +import java.util.Optional; import java.util.Random; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -54,9 +48,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -66,17 +58,13 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.beam.runners.core.metrics.ExecutionStateSampler; -import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.MetricsLogger; import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.internal.CustomSources; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; -import org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics.StreamingPerStageSystemCounterNames; import org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics.StreamingSystemCounterNames; -import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker.Work.State; -import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; import org.apache.beam.runners.dataflow.worker.apiary.FixMultiOutputInfosOnParDoInstructions; import org.apache.beam.runners.dataflow.worker.counters.Counter; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; @@ -97,6 +85,15 @@ import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages; +import org.apache.beam.runners.dataflow.worker.streaming.Commit; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState; +import org.apache.beam.runners.dataflow.worker.streaming.KeyCommitTooLargeException; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.Work.State; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; @@ -125,9 +122,7 @@ import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; -import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; @@ -154,8 +149,29 @@ }) public class StreamingDataflowWorker { - private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class); + // TODO(https://github.com/apache/beam/issues/19632): Update throttling counters to use generic + // throttling-msecs metric. + public static final MetricName BIGQUERY_STREAMING_INSERT_THROTTLE_TIME = + MetricName.named( + "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl", + "throttling-msecs"); + // Maximum number of threads for processing. Currently each thread processes one key at a time. + static final int MAX_PROCESSING_THREADS = 300; + static final long THREAD_EXPIRATION_TIME_SEC = 60; + static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; + static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB + static final int NUM_COMMIT_STREAMS = 1; + static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3; + static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1); + /** + * Sinks are marked 'full' in {@link StreamingModeExecutionContext} once the amount of data sinked + * (across all the sinks, if there are more than one) reaches this limit. This serves as hint for + * readers to stop producing more. This can be disabled with 'disable_limiting_bundle_sink_bytes' + * experiment. + */ + static final int MAX_SINK_BYTES = 10_000_000; + private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class); /** The idGenerator to generate unique id globally. */ private static final IdGenerator idGenerator = IdGenerators.decrementingLongs(); /** @@ -164,7 +180,6 @@ public class StreamingDataflowWorker { */ private static final Function fixMultiOutputInfos = new FixMultiOutputInfosOnParDoInstructions(idGenerator); - /** * Function which converts map tasks to their network representation for execution. * @@ -176,235 +191,35 @@ public class StreamingDataflowWorker { private static final Function> mapTaskToBaseNetwork = new MapTaskToNetworkFunction(idGenerator); - private static Random clientIdGenerator = new Random(); - - // Maximum number of threads for processing. Currently each thread processes one key at a time. - static final int MAX_PROCESSING_THREADS = 300; - static final long THREAD_EXPIRATION_TIME_SEC = 60; - static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; - static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB - static final int NUM_COMMIT_STREAMS = 1; - static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3; - static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1); - private static final int DEFAULT_STATUS_PORT = 8081; - // Maximum size of the result of a GetWork request. private static final long MAX_GET_WORK_FETCH_BYTES = 64L << 20; // 64m - // Reserved ID for counter updates. // Matches kWindmillCounterUpdate in workflow_worker_service_multi_hubs.cc. private static final String WINDMILL_COUNTER_UPDATE_WORK_ID = "3"; - /** Maximum number of failure stacktraces to report in each update sent to backend. */ private static final int MAX_FAILURES_TO_REPORT_IN_UPDATE = 1000; - // TODO(https://github.com/apache/beam/issues/19632): Update throttling counters to use generic - // throttling-msecs metric. 
- public static final MetricName BIGQUERY_STREAMING_INSERT_THROTTLE_TIME = - MetricName.named( - "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl", - "throttling-msecs"); - private static final Duration MAX_LOCAL_PROCESSING_RETRY_DURATION = Duration.standardMinutes(5); - - /** Returns whether an exception was caused by a {@link OutOfMemoryError}. */ - private static boolean isOutOfMemoryError(Throwable t) { - while (t != null) { - if (t instanceof OutOfMemoryError) { - return true; - } - t = t.getCause(); - } - return false; - } - - private static class KeyCommitTooLargeException extends Exception { - - public static KeyCommitTooLargeException causedBy( - String computationId, long byteLimit, WorkItemCommitRequest request) { - StringBuilder message = new StringBuilder(); - message.append("Commit request for stage "); - message.append(computationId); - message.append(" and key "); - message.append(request.getKey().toStringUtf8()); - if (request.getSerializedSize() > 0) { - message.append( - " has size " - + request.getSerializedSize() - + " which is more than the limit of " - + byteLimit); - } else { - message.append(" is larger than 2GB and cannot be processed"); - } - message.append( - ". This may be caused by grouping a very " - + "large amount of data in a single window without using Combine," - + " or by producing a large amount of data from a single input element."); - return new KeyCommitTooLargeException(message.toString()); - } - - private KeyCommitTooLargeException(String message) { - super(message); - } - } - - private static MapTask parseMapTask(String input) throws IOException { - return Transport.getJsonFactory().fromString(input, MapTask.class); - } - - public static void main(String[] args) throws Exception { - JvmInitializers.runOnStartup(); - - DataflowWorkerHarnessHelper.initializeLogging(StreamingDataflowWorker.class); - DataflowWorkerHarnessOptions options = - DataflowWorkerHarnessHelper.initializeGlobalStateAndPipelineOptions( - StreamingDataflowWorker.class); - DataflowWorkerHarnessHelper.configureLogging(options); - checkArgument( - options.isStreaming(), - "%s instantiated with options indicating batch use", - StreamingDataflowWorker.class.getName()); - - checkArgument( - !DataflowRunner.hasExperiment(options, "beam_fn_api"), - "%s cannot be main() class with beam_fn_api enabled", - StreamingDataflowWorker.class.getSimpleName()); - - StreamingDataflowWorker worker = - StreamingDataflowWorker.fromDataflowWorkerHarnessOptions(options); - - // Use the MetricsLogger container which is used by BigQueryIO to periodically log process-wide - // metrics. - MetricsEnvironment.setProcessWideContainer(new MetricsLogger(null)); - - JvmInitializers.runBeforeProcessing(options); - worker.startStatusPages(); - worker.start(); - } - - /** Bounded set of queues, with a maximum total weight. */ - private static class WeightedBoundedQueue { - - private final LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); - private final int maxWeight; - private final Semaphore limit; - private final Function weigher; - - public WeightedBoundedQueue(int maxWeight, Function weigher) { - this.maxWeight = maxWeight; - this.limit = new Semaphore(maxWeight, true); - this.weigher = weigher; - } - - /** - * Adds the value to the queue, blocking if this would cause the overall weight to exceed the - * limit. 
- */ - public void put(V value) { - limit.acquireUninterruptibly(weigher.apply(value)); - queue.add(value); - } - - /** Returns and removes the next value, or null if there is no such value. */ - public @Nullable V poll() { - V result = queue.poll(); - if (result != null) { - limit.release(weigher.apply(result)); - } - return result; - } - - /** - * Retrieves and removes the head of this queue, waiting up to the specified wait time if - * necessary for an element to become available. - * - * @param timeout how long to wait before giving up, in units of {@code unit} - * @param unit a {@code TimeUnit} determining how to interpret the {@code timeout} parameter - * @return the head of this queue, or {@code null} if the specified waiting time elapses before - * an element is available - * @throws InterruptedException if interrupted while waiting - */ - public @Nullable V poll(long timeout, TimeUnit unit) throws InterruptedException { - V result = queue.poll(timeout, unit); - if (result != null) { - limit.release(weigher.apply(result)); - } - return result; - } - - /** Returns and removes the next value, or blocks until one is available. */ - public @Nullable V take() throws InterruptedException { - V result = queue.take(); - limit.release(weigher.apply(result)); - return result; - } - - /** Returns the current weight of the queue. */ - public int weight() { - return maxWeight - limit.availablePermits(); - } - - public int size() { - return queue.size(); - } - } - - // Value class for a queued commit. - static class Commit { - - private Windmill.WorkItemCommitRequest request; - private ComputationState computationState; - private Work work; - - public Commit( - Windmill.WorkItemCommitRequest request, ComputationState computationState, Work work) { - this.request = request; - assert request.getSerializedSize() > 0; - this.computationState = computationState; - this.work = work; - } - - public Windmill.WorkItemCommitRequest getRequest() { - return request; - } - - public ComputationState getComputationState() { - return computationState; - } - - public Work getWork() { - return work; - } - - public int getSize() { - return request.getSerializedSize(); - } - } - + private static final Random clientIdGenerator = new Random(); + final WindmillStateCache stateCache; // Maps from computation ids to per-computation state. private final ConcurrentMap computationMap = new ConcurrentHashMap<>(); private final WeightedBoundedQueue commitQueue = - new WeightedBoundedQueue<>( + WeightedBoundedQueue.create( MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); - // Cache of tokens to commit callbacks. // Using Cache with time eviction policy helps us to prevent memory leak when callback ids are // discarded by Dataflow service and calling commitCallback is best-effort. private final Cache commitCallbacks = CacheBuilder.newBuilder().expireAfterWrite(5L, TimeUnit.MINUTES).build(); - // Map of user state names to system state names. // TODO(drieber): obsolete stateNameMap. Use transformUserNameToStateFamily in // ComputationState instead. 
private final ConcurrentMap stateNameMap = new ConcurrentHashMap<>(); private final ConcurrentMap systemNameToComputationIdMap = new ConcurrentHashMap<>(); - - final WindmillStateCache stateCache; - private final ThreadFactory threadFactory; - private DataflowMapTaskExecutorFactory mapTaskExecutorFactory; private final BoundedQueueExecutor workUnitExecutor; private final WindmillServerStub windmillServer; private final Thread dispatchThread; @@ -415,16 +230,13 @@ public int getSize() { private final StreamingDataflowWorkerOptions options; private final boolean windmillServiceEnabled; private final long clientId; - private final MetricTrackingWindmillServerStub metricTrackingWindmillServer; private final CounterSet pendingDeltaCounters = new CounterSet(); private final CounterSet pendingCumulativeCounters = new CounterSet(); private final java.util.concurrent.ConcurrentLinkedQueue pendingMonitoringInfos = new ConcurrentLinkedQueue<>(); - // Map from stage name to StageInfo containing metrics container registry and per stage counters. private final ConcurrentMap stageInfoMap = new ConcurrentHashMap(); - // Built-in delta counters. private final Counter windmillShuffleBytesRead; private final Counter windmillStateBytesRead; @@ -436,134 +248,35 @@ public int getSize() { private final Counter timeAtMaxActiveThreads; private final Counter windmillMaxObservedWorkItemCommitBytes; private final Counter memoryThrashing; - private ScheduledExecutorService refreshWorkTimer; - private ScheduledExecutorService statusPageTimer; - private final boolean publishCounters; - private ScheduledExecutorService globalWorkerUpdatesTimer; - private int retryLocallyDelayMs = 10000; - - // Periodically fires a global config request to dataflow service. Only used when windmill service - // is enabled. - private ScheduledExecutorService globalConfigRefreshTimer; - private final MemoryMonitor memoryMonitor; private final Thread memoryMonitorThread; - private final WorkerStatusPages statusPages; - // Periodic sender of debug information to the debug capture service. - private DebugCapture.Manager debugCaptureManager = null; - // Limit on bytes sinked (committed) in a work item. private final long maxSinkBytes; // = MAX_SINK_BYTES unless disabled in options. - // Possibly overridden by streaming engine config. - private int maxWorkItemCommitBytes = Integer.MAX_VALUE; - private final EvictingQueue pendingFailuresToReport = - EvictingQueue.create(MAX_FAILURES_TO_REPORT_IN_UPDATE); - + EvictingQueue.create(MAX_FAILURES_TO_REPORT_IN_UPDATE); private final ReaderCache readerCache; - private final WorkUnitClient workUnitClient; private final CompletableFuture isDoneFuture; private final Function> mapTaskToNetwork; - - /** - * Sinks are marked 'full' in {@link StreamingModeExecutionContext} once the amount of data sinked - * (across all the sinks, if there are more than one) reaches this limit. This serves as hint for - * readers to stop producing more. This can be disabled with 'disable_limiting_bundle_sink_bytes' - * experiment. - */ - static final int MAX_SINK_BYTES = 10_000_000; - private final ReaderRegistry readerRegistry = ReaderRegistry.defaultRegistry(); private final SinkRegistry sinkRegistry = SinkRegistry.defaultRegistry(); - - private HotKeyLogger hotKeyLogger; - private final Supplier clock; private final Function executorSupplier; - - /** Contains a few of the stage specific fields. E.g. metrics container registry, counters etc. 
*/ - private static class StageInfo { - - final String stageName; - final String systemName; - final MetricsContainerRegistry metricsContainerRegistry; - final StreamingModeExecutionStateRegistry executionStateRegistry; - final CounterSet deltaCounters; - final Counter throttledMsecs; - final Counter totalProcessingMsecs; - final Counter timerProcessingMsecs; - - StageInfo(String stageName, String systemName, StreamingDataflowWorker worker) { - this.stageName = stageName; - this.systemName = systemName; - metricsContainerRegistry = StreamingStepMetricsContainer.createRegistry(); - executionStateRegistry = new StreamingModeExecutionStateRegistry(worker); - NameContext nameContext = NameContext.create(stageName, null, systemName, null); - deltaCounters = new CounterSet(); - throttledMsecs = - deltaCounters.longSum( - StreamingPerStageSystemCounterNames.THROTTLED_MSECS.counterName(nameContext)); - totalProcessingMsecs = - deltaCounters.longSum( - StreamingPerStageSystemCounterNames.TOTAL_PROCESSING_MSECS.counterName(nameContext)); - timerProcessingMsecs = - deltaCounters.longSum( - StreamingPerStageSystemCounterNames.TIMER_PROCESSING_MSECS.counterName(nameContext)); - } - - List extractCounterUpdates() { - List counterUpdates = new ArrayList<>(); - Iterables.addAll( - counterUpdates, - StreamingStepMetricsContainer.extractMetricUpdates(metricsContainerRegistry)); - Iterables.addAll(counterUpdates, executionStateRegistry.extractUpdates(false)); - for (CounterUpdate counterUpdate : counterUpdates) { - translateKnownStepCounters(counterUpdate); - } - counterUpdates.addAll( - deltaCounters.extractModifiedDeltaUpdates(DataflowCounterUpdateExtractor.INSTANCE)); - return counterUpdates; - } - - // Checks if the step counter affects any per-stage counters. Currently 'throttled_millis' - // is the only counter updated. - private void translateKnownStepCounters(CounterUpdate stepCounterUpdate) { - CounterStructuredName structuredName = - stepCounterUpdate.getStructuredNameAndMetadata().getName(); - if ((THROTTLING_MSECS_METRIC_NAME.getNamespace().equals(structuredName.getOriginNamespace()) - && THROTTLING_MSECS_METRIC_NAME.getName().equals(structuredName.getName())) - || (BIGQUERY_STREAMING_INSERT_THROTTLE_TIME - .getNamespace() - .equals(structuredName.getOriginNamespace()) - && BIGQUERY_STREAMING_INSERT_THROTTLE_TIME - .getName() - .equals(structuredName.getName()))) { - long msecs = DataflowCounterUpdateExtractor.splitIntToLong(stepCounterUpdate.getInteger()); - if (msecs > 0) { - throttledMsecs.addValue(msecs); - } - } - } - } - - public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions( - DataflowWorkerHarnessOptions options) throws IOException { - - return new StreamingDataflowWorker( - Collections.emptyList(), - IntrinsicMapTaskExecutorFactory.defaultFactory(), - new DataflowWorkUnitClient(options, LOG), - options.as(StreamingDataflowWorkerOptions.class), - true, - new HotKeyLogger(), - Instant::now, - (threadName) -> - Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build())); - } + private final DataflowMapTaskExecutorFactory mapTaskExecutorFactory; + private final HotKeyLogger hotKeyLogger; + // Periodic sender of debug information to the debug capture service. 
+ private final DebugCapture.@Nullable Manager debugCaptureManager; + private ScheduledExecutorService refreshWorkTimer; + private ScheduledExecutorService statusPageTimer; + private ScheduledExecutorService globalWorkerUpdatesTimer; + private int retryLocallyDelayMs = 10000; + // Periodically fires a global config request to dataflow service. Only used when windmill service + // is enabled. + private ScheduledExecutorService globalConfigRefreshTimer; + // Possibly overridden by streaming engine config. + private int maxWorkItemCommitBytes = Integer.MAX_VALUE; @VisibleForTesting StreamingDataflowWorker( @@ -593,6 +306,8 @@ public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions( if (windmillServiceEnabled) { this.debugCaptureManager = new DebugCapture.Manager(options, statusPages.getDebugCapturePages()); + } else { + this.debugCaptureManager = null; } this.windmillShuffleBytesRead = pendingDeltaCounters.longSum( @@ -701,6 +416,81 @@ public void run() { LOG.debug("maxWorkItemCommitBytes: {}", maxWorkItemCommitBytes); } + /** Returns whether an exception was caused by a {@link OutOfMemoryError}. */ + private static boolean isOutOfMemoryError(Throwable t) { + while (t != null) { + if (t instanceof OutOfMemoryError) { + return true; + } + t = t.getCause(); + } + return false; + } + + private static MapTask parseMapTask(String input) throws IOException { + return Transport.getJsonFactory().fromString(input, MapTask.class); + } + + public static void main(String[] args) throws Exception { + JvmInitializers.runOnStartup(); + + DataflowWorkerHarnessHelper.initializeLogging(StreamingDataflowWorker.class); + DataflowWorkerHarnessOptions options = + DataflowWorkerHarnessHelper.initializeGlobalStateAndPipelineOptions( + StreamingDataflowWorker.class); + DataflowWorkerHarnessHelper.configureLogging(options); + checkArgument( + options.isStreaming(), + "%s instantiated with options indicating batch use", + StreamingDataflowWorker.class.getName()); + + checkArgument( + !DataflowRunner.hasExperiment(options, "beam_fn_api"), + "%s cannot be main() class with beam_fn_api enabled", + StreamingDataflowWorker.class.getSimpleName()); + + StreamingDataflowWorker worker = + StreamingDataflowWorker.fromDataflowWorkerHarnessOptions(options); + + // Use the MetricsLogger container which is used by BigQueryIO to periodically log process-wide + // metrics. + MetricsEnvironment.setProcessWideContainer(new MetricsLogger(null)); + + JvmInitializers.runBeforeProcessing(options); + worker.startStatusPages(); + worker.start(); + } + + public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions( + DataflowWorkerHarnessOptions options) throws IOException { + + return new StreamingDataflowWorker( + Collections.emptyList(), + IntrinsicMapTaskExecutorFactory.defaultFactory(), + new DataflowWorkUnitClient(options, LOG), + options.as(StreamingDataflowWorkerOptions.class), + true, + new HotKeyLogger(), + Instant::now, + (threadName) -> + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat(threadName).build())); + } + + private static void sleep(int millis) { + Uninterruptibles.sleepUninterruptibly(millis, TimeUnit.MILLISECONDS); + } + + /** Sets the stage name and workId of the current Thread for logging. 
*/ + private static void setUpWorkLoggingContext(Windmill.WorkItem workItem, String computationId) { + String workIdBuilder = + Long.toHexString(workItem.getShardingKey()) + + '-' + + Long.toHexString(workItem.getWorkToken()); + DataflowWorkerLoggingMDC.setWorkId(workIdBuilder); + DataflowWorkerLoggingMDC.setStageName(computationId); + } + private int chooseMaximumNumberOfThreads() { if (options.getNumberOfWorkerHarnessThreads() != 0) { return options.getNumberOfWorkerHarnessThreads(); @@ -810,7 +600,7 @@ public void run() { + options.getWorkerId() + "_" + page.pageName() - + timestamp.toString()) + + timestamp) .replaceAll("/", "_")); writer = new PrintWriter(outputFile, UTF_8.name()); page.captureData(writer); @@ -938,10 +728,6 @@ private synchronized void addComputation( } } - private static void sleep(int millis) { - Uninterruptibles.sleepUninterruptibly(millis, TimeUnit.MILLISECONDS); - } - /** * If the computation is not yet known about, configuration for it will be fetched. This can still * return null if there is no configuration fetched for the computation. @@ -996,7 +782,7 @@ private void dispatchLoop() { inputDataWatermark, synchronizedProcessingTime, workItem, - /*getWorkStreamLatencies=*/ Collections.emptyList()); + /* getWorkStreamLatencies= */ Collections.emptyList()); } } } @@ -1049,138 +835,20 @@ private void scheduleWorkItem( WindmillTimeUtils.windmillToHarnessWatermark(workItem.getOutputDataWatermark()); Preconditions.checkState( outputDataWatermark == null || !outputDataWatermark.isAfter(inputDataWatermark)); - Work work = - new Work(workItem, clock, getWorkStreamLatencies) { - @Override - public void run() { - process( - computationState, - inputDataWatermark, - outputDataWatermark, - synchronizedProcessingTime, - this); - } - }; + Work scheduledWork = + Work.create( + workItem, + clock, + getWorkStreamLatencies, + work -> + process( + computationState, + inputDataWatermark, + outputDataWatermark, + synchronizedProcessingTime, + work)); computationState.activateWork( - ShardedKey.create(workItem.getKey(), workItem.getShardingKey()), work); - } - - @AutoValue - abstract static class ShardedKey { - - public static ShardedKey create(ByteString key, long shardingKey) { - return new AutoValue_StreamingDataflowWorker_ShardedKey(key, shardingKey); - } - - public abstract ByteString key(); - - public abstract long shardingKey(); - - @Override - public final String toString() { - ByteString keyToDisplay = key(); - if (keyToDisplay.size() > 100) { - keyToDisplay = keyToDisplay.substring(0, 100); - } - return String.format("%016x-%s", shardingKey(), TextFormat.escapeBytes(keyToDisplay)); - } - } - - abstract static class Work implements Runnable { - - enum State { - QUEUED(Windmill.LatencyAttribution.State.QUEUED), - PROCESSING(Windmill.LatencyAttribution.State.ACTIVE), - READING(Windmill.LatencyAttribution.State.READING), - COMMIT_QUEUED(Windmill.LatencyAttribution.State.COMMITTING), - COMMITTING(Windmill.LatencyAttribution.State.COMMITTING), - GET_WORK_IN_WINDMILL_WORKER(Windmill.LatencyAttribution.State.GET_WORK_IN_WINDMILL_WORKER), - GET_WORK_IN_TRANSIT_TO_DISPATCHER( - Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_DISPATCHER), - GET_WORK_IN_TRANSIT_TO_USER_WORKER( - Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_USER_WORKER); - - private final Windmill.LatencyAttribution.State latencyAttributionState; - - private State(Windmill.LatencyAttribution.State latencyAttributionState) { - this.latencyAttributionState = latencyAttributionState; - } - - 
Windmill.LatencyAttribution.State toLatencyAttributionState() { - return latencyAttributionState; - } - } - - private final Windmill.WorkItem workItem; - private final Supplier clock; - private final Instant startTime; - private Instant stateStartTime; - private State state; - private final Map totalDurationPerState = - new EnumMap<>(Windmill.LatencyAttribution.State.class); - - public Work( - Windmill.WorkItem workItem, - Supplier clock, - Collection getWorkStreamLatencies) { - this.workItem = workItem; - this.clock = clock; - this.startTime = this.stateStartTime = clock.get(); - this.state = State.QUEUED; - recordGetWorkStreamLatencies(getWorkStreamLatencies); - } - - public Windmill.WorkItem getWorkItem() { - return workItem; - } - - public Instant getStartTime() { - return startTime; - } - - public State getState() { - return state; - } - - public void setState(State state) { - Instant now = clock.get(); - totalDurationPerState.compute( - this.state.toLatencyAttributionState(), - (s, d) -> new Duration(this.stateStartTime, now).plus(d == null ? Duration.ZERO : d)); - this.state = state; - this.stateStartTime = now; - } - - public Instant getStateStartTime() { - return stateStartTime; - } - - private void recordGetWorkStreamLatencies( - Collection getWorkStreamLatencies) { - for (LatencyAttribution latency : getWorkStreamLatencies) { - totalDurationPerState.put( - latency.getState(), Duration.millis(latency.getTotalDurationMillis())); - } - } - - public Collection getLatencyAttributions() { - List list = new ArrayList<>(); - for (Windmill.LatencyAttribution.State state : Windmill.LatencyAttribution.State.values()) { - Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO); - if (state == this.state.toLatencyAttributionState()) { - duration = duration.plus(new Duration(this.stateStartTime, clock.get())); - } - if (duration.equals(Duration.ZERO)) { - continue; - } - list.add( - Windmill.LatencyAttribution.newBuilder() - .setState(state) - .setTotalDurationMillis(duration.getMillis()) - .build()); - } - return list; - } + ShardedKey.create(workItem.getKey(), workItem.getShardingKey()), scheduledWork); } /** @@ -1250,15 +918,9 @@ private void process( final String computationId = computationState.getComputationId(); final ByteString key = workItem.getKey(); work.setState(State.PROCESSING); - { - StringBuilder workIdBuilder = new StringBuilder(33); - workIdBuilder.append(Long.toHexString(workItem.getShardingKey())); - workIdBuilder.append('-'); - workIdBuilder.append(Long.toHexString(workItem.getWorkToken())); - DataflowWorkerLoggingMDC.setWorkId(workIdBuilder.toString()); - } - DataflowWorkerLoggingMDC.setStageName(computationId); + setUpWorkLoggingContext(workItem, computationId); + LOG.debug("Starting processing for {}:\n{}", computationId, work); Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); @@ -1269,7 +931,7 @@ private void process( if (workItem.getSourceState().getOnlyFinalize()) { outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); work.setState(State.COMMIT_QUEUED); - commitQueue.put(new Commit(outputBuilder.build(), computationState, work)); + commitQueue.put(Commit.create(outputBuilder.build(), computationState, work)); return; } @@ -1279,7 +941,7 @@ private void process( StageInfo stageInfo = stageInfoMap.computeIfAbsent( - mapTask.getStageName(), s -> new StageInfo(s, mapTask.getSystemName(), this)); + mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName(), this)); ExecutionState executionState = null; String counterName = "dataflow_source_bytes_processed-" + mapTask.getSystemName(); @@ -1304,12 +966,14 @@ private void process( DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = new DataflowExecutionContext.DataflowExecutionStateTracker( ExecutionStateSampler.instance(), - stageInfo.executionStateRegistry.getState( - NameContext.forStage(mapTask.getStageName()), - "other", - null, - ScopedProfiler.INSTANCE.emptyScope()), - stageInfo.deltaCounters, + stageInfo + .executionStateRegistry() + .getState( + NameContext.forStage(mapTask.getStageName()), + "other", + null, + ScopedProfiler.INSTANCE.emptyScope()), + stageInfo.deltaCounters(), options, computationId); StreamingModeExecutionContext context = @@ -1321,9 +985,9 @@ private void process( ? computationState.getTransformUserNameToStateFamily() : stateNameMap, stateCache.forComputation(computationId), - stageInfo.metricsContainerRegistry, + stageInfo.metricsContainerRegistry(), executionStateTracker, - stageInfo.executionStateRegistry, + stageInfo.executionStateRegistry(), maxSinkBytes); DataflowMapTaskExecutor mapTaskExecutor = mapTaskExecutorFactory.create( @@ -1370,8 +1034,18 @@ private void process( .setSamplingPeriod(100) .countBytes(counterName)); } - executionState = - new ExecutionState(mapTaskExecutor, context, keyCoder, executionStateTracker); + + ExecutionState.Builder executionStateBuilder = + ExecutionState.builder() + .setWorkExecutor(mapTaskExecutor) + .setContext(context) + .setExecutionStateTracker(executionStateTracker); + + if (keyCoder != null) { + executionStateBuilder.setKeyCoder(keyCoder); + } + + executionState = executionStateBuilder.build(); } WindmillStateReader stateReader = @@ -1398,10 +1072,10 @@ public void close() { // // The coder type that will be present is: // WindowedValueCoder(TimerOrElementCoder(KvCoder)) - @Nullable Coder keyCoder = executionState.getKeyCoder(); + Optional> keyCoder = executionState.keyCoder(); @Nullable Object executionKey = - keyCoder == null ? null : keyCoder.decode(key.newInput(), Coder.Context.OUTER); + !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); if (workItem.hasHotKeyInfo()) { Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); @@ -1410,7 +1084,7 @@ public void close() { // The MapTask instruction is ordered by dependencies, such that the first element is // always going to be the shuffle task. String stepName = computationState.getMapTask().getInstructions().get(0).getName(); - if (options.isHotKeyLoggingEnabled() && keyCoder != null) { + if (options.isHotKeyLoggingEnabled() && keyCoder.isPresent()) { hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); } else { hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); @@ -1418,7 +1092,7 @@ public void close() { } executionState - .getContext() + .context() .start( executionKey, workItem, @@ -1430,13 +1104,13 @@ public void close() { outputBuilder); // Blocks while executing work. - executionState.getWorkExecutor().execute(); + executionState.workExecutor().execute(); // Reports source bytes processed to workitemcommitrequest if available. 
try { long sourceBytesProcessed = 0; HashMap counters = - ((DataflowMapTaskExecutor) executionState.getWorkExecutor()) + ((DataflowMapTaskExecutor) executionState.workExecutor()) .getReadOperation() .receivers[0] .getOutputCounters(); @@ -1450,9 +1124,9 @@ public void close() { } Iterables.addAll( - this.pendingMonitoringInfos, executionState.getWorkExecutor().extractMetricUpdates()); + this.pendingMonitoringInfos, executionState.workExecutor().extractMetricUpdates()); - commitCallbacks.putAll(executionState.getContext().flushState()); + commitCallbacks.putAll(executionState.context().flushState()); // Release the execution state for another thread to use. computationState.getExecutionStateQueue().offer(executionState); @@ -1481,7 +1155,7 @@ public void close() { commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize); } - commitQueue.put(new Commit(commitRequest, computationState, work)); + commitQueue.put(Commit.create(commitRequest, computationState, work)); // Compute shuffle and state byte statistics these will be flushed asynchronously. long stateBytesWritten = @@ -1505,8 +1179,8 @@ public void close() { } catch (Throwable t) { if (executionState != null) { try { - executionState.getContext().invalidateCache(); - executionState.getWorkExecutor().close(); + executionState.context().invalidateCache(); + executionState.workExecutor().close(); } catch (Exception e) { LOG.warn("Failed to close map task executor: ", e); } finally { @@ -1571,7 +1245,7 @@ public void close() { } else { // Consider the item invalid. It will eventually be retried by Windmill if it still needs to // be processed. - computationState.completeWork( + computationState.completeWorkAndScheduleNextWorkForKey( ShardedKey.create(key, workItem.getShardingKey()), workItem.getWorkToken()); } } finally { @@ -1579,7 +1253,7 @@ public void close() { // work items causing exceptions are also accounted in time spent. long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); - stageInfo.totalProcessingMsecs.addValue(processingTimeMsecs); + stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); // Attribute all the processing to timers if the work item contains any timers. // Tests show that work items rarely contain both timers and message bundles. It should @@ -1587,7 +1261,7 @@ public void close() { // Another option: Derive time split between messages and timers based on recent totals. // either here or in DFE. 
if (work.getWorkItem().hasTimers()) { - stageInfo.timerProcessingMsecs.addValue(processingTimeMsecs); + stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); } DataflowWorkerLoggingMDC.setWorkId(null); @@ -1620,8 +1294,8 @@ private void commitLoop() { continue; } while (commit != null) { - ComputationState computationState = commit.getComputationState(); - commit.getWork().setState(State.COMMITTING); + ComputationState computationState = commit.computationState(); + commit.work().setState(Work.State.COMMITTING); Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = computationRequestMap.get(computationState); if (computationRequestBuilder == null) { @@ -1629,7 +1303,7 @@ private void commitLoop() { computationRequestBuilder.setComputationId(computationState.getComputationId()); computationRequestMap.put(computationState, computationRequestBuilder); } - computationRequestBuilder.addRequests(commit.getRequest()); + computationRequestBuilder.addRequests(commit.request()); // Send the request if we've exceeded the bytes or there is no more // pending work. commitBytes is a long, so this cannot overflow. commitBytes += commit.getSize(); @@ -1647,7 +1321,7 @@ private void commitLoop() { computationRequestMap.entrySet()) { ComputationState computationState = entry.getKey(); for (Windmill.WorkItemCommitRequest workRequest : entry.getValue().getRequestsList()) { - computationState.completeWork( + computationState.completeWorkAndScheduleNextWorkForKey( ShardedKey.create(workRequest.getKey(), workRequest.getShardingKey()), workRequest.getWorkToken()); } @@ -1658,34 +1332,34 @@ private void commitLoop() { // Adds the commit to the commitStream if it fits, returning true iff it is consumed. private boolean addCommitToStream(Commit commit, CommitWorkStream commitStream) { Preconditions.checkNotNull(commit); - final ComputationState state = commit.getComputationState(); - final Windmill.WorkItemCommitRequest request = commit.getRequest(); + final ComputationState state = commit.computationState(); + final Windmill.WorkItemCommitRequest request = commit.request(); final int size = commit.getSize(); - commit.getWork().setState(State.COMMITTING); + commit.work().setState(Work.State.COMMITTING); activeCommitBytes.addAndGet(size); if (commitStream.commitWorkItem( - state.computationId, + state.getComputationId(), request, (Windmill.CommitStatus status) -> { if (status != Windmill.CommitStatus.OK) { readerCache.invalidateReader( WindmillComputationKey.create( - state.computationId, request.getKey(), request.getShardingKey())); + state.getComputationId(), request.getKey(), request.getShardingKey())); stateCache - .forComputation(state.computationId) + .forComputation(state.getComputationId()) .invalidate(request.getKey(), request.getShardingKey()); } activeCommitBytes.addAndGet(-size); // This may throw an exception if the commit was not active, which is possible if it // was deemed stuck. - state.completeWork( + state.completeWorkAndScheduleNextWorkForKey( ShardedKey.create(request.getKey(), request.getShardingKey()), request.getWorkToken()); })) { return true; } else { // Back out the stats changes since the commit wasn't consumed. 
- commit.getWork().setState(State.COMMIT_QUEUED); + commit.work().setState(Work.State.COMMIT_QUEUED); activeCommitBytes.addAndGet(-size); return false; } @@ -1699,7 +1373,7 @@ private Commit batchCommitsToStream(CommitWorkStream commitStream) { Commit commit; try { if (commits < 5) { - commit = commitQueue.poll(10 - 2 * commits, TimeUnit.MILLISECONDS); + commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS); } else { commit = commitQueue.poll(); } @@ -1786,7 +1460,8 @@ private void getConfigFromWindmill(String computation) { addComputation( computationId, mapTask, - transformUserNameToStateFamilyByComputationId.get(computationId)); + transformUserNameToStateFamilyByComputationId.getOrDefault( + computationId, ImmutableMap.of())); } catch (IOException e) { LOG.warn("Parsing MapTask failed: {}", serializedMapTask); LOG.warn("Error: ", e); @@ -1804,13 +1479,12 @@ private void getConfigFromWindmill(String computation) { * @throws IOException if the RPC fails. */ private void getConfigFromDataflowService(@Nullable String computation) throws IOException { - Optional workItem; - if (computation != null) { - workItem = workUnitClient.getStreamingConfigWorkItem(computation); - } else { - workItem = workUnitClient.getGlobalStreamingConfigWorkItem(); - } - if (workItem == null || !workItem.isPresent() || workItem.get() == null) { + Optional workItem = + computation != null + ? workUnitClient.getStreamingConfigWorkItem(computation) + : workUnitClient.getGlobalStreamingConfigWorkItem(); + + if (!workItem.isPresent()) { return; } StreamingConfigTask config = workItem.get().getStreamingConfigTask(); @@ -1837,7 +1511,8 @@ private void getConfigFromDataflowService(@Nullable String computation) throws I addComputation( computationConfig.getComputationId(), mapTask, - computationConfig.getTransformUserNameToStateFamily()); + Optional.ofNullable(computationConfig.getTransformUserNameToStateFamily()) + .orElseGet(ImmutableMap::of)); } } @@ -2173,277 +1848,6 @@ private void invalidateStuckCommits() { } } - /** - * Class representing the state of a computation. - * - *
<p>
This class is synchronized, but only used from the dispatch and commit threads, so should - * not be heavily contended. Still, blocking work should not be done by it. - */ - static class ComputationState implements AutoCloseable { - - private final String computationId; - private final MapTask mapTask; - private final ImmutableMap transformUserNameToStateFamily; - // Map from key to work for the key. The first item in the queue is - // actively processing. Synchronized by itself. - private final Map> activeWork = new HashMap<>(); - private final BoundedQueueExecutor executor; - private final ConcurrentLinkedQueue executionStateQueue = - new ConcurrentLinkedQueue<>(); - private final WindmillStateCache.ForComputation computationStateCache; - - public ComputationState( - String computationId, - MapTask mapTask, - BoundedQueueExecutor executor, - Map transformUserNameToStateFamily, - WindmillStateCache.ForComputation computationStateCache) { - this.computationId = computationId; - this.mapTask = mapTask; - this.executor = executor; - this.transformUserNameToStateFamily = - transformUserNameToStateFamily != null - ? ImmutableMap.copyOf(transformUserNameToStateFamily) - : ImmutableMap.of(); - this.computationStateCache = computationStateCache; - Preconditions.checkNotNull(mapTask.getStageName()); - Preconditions.checkNotNull(mapTask.getSystemName()); - } - - public String getComputationId() { - return computationId; - } - - public MapTask getMapTask() { - return mapTask; - } - - public ImmutableMap getTransformUserNameToStateFamily() { - return transformUserNameToStateFamily; - } - - public ConcurrentLinkedQueue getExecutionStateQueue() { - return executionStateQueue; - } - - /** Mark the given shardedKey and work as active. */ - public boolean activateWork(ShardedKey shardedKey, Work work) { - synchronized (activeWork) { - Deque queue = activeWork.get(shardedKey); - if (queue != null) { - Preconditions.checkState(!queue.isEmpty()); - // Ensure we don't already have this work token queueud. - for (Work queuedWork : queue) { - if (queuedWork.getWorkItem().getWorkToken() == work.getWorkItem().getWorkToken()) { - return false; - } - } - // Queue the work for later processing. - queue.addLast(work); - return true; - } else { - queue = new ArrayDeque<>(); - queue.addLast(work); - activeWork.put(shardedKey, queue); - // Fall through to execute without the lock held. - } - } - executor.execute(work, work.getWorkItem().getSerializedSize()); - return true; - } - - /** - * Marks the work for the given shardedKey as complete. Schedules queued work for the key if - * any. - */ - public void completeWork(ShardedKey shardedKey, long workToken) { - Work nextWork; - synchronized (activeWork) { - Queue queue = activeWork.get(shardedKey); - if (queue == null) { - // Work may have been completed due to clearing of stuck commits. - LOG.warn( - "Unable to complete inactive work for key {} and token {}.", shardedKey, workToken); - return; - } - Work completedWork = queue.peek(); - // avoid Preconditions.checkState here to prevent eagerly evaluating the - // format string parameters for the error message. - if (completedWork == null) { - throw new IllegalStateException( - String.format( - "Active key %s without work, expected token %d", shardedKey, workToken)); - } - if (completedWork.getWorkItem().getWorkToken() != workToken) { - // Work may have been completed due to clearing of stuck commits. 
- LOG.warn( - "Unable to complete due to token mismatch for key {} and token {}, actual token was {}.", - shardedKey, - workToken, - completedWork.getWorkItem().getWorkToken()); - return; - } - queue.remove(); // We consumed the matching work item. - nextWork = queue.peek(); - if (nextWork == null) { - Preconditions.checkState(queue == activeWork.remove(shardedKey)); - } - } - if (nextWork != null) { - executor.forceExecute(nextWork, nextWork.getWorkItem().getSerializedSize()); - } - } - - public void invalidateStuckCommits(Instant stuckCommitDeadline) { - synchronized (activeWork) { - // Determine the stuck commit keys but complete them outside of iterating over - // activeWork as completeWork may delete the entry from activeWork. - Map stuckCommits = new HashMap<>(); - for (Map.Entry> entry : activeWork.entrySet()) { - ShardedKey shardedKey = entry.getKey(); - Work work = entry.getValue().peek(); - if (work.getState() == State.COMMITTING - && work.getStateStartTime().isBefore(stuckCommitDeadline)) { - LOG.error( - "Detected key {} stuck in COMMITTING state since {}, completing it with error.", - shardedKey, - work.getStateStartTime()); - stuckCommits.put(shardedKey, work.getWorkItem().getWorkToken()); - } - } - for (Map.Entry stuckCommit : stuckCommits.entrySet()) { - computationStateCache.invalidate( - stuckCommit.getKey().key(), stuckCommit.getKey().shardingKey()); - completeWork(stuckCommit.getKey(), stuckCommit.getValue()); - } - } - } - - /** Adds any work started before the refreshDeadline to the GetDataRequest builder. */ - public List getKeysToRefresh(Instant refreshDeadline) { - List result = new ArrayList<>(); - synchronized (activeWork) { - for (Map.Entry> entry : activeWork.entrySet()) { - ShardedKey shardedKey = entry.getKey(); - for (Work work : entry.getValue()) { - if (work.getStartTime().isBefore(refreshDeadline)) { - result.add( - Windmill.KeyedGetDataRequest.newBuilder() - .setKey(shardedKey.key()) - .setShardingKey(shardedKey.shardingKey()) - .setWorkToken(work.getWorkItem().getWorkToken()) - .addAllLatencyAttribution(work.getLatencyAttributions()) - .build()); - } - } - } - } - return result; - } - - private String elapsedString(Instant start, Instant end) { - Duration activeFor = new Duration(start, end); - // Duration's toString always starts with "PT"; remove that here. - return activeFor.toString().substring(2); - } - - public void printActiveWork(PrintWriter writer) { - final Instant now = Instant.now(); - // The max number of keys in COMMITTING or COMMIT_QUEUED status to be shown. - final int maxCommitPending = 50; - int commitPendingCount = 0; - writer.println( - ""); - writer.println( - ""); - // We use a StringBuilder in the synchronized section to buffer writes since the provided - // PrintWriter may block when flushing. - StringBuilder builder = new StringBuilder(); - synchronized (activeWork) { - for (Map.Entry> entry : activeWork.entrySet()) { - Queue queue = entry.getValue(); - Preconditions.checkNotNull(queue); - Work work = queue.peek(); - Preconditions.checkNotNull(work); - Windmill.WorkItem workItem = work.getWorkItem(); - State state = work.getState(); - if (state == State.COMMITTING || state == State.COMMIT_QUEUED) { - if (++commitPendingCount >= maxCommitPending) { - continue; - } - } - builder.append(""); - builder.append("\n"); - } - } - writer.print(builder.toString()); - writer.println("
<tr><th>Key</th><th>Token</th><th>Queued</th><th>Active For</th><th>State</th><th>State Active For</th></tr>
"); - builder.append(String.format("%016x", workItem.getShardingKey())); - builder.append(""); - builder.append(String.format("%016x", workItem.getWorkToken())); - builder.append(""); - builder.append(queue.size() - 1); - builder.append(""); - builder.append(elapsedString(work.getStartTime(), now)); - builder.append(""); - builder.append(state); - builder.append(""); - builder.append(elapsedString(work.getStateStartTime(), now)); - builder.append("
"); - if (commitPendingCount >= maxCommitPending) { - writer.println("
"); - writer.print("Skipped keys in COMMITTING/COMMIT_QUEUED: "); - writer.println(commitPendingCount - maxCommitPending); - writer.println("
"); - } - } - - @Override - public void close() throws Exception { - ExecutionState executionState; - while ((executionState = executionStateQueue.poll()) != null) { - executionState.getWorkExecutor().close(); - } - executionStateQueue.clear(); - } - } - - private static class ExecutionState { - - public final DataflowWorkExecutor workExecutor; - public final StreamingModeExecutionContext context; - public final @Nullable Coder keyCoder; - private final ExecutionStateTracker executionStateTracker; - - public ExecutionState( - DataflowWorkExecutor workExecutor, - StreamingModeExecutionContext context, - Coder keyCoder, - ExecutionStateTracker executionStateTracker) { - this.workExecutor = workExecutor; - this.context = context; - this.keyCoder = keyCoder; - this.executionStateTracker = executionStateTracker; - } - - public DataflowWorkExecutor getWorkExecutor() { - return workExecutor; - } - - public StreamingModeExecutionContext getContext() { - return context; - } - - public ExecutionStateTracker getExecutionStateTracker() { - return executionStateTracker; - } - - public @Nullable Coder getKeyCoder() { - return keyCoder; - } - } - private class HarnessDataProvider implements StatusDataProvider { @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java index c3c11716e4c7..82fbcd82c131 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java @@ -21,7 +21,7 @@ import com.google.api.services.dataflow.model.WorkItemServiceState; import com.google.api.services.dataflow.model.WorkItemStatus; import java.io.IOException; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; +import java.util.Optional; /** Abstract base class describing a client for WorkItem work units. */ interface WorkUnitClient { @@ -31,14 +31,14 @@ interface WorkUnitClient { Optional getWorkItem() throws IOException; /** - * Returns a new global streaming config WorkItem, or returns {@link Optional#absent()} if no work + * Returns a new global streaming config WorkItem, or returns {@link Optional#empty()} if no work * was found. */ Optional getGlobalStreamingConfigWorkItem() throws IOException; /** * Returns a streaming config WorkItem for the given computation, or returns {@link - * Optional#absent()} if no work was found. + * Optional#empty()} if no work was found. */ Optional getStreamingConfigWorkItem(String computationId) throws IOException; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/NameContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/NameContext.java index 6188386a4e67..4f4a1c3834e5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/NameContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/NameContext.java @@ -31,7 +31,10 @@ public abstract class NameContext { * systemName} and a {@code userName}. 
*/ public static NameContext create( - String stageName, String originalName, String systemName, String userName) { + String stageName, + @Nullable String originalName, + String systemName, + @Nullable String userName) { return new AutoValue_NameContext(stageName, originalName, systemName, userName); } @@ -44,7 +47,7 @@ public static NameContext forStage(String stageName) { } /** Returns the name of the stage this instruction is executing in. */ - public abstract @Nullable String stageName(); + public abstract String stageName(); /** * Returns the "original" name of this instruction. This name is a short name assigned by the SDK diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java new file mode 100644 index 000000000000..529bb0a41907 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; + +import java.io.PrintWriter; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Queue; +import java.util.function.BiConsumer; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages the active {@link Work} queues for their {@link ShardedKey}(s). Gives an interface to + * activate, queue, and complete {@link Work} (including invalidating stuck {@link Work}). 
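+ *
+ * <p>A minimal sketch of the intended calling pattern, assuming a caller shaped like {@code
+ * ComputationState} (illustrative only; the executor wiring is an assumption of this sketch):
+ *
+ * <pre>{@code
+ * ActiveWorkState activeWorkState = ActiveWorkState.create(computationStateCache);
+ * if (activeWorkState.activateWorkForKey(shardedKey, work) == ActivateWorkResult.EXECUTE) {
+ *   executor.execute(work, work.getWorkItem().getSerializedSize());
+ * }
+ * // Once the commit for this work token finishes, schedule whatever was queued behind it.
+ * activeWorkState
+ *     .completeWorkAndGetNextWorkForKey(shardedKey, work.getWorkItem().getWorkToken())
+ *     .ifPresent(next -> executor.forceExecute(next, next.getWorkItem().getSerializedSize()));
+ * }</pre>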
+ */ +@ThreadSafe +final class ActiveWorkState { + private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkState.class); + + /* The max number of keys in COMMITTING or COMMIT_QUEUED status to be shown.*/ + private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50; + + /** + * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link + * Queue} is actively processing. + */ + @GuardedBy("this") + private final Map> activeWork; + + @GuardedBy("this") + private final WindmillStateCache.ForComputation computationStateCache; + + private ActiveWorkState( + Map> activeWork, + WindmillStateCache.ForComputation computationStateCache) { + this.activeWork = activeWork; + this.computationStateCache = computationStateCache; + } + + static ActiveWorkState create(WindmillStateCache.ForComputation computationStateCache) { + return new ActiveWorkState(new HashMap<>(), computationStateCache); + } + + @VisibleForTesting + static ActiveWorkState forTesting( + Map> activeWork, + WindmillStateCache.ForComputation computationStateCache) { + return new ActiveWorkState(activeWork, computationStateCache); + } + + /** + * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 {@link + * ActivateWorkResult} + * + *

1. EXECUTE: The {@link ShardedKey} has not been seen before; create a {@link Queue} of {@link Work} + * for the key. The caller should execute the work. + * + *

2. DUPLICATE: A work queue for the {@link ShardedKey} exists, and this work token is already in + * it; mark the {@link Work} as a duplicate. + * + *

3. QUEUED: A work queue for the {@link ShardedKey} exists, and the work is not in the key's + * work queue, queue the work for later processing. + */ + synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) { + Deque workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>()); + + // This key does not have any work queued up on it. Create one, insert Work, and mark the work + // to be executed. + if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) { + workQueue.addLast(work); + activeWork.put(shardedKey, workQueue); + return ActivateWorkResult.EXECUTE; + } + + // Ensure we don't already have this work token queued. + for (Work queuedWork : workQueue) { + if (queuedWork.getWorkItem().getWorkToken() == work.getWorkItem().getWorkToken()) { + return ActivateWorkResult.DUPLICATE; + } + } + + // Queue the work for later processing. + workQueue.addLast(work); + return ActivateWorkResult.QUEUED; + } + + /** + * Removes the complete work from the {@link Queue}. The {@link Work} is marked as completed + * if its workToken matches the one that is passed in. Returns the next {@link Work} in the {@link + * ShardedKey}'s work queue, if one exists else removes the {@link ShardedKey} from {@link + * #activeWork}. + */ + synchronized Optional completeWorkAndGetNextWorkForKey( + ShardedKey shardedKey, long workToken) { + @Nullable Queue workQueue = activeWork.get(shardedKey); + if (workQueue == null) { + // Work may have been completed due to clearing of stuck commits. + LOG.warn("Unable to complete inactive work for key {} and token {}.", shardedKey, workToken); + return Optional.empty(); + } + removeCompletedWorkFromQueue(workQueue, shardedKey, workToken); + return getNextWork(workQueue, shardedKey); + } + + private synchronized void removeCompletedWorkFromQueue( + Queue workQueue, ShardedKey shardedKey, long workToken) { + // avoid Preconditions.checkState here to prevent eagerly evaluating the + // format string parameters for the error message. + Work completedWork = + Optional.ofNullable(workQueue.peek()) + .orElseThrow( + () -> + new IllegalStateException( + String.format( + "Active key %s without work, expected token %d", + shardedKey, workToken))); + + if (completedWork.getWorkItem().getWorkToken() != workToken) { + // Work may have been completed due to clearing of stuck commits. + LOG.warn( + "Unable to complete due to token mismatch for key {} and token {}, actual token was {}.", + shardedKey, + workToken, + completedWork.getWorkItem().getWorkToken()); + return; + } + + // We consumed the matching work item. + workQueue.remove(); + } + + private synchronized Optional getNextWork(Queue workQueue, ShardedKey shardedKey) { + Optional nextWork = Optional.ofNullable(workQueue.peek()); + if (!nextWork.isPresent()) { + Preconditions.checkState(workQueue == activeWork.remove(shardedKey)); + } + + return nextWork; + } + + /** + * Invalidates all {@link Work} that is in the {@link Work.State#COMMITTING} state which started + * before the stuckCommitDeadline. 
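+ *
+ * <p>Sketch of the expected call, as wired up from {@code ComputationState} (the ten-minute
+ * deadline below is an assumption for illustration):
+ *
+ * <pre>{@code
+ * Instant stuckCommitDeadline = Instant.now().minus(Duration.standardMinutes(10));
+ * activeWorkState.invalidateStuckCommits(
+ *     stuckCommitDeadline, computationState::completeWorkAndScheduleNextWorkForKey);
+ * }</pre>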
+ */ + synchronized void invalidateStuckCommits( + Instant stuckCommitDeadline, BiConsumer shardedKeyAndWorkTokenConsumer) { + for (Entry shardedKeyAndWorkToken : + getStuckCommitsAt(stuckCommitDeadline).entrySet()) { + ShardedKey shardedKey = shardedKeyAndWorkToken.getKey(); + long workToken = shardedKeyAndWorkToken.getValue(); + computationStateCache.invalidate(shardedKey.key(), shardedKey.shardingKey()); + shardedKeyAndWorkTokenConsumer.accept(shardedKey, workToken); + } + } + + private synchronized ImmutableMap getStuckCommitsAt( + Instant stuckCommitDeadline) { + // Determine the stuck commit keys but complete them outside the loop iterating over + // activeWork as completeWork may delete the entry from activeWork. + ImmutableMap.Builder stuckCommits = ImmutableMap.builder(); + for (Entry> entry : activeWork.entrySet()) { + ShardedKey shardedKey = entry.getKey(); + @Nullable Work work = entry.getValue().peek(); + if (work != null) { + if (work.isStuckCommittingAt(stuckCommitDeadline)) { + LOG.error( + "Detected key {} stuck in COMMITTING state since {}, completing it with error.", + shardedKey, + work.getStateStartTime()); + stuckCommits.put(shardedKey, work.getWorkItem().getWorkToken()); + } + } + } + + return stuckCommits.build(); + } + + synchronized ImmutableList getKeysToRefresh(Instant refreshDeadline) { + return activeWork.entrySet().stream() + .flatMap(entry -> toKeyedGetDataRequestStream(entry, refreshDeadline)) + .collect(toImmutableList()); + } + + private static Stream toKeyedGetDataRequestStream( + Entry> shardedKeyAndWorkQueue, Instant refreshDeadline) { + ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey(); + Deque workQueue = shardedKeyAndWorkQueue.getValue(); + + return workQueue.stream() + .filter(work -> work.getStartTime().isBefore(refreshDeadline)) + .map( + work -> + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(shardedKey.key()) + .setShardingKey(shardedKey.shardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .addAllLatencyAttribution(work.getLatencyAttributions()) + .build()); + } + + synchronized void printActiveWork(PrintWriter writer, Instant now) { + writer.println( + ""); + writer.println( + ""); + // Use StringBuilder because we are appending in loop. + StringBuilder activeWorkStatus = new StringBuilder(); + int commitsPendingCount = 0; + for (Map.Entry> entry : activeWork.entrySet()) { + Queue workQueue = Preconditions.checkNotNull(entry.getValue()); + Work activeWork = Preconditions.checkNotNull(workQueue.peek()); + Windmill.WorkItem workItem = activeWork.getWorkItem(); + if (activeWork.isCommitPending()) { + if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) { + continue; + } + } + activeWorkStatus.append(""); + activeWorkStatus.append("\n"); + } + + writer.print(activeWorkStatus); + writer.println("
<tr><th>Key</th><th>Token</th><th>Queued</th><th>Active For</th><th>State</th><th>State Active For</th></tr>
"); + activeWorkStatus.append(String.format("%016x", workItem.getShardingKey())); + activeWorkStatus.append(""); + activeWorkStatus.append(String.format("%016x", workItem.getWorkToken())); + activeWorkStatus.append(""); + activeWorkStatus.append(workQueue.size() - 1); + activeWorkStatus.append(""); + activeWorkStatus.append(elapsedString(activeWork.getStartTime(), now)); + activeWorkStatus.append(""); + activeWorkStatus.append(activeWork.getState()); + activeWorkStatus.append(""); + activeWorkStatus.append(elapsedString(activeWork.getStateStartTime(), now)); + activeWorkStatus.append("
"); + + if (commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) { + writer.println("
"); + writer.print("Skipped keys in COMMITTING/COMMIT_QUEUED: "); + writer.println(commitsPendingCount - MAX_PRINTABLE_COMMIT_PENDING_KEYS); + writer.println("
"); + } + } + + private static String elapsedString(Instant start, Instant end) { + Duration activeFor = new Duration(start, end); + // Duration's toString always starts with "PT"; remove that here. + return activeFor.toString().substring(2); + } + + enum ActivateWorkResult { + QUEUED, + EXECUTE, + DUPLICATE + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java new file mode 100644 index 000000000000..946897967561 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; + +/** Value class for a queued commit. */ +@AutoValue +public abstract class Commit { + + public static Commit create( + WorkItemCommitRequest request, ComputationState computationState, Work work) { + Preconditions.checkArgument(request.getSerializedSize() > 0); + return new AutoValue_Commit(request, computationState, work); + } + + public abstract WorkItemCommitRequest request(); + + public abstract ComputationState computationState(); + + public abstract Work work(); + + public final int getSize() { + return request().getSerializedSize(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java new file mode 100644 index 000000000000..a902d2b13a77 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.api.services.dataflow.model.MapTask; +import java.io.PrintWriter; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Instant; + +/** + * Class representing the state of a computation. + * + *

This class is synchronized, but only used from the dispatch and commit threads, so should not + * be heavily contended. Still, blocking work should not be done by it. + */ +public class ComputationState implements AutoCloseable { + private final String computationId; + private final MapTask mapTask; + private final ImmutableMap transformUserNameToStateFamily; + private final ActiveWorkState activeWorkState; + private final BoundedQueueExecutor executor; + private final ConcurrentLinkedQueue executionStateQueue; + + public ComputationState( + String computationId, + MapTask mapTask, + BoundedQueueExecutor executor, + Map transformUserNameToStateFamily, + WindmillStateCache.ForComputation computationStateCache) { + Preconditions.checkNotNull(mapTask.getStageName()); + Preconditions.checkNotNull(mapTask.getSystemName()); + this.computationId = computationId; + this.mapTask = mapTask; + this.executor = executor; + this.transformUserNameToStateFamily = ImmutableMap.copyOf(transformUserNameToStateFamily); + this.executionStateQueue = new ConcurrentLinkedQueue<>(); + this.activeWorkState = ActiveWorkState.create(computationStateCache); + } + + public String getComputationId() { + return computationId; + } + + public MapTask getMapTask() { + return mapTask; + } + + public ImmutableMap getTransformUserNameToStateFamily() { + return transformUserNameToStateFamily; + } + + public ConcurrentLinkedQueue getExecutionStateQueue() { + return executionStateQueue; + } + + /** + * Mark the given {@link ShardedKey} and {@link Work} as active, and schedules execution of {@link + * Work} if there is no active {@link Work} for the {@link ShardedKey} already processing. + */ + public boolean activateWork(ShardedKey shardedKey, Work work) { + switch (activeWorkState.activateWorkForKey(shardedKey, work)) { + case DUPLICATE: + return false; + case QUEUED: + return true; + case EXECUTE: + { + execute(work); + return true; + } + default: + // This will never happen, the switch is exhaustive. + throw new IllegalStateException("Unrecognized ActivateWorkResult"); + } + } + + /** + * Marks the work for the given shardedKey as complete. Schedules queued work for the key if any. + */ + public void completeWorkAndScheduleNextWorkForKey(ShardedKey shardedKey, long workToken) { + activeWorkState + .completeWorkAndGetNextWorkForKey(shardedKey, workToken) + .ifPresent(this::forceExecute); + } + + public void invalidateStuckCommits(Instant stuckCommitDeadline) { + activeWorkState.invalidateStuckCommits( + stuckCommitDeadline, this::completeWorkAndScheduleNextWorkForKey); + } + + private void execute(Work work) { + executor.execute(work, work.getWorkItem().getSerializedSize()); + } + + private void forceExecute(Work work) { + executor.forceExecute(work, work.getWorkItem().getSerializedSize()); + } + + /** Adds any work started before the refreshDeadline to the GetDataRequest builder. 
*/ + public List getKeysToRefresh(Instant refreshDeadline) { + return activeWorkState.getKeysToRefresh(refreshDeadline); + } + + public void printActiveWork(PrintWriter writer) { + activeWorkState.printActiveWork(writer, Instant.now()); + } + + @Override + public void close() throws Exception { + @Nullable ExecutionState executionState; + while ((executionState = executionStateQueue.poll()) != null) { + executionState.workExecutor().close(); + } + executionStateQueue.clear(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java new file mode 100644 index 000000000000..ba35179a75b3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import java.util.Optional; +import org.apache.beam.runners.core.metrics.ExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.sdk.coders.Coder; + +@AutoValue +public abstract class ExecutionState { + + public abstract DataflowWorkExecutor workExecutor(); + + public abstract StreamingModeExecutionContext context(); + + public abstract Optional> keyCoder(); + + public abstract ExecutionStateTracker executionStateTracker(); + + public static ExecutionState.Builder builder() { + return new AutoValue_ExecutionState.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); + + public abstract Builder setContext(StreamingModeExecutionContext context); + + public abstract Builder setKeyCoder(Coder keyCoder); + + public abstract Builder setExecutionStateTracker(ExecutionStateTracker executionStateTracker); + + public abstract ExecutionState build(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java new file mode 100644 index 000000000000..090d9981309e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; + +public final class KeyCommitTooLargeException extends Exception { + + public static KeyCommitTooLargeException causedBy( + String computationId, long byteLimit, Windmill.WorkItemCommitRequest request) { + StringBuilder message = new StringBuilder(); + message.append("Commit request for stage "); + message.append(computationId); + message.append(" and key "); + message.append(request.getKey().toStringUtf8()); + if (request.getSerializedSize() > 0) { + message.append( + " has size " + + request.getSerializedSize() + + " which is more than the limit of " + + byteLimit); + } else { + message.append(" is larger than 2GB and cannot be processed"); + } + message.append( + ". 
This may be caused by grouping a very " + + "large amount of data in a single window without using Combine," + + " or by producing a large amount of data from a single input element."); + return new KeyCommitTooLargeException(message.toString()); + } + + private KeyCommitTooLargeException(String message) { + super(message); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ShardedKey.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ShardedKey.java new file mode 100644 index 000000000000..86433d9e6752 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ShardedKey.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; + +@AutoValue +public abstract class ShardedKey { + + public static ShardedKey create(ByteString key, long shardingKey) { + return new AutoValue_ShardedKey(key, shardingKey); + } + + public abstract ByteString key(); + + public abstract long shardingKey(); + + @Override + public final String toString() { + return String.format("%016x", shardingKey()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java new file mode 100644 index 000000000000..b514dfc84bb9 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics.THROTTLING_MSECS_METRIC_NAME; + +import com.google.api.services.dataflow.model.CounterStructuredName; +import com.google.api.services.dataflow.model.CounterUpdate; +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics; +import org.apache.beam.runners.dataflow.worker.MetricsContainerRegistry; +import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; +import org.apache.beam.runners.dataflow.worker.StreamingStepMetricsContainer; +import org.apache.beam.runners.dataflow.worker.counters.Counter; +import org.apache.beam.runners.dataflow.worker.counters.CounterSet; +import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; +import org.apache.beam.runners.dataflow.worker.counters.NameContext; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; + +/** Contains a few of the stage specific fields. E.g. metrics container registry, counters etc. */ +@AutoValue +public abstract class StageInfo { + public static StageInfo create( + String stageName, String systemName, StreamingDataflowWorker worker) { + NameContext nameContext = NameContext.create(stageName, null, systemName, null); + CounterSet deltaCounters = new CounterSet(); + return new AutoValue_StageInfo( + stageName, + systemName, + StreamingStepMetricsContainer.createRegistry(), + new StreamingModeExecutionStateRegistry(worker), + deltaCounters, + deltaCounters.longSum( + DataflowSystemMetrics.StreamingPerStageSystemCounterNames.THROTTLED_MSECS.counterName( + nameContext)), + deltaCounters.longSum( + DataflowSystemMetrics.StreamingPerStageSystemCounterNames.TOTAL_PROCESSING_MSECS + .counterName(nameContext)), + deltaCounters.longSum( + DataflowSystemMetrics.StreamingPerStageSystemCounterNames.TIMER_PROCESSING_MSECS + .counterName(nameContext))); + } + + public abstract String stageName(); + + public abstract String systemName(); + + public abstract MetricsContainerRegistry + metricsContainerRegistry(); + + public abstract StreamingModeExecutionStateRegistry executionStateRegistry(); + + public abstract CounterSet deltaCounters(); + + public abstract Counter throttledMsecs(); + + public abstract Counter totalProcessingMsecs(); + + public abstract Counter timerProcessingMsecs(); + + public List extractCounterUpdates() { + List counterUpdates = new ArrayList<>(); + Iterables.addAll( + counterUpdates, + StreamingStepMetricsContainer.extractMetricUpdates(metricsContainerRegistry())); + Iterables.addAll(counterUpdates, executionStateRegistry().extractUpdates(false)); + for (CounterUpdate counterUpdate : counterUpdates) { + translateKnownStepCounters(counterUpdate); + } + counterUpdates.addAll( + deltaCounters().extractModifiedDeltaUpdates(DataflowCounterUpdateExtractor.INSTANCE)); + return counterUpdates; + } + + /** + * Checks if the step counter affects any per-stage counters. Currently 'throttled_millis' is the + * only counter updated. 
+ */ + private void translateKnownStepCounters(CounterUpdate stepCounterUpdate) { + CounterStructuredName structuredName = + stepCounterUpdate.getStructuredNameAndMetadata().getName(); + if ((THROTTLING_MSECS_METRIC_NAME.getNamespace().equals(structuredName.getOriginNamespace()) + && THROTTLING_MSECS_METRIC_NAME.getName().equals(structuredName.getName())) + || (StreamingDataflowWorker.BIGQUERY_STREAMING_INSERT_THROTTLE_TIME + .getNamespace() + .equals(structuredName.getOriginNamespace()) + && StreamingDataflowWorker.BIGQUERY_STREAMING_INSERT_THROTTLE_TIME + .getName() + .equals(structuredName.getName()))) { + long msecs = DataflowCounterUpdateExtractor.splitIntToLong(stepCounterUpdate.getInteger()); + if (msecs > 0) { + throttledMsecs().addValue(msecs); + } + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java new file mode 100644 index 000000000000..f2893f3e7191 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Bounded set of queues, with a maximum total weight. */ +public final class WeightedBoundedQueue { + + private final LinkedBlockingQueue queue; + private final int maxWeight; + private final Semaphore limit; + private final Function weigher; + + private WeightedBoundedQueue( + LinkedBlockingQueue linkedBlockingQueue, + int maxWeight, + Semaphore limit, + Function weigher) { + this.queue = linkedBlockingQueue; + this.maxWeight = maxWeight; + this.limit = limit; + this.weigher = weigher; + } + + public static WeightedBoundedQueue create(int maxWeight, Function weigherFn) { + return new WeightedBoundedQueue<>( + new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn); + } + + /** + * Adds the value to the queue, blocking if this would cause the overall weight to exceed the + * limit. + */ + public void put(V value) { + limit.acquireUninterruptibly(weigher.apply(value)); + queue.add(value); + } + + /** Returns and removes the next value, or null if there is no such value. 
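+ *
+ * <p>Illustrative usage of the weighted bounds (sketch only; the 128 MB limit and the weigher
+ * below are assumptions, not values taken from this change):
+ *
+ * <pre>{@code
+ * WeightedBoundedQueue<Commit> commitQueue =
+ *     WeightedBoundedQueue.create(128 << 20, commit -> Math.min(128 << 20, commit.getSize()));
+ * commitQueue.put(commit);          // blocks if the queued weight would exceed the limit
+ * Commit next = commitQueue.poll(); // null if empty; releases the polled commit's weight
+ * }</pre>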
*/ + public @Nullable V poll() { + V result = queue.poll(); + if (result != null) { + limit.release(weigher.apply(result)); + } + return result; + } + + /** + * Retrieves and removes the head of this queue, waiting up to the specified wait time if + * necessary for an element to become available. + * + * @param timeout how long to wait before giving up, in units of {@code unit} + * @param unit a {@code TimeUnit} determining how to interpret the {@code timeout} parameter + * @return the head of this queue, or {@code null} if the specified waiting time elapses before an + * element is available + * @throws InterruptedException if interrupted while waiting + */ + public @Nullable V poll(long timeout, TimeUnit unit) throws InterruptedException { + V result = queue.poll(timeout, unit); + if (result != null) { + limit.release(weigher.apply(result)); + } + return result; + } + + /** Returns and removes the next value, or blocks until one is available. */ + public @Nullable V take() throws InterruptedException { + V result = queue.take(); + limit.release(weigher.apply(result)); + return result; + } + + /** Returns the current weight of the queue. */ + public int queuedElementsWeight() { + return maxWeight - limit.availablePermits(); + } + + public int size() { + return queue.size(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java new file mode 100644 index 000000000000..cc3f6d1b12b2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.joda.time.Duration; +import org.joda.time.Instant; + +@NotThreadSafe +public class Work implements Runnable { + + private final Windmill.WorkItem workItem; + private final Supplier clock; + private final Instant startTime; + private final Map totalDurationPerState; + private final Consumer processWorkFn; + private TimedState currentState; + + private Work(Windmill.WorkItem workItem, Supplier clock, Consumer processWorkFn) { + this.workItem = workItem; + this.clock = clock; + this.processWorkFn = processWorkFn; + this.startTime = clock.get(); + this.totalDurationPerState = new EnumMap<>(Windmill.LatencyAttribution.State.class); + this.currentState = TimedState.initialState(startTime); + } + + public static Work create( + Windmill.WorkItem workItem, + Supplier clock, + Collection getWorkStreamLatencies, + Consumer processWorkFn) { + Work work = new Work(workItem, clock, processWorkFn); + work.recordGetWorkStreamLatencies(getWorkStreamLatencies); + return work; + } + + @Override + public void run() { + processWorkFn.accept(this); + } + + public Windmill.WorkItem getWorkItem() { + return workItem; + } + + public Instant getStartTime() { + return startTime; + } + + public State getState() { + return currentState.state(); + } + + public void setState(State state) { + Instant now = clock.get(); + totalDurationPerState.compute( + this.currentState.state().toLatencyAttributionState(), + (s, d) -> + new Duration(this.currentState.startTime(), now).plus(d == null ? 
Duration.ZERO : d)); + this.currentState = TimedState.create(state, now); + } + + public boolean isCommitPending() { + return currentState.isCommitPending(); + } + + public Instant getStateStartTime() { + return currentState.startTime(); + } + + private void recordGetWorkStreamLatencies( + Collection getWorkStreamLatencies) { + for (Windmill.LatencyAttribution latency : getWorkStreamLatencies) { + totalDurationPerState.put( + latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + } + } + + public Collection getLatencyAttributions() { + List list = new ArrayList<>(); + for (Windmill.LatencyAttribution.State state : Windmill.LatencyAttribution.State.values()) { + Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO); + if (state == this.currentState.state().toLatencyAttributionState()) { + duration = duration.plus(new Duration(this.currentState.startTime(), clock.get())); + } + if (duration.equals(Duration.ZERO)) { + continue; + } + list.add( + Windmill.LatencyAttribution.newBuilder() + .setState(state) + .setTotalDurationMillis(duration.getMillis()) + .build()); + } + return list; + } + + boolean isStuckCommittingAt(Instant stuckCommitDeadline) { + return currentState.state() == Work.State.COMMITTING + && currentState.startTime().isBefore(stuckCommitDeadline); + } + + public enum State { + QUEUED(Windmill.LatencyAttribution.State.QUEUED), + PROCESSING(Windmill.LatencyAttribution.State.ACTIVE), + READING(Windmill.LatencyAttribution.State.READING), + COMMIT_QUEUED(Windmill.LatencyAttribution.State.COMMITTING), + COMMITTING(Windmill.LatencyAttribution.State.COMMITTING), + GET_WORK_IN_WINDMILL_WORKER(Windmill.LatencyAttribution.State.GET_WORK_IN_WINDMILL_WORKER), + GET_WORK_IN_TRANSIT_TO_DISPATCHER( + Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_DISPATCHER), + GET_WORK_IN_TRANSIT_TO_USER_WORKER( + Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_USER_WORKER); + + private final Windmill.LatencyAttribution.State latencyAttributionState; + + State(Windmill.LatencyAttribution.State latencyAttributionState) { + this.latencyAttributionState = latencyAttributionState; + } + + Windmill.LatencyAttribution.State toLatencyAttributionState() { + return latencyAttributionState; + } + } + + /** + * Represents the current state of an instance of {@link Work}. Contains the {@link State} and + * {@link Instant} when it started. 
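+ *
+ * <p>Sketch of how these timed transitions feed latency attribution (illustrative only; the
+ * explicit {@code setState} calls below stand in for the worker's processing loop):
+ *
+ * <pre>{@code
+ * Work work = Work.create(workItem, Instant::now, Collections.emptyList(), w -> {});
+ * work.setState(Work.State.PROCESSING);    // time spent QUEUED is recorded under its bucket
+ * work.setState(Work.State.COMMIT_QUEUED); // time spent PROCESSING is recorded as ACTIVE
+ * Collection<Windmill.LatencyAttribution> latencies = work.getLatencyAttributions();
+ * }</pre>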
+ */ + @AutoValue + abstract static class TimedState { + private static TimedState create(State state, Instant startTime) { + return new AutoValue_Work_TimedState(state, startTime); + } + + private static TimedState initialState(Instant startTime) { + return create(State.QUEUED, startTime); + } + + private boolean isCommitPending() { + return state() == Work.State.COMMITTING || state() == Work.State.COMMIT_QUEUED; + } + + abstract State state(); + + abstract Instant startTime(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorkerTest.java index 8d5660548e75..b4f544129db6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchDataflowWorkerTest.java @@ -37,11 +37,11 @@ import com.google.api.services.dataflow.model.WorkItemStatus; import java.io.IOException; import java.util.ArrayList; +import java.util.Optional; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.util.TimeUtil; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.util.FastNanoClockAndSleeper; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.hamcrest.Description; import org.hamcrest.TypeSafeMatcher; @@ -61,16 +61,10 @@ @RunWith(JUnit4.class) public class BatchDataflowWorkerTest { - private static class WorkerException extends Exception {} - @Rule public FastNanoClockAndSleeper clockAndSleeper = new FastNanoClockAndSleeper(); - @Mock WorkUnitClient mockWorkUnitClient; - @Mock DataflowWorkProgressUpdater mockProgressUpdater; - @Mock DataflowWorkExecutor mockWorkExecutor; - DataflowWorkerHarnessOptions options; @Before @@ -98,7 +92,7 @@ public void testWhenNoWorkIsReturnedThatWeImmediatelyRetry() throws Exception { workItem.setReportStatusInterval(TimeUtil.toCloudDuration(Duration.standardMinutes(1))); when(mockWorkUnitClient.getWorkItem()) - .thenReturn(Optional.absent()) + .thenReturn(Optional.empty()) .thenReturn(Optional.of(workItem)); assertTrue(worker.getAndPerformWork()); @@ -138,7 +132,7 @@ public void testWhenProcessingWorkUnitFailsWeReportStatus() throws Exception { Throwable error = errorCaptor.getValue(); assertThat(error, notNullValue()); - assertThat(error.getMessage(), equalTo("Unknown kind of work item: " + workItem.toString())); + assertThat(error.getMessage(), equalTo("Unknown kind of work item: " + workItem)); } @Test @@ -168,8 +162,9 @@ public void testStopProgressReportInCaseOfFailure() throws Exception { @Test public void testIsSplitResponseTooLarge() throws IOException { SourceSplitResponse splitResponse = new SourceSplitResponse(); - splitResponse.setShards( - ImmutableList.of(new SourceSplitShard(), new SourceSplitShard())); + splitResponse.setShards(ImmutableList.of(new SourceSplitShard(), new SourceSplitShard())); assertThat(DataflowApiUtils.computeSerializedSizeBytes(splitResponse), greaterThan(0L)); } + + private static class WorkerException extends Exception {} } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java index e8b5ce8d0df2..3c63f3cc19d2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java @@ -34,6 +34,7 @@ import com.google.api.services.dataflow.model.SeqMapTask; import com.google.api.services.dataflow.model.WorkItem; import java.io.IOException; +import java.util.Optional; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.testing.RestoreDataflowLoggingMDC; @@ -42,7 +43,6 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.RestoreSystemProperties; import org.apache.beam.sdk.util.FastNanoClockAndSleeper; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.junit.Before; @@ -61,6 +61,9 @@ @RunWith(JUnit4.class) public class DataflowWorkUnitClientTest { private static final Logger LOG = LoggerFactory.getLogger(DataflowWorkUnitClientTest.class); + private static final String PROJECT_ID = "TEST_PROJECT_ID"; + private static final String JOB_ID = "TEST_JOB_ID"; + private static final String WORKER_ID = "TEST_WORKER_ID"; @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); @Rule public TestRule restoreLogging = new RestoreDataflowLoggingMDC(); @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -69,10 +72,6 @@ public class DataflowWorkUnitClientTest { @Mock private MockLowLevelHttpRequest request; private DataflowWorkerHarnessOptions pipelineOptions; - private static final String PROJECT_ID = "TEST_PROJECT_ID"; - private static final String JOB_ID = "TEST_JOB_ID"; - private static final String WORKER_ID = "TEST_WORKER_ID"; - @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); @@ -104,10 +103,10 @@ public void testCloudServiceCall() throws Exception { .fromString(request.getContentAsString(), LeaseWorkItemRequest.class); assertEquals(WORKER_ID, actualRequest.getWorkerId()); assertEquals( - ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), + ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), actualRequest.getWorkerCapabilities()); assertEquals( - ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), actualRequest.getWorkItemTypes()); assertEquals("1234", DataflowWorkerLoggingMDC.getWorkId()); } @@ -151,17 +150,17 @@ public void testCloudServiceCallNoWorkPresent() throws Exception { WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG); - assertEquals(Optional.absent(), client.getWorkItem()); + assertEquals(Optional.empty(), client.getWorkItem()); LeaseWorkItemRequest actualRequest = Transport.getJsonFactory() .fromString(request.getContentAsString(), LeaseWorkItemRequest.class); assertEquals(WORKER_ID, 
actualRequest.getWorkerId()); assertEquals( - ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), + ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), actualRequest.getWorkerCapabilities()); assertEquals( - ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), actualRequest.getWorkItemTypes()); } @@ -175,17 +174,17 @@ public void testCloudServiceCallNoWorkId() throws Exception { WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG); - assertEquals(Optional.absent(), client.getWorkItem()); + assertEquals(Optional.empty(), client.getWorkItem()); LeaseWorkItemRequest actualRequest = Transport.getJsonFactory() .fromString(request.getContentAsString(), LeaseWorkItemRequest.class); assertEquals(WORKER_ID, actualRequest.getWorkerId()); assertEquals( - ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), + ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), actualRequest.getWorkerCapabilities()); assertEquals( - ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), actualRequest.getWorkItemTypes()); } @@ -195,17 +194,17 @@ public void testCloudServiceCallNoWorkItem() throws Exception { WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG); - assertEquals(Optional.absent(), client.getWorkItem()); + assertEquals(Optional.empty(), client.getWorkItem()); LeaseWorkItemRequest actualRequest = Transport.getJsonFactory() .fromString(request.getContentAsString(), LeaseWorkItemRequest.class); assertEquals(WORKER_ID, actualRequest.getWorkerId()); assertEquals( - ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), + ImmutableList.of(WORKER_ID, "remote_source", "custom_source"), actualRequest.getWorkerCapabilities()); assertEquals( - ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), + ImmutableList.of("map_task", "seq_map_task", "remote_source_task"), actualRequest.getWorkItemTypes()); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 95b3a43ebf49..82fc38055a88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -28,6 +28,8 @@ import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.eq; @@ -65,6 +67,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.PriorityQueue; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -92,8 +95,10 @@ import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.util.Structs; -import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker.ShardedKey; import 
org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.testing.RestoreDataflowLoggingMDC; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; @@ -117,7 +122,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CollectionCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.ListCoder; @@ -158,7 +162,6 @@ import org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.TextFormat; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -191,36 +194,16 @@ @SuppressWarnings("unused") public class StreamingDataflowWorkerTest { - private final boolean streamingEngine; - - @Parameterized.Parameters(name = "{index}: [streamingEngine={0}]") - public static Iterable data() { - return Arrays.asList(new Object[][] {{false}, {true}}); - } - - public StreamingDataflowWorkerTest(Boolean streamingEngine) { - this.streamingEngine = streamingEngine; - } - private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorkerTest.class); - private static final IntervalWindow DEFAULT_WINDOW = new IntervalWindow(new Instant(1234), Duration.millis(1000)); - private static final IntervalWindow WINDOW_AT_ZERO = new IntervalWindow(new Instant(0), new Instant(1000)); - private static final IntervalWindow WINDOW_AT_ONE_SECOND = new IntervalWindow(new Instant(1000), new Instant(2000)); - private static final Coder DEFAULT_WINDOW_CODER = IntervalWindow.getCoder(); private static final Coder> DEFAULT_WINDOW_COLLECTION_CODER = CollectionCoder.of(DEFAULT_WINDOW_CODER); - - private byte[] intervalWindowBytes(IntervalWindow window) throws Exception { - return CoderUtils.encodeToByteArray(DEFAULT_WINDOW_COLLECTION_CODER, Arrays.asList(window)); - } - // Default values that are unimportant for correctness, but must be consistent // between pieces of this test suite private static final String DEFAULT_COMPUTATION_ID = "computation"; @@ -242,14 +225,26 @@ private byte[] intervalWindowBytes(IntervalWindow window) throws Exception { private static final ByteString DEFAULT_KEY_BYTES = ByteString.copyFromUtf8(DEFAULT_KEY_STRING); private static final String DEFAULT_DATA_STRING = "data"; private static final String DEFAULT_DESTINATION_STREAM_ID = "out"; - - @Rule public BlockingFn blockingFn = new BlockingFn(); - @Rule public TestRule restoreMDC = new RestoreDataflowLoggingMDC(); - @Rule public ErrorCollector errorCollector = new ErrorCollector(); - - WorkUnitClient mockWorkUnitClient = mock(WorkUnitClient.class); - HotKeyLogger hotKeyLogger = mock(HotKeyLogger.class); - + private 
static final Function EMPTY_DATA_RESPONDER = + (GetDataRequest request) -> { + GetDataResponse.Builder builder = GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + return builder.build(); + }; + private final boolean streamingEngine; private final Supplier idGenerator = new Supplier() { private final AtomicLong idGenerator = new AtomicLong(1L); @@ -259,6 +254,50 @@ public Long get() { return idGenerator.getAndIncrement(); } }; + @Rule public BlockingFn blockingFn = new BlockingFn(); + @Rule public TestRule restoreMDC = new RestoreDataflowLoggingMDC(); + @Rule public ErrorCollector errorCollector = new ErrorCollector(); + WorkUnitClient mockWorkUnitClient = mock(WorkUnitClient.class); + HotKeyLogger hotKeyLogger = mock(HotKeyLogger.class); + + public StreamingDataflowWorkerTest(Boolean streamingEngine) { + this.streamingEngine = streamingEngine; + } + + @Parameterized.Parameters(name = "{index}: [streamingEngine={0}]") + public static Iterable data() { + return Arrays.asList(new Object[][] {{false}, {true}}); + } + + private static CounterUpdate getCounter(Iterable counters, String name) { + for (CounterUpdate counter : counters) { + if (counter.getNameAndKind().getName().equals(name)) { + return counter; + } + } + return null; + } + + static Work createMockWork(long workToken) { + return Work.create( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), + Instant::now, + Collections.emptyList(), + work -> {}); + } + + static Work createMockWork(long workToken, Consumer processWorkFn) { + return Work.create( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), + Instant::now, + Collections.emptyList(), + work -> {}); + } + + private byte[] intervalWindowBytes(IntervalWindow window) throws Exception { + return CoderUtils.encodeToByteArray( + DEFAULT_WINDOW_COLLECTION_CODER, Collections.singletonList(window)); + } private String keyStringForIndex(int index) { return DEFAULT_KEY_STRING + index; @@ -273,7 +312,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { CloudObject.forClassName( "com.google.cloud.dataflow.sdk.util.TimerOrElement$TimerOrElementCoder"); List component = - Collections.singletonList(CloudObjects.asCloudObject(coder, /*sdkComponents=*/ null)); + Collections.singletonList(CloudObjects.asCloudObject(coder, /* sdkComponents= */ null)); Structs.addList(timerCloudObject, PropertyNames.COMPONENT_ENCODINGS, component); CloudObject encodedCoder = CloudObject.forClassName("kind:windowed_value"); @@ -283,7 +322,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { PropertyNames.COMPONENT_ENCODINGS, ImmutableList.of( timerCloudObject, - CloudObjects.asCloudObject(IntervalWindowCoder.of(), /*sdkComponents=*/ null))); + CloudObjects.asCloudObject(IntervalWindowCoder.of(), /* sdkComponents= */ null))); return new ParallelInstruction() 
.setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) @@ -295,7 +334,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { .setSpec(CloudObject.forClass(WindowingWindmillReader.class)) .setCodec(encodedCoder))) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setName(Long.toString(idGenerator.get())) .setCodec(encodedCoder) @@ -315,9 +354,9 @@ private ParallelInstruction makeSourceInstruction(Coder coder) { .setCodec( CloudObjects.asCloudObject( WindowedValue.getFullCoder(coder, IntervalWindow.getCoder()), - /*sdkComponents=*/ null)))) + /* sdkComponents= */ null)))) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setName(Long.toString(idGenerator.get())) .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) @@ -325,7 +364,7 @@ private ParallelInstruction makeSourceInstruction(Coder coder) { .setCodec( CloudObjects.asCloudObject( WindowedValue.getFullCoder(coder, IntervalWindow.getCoder()), - /*sdkComponents=*/ null)))); + /* sdkComponents= */ null)))); } private ParallelInstruction makeDoFnInstruction( @@ -360,9 +399,9 @@ private ParallelInstruction makeDoFnInstruction( .setNumOutputs(1) .setUserFn(spec) .setMultiOutputInfos( - Arrays.asList(new MultiOutputInfo().setTag(PropertyNames.OUTPUT)))) + Collections.singletonList(new MultiOutputInfo().setTag(PropertyNames.OUTPUT)))) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setName(PropertyNames.OUTPUT) .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) @@ -371,7 +410,7 @@ private ParallelInstruction makeDoFnInstruction( CloudObjects.asCloudObject( WindowedValue.getFullCoder( outputCoder, windowingStrategy.getWindowFn().windowCoder()), - /*sdkComponents=*/ null)))); + /* sdkComponents= */ null)))); } private ParallelInstruction makeDoFnInstruction( @@ -407,7 +446,7 @@ private ParallelInstruction makeSinkInstruction( .setCodec( CloudObjects.asCloudObject( WindowedValue.getFullCoder(coder, windowCoder), - /*sdkComponents=*/ null)))); + /* sdkComponents= */ null)))); } private ParallelInstruction makeSinkInstruction( @@ -493,26 +532,6 @@ private Windmill.GetWorkResponse buildSessionInput( .build(); } - private static final Function EMPTY_DATA_RESPONDER = - (GetDataRequest request) -> { - GetDataResponse.Builder builder = GetDataResponse.newBuilder(); - for (ComputationGetDataRequest compRequest : request.getRequestsList()) { - ComputationGetDataResponse.Builder compBuilder = - builder.addDataBuilder().setComputationId(compRequest.getComputationId()); - for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { - KeyedGetDataResponse.Builder keyBuilder = - compBuilder - .addDataBuilder() - .setKey(keyRequest.getKey()) - .setShardingKey(keyRequest.getShardingKey()); - keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); - keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); - keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); - } - } - return builder.build(); - }; - private Windmill.GetWorkResponse makeInput(int index, long timestamp) throws Exception { return makeInput(index, timestamp, keyStringForIndex(index), DEFAULT_SHARDING_KEY); } @@ -552,7 +571,8 @@ private Windmill.GetWorkResponse makeInput( + " }" + "}", CoderUtils.encodeToByteArray( - CollectionCoder.of(IntervalWindow.getCoder()), Arrays.asList(DEFAULT_WINDOW))); + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); } /** @@ -684,8 +704,7 @@ private StreamingComputationConfig 
makeDefaultStreamingComputationConfig( return config; } - private ByteString addPaneTag(PaneInfo pane, byte[] windowBytes) - throws CoderException, IOException { + private ByteString addPaneTag(PaneInfo pane, byte[] windowBytes) throws IOException { ByteStringOutputStream output = new ByteStringOutputStream(); PaneInfo.PaneInfoCoder.INSTANCE.encode(pane, output, Context.OUTER); output.write(windowBytes); @@ -718,7 +737,7 @@ private StreamingDataflowWorker makeWorker( throws Exception { StreamingDataflowWorker worker = new StreamingDataflowWorker( - Arrays.asList(defaultMapTask(instructions)), + Collections.singletonList(defaultMapTask(instructions)), IntrinsicMapTaskExecutorFactory.defaultFactory(), mockWorkUnitClient, options, @@ -888,34 +907,6 @@ public void testHotKeyLoggingNotEnabled() throws Exception { verify(hotKeyLogger, atLeastOnce()).logHotKeyDetection(nullable(String.class), any()); } - static class BlockingFn extends DoFn implements TestRule { - - public static CountDownLatch blocker = new CountDownLatch(1); - public static Semaphore counter = new Semaphore(0); - public static AtomicInteger callCounter = new AtomicInteger(0); - - @ProcessElement - public void processElement(ProcessContext c) throws InterruptedException { - callCounter.incrementAndGet(); - counter.release(); - blocker.await(); - c.output(c.element()); - } - - @Override - public Statement apply(final Statement base, final Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - blocker = new CountDownLatch(1); - counter = new Semaphore(0); - callCounter = new AtomicInteger(); - base.evaluate(); - } - }; - } - } - @Test public void testIgnoreRetriedKeys() throws Exception { final int numIters = 4; @@ -1079,21 +1070,6 @@ public void testNumberOfWorkerHarnessThreadsIsHonored() throws Exception { BlockingFn.blocker.countDown(); } - static class KeyTokenInvalidFn extends DoFn, KV> { - - static boolean thrown = false; - - @ProcessElement - public void processElement(ProcessContext c) { - if (!thrown) { - thrown = true; - throw new KeyTokenInvalidException("key"); - } else { - c.output(c.element()); - } - } - } - @Test public void testKeyTokenInvalidException() throws Exception { if (streamingEngine) { @@ -1132,22 +1108,6 @@ public void testKeyTokenInvalidException() throws Exception { assertEquals(1, result.size()); } - static class LargeCommitFn extends DoFn, KV> { - - @ProcessElement - public void processElement(ProcessContext c) { - if (c.element().getKey().equals("large_key")) { - StringBuilder s = new StringBuilder(); - for (int i = 0; i < 100; ++i) { - s.append("large_commit"); - } - c.output(KV.of(c.element().getKey(), s.toString())); - } else { - c.output(c.element()); - } - } - } - @Test public void testKeyCommitTooLargeException() throws Exception { KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); @@ -1216,15 +1176,6 @@ public void testKeyCommitTooLargeException() throws Exception { assertTrue(foundErrors); } - static class ChangeKeysFn extends DoFn, KV> { - - @ProcessElement - public void processElement(ProcessContext c) { - KV elem = c.element(); - c.output(KV.of(elem.getKey() + "_" + elem.getValue(), elem.getValue())); - } - } - @Test public void testKeyChange() throws Exception { KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); @@ -1280,23 +1231,6 @@ public void testKeyChange() throws Exception { } } - static class TestExceptionFn extends DoFn { - - @ProcessElement - public void 
processElement(ProcessContext c) throws Exception { - if (firstTime) { - firstTime = false; - try { - throw new Exception("Exception!"); - } catch (Exception e) { - throw new Exception("Another exception!", e); - } - } - } - - boolean firstTime = true; - } - @Test(timeout = 30000) public void testExceptions() throws Exception { if (streamingEngine) { @@ -1340,7 +1274,8 @@ public void testExceptions() throws Exception { + " }" + "}", CoderUtils.encodeToByteArray( - CollectionCoder.of(IntervalWindow.getCoder()), Arrays.asList(DEFAULT_WINDOW)))); + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW)))); StreamingDataflowWorker worker = makeWorker(instructions, createTestingPipelineOptions(server), true /* publishCounters */); @@ -1422,7 +1357,7 @@ public void testAssignWindows() throws Exception { .setNumOutputs(1) .setUserFn(spec)) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) .setSystemName(DEFAULT_OUTPUT_SYSTEM_NAME) @@ -1431,7 +1366,7 @@ public void testAssignWindows() throws Exception { CloudObjects.asCloudObject( WindowedValue.getFullCoder( StringUtf8Coder.of(), IntervalWindow.getCoder()), - /*sdkComponents=*/ null)))); + /* sdkComponents= */ null)))); List instructions = Arrays.asList( @@ -1539,7 +1474,7 @@ public void testMergeWindows() throws Exception { addObject( spec, WorkerPropertyNames.INPUT_CODER, - CloudObjects.asCloudObject(windowedKvCoder, /*sdkComponents=*/ null)); + CloudObjects.asCloudObject(windowedKvCoder, /* sdkComponents= */ null)); ParallelInstruction mergeWindowsInstruction = new ParallelInstruction() @@ -1552,14 +1487,14 @@ public void testMergeWindows() throws Exception { .setNumOutputs(1) .setUserFn(spec)) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) .setSystemName(DEFAULT_OUTPUT_SYSTEM_NAME) .setName("output") .setCodec( CloudObjects.asCloudObject( - windowedGroupedCoder, /*sdkComponents=*/ null)))); + windowedGroupedCoder, /* sdkComponents= */ null)))); List instructions = Arrays.asList( @@ -1749,7 +1684,7 @@ public void testMergeWindows() throws Exception { assertEquals( PaneInfo.createPane(true, true, Timing.ON_TIME), PaneInfoCoder.INSTANCE.decode(inStream)); assertEquals( - Arrays.asList(WINDOW_AT_ZERO), + Collections.singletonList(WINDOW_AT_ZERO), DEFAULT_WINDOW_COLLECTION_CODER.decode(inStream, Coder.Context.OUTER)); // Data was deleted @@ -1799,15 +1734,6 @@ public void testMergeWindows() throws Exception { assertEquals(0L, splitIntToLong(getCounter(counters, "WindmillShuffleBytesRead").getInteger())); } - static class PassthroughDoFn - extends DoFn>, KV>> { - - @ProcessElement - public void processElement(ProcessContext c) { - c.output(c.element()); - } - } - @Test // Runs a merging windows test verifying stored state, holds and timers with caching due to // the first processing having is_new_key set. 
@@ -1835,7 +1761,7 @@ public void testMergeWindowsCaching() throws Exception { addObject( spec, WorkerPropertyNames.INPUT_CODER, - CloudObjects.asCloudObject(windowedKvCoder, /*sdkComponents=*/ null)); + CloudObjects.asCloudObject(windowedKvCoder, /* sdkComponents= */ null)); ParallelInstruction mergeWindowsInstruction = new ParallelInstruction() @@ -1848,14 +1774,14 @@ public void testMergeWindowsCaching() throws Exception { .setNumOutputs(1) .setUserFn(spec)) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) .setSystemName(DEFAULT_OUTPUT_SYSTEM_NAME) .setName("output") .setCodec( CloudObjects.asCloudObject( - windowedGroupedCoder, /*sdkComponents=*/ null)))); + windowedGroupedCoder, /* sdkComponents= */ null)))); List instructions = Arrays.asList( @@ -2048,7 +1974,7 @@ public void testMergeWindowsCaching() throws Exception { assertEquals( PaneInfo.createPane(true, true, Timing.ON_TIME), PaneInfoCoder.INSTANCE.decode(inStream)); assertEquals( - Arrays.asList(WINDOW_AT_ZERO), + Collections.singletonList(WINDOW_AT_ZERO), DEFAULT_WINDOW_COLLECTION_CODER.decode(inStream, Coder.Context.OUTER)); // Data was deleted @@ -2103,27 +2029,6 @@ public void testMergeWindowsCaching() throws Exception { assertEquals(4, stats.missCount()); } - static class Action { - - public Action(GetWorkResponse response) { - this.response = response; - } - - Action withHolds(WatermarkHold... holds) { - this.expectedHolds = holds; - return this; - } - - Action withTimers(Timer... timers) { - this.expectedTimers = timers; - return this; - } - - GetWorkResponse response; - Timer[] expectedTimers = new Timer[] {}; - WatermarkHold[] expectedHolds = new WatermarkHold[] {}; - } - // Helper for running tests for merging sessions based upon Actions consisting of GetWorkResponse // and expected timers and holds in the corresponding commit. All GetData requests are responded // to with empty state, relying on user worker caching to keep data written. @@ -2156,7 +2061,7 @@ private void runMergeSessionsActions(List actions) throws Exception { addObject( spec, WorkerPropertyNames.INPUT_CODER, - CloudObjects.asCloudObject(windowedKvCoder, /*sdkComponents=*/ null)); + CloudObjects.asCloudObject(windowedKvCoder, /* sdkComponents= */ null)); ParallelInstruction mergeWindowsInstruction = new ParallelInstruction() @@ -2169,14 +2074,14 @@ private void runMergeSessionsActions(List actions) throws Exception { .setNumOutputs(1) .setUserFn(spec)) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) .setSystemName(DEFAULT_OUTPUT_SYSTEM_NAME) .setName("output") .setCodec( CloudObjects.asCloudObject( - windowedGroupedCoder, /*sdkComponents=*/ null)))); + windowedGroupedCoder, /* sdkComponents= */ null)))); List instructions = Arrays.asList( @@ -2211,8 +2116,10 @@ private void runMergeSessionsActions(List actions) throws Exception { public void testMergeSessionWindows() throws Exception { // Test a single late window. 
runMergeSessionsActions( - Arrays.asList( - new Action(buildSessionInput(1, 40, 0, Arrays.asList(1L), Collections.EMPTY_LIST)) + Collections.singletonList( + new Action( + buildSessionInput( + 1, 40, 0, Collections.singletonList(1L), Collections.EMPTY_LIST)) .withHolds( buildHold("/gAAAAAAAAAsK/+uhold", -1, true), buildHold("/gAAAAAAAAAsK/+uextra", -1, true)) @@ -2226,7 +2133,9 @@ public void testMergeSessionWindows() throws Exception { // elements runMergeSessionsActions( Arrays.asList( - new Action(buildSessionInput(1, 0, 0, Arrays.asList(1L), Collections.EMPTY_LIST)) + new Action( + buildSessionInput( + 1, 0, 0, Collections.singletonList(1L), Collections.EMPTY_LIST)) .withHolds(buildHold("/gAAAAAAAAAsK/+uhold", 10, false)) .withTimers( buildWatermarkTimer("/s/gAAAAAAAAAsK/+0", 10), @@ -2237,12 +2146,14 @@ public void testMergeSessionWindows() throws Exception { 30, 0, Collections.EMPTY_LIST, - Arrays.asList(buildWatermarkTimer("/s/gAAAAAAAAAsK/+0", 10)))) + Collections.singletonList(buildWatermarkTimer("/s/gAAAAAAAAAsK/+0", 10)))) .withTimers(buildWatermarkTimer("/s/gAAAAAAAAAsK/+0", 3600010)) .withHolds( buildHold("/gAAAAAAAAAsK/+uhold", -1, true), buildHold("/gAAAAAAAAAsK/+uextra", -1, true)), - new Action(buildSessionInput(3, 30, 0, Arrays.asList(8L), Collections.EMPTY_LIST)) + new Action( + buildSessionInput( + 3, 30, 0, Collections.singletonList(8L), Collections.EMPTY_LIST)) .withTimers( buildWatermarkTimer("/s/gAAAAAAAABIR/+0", 3600017), buildWatermarkTimer("/s/gAAAAAAAAAsK/+0", 10, true), @@ -2250,7 +2161,9 @@ public void testMergeSessionWindows() throws Exception { .withHolds( buildHold("/gAAAAAAAAAsK/+uhold", -1, true), buildHold("/gAAAAAAAAAsK/+uextra", -1, true)), - new Action(buildSessionInput(4, 30, 0, Arrays.asList(31L), Collections.EMPTY_LIST)) + new Action( + buildSessionInput( + 4, 30, 0, Collections.singletonList(31L), Collections.EMPTY_LIST)) .withTimers( buildWatermarkTimer("/s/gAAAAAAAACkK/+0", 3600040), buildWatermarkTimer("/s/gAAAAAAAACkK/+0", 40)) @@ -2274,31 +2187,13 @@ public void testMergeSessionWindows() throws Exception { 50, 0, Collections.EMPTY_LIST, - Arrays.asList(buildWatermarkTimer("/s/gAAAAAAAACko/+0", 40)))) + Collections.singletonList(buildWatermarkTimer("/s/gAAAAAAAACko/+0", 40)))) .withTimers(buildWatermarkTimer("/s/gAAAAAAAACko/+0", 3600040)) .withHolds( buildHold("/gAAAAAAAAAsK/+uhold", -1, true), buildHold("/gAAAAAAAAAsK/+uextra", -1, true)))); } - private static CounterUpdate getCounter(Iterable counters, String name) { - for (CounterUpdate counter : counters) { - if (counter.getNameAndKind().getName().equals(name)) { - return counter; - } - } - return null; - } - - static class PrintFn extends DoFn>, String> { - - @ProcessElement - public void processElement(ProcessContext c) { - KV elem = c.element().getValue(); - c.output(elem.getKey() + ":" + elem.getValue()); - } - } - private List makeUnboundedSourcePipeline() throws Exception { return makeUnboundedSourcePipeline(1, new PrintFn()); } @@ -2316,7 +2211,7 @@ private List makeUnboundedSourcePipeline( ValueWithRecordId.ValueWithRecordIdCoder.of( KvCoder.of(VarIntCoder.of(), VarIntCoder.of())), GlobalWindow.Coder.INSTANCE), - /*sdkComponents=*/ null); + /* sdkComponents= */ null); return Arrays.asList( new ParallelInstruction() @@ -2329,7 +2224,7 @@ private List makeUnboundedSourcePipeline( new TestCountingSource(numMessagesPerShard), options) .setCodec(codec))) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setName("read_output") 
.setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) @@ -2382,7 +2277,7 @@ public void testUnboundedSources() throws Exception { PaneInfo.NO_FIRING, CoderUtils.encodeToByteArray( CollectionCoder.of(GlobalWindow.Coder.INSTANCE), - Arrays.asList(GlobalWindow.INSTANCE)), + Collections.singletonList(GlobalWindow.INSTANCE)), parseCommitRequest( "key: \"0000000000000001\" " + "sharding_key: 1 " @@ -2457,7 +2352,7 @@ public void testUnboundedSources() throws Exception { assertThat(finalizeTracker, contains(0)); - assertEquals(null, getCounter(counters, "dataflow_input_size-computation")); + assertNull(getCounter(counters, "dataflow_input_size-computation")); // Test recovery (on a new key so fresh reader state). Counter is done. server @@ -2503,7 +2398,7 @@ public void testUnboundedSources() throws Exception { + "source_watermark: 1000") .build())); - assertEquals(null, getCounter(counters, "dataflow_input_size-computation")); + assertNull(getCounter(counters, "dataflow_input_size-computation")); } @Test @@ -2549,7 +2444,7 @@ public void testUnboundedSourcesDrain() throws Exception { PaneInfo.NO_FIRING, CoderUtils.encodeToByteArray( CollectionCoder.of(GlobalWindow.Coder.INSTANCE), - Arrays.asList(GlobalWindow.INSTANCE)), + Collections.singletonList(GlobalWindow.INSTANCE)), parseCommitRequest( "key: \"0000000000000001\" " + "sharding_key: 1 " @@ -2659,7 +2554,7 @@ public void testUnboundedSourceWorkRetry() throws Exception { PaneInfo.NO_FIRING, CoderUtils.encodeToByteArray( CollectionCoder.of(GlobalWindow.Coder.INSTANCE), - Arrays.asList(GlobalWindow.INSTANCE)), + Collections.singletonList(GlobalWindow.INSTANCE)), parseCommitRequest( "key: \"0000000000000001\" " + "sharding_key: 1 " @@ -2750,26 +2645,13 @@ public void testUnboundedSourceWorkRetry() throws Exception { assertThat(finalizeTracker, contains(0)); } - private static class MockWork extends StreamingDataflowWorker.Work { - - public MockWork(long workToken) { - super( - Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), - Instant::now, - Collections.emptyList()); - } - - @Override - public void run() {} - } - @Test public void testActiveWork() throws Exception { BoundedQueueExecutor mockExecutor = Mockito.mock(BoundedQueueExecutor.class); - StreamingDataflowWorker.ComputationState computationState = - new StreamingDataflowWorker.ComputationState( + ComputationState computationState = + new ComputationState( "computation", - defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), mockExecutor, ImmutableMap.of(), null); @@ -2777,49 +2659,49 @@ public void testActiveWork() throws Exception { ShardedKey key1 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 1); ShardedKey key2 = ShardedKey.create(ByteString.copyFromUtf8("key2"), 2); - MockWork m1 = new MockWork(1); + Work m1 = createMockWork(1); assertTrue(computationState.activateWork(key1, m1)); Mockito.verify(mockExecutor).execute(m1, m1.getWorkItem().getSerializedSize()); - computationState.completeWork(key1, 1); + computationState.completeWorkAndScheduleNextWorkForKey(key1, 1); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify work queues. 
- MockWork m2 = new MockWork(2); + Work m2 = createMockWork(2); assertTrue(computationState.activateWork(key1, m2)); Mockito.verify(mockExecutor).execute(m2, m2.getWorkItem().getSerializedSize()); - MockWork m3 = new MockWork(3); + Work m3 = createMockWork(3); assertTrue(computationState.activateWork(key1, m3)); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify another key is a separate queue. - MockWork m4 = new MockWork(4); + Work m4 = createMockWork(4); assertTrue(computationState.activateWork(key2, m4)); Mockito.verify(mockExecutor).execute(m4, m4.getWorkItem().getSerializedSize()); - computationState.completeWork(key2, 4); + computationState.completeWorkAndScheduleNextWorkForKey(key2, 4); Mockito.verifyNoMoreInteractions(mockExecutor); - computationState.completeWork(key1, 2); + computationState.completeWorkAndScheduleNextWorkForKey(key1, 2); Mockito.verify(mockExecutor).forceExecute(m3, m3.getWorkItem().getSerializedSize()); - computationState.completeWork(key1, 3); + computationState.completeWorkAndScheduleNextWorkForKey(key1, 3); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify duplicate work dropped. - MockWork m5 = new MockWork(5); + Work m5 = createMockWork(5); computationState.activateWork(key1, m5); Mockito.verify(mockExecutor).execute(m5, m5.getWorkItem().getSerializedSize()); assertFalse(computationState.activateWork(key1, m5)); Mockito.verifyNoMoreInteractions(mockExecutor); - computationState.completeWork(key1, 5); + computationState.completeWorkAndScheduleNextWorkForKey(key1, 5); Mockito.verifyNoMoreInteractions(mockExecutor); } @Test public void testActiveWorkForShardedKeys() throws Exception { BoundedQueueExecutor mockExecutor = Mockito.mock(BoundedQueueExecutor.class); - StreamingDataflowWorker.ComputationState computationState = - new StreamingDataflowWorker.ComputationState( + ComputationState computationState = + new ComputationState( "computation", - defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), mockExecutor, ImmutableMap.of(), null); @@ -2827,22 +2709,22 @@ public void testActiveWorkForShardedKeys() throws Exception { ShardedKey key1Shard1 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 1); ShardedKey key1Shard2 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 2); - MockWork m1 = new MockWork(1); + Work m1 = createMockWork(1); assertTrue(computationState.activateWork(key1Shard1, m1)); Mockito.verify(mockExecutor).execute(m1, m1.getWorkItem().getSerializedSize()); - computationState.completeWork(key1Shard1, 1); + computationState.completeWorkAndScheduleNextWorkForKey(key1Shard1, 1); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify work queues. - MockWork m2 = new MockWork(2); + Work m2 = createMockWork(2); assertTrue(computationState.activateWork(key1Shard1, m2)); Mockito.verify(mockExecutor).execute(m2, m2.getWorkItem().getSerializedSize()); - MockWork m3 = new MockWork(3); + Work m3 = createMockWork(3); assertTrue(computationState.activateWork(key1Shard1, m3)); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify a different shard of key is a separate queue. 
- MockWork m4 = new MockWork(3); + Work m4 = createMockWork(3); assertFalse(computationState.activateWork(key1Shard1, m4)); Mockito.verifyNoMoreInteractions(mockExecutor); assertTrue(computationState.activateWork(key1Shard2, m4)); @@ -2850,7 +2732,7 @@ public void testActiveWorkForShardedKeys() throws Exception { // Verify duplicate work dropped assertFalse(computationState.activateWork(key1Shard2, m4)); - computationState.completeWork(key1Shard2, 3); + computationState.completeWorkAndScheduleNextWorkForKey(key1Shard2, 3); Mockito.verifyNoMoreInteractions(mockExecutor); } @@ -2873,10 +2755,10 @@ public void testMaxThreadMetric() throws Exception { .setDaemon(true) .build()); - StreamingDataflowWorker.ComputationState computationState = - new StreamingDataflowWorker.ComputationState( + ComputationState computationState = + new ComputationState( "computation", - defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), executor, ImmutableMap.of(), null); @@ -2886,29 +2768,17 @@ public void testMaxThreadMetric() throws Exception { // overriding definition of MockWork to add sleep, which will help us keep track of how // long each work item takes to process and therefore let us manipulate how long the time // at which we're at max threads is. - MockWork m2 = - new MockWork(2) { - @Override - public void run() { - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } + Consumer sleepProcessWorkFn = + unused -> { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); } }; - MockWork m3 = - new MockWork(3) { - @Override - public void run() { - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - } - }; + Work m2 = createMockWork(2, sleepProcessWorkFn); + Work m3 = createMockWork(3, sleepProcessWorkFn); assertTrue(computationState.activateWork(key1Shard1, m2)); assertTrue(computationState.activateWork(key1Shard1, m3)); @@ -2923,41 +2793,6 @@ public void run() { executor.shutdown(); } - static class TestExceptionInvalidatesCacheFn - extends DoFn>, String> { - - static boolean thrown = false; - - @StateId("int") - private final StateSpec> counter = StateSpecs.value(VarIntCoder.of()); - - @ProcessElement - public void processElement(ProcessContext c, @StateId("int") ValueState state) - throws Exception { - KV elem = c.element().getValue(); - if (elem.getValue() == 0) { - LOG.error("**** COUNTER 0 ****"); - assertEquals(null, state.read()); - state.write(42); - assertEquals((Integer) 42, state.read()); - } else if (elem.getValue() == 1) { - LOG.error("**** COUNTER 1 ****"); - assertEquals((Integer) 42, state.read()); - } else if (elem.getValue() == 2) { - if (!thrown) { - LOG.error("**** COUNTER 2 (will throw) ****"); - thrown = true; - throw new Exception("Exception!"); - } - LOG.error("**** COUNTER 2 (retry) ****"); - assertEquals((Integer) 42, state.read()); - } else { - throw new RuntimeException("only expecting values [0,2]"); - } - c.output(elem.getKey() + ":" + elem.getValue()); - } - } - @Test public void testExceptionInvalidatesCache() throws Exception { // We'll need to force the system to limit bundles to one message at a time. 
@@ -3003,7 +2838,7 @@ public void testExceptionInvalidatesCache() throws Exception { ValueWithRecordId.ValueWithRecordIdCoder.of( KvCoder.of(VarIntCoder.of(), VarIntCoder.of())), GlobalWindow.Coder.INSTANCE), - /*sdkComponents=*/ null); + /* sdkComponents= */ null); TestCountingSource counter = new TestCountingSource(3).withThrowOnFirstSnapshot(true); @@ -3018,7 +2853,7 @@ public void testExceptionInvalidatesCache() throws Exception { .setSource( CustomSources.serializeToCloudSource(counter, options).setCodec(codec))) .setOutputs( - Arrays.asList( + Collections.singletonList( new InstructionOutput() .setName("read_output") .setOriginalName(DEFAULT_OUTPUT_ORIGINAL_NAME) @@ -3168,21 +3003,6 @@ public void testExceptionInvalidatesCache() throws Exception { } } - private static class FanoutFn extends DoFn { - - @ProcessElement - public void processElement(ProcessContext c) { - StringBuilder builder = new StringBuilder(1000000); - for (int i = 0; i < 1000000; i++) { - builder.append(' '); - } - String largeString = builder.toString(); - for (int i = 0; i < 3000; i++) { - c.output(largeString); - } - } - } - @Test public void testHugeCommits() throws Exception { List instructions = @@ -3202,15 +3022,6 @@ public void testHugeCommits() throws Exception { worker.stop(); } - private static class SlowDoFn extends DoFn { - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - Thread.sleep(1000); - c.output(c.element()); - } - } - @Test public void testActiveWorkRefresh() throws Exception { List instructions = @@ -3235,290 +3046,54 @@ public void testActiveWorkRefresh() throws Exception { assertThat(server.numGetDataRequests(), greaterThan(0)); } - static class FakeClock implements Supplier { - private class FakeScheduledExecutor implements ScheduledExecutorService { - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return true; - } - - @Override - public void execute(Runnable command) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public List> invokeAll(Collection> tasks) - throws InterruptedException { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public List> invokeAll( - Collection> tasks, long timeout, TimeUnit unit) - throws InterruptedException { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public T invokeAny(Collection> tasks) - throws ExecutionException, InterruptedException { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) - throws ExecutionException, InterruptedException, TimeoutException { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public boolean isShutdown() { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public boolean isTerminated() { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public void shutdown() {} - - @Override - public List shutdownNow() { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public Future submit(Callable task) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public Future submit(Runnable task) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public Future submit(Runnable task, T result) { - throw new 
UnsupportedOperationException("Not implemented yet"); - } - - @Override - public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public ScheduledFuture scheduleAtFixedRate( - Runnable command, long initialDelay, long period, TimeUnit unit) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - @Override - public ScheduledFuture scheduleWithFixedDelay( - Runnable command, long initialDelay, long delay, TimeUnit unit) { - if (delay <= 0) { - throw new UnsupportedOperationException( - "Please supply a delay > 0 to scheduleWithFixedDelay"); - } - FakeClock.this.schedule( - Duration.millis(unit.toMillis(initialDelay)), - new Runnable() { - @Override - public void run() { - command.run(); - FakeClock.this.schedule(Duration.millis(unit.toMillis(delay)), this); - } - }); - FakeClock.this.sleep(Duration.ZERO); // Execute work that has an intial delay of zero. - return null; - } - } - - private static class Job implements Comparable { - final Instant when; - final Runnable work; - - Job(Instant when, Runnable work) { - this.when = when; - this.work = work; - } - - @Override - public int compareTo(Job job) { - return when.compareTo(job.when); - } - } - - private final PriorityQueue jobs = new PriorityQueue<>(); - private Instant now = Instant.now(); - - public ScheduledExecutorService newFakeScheduledExecutor(String unused) { - return new FakeScheduledExecutor(); - } - - @Override - public synchronized Instant get() { - return now; - } - - public synchronized void clear() { - jobs.clear(); - } - - public synchronized void sleep(Duration duration) { - if (duration.isShorterThan(Duration.ZERO)) { - throw new UnsupportedOperationException("Cannot sleep backwards in time"); - } - Instant endOfSleep = now.plus(duration); - while (true) { - Job job = jobs.peek(); - if (job == null || job.when.isAfter(endOfSleep)) { - break; - } - jobs.remove(); - now = job.when; - job.work.run(); - } - now = endOfSleep; - } - - private synchronized void schedule(Duration fromNow, Runnable work) { - jobs.add(new Job(now.plus(fromNow), work)); - } - } - - private static class FakeSlowDoFn extends DoFn { - private static FakeClock clock; // A static variable keeps this DoFn serializable. 
- private final Duration sleep; - - FakeSlowDoFn(FakeClock clock, Duration sleep) { - FakeSlowDoFn.clock = clock; - this.sleep = sleep; - } - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - clock.sleep(sleep); - c.output(c.element()); - } - } - @Test - public void testLatencyAttributionProtobufsPopulated() throws Exception { + public void testLatencyAttributionProtobufsPopulated() { FakeClock clock = new FakeClock(); - StreamingDataflowWorker.Work work = - new StreamingDataflowWorker.Work(null, clock, Collections.emptyList()) { - @Override - public void run() {} - }; + Work work = Work.create(null, clock, Collections.emptyList(), unused -> {}); clock.sleep(Duration.millis(10)); - work.setState(StreamingDataflowWorker.Work.State.PROCESSING); + work.setState(Work.State.PROCESSING); clock.sleep(Duration.millis(20)); - work.setState(StreamingDataflowWorker.Work.State.READING); + work.setState(Work.State.READING); clock.sleep(Duration.millis(30)); - work.setState(StreamingDataflowWorker.Work.State.PROCESSING); + work.setState(Work.State.PROCESSING); clock.sleep(Duration.millis(40)); - work.setState(StreamingDataflowWorker.Work.State.COMMIT_QUEUED); + work.setState(Work.State.COMMIT_QUEUED); clock.sleep(Duration.millis(50)); - work.setState(StreamingDataflowWorker.Work.State.COMMITTING); + work.setState(Work.State.COMMITTING); clock.sleep(Duration.millis(60)); Iterator it = work.getLatencyAttributions().iterator(); assertTrue(it.hasNext()); LatencyAttribution lat = it.next(); - assertTrue(lat.getState() == LatencyAttribution.State.QUEUED); - assertTrue(lat.getTotalDurationMillis() == 10); + assertSame(State.QUEUED, lat.getState()); + assertEquals(10, lat.getTotalDurationMillis()); assertTrue(it.hasNext()); lat = it.next(); - assertTrue(lat.getState() == LatencyAttribution.State.ACTIVE); - assertTrue(lat.getTotalDurationMillis() == 60); + assertSame(State.ACTIVE, lat.getState()); + assertEquals(60, lat.getTotalDurationMillis()); assertTrue(it.hasNext()); lat = it.next(); - assertTrue(lat.getState() == LatencyAttribution.State.READING); - assertTrue(lat.getTotalDurationMillis() == 30); + assertSame(State.READING, lat.getState()); + assertEquals(30, lat.getTotalDurationMillis()); assertTrue(it.hasNext()); lat = it.next(); - assertTrue(lat.getState() == LatencyAttribution.State.COMMITTING); - assertTrue(lat.getTotalDurationMillis() == 110); - assertTrue(!it.hasNext()); + assertSame(State.COMMITTING, lat.getState()); + assertEquals(110, lat.getTotalDurationMillis()); + assertFalse(it.hasNext()); } - // Aggregates LatencyAttribution data from active work refresh requests. - static class ActiveWorkRefreshSink { - private final Function responder; - private final Map> totalDurations = - new HashMap<>(); - - ActiveWorkRefreshSink(Function responder) { - this.responder = responder; - } + @Test + public void testLatencyAttributionToQueuedState() throws Exception { + final int workToken = 3232; // A unique id makes it easier to search logs. - Duration getLatencyAttributionDuration(long workToken, LatencyAttribution.State state) { - EnumMap durations = totalDurations.get(workToken); - return durations == null ? 
Duration.ZERO : durations.getOrDefault(state, Duration.ZERO); - } - - boolean isActiveWorkRefresh(GetDataRequest request) { - for (ComputationGetDataRequest computationRequest : request.getRequestsList()) { - if (!computationRequest.getComputationId().equals(DEFAULT_COMPUTATION_ID)) { - return false; - } - for (KeyedGetDataRequest keyedRequest : computationRequest.getRequestsList()) { - if (keyedRequest.getWorkToken() == 0 - || keyedRequest.getShardingKey() != DEFAULT_SHARDING_KEY - || keyedRequest.getValuesToFetchCount() != 0 - || keyedRequest.getBagsToFetchCount() != 0 - || keyedRequest.getTagValuePrefixesToFetchCount() != 0 - || keyedRequest.getWatermarkHoldsToFetchCount() != 0) { - return false; - } - } - } - return true; - } - - GetDataResponse getData(GetDataRequest request) { - if (!isActiveWorkRefresh(request)) { - return responder.apply(request); - } - for (ComputationGetDataRequest computationRequest : request.getRequestsList()) { - for (KeyedGetDataRequest keyedRequest : computationRequest.getRequestsList()) { - for (LatencyAttribution la : keyedRequest.getLatencyAttributionList()) { - EnumMap durations = - totalDurations.computeIfAbsent( - keyedRequest.getWorkToken(), - (Long workToken) -> - new EnumMap( - LatencyAttribution.State.class)); - Duration cur = Duration.millis(la.getTotalDurationMillis()); - durations.compute(la.getState(), (s, d) -> d == null || d.isShorterThan(cur) ? cur : d); - } - } - } - return EMPTY_DATA_RESPONDER.apply(request); - } - } - - @Test - public void testLatencyAttributionToQueuedState() throws Exception { - final int workToken = 3232; // A unique id makes it easier to search logs. - - FakeClock clock = new FakeClock(); - List instructions = - Arrays.asList( - makeSourceInstruction(StringUtf8Coder.of()), - makeDoFnInstruction( - new FakeSlowDoFn(clock, Duration.millis(1000)), 0, StringUtf8Coder.of()), - makeSinkInstruction(StringUtf8Coder.of(), 0)); + FakeClock clock = new FakeClock(); + List instructions = + Arrays.asList( + makeSourceInstruction(StringUtf8Coder.of()), + makeDoFnInstruction( + new FakeSlowDoFn(clock, Duration.millis(1000)), 0, StringUtf8Coder.of()), + makeSinkInstruction(StringUtf8Coder.of(), 0)); FakeWindmillServer server = new FakeWindmillServer(errorCollector); StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); @@ -3545,14 +3120,9 @@ public void testLatencyAttributionToQueuedState() throws Exception { worker.stop(); - assertTrue( - awrSink - .getLatencyAttributionDuration(workToken, LatencyAttribution.State.QUEUED) - .equals(Duration.millis(1000))); - assertTrue( - awrSink - .getLatencyAttributionDuration(workToken + 1, LatencyAttribution.State.QUEUED) - .equals(Duration.ZERO)); + assertEquals( + awrSink.getLatencyAttributionDuration(workToken, State.QUEUED), Duration.millis(1000)); + assertEquals(awrSink.getLatencyAttributionDuration(workToken + 1, State.QUEUED), Duration.ZERO); } @Test @@ -3587,22 +3157,8 @@ public void testLatencyAttributionToActiveState() throws Exception { worker.stop(); - assertTrue( - awrSink - .getLatencyAttributionDuration(workToken, LatencyAttribution.State.ACTIVE) - .equals(Duration.millis(1000))); - } - - // A DoFn that triggers a GetData request. 
- static class ReadingDoFn extends DoFn { - @StateId("int") - private final StateSpec> counter = StateSpecs.value(VarIntCoder.of()); - - @ProcessElement - public void processElement(ProcessContext c, @StateId("int") ValueState state) { - state.read(); - c.output(c.element()); - } + assertEquals( + awrSink.getLatencyAttributionDuration(workToken, State.ACTIVE), Duration.millis(1000)); } @Test @@ -3642,10 +3198,8 @@ public void testLatencyAttributionToReadingState() throws Exception { worker.stop(); - assertTrue( - awrSink - .getLatencyAttributionDuration(workToken, LatencyAttribution.State.READING) - .equals(Duration.millis(1000))); + assertEquals( + awrSink.getLatencyAttributionDuration(workToken, State.READING), Duration.millis(1000)); } @Test @@ -3685,10 +3239,8 @@ public void testLatencyAttributionToCommittingState() throws Exception { worker.stop(); - assertTrue( - awrSink - .getLatencyAttributionDuration(workToken, LatencyAttribution.State.COMMITTING) - .equals(Duration.millis(1000))); + assertEquals( + awrSink.getLatencyAttributionDuration(workToken, State.COMMITTING), Duration.millis(1000)); } @Test @@ -3742,24 +3294,6 @@ public void testLatencyAttributionPopulatedInCommitRequest() throws Exception { } } - /** For each input element, emits a large string. */ - private static class InflateDoFn extends DoFn>, String> { - - final int inflatedSize; - - /** For each input elements, outputs a string of this length */ - InflateDoFn(int inflatedSize) { - this.inflatedSize = inflatedSize; - } - - @ProcessElement - public void processElement(ProcessContext c) { - char[] chars = new char[inflatedSize]; - Arrays.fill(chars, ' '); - c.output(new String(chars)); - } - } - @Test public void testLimitOnOutputBundleSize() throws Exception { // This verifies that ReadOperation, StreamingModeExecutionContext, and windmill sinks @@ -3958,4 +3492,459 @@ public void testStuckCommit() throws Exception { .build(), removeDynamicFields(result.get(1L))); } + + static class BlockingFn extends DoFn implements TestRule { + + public static CountDownLatch blocker = new CountDownLatch(1); + public static Semaphore counter = new Semaphore(0); + public static AtomicInteger callCounter = new AtomicInteger(0); + + @ProcessElement + public void processElement(ProcessContext c) throws InterruptedException { + callCounter.incrementAndGet(); + counter.release(); + blocker.await(); + c.output(c.element()); + } + + @Override + public Statement apply(final Statement base, final Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + blocker = new CountDownLatch(1); + counter = new Semaphore(0); + callCounter = new AtomicInteger(); + base.evaluate(); + } + }; + } + } + + static class KeyTokenInvalidFn extends DoFn, KV> { + + static boolean thrown = false; + + @ProcessElement + public void processElement(ProcessContext c) { + if (!thrown) { + thrown = true; + throw new KeyTokenInvalidException("key"); + } else { + c.output(c.element()); + } + } + } + + static class LargeCommitFn extends DoFn, KV> { + + @ProcessElement + public void processElement(ProcessContext c) { + if (c.element().getKey().equals("large_key")) { + StringBuilder s = new StringBuilder(); + for (int i = 0; i < 100; ++i) { + s.append("large_commit"); + } + c.output(KV.of(c.element().getKey(), s.toString())); + } else { + c.output(c.element()); + } + } + } + + static class ChangeKeysFn extends DoFn, KV> { + + @ProcessElement + public void processElement(ProcessContext c) { + KV elem = c.element(); + 
c.output(KV.of(elem.getKey() + "_" + elem.getValue(), elem.getValue())); + } + } + + static class TestExceptionFn extends DoFn { + + boolean firstTime = true; + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + if (firstTime) { + firstTime = false; + try { + throw new Exception("Exception!"); + } catch (Exception e) { + throw new Exception("Another exception!", e); + } + } + } + } + + static class PassthroughDoFn + extends DoFn>, KV>> { + + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element()); + } + } + + static class Action { + + GetWorkResponse response; + Timer[] expectedTimers = new Timer[] {}; + WatermarkHold[] expectedHolds = new WatermarkHold[] {}; + + public Action(GetWorkResponse response) { + this.response = response; + } + + Action withHolds(WatermarkHold... holds) { + this.expectedHolds = holds; + return this; + } + + Action withTimers(Timer... timers) { + this.expectedTimers = timers; + return this; + } + } + + static class PrintFn extends DoFn>, String> { + + @ProcessElement + public void processElement(ProcessContext c) { + KV elem = c.element().getValue(); + c.output(elem.getKey() + ":" + elem.getValue()); + } + } + + private static class MockWork { + Work create(long workToken) { + return Work.create( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), + Instant::now, + Collections.emptyList(), + work -> {}); + } + } + + static class TestExceptionInvalidatesCacheFn + extends DoFn>, String> { + + static boolean thrown = false; + + @StateId("int") + private final StateSpec> counter = StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void processElement(ProcessContext c, @StateId("int") ValueState state) + throws Exception { + KV elem = c.element().getValue(); + if (elem.getValue() == 0) { + LOG.error("**** COUNTER 0 ****"); + assertNull(state.read()); + state.write(42); + assertEquals((Integer) 42, state.read()); + } else if (elem.getValue() == 1) { + LOG.error("**** COUNTER 1 ****"); + assertEquals((Integer) 42, state.read()); + } else if (elem.getValue() == 2) { + if (!thrown) { + LOG.error("**** COUNTER 2 (will throw) ****"); + thrown = true; + throw new Exception("Exception!"); + } + LOG.error("**** COUNTER 2 (retry) ****"); + assertEquals((Integer) 42, state.read()); + } else { + throw new RuntimeException("only expecting values [0,2]"); + } + c.output(elem.getKey() + ":" + elem.getValue()); + } + } + + private static class FanoutFn extends DoFn { + + @ProcessElement + public void processElement(ProcessContext c) { + StringBuilder builder = new StringBuilder(1000000); + for (int i = 0; i < 1000000; i++) { + builder.append(' '); + } + String largeString = builder.toString(); + for (int i = 0; i < 3000; i++) { + c.output(largeString); + } + } + } + + private static class SlowDoFn extends DoFn { + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + Thread.sleep(1000); + c.output(c.element()); + } + } + + static class FakeClock implements Supplier { + private final PriorityQueue jobs = new PriorityQueue<>(); + private Instant now = Instant.now(); + + public ScheduledExecutorService newFakeScheduledExecutor(String unused) { + return new FakeScheduledExecutor(); + } + + @Override + public synchronized Instant get() { + return now; + } + + public synchronized void clear() { + jobs.clear(); + } + + public synchronized void sleep(Duration duration) { + if (duration.isShorterThan(Duration.ZERO)) { + throw new 
UnsupportedOperationException("Cannot sleep backwards in time"); + } + Instant endOfSleep = now.plus(duration); + while (true) { + Job job = jobs.peek(); + if (job == null || job.when.isAfter(endOfSleep)) { + break; + } + jobs.remove(); + now = job.when; + job.work.run(); + } + now = endOfSleep; + } + + private synchronized void schedule(Duration fromNow, Runnable work) { + jobs.add(new Job(now.plus(fromNow), work)); + } + + private static class Job implements Comparable { + final Instant when; + final Runnable work; + + Job(Instant when, Runnable work) { + this.when = when; + this.work = work; + } + + @Override + public int compareTo(Job job) { + return when.compareTo(job.when); + } + } + + private class FakeScheduledExecutor implements ScheduledExecutorService { + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return true; + } + + @Override + public void execute(Runnable command) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public T invokeAny(Collection> tasks) + throws ExecutionException, InterruptedException { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws ExecutionException, InterruptedException, TimeoutException { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public boolean isShutdown() { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public void shutdown() {} + + @Override + public List shutdownNow() { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public Future submit(Callable task) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public Future submit(Runnable task) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public Future submit(Runnable task, T result) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit) { + if (delay <= 0) { + throw new UnsupportedOperationException( + "Please supply a delay > 0 to scheduleWithFixedDelay"); + } + FakeClock.this.schedule( + Duration.millis(unit.toMillis(initialDelay)), + new Runnable() { + @Override + public void run() { + command.run(); + FakeClock.this.schedule(Duration.millis(unit.toMillis(delay)), this); + } 
+ }); + FakeClock.this.sleep(Duration.ZERO); // Execute work that has an intial delay of zero. + return null; + } + } + } + + private static class FakeSlowDoFn extends DoFn { + private static FakeClock clock; // A static variable keeps this DoFn serializable. + private final Duration sleep; + + FakeSlowDoFn(FakeClock clock, Duration sleep) { + FakeSlowDoFn.clock = clock; + this.sleep = sleep; + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + clock.sleep(sleep); + c.output(c.element()); + } + } + + // Aggregates LatencyAttribution data from active work refresh requests. + static class ActiveWorkRefreshSink { + private final Function responder; + private final Map> totalDurations = + new HashMap<>(); + + ActiveWorkRefreshSink(Function responder) { + this.responder = responder; + } + + Duration getLatencyAttributionDuration(long workToken, LatencyAttribution.State state) { + EnumMap durations = totalDurations.get(workToken); + return durations == null ? Duration.ZERO : durations.getOrDefault(state, Duration.ZERO); + } + + boolean isActiveWorkRefresh(GetDataRequest request) { + for (ComputationGetDataRequest computationRequest : request.getRequestsList()) { + if (!computationRequest.getComputationId().equals(DEFAULT_COMPUTATION_ID)) { + return false; + } + for (KeyedGetDataRequest keyedRequest : computationRequest.getRequestsList()) { + if (keyedRequest.getWorkToken() == 0 + || keyedRequest.getShardingKey() != DEFAULT_SHARDING_KEY + || keyedRequest.getValuesToFetchCount() != 0 + || keyedRequest.getBagsToFetchCount() != 0 + || keyedRequest.getTagValuePrefixesToFetchCount() != 0 + || keyedRequest.getWatermarkHoldsToFetchCount() != 0) { + return false; + } + } + } + return true; + } + + GetDataResponse getData(GetDataRequest request) { + if (!isActiveWorkRefresh(request)) { + return responder.apply(request); + } + for (ComputationGetDataRequest computationRequest : request.getRequestsList()) { + for (KeyedGetDataRequest keyedRequest : computationRequest.getRequestsList()) { + for (LatencyAttribution la : keyedRequest.getLatencyAttributionList()) { + EnumMap durations = + totalDurations.computeIfAbsent( + keyedRequest.getWorkToken(), + (Long workToken) -> + new EnumMap( + LatencyAttribution.State.class)); + Duration cur = Duration.millis(la.getTotalDurationMillis()); + durations.compute(la.getState(), (s, d) -> d == null || d.isShorterThan(cur) ? cur : d); + } + } + } + return EMPTY_DATA_RESPONDER.apply(request); + } + } + + // A DoFn that triggers a GetData request. + static class ReadingDoFn extends DoFn { + @StateId("int") + private final StateSpec> counter = StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void processElement(ProcessContext c, @StateId("int") ValueState state) { + state.read(); + c.output(c.element()); + } + } + + /** For each input element, emits a large string. 
*/ + private static class InflateDoFn extends DoFn>, String> { + + final int inflatedSize; + + /** For each input element, outputs a string of this length. */ + InflateDoFn(int inflatedSize) { + this.inflatedSize = inflatedSize; + } + + @ProcessElement + public void processElement(ProcessContext c) { + char[] chars = new char[inflatedSize]; + Arrays.fill(chars, ' '); + c.output(new String(chars)); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java new file mode 100644 index 000000000000..1f3dee4b76ba --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.auto.value.AutoValue; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ActiveWorkStateTest { + + private final WindmillStateCache.ForComputation computationStateCache = + mock(WindmillStateCache.ForComputation.class); + private Map> readOnlyActiveWork; + + private ActiveWorkState activeWorkState; + + private static ShardedKey shardedKey(String str, long shardKey) { + return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey); + } + + private static Work emptyWork() { + return createWork(null); + } + + private
static Work createWork(@Nullable Windmill.WorkItem workItem) { + return Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {}); + } + + private static Work expiredWork(Windmill.WorkItem workItem) { + return Work.create(workItem, () -> Instant.EPOCH, Collections.emptyList(), unused -> {}); + } + + private static Windmill.WorkItem createWorkItem(long workToken) { + return Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("")) + .setShardingKey(1) + .setWorkToken(workToken) + .build(); + } + + @Before + public void setup() { + Map> readWriteActiveWorkMap = new HashMap<>(); + // Only use readOnlyActiveWork to verify internal behavior in reaction to exposed API calls. + readOnlyActiveWork = Collections.unmodifiableMap(readWriteActiveWorkMap); + activeWorkState = ActiveWorkState.forTesting(readWriteActiveWorkMap, computationStateCache); + } + + @Test + public void testActivateWorkForKey_EXECUTE_unknownKey() { + ActivateWorkResult activateWorkResult = + activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), emptyWork()); + + assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult); + } + + @Test + public void testActivateWorkForKey_EXECUTE_emptyWorkQueueForKey() { + ShardedKey shardedKey = shardedKey("someKey", 1L); + long workToken = 1L; + + ActivateWorkResult activateWorkResult = + activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken))); + + Optional nextWorkForKey = + activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workToken); + + assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult); + assertEquals(Optional.empty(), nextWorkForKey); + assertThat(readOnlyActiveWork).doesNotContainKey(shardedKey); + } + + @Test + public void testActivateWorkForKey_DUPLICATE() { + long workToken = 10L; + ShardedKey shardedKey = shardedKey("someKey", 1L); + + // ActivateWork with the same shardedKey, and the same workTokens. + activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken))); + ActivateWorkResult activateWorkResult = + activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken))); + + assertEquals(ActivateWorkResult.DUPLICATE, activateWorkResult); + } + + @Test + public void testActivateWorkForKey_QUEUED() { + ShardedKey shardedKey = shardedKey("someKey", 1L); + + // ActivateWork with the same shardedKey, but different workTokens.
+ activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(1L))); + ActivateWorkResult activateWorkResult = + activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(2L))); + + assertEquals(ActivateWorkResult.QUEUED, activateWorkResult); + } + + @Test + public void testCompleteWorkAndGetNextWorkForKey_noWorkQueueForKey() { + assertEquals( + Optional.empty(), + activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey("someKey", 1L), 10L)); + } + + @Test + public void testCompleteWorkAndGetNextWorkForKey_currentWorkInQueueDoesNotMatchWorkToComplete() { + long workTokenToComplete = 1L; + + Work workInQueue = createWork(createWorkItem(2L)); + ShardedKey shardedKey = shardedKey("someKey", 1L); + + activeWorkState.activateWorkForKey(shardedKey, workInQueue); + activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workTokenToComplete); + + assertEquals(1, readOnlyActiveWork.get(shardedKey).size()); + assertEquals(workInQueue, readOnlyActiveWork.get(shardedKey).peek()); + } + + @Test + public void testCompleteWorkAndGetNextWorkForKey_removesWorkFromQueueWhenComplete() { + long workTokenToComplete = 1L; + + Work activeWork = createWork(createWorkItem(workTokenToComplete)); + Work nextWork = createWork(createWorkItem(2L)); + ShardedKey shardedKey = shardedKey("someKey", 1L); + + activeWorkState.activateWorkForKey(shardedKey, activeWork); + activeWorkState.activateWorkForKey(shardedKey, nextWork); + activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workTokenToComplete); + + assertEquals(nextWork, readOnlyActiveWork.get(shardedKey).peek()); + assertEquals(1, readOnlyActiveWork.get(shardedKey).size()); + assertFalse(readOnlyActiveWork.get(shardedKey).contains(activeWork)); + } + + @Test + public void testCompleteWorkAndGetNextWorkForKey_removesQueueIfNoWorkPresent() { + Work workInQueue = createWork(createWorkItem(1L)); + ShardedKey shardedKey = shardedKey("someKey", 1L); + + activeWorkState.activateWorkForKey(shardedKey, workInQueue); + activeWorkState.completeWorkAndGetNextWorkForKey( + shardedKey, workInQueue.getWorkItem().getWorkToken()); + + assertFalse(readOnlyActiveWork.containsKey(shardedKey)); + } + + @Test + public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() { + Work workToBeCompleted = createWork(createWorkItem(1L)); + Work nextWork = createWork(createWorkItem(2L)); + ShardedKey shardedKey = shardedKey("someKey", 1L); + + activeWorkState.activateWorkForKey(shardedKey, workToBeCompleted); + activeWorkState.activateWorkForKey(shardedKey, nextWork); + activeWorkState.completeWorkAndGetNextWorkForKey( + shardedKey, workToBeCompleted.getWorkItem().getWorkToken()); + + Optional nextWorkOpt = + activeWorkState.completeWorkAndGetNextWorkForKey( + shardedKey, workToBeCompleted.getWorkItem().getWorkToken()); + + assertTrue(nextWorkOpt.isPresent()); + assertSame(nextWork, nextWorkOpt.get()); + + Optional endOfWorkQueue = + activeWorkState.completeWorkAndGetNextWorkForKey( + shardedKey, nextWork.getWorkItem().getWorkToken()); + + assertFalse(endOfWorkQueue.isPresent()); + assertFalse(readOnlyActiveWork.containsKey(shardedKey)); + } + + @Test + public void testInvalidateStuckCommits() { + Map invalidatedCommits = new HashMap<>(); + + Work stuckWork1 = expiredWork(createWorkItem(1L)); + stuckWork1.setState(Work.State.COMMITTING); + Work stuckWork2 = expiredWork(createWorkItem(2L)); + stuckWork2.setState(Work.State.COMMITTING); + ShardedKey shardedKey1 = shardedKey("someKey", 1L); + ShardedKey shardedKey2 = shardedKey("anotherKey",
2L); + + activeWorkState.activateWorkForKey(shardedKey1, stuckWork1); + activeWorkState.activateWorkForKey(shardedKey2, stuckWork2); + + activeWorkState.invalidateStuckCommits(Instant.now(), invalidatedCommits::put); + + assertThat(invalidatedCommits) + .containsEntry(shardedKey1, stuckWork1.getWorkItem().getWorkToken()); + assertThat(invalidatedCommits) + .containsEntry(shardedKey2, stuckWork2.getWorkItem().getWorkToken()); + verify(computationStateCache).invalidate(shardedKey1.key(), shardedKey1.shardingKey()); + verify(computationStateCache).invalidate(shardedKey2.key(), shardedKey2.shardingKey()); + } + + @Test + public void testGetKeysToRefresh() { + Instant refreshDeadline = Instant.now(); + + Work freshWork = createWork(createWorkItem(3L)); + Work refreshableWork1 = expiredWork(createWorkItem(1L)); + refreshableWork1.setState(Work.State.COMMITTING); + Work refreshableWork2 = expiredWork(createWorkItem(2L)); + refreshableWork2.setState(Work.State.COMMITTING); + ShardedKey shardedKey1 = shardedKey("someKey", 1L); + ShardedKey shardedKey2 = shardedKey("anotherKey", 2L); + + activeWorkState.activateWorkForKey(shardedKey1, refreshableWork1); + activeWorkState.activateWorkForKey(shardedKey1, freshWork); + activeWorkState.activateWorkForKey(shardedKey2, refreshableWork2); + + ImmutableList requests = activeWorkState.getKeysToRefresh(refreshDeadline); + + ImmutableList expected = + ImmutableList.of( + GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey1, refreshableWork1), + GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey2, refreshableWork2)); + + ImmutableList actual = + requests.stream() + .map(GetDataRequestKeyShardingKeyAndWorkToken::from) + .collect(toImmutableList()); + + assertThat(actual).containsExactlyElementsIn(expected); + } + + @AutoValue + abstract static class GetDataRequestKeyShardingKeyAndWorkToken { + + private static GetDataRequestKeyShardingKeyAndWorkToken create( + ByteString key, long shardingKey, long workToken) { + return new AutoValue_ActiveWorkStateTest_GetDataRequestKeyShardingKeyAndWorkToken( + key, shardingKey, workToken); + } + + private static GetDataRequestKeyShardingKeyAndWorkToken from( + KeyedGetDataRequest keyedGetDataRequest) { + return create( + keyedGetDataRequest.getKey(), + keyedGetDataRequest.getShardingKey(), + keyedGetDataRequest.getWorkToken()); + } + + private static GetDataRequestKeyShardingKeyAndWorkToken from(ShardedKey shardedKey, Work work) { + return create(shardedKey.key(), shardedKey.shardingKey(), work.getWorkItem().getWorkToken()); + } + + abstract ByteString key(); + + abstract long shardingKey(); + + abstract long workToken(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java new file mode 100644 index 000000000000..b2d98fb0e954 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WeightBoundedQueueTest { + private static final int MAX_WEIGHT = 10; + + @Test + public void testPut_hasCapacity() { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + + int insertedValue = 1; + + queue.put(insertedValue); + + assertEquals(insertedValue, queue.queuedElementsWeight()); + assertEquals(1, queue.size()); + assertEquals(insertedValue, (int) queue.poll()); + } + + @Test + public void testPut_noCapacity() throws InterruptedException { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + + // Insert value that takes all the capacity into the queue. + queue.put(MAX_WEIGHT); + + // Try to insert another value into the queue. This will block since there is no capacity in the + // queue. + Thread putThread = + new Thread( + () -> { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + queue.put(MAX_WEIGHT); + }); + putThread.start(); + + // Should only see the first value in the queue, since the queue is at capacity. putThread + // should be blocked. + assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(1, queue.size()); + + // Poll the queue, pulling off the only value inside and freeing up the capacity in the queue. + queue.poll(); + + // Wait for the putThread which was previously blocked due to the queue being at capacity.
+ putThread.join(); + + assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(1, queue.size()); + } + + @Test + public void testPoll() { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + + int insertedValue1 = 1; + int insertedValue2 = 2; + + queue.put(insertedValue1); + queue.put(insertedValue2); + + assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight()); + assertEquals(2, queue.size()); + assertEquals(insertedValue1, (int) queue.poll()); + assertEquals(1, queue.size()); + } + + @Test + public void testPoll_withTimeout() throws InterruptedException { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + int pollWaitTimeMillis = 10000; + int insertedValue1 = 1; + + AtomicInteger pollResult = new AtomicInteger(); + Thread pollThread = + new Thread( + () -> { + int polled; + try { + polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS); + pollResult.set(polled); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + pollThread.start(); + Thread.sleep(pollWaitTimeMillis / 100); + queue.put(insertedValue1); + pollThread.join(); + + assertEquals(insertedValue1, pollResult.get()); + } + + @Test + public void testPoll_withTimeout_timesOut() throws InterruptedException { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + int defaultPollResult = -10; + int pollWaitTimeMillis = 100; + int insertedValue1 = 1; + + // AtomicInteger default isn't null, so set it to a negative value and verify that it doesn't + // change. + AtomicInteger pollResult = new AtomicInteger(defaultPollResult); + + Thread pollThread = + new Thread( + () -> { + int polled; + try { + polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS); + pollResult.set(polled); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + pollThread.start(); + Thread.sleep(pollWaitTimeMillis * 100); + queue.put(insertedValue1); + pollThread.join(); + + assertEquals(defaultPollResult, pollResult.get()); + } + + @Test + public void testPoll_emptyQueue() { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + + assertNull(queue.poll()); + } + + @Test + public void testTake() throws InterruptedException { + WeightedBoundedQueue queue = + WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + + AtomicInteger value = new AtomicInteger(); + // Should block until value is available + Thread takeThread = + new Thread( + () -> { + try { + value.set(queue.take()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + takeThread.start(); + + Thread.sleep(100); + queue.put(MAX_WEIGHT); + + takeThread.join(); + + assertEquals(MAX_WEIGHT, value.get()); + } +}
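
A minimal usage sketch of the queue exercised by WeightBoundedQueueTest above, assuming only the members the tests call (WeightedBoundedQueue.create(maxWeight, weigher), put, and take) and an Integer element type; the wrapper class, consumer thread, and println are illustrative only and not part of this patch:

// Hypothetical sketch: put() blocks once the queued weight reaches maxWeight, and a consumer
// draining the queue with take() is what unblocks the producer.
class WeightedBoundedQueueSketch {
  public static void main(String[] args) throws InterruptedException {
    WeightedBoundedQueue<Integer> queue =
        WeightedBoundedQueue.create(/* maxWeight= */ 10, i -> Math.min(10, i));
    Thread consumer =
        new Thread(
            () -> {
              try {
                // take() blocks until an element is available.
                System.out.println("took " + queue.take());
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
              }
            });
    consumer.start();
    // The queue is empty, so this element (weight 10) is accepted immediately and uses the whole
    // weight budget; a second put(10) would block until the consumer drained the first element.
    queue.put(10);
    consumer.join();
  }
}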