From 780eef98083fe56f81cc5c62dc8ff193993584f0 Mon Sep 17 00:00:00 2001 From: twosom <72733442+twosom@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:35:35 +0900 Subject: [PATCH] Replace StateTag.StateBinder to top level StateBinder in SparkStateInternals (#31798) --- ...PostCommit_Java_ValidatesRunner_Spark.json | 3 +- ...idatesRunner_SparkStructuredStreaming.json | 3 +- ...mit_Java_ValidatesRunner_Spark_Java11.json | 3 +- .../apache/beam/runners/core/StateTag.java | 5 +- .../spark/stateful/SparkStateInternals.java | 108 +++++++++--------- 5 files changed, 62 insertions(+), 60 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json index b970762c8397..d59e273949da 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json index b970762c8397..d59e273949da 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json index b970762c8397..d59e273949da 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test" } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java index 8c699ac31117..0106f95ed748 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java @@ -69,8 +69,9 @@ public interface StateTag extends Serializable { /** * Visitor for binding a {@link StateSpec} and to the associated {@link State}. * - * @deprecated for migration only; runners should reference the top level {@link StateBinder} and - * move towards {@link StateSpec} rather than {@link StateTag}. + * @deprecated for migration only; runners should reference the top level {@link + * org.apache.beam.sdk.state.StateBinder} and move towards {@link StateSpec} rather than + * {@link StateTag}. */ @Deprecated public interface StateBinder { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java index 731cadb89f0c..7ca0dc29e615 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java @@ -27,7 +27,6 @@ import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTag.StateBinder; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; @@ -42,11 +41,13 @@ import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.StateBinder; import org.apache.beam.sdk.state.StateContext; +import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; -import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; +import org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; @@ -96,45 +97,47 @@ public K getKey() { @Override public T state( StateNamespace namespace, StateTag address, StateContext c) { - return address.bind(new SparkStateBinder(namespace, c)); + return address.getSpec().bind(address.getId(), new SparkStateBinder(namespace, c)); } private class SparkStateBinder implements StateBinder { private final StateNamespace namespace; - private final StateContext c; + private final StateContext stateContext; - private SparkStateBinder(StateNamespace namespace, StateContext c) { + private SparkStateBinder(StateNamespace namespace, StateContext stateContext) { this.namespace = namespace; - this.c = c; + this.stateContext = stateContext; } @Override - public ValueState bindValue(StateTag> address, Coder coder) { - return new SparkValueState<>(namespace, address, coder); + public ValueState bindValue(String id, StateSpec> spec, Coder coder) { + return new SparkValueState<>(namespace, id, coder); } @Override - public BagState bindBag(StateTag> address, Coder elemCoder) { - return new SparkBagState<>(namespace, address, elemCoder); + public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { + return new SparkBagState<>(namespace, id, elemCoder); } @Override - public SetState bindSet(StateTag> spec, Coder elemCoder) { + public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { throw new UnsupportedOperationException( String.format("%s is not supported", SetState.class.getSimpleName())); } @Override public MapState bindMap( - StateTag> address, + String id, + StateSpec> spec, Coder mapKeyCoder, Coder mapValueCoder) { - return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder, mapValueCoder)); + return new SparkMapState<>(namespace, id, MapCoder.of(mapKeyCoder, mapValueCoder)); } @Override public MultimapState bindMultimap( - StateTag> spec, + String id, + StateSpec> spec, Coder keyCoder, Coder valueCoder) { throw new UnsupportedOperationException( @@ -143,50 +146,51 @@ public MultimapState bindMultimap( @Override public OrderedListState bindOrderedList( - StateTag> spec, Coder elemCoder) { + String id, StateSpec> spec, Coder elemCoder) { throw new UnsupportedOperationException( String.format("%s is not supported", OrderedListState.class.getSimpleName())); } @Override - public CombiningState bindCombiningValue( - StateTag> address, + public CombiningState bindCombining( + String id, + StateSpec> spec, Coder accumCoder, CombineFn combineFn) { - return new SparkCombiningState<>(namespace, address, accumCoder, combineFn); + return new SparkCombiningState<>(namespace, id, accumCoder, combineFn); } @Override public - CombiningState bindCombiningValueWithContext( - StateTag> address, + CombiningState bindCombiningWithContext( + String id, + StateSpec> spec, Coder accumCoder, - CombineFnWithContext combineFn) { + CombineWithContext.CombineFnWithContext combineFn) { return new SparkCombiningState<>( - namespace, address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + namespace, id, accumCoder, CombineFnUtil.bindContext(combineFn, stateContext)); } @Override public WatermarkHoldState bindWatermark( - StateTag address, TimestampCombiner timestampCombiner) { - return new SparkWatermarkHoldState(namespace, address, timestampCombiner); + String id, StateSpec spec, TimestampCombiner timestampCombiner) { + return new SparkWatermarkHoldState(namespace, id, timestampCombiner); } } private class AbstractState { final StateNamespace namespace; - final StateTag address; + final String id; final Coder coder; - private AbstractState( - StateNamespace namespace, StateTag address, Coder coder) { + private AbstractState(StateNamespace namespace, String id, Coder coder) { this.namespace = namespace; - this.address = address; + this.id = id; this.coder = coder; } T readValue() { - byte[] buf = stateTable.get(namespace.stringKey(), address.getId()); + byte[] buf = stateTable.get(namespace.stringKey(), id); if (buf != null) { return CoderHelpers.fromByteArray(buf, coder); } @@ -194,12 +198,11 @@ T readValue() { } void writeValue(T input) { - stateTable.put( - namespace.stringKey(), address.getId(), CoderHelpers.toByteArray(input, coder)); + stateTable.put(namespace.stringKey(), id, CoderHelpers.toByteArray(input, coder)); } public void clear() { - stateTable.remove(namespace.stringKey(), address.getId()); + stateTable.remove(namespace.stringKey(), id); } @Override @@ -212,22 +215,21 @@ public boolean equals(@Nullable Object o) { } @SuppressWarnings("unchecked") AbstractState that = (AbstractState) o; - return namespace.equals(that.namespace) && address.equals(that.address); + return namespace.equals(that.namespace) && id.equals(that.id); } @Override public int hashCode() { int result = namespace.hashCode(); - result = 31 * result + address.hashCode(); + result = 31 * result + id.hashCode(); return result; } } private class SparkValueState extends AbstractState implements ValueState { - private SparkValueState( - StateNamespace namespace, StateTag> address, Coder coder) { - super(namespace, address, coder); + private SparkValueState(StateNamespace namespace, String id, Coder coder) { + super(namespace, id, coder); } @Override @@ -252,10 +254,8 @@ private class SparkWatermarkHoldState extends AbstractState private final TimestampCombiner timestampCombiner; SparkWatermarkHoldState( - StateNamespace namespace, - StateTag address, - TimestampCombiner timestampCombiner) { - super(namespace, address, InstantCoder.of()); + StateNamespace namespace, String id, TimestampCombiner timestampCombiner) { + super(namespace, id, InstantCoder.of()); this.timestampCombiner = timestampCombiner; } @@ -287,7 +287,7 @@ public ReadableState readLater() { @Override public Boolean read() { - return stateTable.get(namespace.stringKey(), address.getId()) == null; + return stateTable.get(namespace.stringKey(), id) == null; } }; } @@ -299,22 +299,22 @@ public TimestampCombiner getTimestampCombiner() { } @SuppressWarnings("TypeParameterShadowing") - private class SparkCombiningState extends AbstractState + private class SparkCombiningState extends AbstractState implements CombiningState { private final CombineFn combineFn; private SparkCombiningState( StateNamespace namespace, - StateTag> address, + String id, Coder coder, CombineFn combineFn) { - super(namespace, address, coder); + super(namespace, id, coder); this.combineFn = combineFn; } @Override - public SparkCombiningState readLater() { + public SparkCombiningState readLater() { return this; } @@ -348,7 +348,7 @@ public ReadableState readLater() { @Override public Boolean read() { - return stateTable.get(namespace.stringKey(), address.getId()) == null; + return stateTable.get(namespace.stringKey(), id) == null; } }; } @@ -369,10 +369,8 @@ private final class SparkMapState extends AbstractState> implements MapState { private SparkMapState( - StateNamespace namespace, - StateTag address, - Coder> coder) { - super(namespace, address, coder); + StateNamespace namespace, String id, Coder> coder) { + super(namespace, id, coder); } @Override @@ -490,7 +488,7 @@ public ReadableState isEmpty() { return new ReadableState() { @Override public Boolean read() { - return stateTable.get(namespace.stringKey(), address.getId()) == null; + return stateTable.get(namespace.stringKey(), id) == null; } @Override @@ -502,8 +500,8 @@ public ReadableState readLater() { } private final class SparkBagState extends AbstractState> implements BagState { - private SparkBagState(StateNamespace namespace, StateTag> address, Coder coder) { - super(namespace, address, ListCoder.of(coder)); + private SparkBagState(StateNamespace namespace, String id, Coder coder) { + super(namespace, id, ListCoder.of(coder)); } @Override @@ -537,7 +535,7 @@ public ReadableState readLater() { @Override public Boolean read() { - return stateTable.get(namespace.stringKey(), address.getId()) == null; + return stateTable.get(namespace.stringKey(), id) == null; } }; }