diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java index db918aa680c6..bd52e43ec71c 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.aws2.sqs; import static java.util.Collections.EMPTY_LIST; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; @@ -27,10 +28,19 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; +import java.util.ConcurrentModificationException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -61,6 +71,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import 
org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.dataflow.qual.Pure; @@ -152,6 +163,7 @@ public static WriteBatches writeBatches() { .concurrentRequests(WriteBatches.DEFAULT_CONCURRENCY) .batchSize(WriteBatches.MAX_BATCH_SIZE) .batchTimeout(WriteBatches.DEFAULT_BATCH_TIMEOUT) + .strictTimeouts(false) .build(); } @@ -289,6 +301,8 @@ public abstract static class WriteBatches abstract @Pure Duration batchTimeout(); + abstract @Pure boolean strictTimeouts(); + abstract @Pure int batchSize(); abstract @Pure ClientConfiguration clientConfiguration(); @@ -311,6 +325,8 @@ abstract static class Builder { abstract Builder batchTimeout(Duration duration); + abstract Builder strictTimeouts(boolean strict); + abstract Builder batchSize(int batchSize); abstract Builder clientConfiguration(ClientConfiguration config); @@ -363,10 +379,20 @@ public WriteBatches withBatchSize(int batchSize) { /** * The duration to accumulate records before timing out, default is 3 secs. * - *

Timeouts will be checked upon arrival of new messages. + *

By default timeouts will be checked upon arrival of records. */ public WriteBatches withBatchTimeout(Duration timeout) { - return builder().batchTimeout(timeout).build(); + return withBatchTimeout(timeout, false); + } + + /** + * The duration to accumulate records before timing out, default is 3 secs. + * + *

By default timeouts will be checked upon arrival of records. If using {@code strict} + * enforcement, timeouts will be checked by a separate thread. + */ + public WriteBatches withBatchTimeout(Duration timeout, boolean strict) { + return builder().batchTimeout(timeout).strictTimeouts(strict).build(); } /** Dynamic record based destination to write to. */ @@ -546,12 +572,17 @@ public void finishSpecifyingOutput( } private static class BatchHandler implements AutoCloseable { + private static final int CHECKS_PER_TIMEOUT_PERIOD = 5; + private final WriteBatches spec; private final SqsAsyncClient sqs; private final Batches batches; private final EntryMapperFn entryMapper; private final AsyncBatchWriteHandler handler; + private final @Nullable ScheduledExecutorService scheduler; + + private @MonotonicNonNull ScheduledFuture expirationCheck = null; BatchHandler(WriteBatches spec, EntryMapperFn entryMapper, AwsOptions options) { this.spec = spec; @@ -567,8 +598,10 @@ private static class BatchHandler implements AutoCloseable { error -> error.code(), record -> record.id(), error -> error.id()); + this.scheduler = + spec.strictTimeouts() ? 
Executors.newSingleThreadScheduledExecutor() : null; if (spec.queueUrl() != null) { - this.batches = new Single(spec.queueUrl()); + this.batches = new Single(); } else if (spec.dynamicDestination() != null) { this.batches = new Dynamic(spec.dynamicDestination()); } else { @@ -585,6 +618,13 @@ private static CompletableFuture> sendMessageBatch( public void startBundle() { handler.reset(); + if (scheduler != null && spec.strictTimeouts()) { + long timeout = spec.batchTimeout().getMillis(); + long period = Math.max(1L, timeout / CHECKS_PER_TIMEOUT_PERIOD); // fixed delay must be > 0, also for timeouts < 5ms + expirationCheck = + scheduler.scheduleWithFixedDelay( + () -> batches.submitExpired(false), timeout, period, MILLISECONDS); + } } public void process(T msg) { @@ -592,18 +632,19 @@ public void process(T msg) { Batch batch = batches.getLocked(msg); batch.add(entry); if (batch.size() >= spec.batchSize() || batch.isExpired()) { - writeEntries(batch, true); + submitEntries(batch, true); } else { checkState(batch.lock(false)); // unlock to continue writing to batch } - // check timeouts synchronously on arrival of new messages - batches.writeExpired(true); + // check for expired batches synchronously + batches.submitExpired(true); } - private void writeEntries(Batch batch, boolean throwPendingFailures) { + /** Submit entries of a {@link Batch} to the async write handler. 
*/ + private void submitEntries(Batch batch, boolean throwFailures) { try { - handler.batchWrite(batch.queue, batch.getAndClear(), throwPendingFailures); + handler.batchWrite(batch.queue, batch.getAndClose(), throwFailures); } catch (RuntimeException e) { throw e; } catch (Throwable e) { @@ -612,32 +653,54 @@ private void writeEntries(Batch batch, boolean throwPendingFailures) { } public void finishBundle() throws Throwable { - batches.writeAll(); + if (expirationCheck != null) { + expirationCheck.cancel(false); + while (true) { + try { + expirationCheck.get(3, TimeUnit.SECONDS); + } catch (TimeoutException e) { + LOG.warn("Waiting for timeout check to complete"); + } catch (CancellationException e) { + break; // NOTE(review): get() on a cancelled future throws right away; confirm no check is still in flight here + } + } + } + // safe to write remaining batches without risking to encounter locked ones + checkState(batches.submitAll()); handler.waitForCompletion(); } @Override public void close() throws Exception { sqs.close(); + if (scheduler != null) { + scheduler.shutdown(); + } } /** * Batch(es) of a single fixed or several dynamic queues. * - *

{@link #getLocked} is meant to support atomic writes from multiple threads if using an - * appropriate thread-safe implementation. This is necessary to later support strict timeouts - * (see below). + *

A {@link Batch} can only ever be modified from the single runner thread. * - *

For simplicity, check for expired messages after appending to a batch. For strict - * enforcement of timeouts, {@link #writeExpired} would have to be periodically called using a - * scheduler and requires also a thread-safe impl of {@link Batch#lock(boolean)}. + *

In case of strict timeouts, a batch may be submitted to the write handler by periodic + * expiration checks using a scheduler. Otherwise, and by default, this is done after + * appending to a batch. {@link Batch#lock(boolean)} prevents concurrent access to a batch + * between threads. Once a batch was locked by an expiration check, it must always be + * submitted to the write handler. */ + @NotThreadSafe private abstract class Batches { private int nextId = 0; // only ever used from one "runner" thread abstract int maxBatches(); - /** Next batch entry id is guaranteed to be unique for all open batches. */ + /** + * Next batch entry id is guaranteed to be unique for all open batches. + * + *

This method is not thread-safe and may only ever be called from the single runner + * thread. + */ String nextId() { if (nextId >= (spec.batchSize() * maxBatches())) { nextId = 0; @@ -645,24 +708,40 @@ String nextId() { return Integer.toString(nextId++); } - /** Get existing or new locked batch that can be written to. */ + /** + * Get an existing or new locked batch to append new messages. + * + *

This method is not thread-safe and may only ever be called from a single runner + * thread. If this encounters a locked batch, it assumes the {@link Batch} is currently + * written to SQS and creates a new one. + */ abstract Batch getLocked(T record); - /** Write all remaining batches (that can be locked). */ - abstract void writeAll(); - - /** Write all expired batches (that can be locked). */ - abstract void writeExpired(boolean throwPendingFailures); - - /** Create a new locked batch that is ready for writing. */ - Batch createLocked(String queue) { - return new Batch(queue, spec.batchSize(), spec.batchTimeout()); - } - - /** Write a batch if it can be locked. */ - protected boolean writeLocked(Batch batch, boolean throwPendingFailures) { - if (batch.lock(true)) { - writeEntries(batch, throwPendingFailures); + /** + * Submit all remaining batches (that can be locked) to the write handler. + * + * @return {@code true} if successful for all batches. + */ + abstract boolean submitAll(); + + /** + * Submit all expired batches (that can be locked) to the write handler. + * + *

This is the only method that may be invoked from a thread other than the runner + * thread. + */ + abstract void submitExpired(boolean throwFailures); + + /** + * Submit a batch to the write handler if it can be locked. + * + * @return {@code true} if successful (or closed). + */ + protected boolean lockAndSubmit(Batch batch, boolean throwFailures) { + if (batch.isClosed()) { + return true; // nothing to submit + } else if (batch.lock(true)) { + submitEntries(batch, throwFailures); return true; } return false; @@ -672,11 +751,7 @@ protected boolean writeLocked(Batch batch, boolean throwPendingFailures) { /** Batch of a single, fixed queue. */ @NotThreadSafe private class Single extends Batches { - private Batch batch; - - Single(String queue) { - this.batch = new Batch(queue, EMPTY_LIST, Batch.NEVER); // locked - } + private @Nullable Batch batch; @Override int maxBatches() { @@ -685,18 +760,21 @@ int maxBatches() { @Override Batch getLocked(T record) { - return batch.lock(true) ? batch : (batch = createLocked(batch.queue)); + if (batch == null || !batch.lock(true)) { + batch = Batch.createLocked(checkStateNotNull(spec.queueUrl()), spec); + } + return batch; } @Override - void writeAll() { - writeLocked(batch, true); + boolean submitAll() { + return batch == null || lockAndSubmit(batch, true); } @Override - void writeExpired(boolean throwPendingFailures) { - if (batch.isExpired()) { - writeLocked(batch, throwPendingFailures); + void submitExpired(boolean throwFailures) { + if (batch != null && batch.isExpired()) { + lockAndSubmit(batch, throwFailures); } } } @@ -709,8 +787,9 @@ private class Dynamic extends Batches { (queue, batch) -> batch != null && batch.lock(true) ? 
batch : createLocked(queue); private final Map<@NonNull String, Batch> batches = new HashMap<>(); + private final AtomicBoolean submitExpiredRunning = new AtomicBoolean(false); + private final AtomicReference nextTimeout = new AtomicReference<>(Batch.NEVER); private final DynamicDestination destination; - private Instant nextTimeout = Batch.NEVER; Dynamic(DynamicDestination destination) { this.destination = destination; @@ -727,77 +806,118 @@ Batch getLocked(T record) { } @Override - void writeAll() { - batches.values().forEach(batch -> writeLocked(batch, true)); + boolean submitAll() { + AtomicBoolean res = new AtomicBoolean(true); + batches.values().forEach(batch -> res.compareAndSet(true, lockAndSubmit(batch, true))); batches.clear(); - nextTimeout = Batch.NEVER; + nextTimeout.set(Batch.NEVER); + return res.get(); } - private void writeExpired(Batch batch) { - if (!batch.isExpired() || !writeLocked(batch, true)) { - // find next timeout for remaining, unwritten batches - if (batch.timeout.isBefore(nextTimeout)) { - nextTimeout = batch.timeout; - } + private void updateNextTimeout(Batch batch) { + Instant prev; + do { + prev = nextTimeout.get(); + } while (batch.expirationTime.isBefore(prev) + && !nextTimeout.compareAndSet(prev, batch.expirationTime)); + } + + private void submitExpired(Batch batch, boolean throwFailures) { + if (!batch.isClosed() && (!batch.isExpired() || !lockAndSubmit(batch, throwFailures))) { + updateNextTimeout(batch); } } @Override - void writeExpired(boolean throwPendingFailures) { - if (nextTimeout.isBeforeNow()) { - nextTimeout = Batch.NEVER; - batches.values().forEach(this::writeExpired); + void submitExpired(boolean throwFailures) { + Instant timeout = nextTimeout.get(); + if (timeout.isBeforeNow()) { + // prevent concurrent checks for expired batches + if (submitExpiredRunning.compareAndSet(false, true)) { + try { + nextTimeout.set(Batch.NEVER); + batches.values().forEach(b -> submitExpired(b, throwFailures)); + } catch 
(ConcurrentModificationException e) { + // Can happen rarely when adding a new dynamic destination and is expected. + // Reset old timeout to repeat check asap. + nextTimeout.set(timeout); + } finally { + submitExpiredRunning.set(false); + } + } } } - @Override Batch createLocked(String queue) { - Batch batch = super.createLocked(queue); - if (batch.timeout.isBefore(nextTimeout)) { - nextTimeout = batch.timeout; - } + Batch batch = Batch.createLocked(queue, spec); + updateNextTimeout(batch); return batch; } } } - /** - * Batch of entries of a queue. - * - *

Overwrite {@link #lock} with a thread-safe implementation to support concurrent usage. - */ + /** Batch of entries of a queue. */ @NotThreadSafe - private static final class Batch { + private abstract static class Batch { private static final Instant NEVER = Instant.ofEpochMilli(Long.MAX_VALUE); + private final String queue; - private Instant timeout; + private final Instant expirationTime; private List entries; - Batch(String queue, int size, Duration bufferedTime) { - this(queue, new ArrayList<>(size), Instant.now().plus(bufferedTime)); + static Batch createLocked(String queue, SqsIO.WriteBatches spec) { + return spec.strictTimeouts() + ? new BatchWithAtomicLock(queue, spec.batchSize(), spec.batchTimeout()) + : new BatchWithNoopLock(queue, spec.batchSize(), spec.batchTimeout()); } - Batch(String queue, List entries, Instant timeout) { - this.queue = queue; - this.entries = entries; - this.timeout = timeout; + /** A {@link Batch} with a noop lock that just rejects un/locking if closed. */ + private static class BatchWithNoopLock extends Batch { + BatchWithNoopLock(String queue, int size, Duration timeout) { + super(queue, size, timeout); + } + + @Override + boolean lock(boolean lock) { + return !isClosed(); // always un/lock unless closed + } } - /** Attempt to un/lock this batch and return if successful. */ - boolean lock(boolean lock) { - // thread unsafe dummy impl that rejects locking batches after getAndClear - return !NEVER.equals(timeout) || !lock; + /** A {@link Batch} supporting atomic locking for concurrent usage. 
*/ + private static class BatchWithAtomicLock extends Batch { + private final AtomicBoolean locked = new AtomicBoolean(true); // always lock on creation + + BatchWithAtomicLock(String queue, int size, Duration timeout) { + super(queue, size, timeout); + } + + @Override + boolean lock(boolean lock) { + return !isClosed() && locked.compareAndSet(!lock, lock); + } + } + + private Batch(String queue, int size, Duration timeout) { + this.queue = queue; + this.entries = new ArrayList<>(size); + this.expirationTime = Instant.now().plus(timeout); } - /** Get and clear entries for writing. */ - List getAndClear() { + /** Attempt to un/lock this batch, if closed this always fails. */ + abstract boolean lock(boolean lock); + + /** + * Get and clear entries for submission to the write handler. + * + *

The batch must be locked and kept locked, it can't be modified anymore. + */ + List getAndClose() { List res = entries; entries = EMPTY_LIST; - timeout = NEVER; return res; } - /** Add entry to this batch. */ + /** Append entry (only use if locked!). */ void add(SendMessageBatchRequestEntry entry) { entries.add(entry); } @@ -807,7 +927,11 @@ int size() { } boolean isExpired() { - return timeout.isBeforeNow(); + return expirationTime.isBeforeNow(); + } + + boolean isClosed() { + return entries == EMPTY_LIST; } } } diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java index 0dc0719cc470..cccd42c47eac 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java @@ -17,23 +17,28 @@ */ package org.apache.beam.sdk.io.aws2.sqs; +import static java.lang.Math.sqrt; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.supplyAsync; import static java.util.stream.Collectors.toList; import static java.util.stream.IntStream.range; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.commons.lang3.RandomUtils.nextInt; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.joda.time.Duration.millis; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.times; import static 
org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import java.util.Arrays; +import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -263,6 +268,29 @@ public void testWriteBatchesWithTimeout() { verify(sqs).sendMessageBatch(request("queue", entries[3], entries[4])); } + @Test + public void testWriteBatchesWithStrictTimeout() { + when(sqs.sendMessageBatch(any(SendMessageBatchRequest.class))) + .thenReturn(completedFuture(SendMessageBatchResponse.builder().build())); + + p.apply(Create.of(5)) + .apply(ParDo.of(new CreateMessages())) + .apply( + // simulate a delay between messages (100ms) that is shorter than the batch timeout (150ms) + SqsIO.writeBatches() + .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY)) + .withBatchTimeout(millis(150), true) + .to("queue")); + + p.run().waitUntilFinish(); + + SendMessageBatchRequestEntry[] entries = entries(range(0, 5)); + // using strict timeouts, batches are timed out by a separate thread + verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1])); + verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3])); + verify(sqs).sendMessageBatch(request("queue", entries[4])); + } + @Test public void testWriteBatchesToDynamic() { when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS)); @@ -315,6 +343,36 @@ public void testWriteBatchesToDynamicWithTimeout() { verify(sqs).sendMessageBatch(request("even", entries[4])); } + @Test + public void testWriteBatchesToDynamicWithStrictTimeout() { + when(sqs.sendMessageBatch(any(SendMessageBatchRequest.class))) + .thenReturn(completedFuture(SendMessageBatchResponse.builder().build())); + + p.apply(Create.of(10000)) + .apply(ParDo.of(new CreateMessages())) + .apply( + // simulate a delay between messages (1ms) that is shorter than the batch timeout (10ms) + SqsIO.writeBatches() + .withEntryMapper(withDelay(millis(1), SET_MESSAGE_BODY)) + .withBatchTimeout(millis(10), true) + // Use 
sqrt to change the rate of newly created dynamic destinations over time + .to(msg -> String.valueOf(nextInt(0, (int) (1 + sqrt(Integer.valueOf(msg))))))); + + p.run().waitUntilFinish(); + + ArgumentCaptor reqCaptor = + ArgumentCaptor.forClass(SendMessageBatchRequest.class); + verify(sqs, atLeastOnce()).sendMessageBatch(reqCaptor.capture()); + + Set capturedMessages = new HashSet<>(); + for (SendMessageBatchRequest req : reqCaptor.getAllValues()) { + for (SendMessageBatchRequestEntry entry : req.entries()) { + assertTrue("duplicate message", capturedMessages.add(entry.messageBody())); + } + } + assertEquals("Invalid message count", 10000, capturedMessages.size()); + } + private SendMessageBatchRequest anyRequest() { return any(); }