diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java index 087a7b614..ec5b8129f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java @@ -21,19 +21,16 @@ package io.temporal.internal.worker; import io.temporal.api.enums.v1.TaskQueueKind; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.concurrent.ThreadSafe; @ThreadSafe public class StickyQueueBalancer { private final int pollersCount; private final boolean stickyQueueEnabled; - private final AtomicInteger stickyPollers = new AtomicInteger(0); - private final AtomicInteger normalPollers = new AtomicInteger(0); - private final AtomicBoolean disableNormalPoll = new AtomicBoolean(false); - - private volatile long stickyBacklogSize = 0; + private int stickyPollers = 0; + private int normalPollers = 0; + private boolean disableNormalPoll = false; + private long stickyBacklogSize = 0; public StickyQueueBalancer(int pollersCount, boolean stickyQueueEnabled) { this.pollersCount = pollersCount; @@ -43,35 +40,35 @@ public StickyQueueBalancer(int pollersCount, boolean stickyQueueEnabled) { /** * @return task queue kind that should be used for the next poll */ - public TaskQueueKind makePoll() { + public synchronized TaskQueueKind makePoll() { if (stickyQueueEnabled) { - if (disableNormalPoll.get()) { - stickyPollers.incrementAndGet(); + if (disableNormalPoll) { + stickyPollers++; return TaskQueueKind.TASK_QUEUE_KIND_STICKY; } // If pollersCount >= stickyBacklogSize > 0 we want to go back to a normal ratio to avoid a // situation that too many pollers (all of them in the worst case) will open only sticky queue // polls observing a stickyBacklogSize == 1 for example (which actually can be 0 already at // that moment) and get stuck causing dip in worker load. - if (stickyBacklogSize > pollersCount || stickyPollers.get() <= normalPollers.get()) { - stickyPollers.incrementAndGet(); + if (stickyBacklogSize > pollersCount || stickyPollers <= normalPollers) { + stickyPollers++; return TaskQueueKind.TASK_QUEUE_KIND_STICKY; } } - normalPollers.incrementAndGet(); + normalPollers++; return TaskQueueKind.TASK_QUEUE_KIND_NORMAL; } /** * @param taskQueueKind what kind of task queue poll was just finished */ - public void finishPoll(TaskQueueKind taskQueueKind) { + public synchronized void finishPoll(TaskQueueKind taskQueueKind) { switch (taskQueueKind) { case TASK_QUEUE_KIND_NORMAL: - normalPollers.decrementAndGet(); + normalPollers--; break; case TASK_QUEUE_KIND_STICKY: - stickyPollers.decrementAndGet(); + stickyPollers--; break; default: throw new IllegalArgumentException("Invalid task queue kind: " + taskQueueKind); @@ -83,18 +80,14 @@ public void finishPoll(TaskQueueKind taskQueueKind) { * @param backlogSize backlog size from the poll response, helps to determine if the sticky queue * is backlogged */ - public void finishPoll(TaskQueueKind taskQueueKind, long backlogSize) { + public synchronized void finishPoll(TaskQueueKind taskQueueKind, long backlogSize) { finishPoll(taskQueueKind); if (TaskQueueKind.TASK_QUEUE_KIND_STICKY.equals(taskQueueKind)) { stickyBacklogSize = backlogSize; } } - public void disableNormalPoll() { - disableNormalPoll.set(true); - } - - public int getNormalPollerCount() { - return normalPollers.get(); + public synchronized void disableNormalPoll() { + disableNormalPoll = true; } }