Skip to content

Commit

Permalink
Replace StateTag.StateBinder to top level StateBinder in SparkStateInternals (#31798)
Browse files Browse the repository at this point in the history
  • Loading branch information
twosom authored Aug 12, 2024
1 parent 2f93d8b commit 780eef9
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test"
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test"
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test"
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test"
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test"
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test"
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ public interface StateTag<StateT extends State> extends Serializable {
/**
* Visitor for binding a {@link StateSpec} to the associated {@link State}.
*
* @deprecated for migration only; runners should reference the top level {@link StateBinder} and
* move towards {@link StateSpec} rather than {@link StateTag}.
* @deprecated for migration only; runners should reference the top level {@link
* org.apache.beam.sdk.state.StateBinder} and move towards {@link StateSpec} rather than
* {@link StateTag}.
*/
@Deprecated
public interface StateBinder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTag.StateBinder;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.InstantCoder;
Expand All @@ -42,11 +41,13 @@
import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateBinder;
import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
Expand Down Expand Up @@ -96,45 +97,47 @@ public K getKey() {
@Override
public <T extends State> T state(
StateNamespace namespace, StateTag<T> address, StateContext<?> c) {
return address.bind(new SparkStateBinder(namespace, c));
return address.getSpec().bind(address.getId(), new SparkStateBinder(namespace, c));
}

private class SparkStateBinder implements StateBinder {
private final StateNamespace namespace;
private final StateContext<?> c;
private final StateContext<?> stateContext;

private SparkStateBinder(StateNamespace namespace, StateContext<?> c) {
private SparkStateBinder(StateNamespace namespace, StateContext<?> stateContext) {
this.namespace = namespace;
this.c = c;
this.stateContext = stateContext;
}

@Override
public <T> ValueState<T> bindValue(StateTag<ValueState<T>> address, Coder<T> coder) {
return new SparkValueState<>(namespace, address, coder);
public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
return new SparkValueState<>(namespace, id, coder);
}

@Override
public <T> BagState<T> bindBag(StateTag<BagState<T>> address, Coder<T> elemCoder) {
return new SparkBagState<>(namespace, address, elemCoder);
public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
return new SparkBagState<>(namespace, id, elemCoder);
}

@Override
public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) {
public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", SetState.class.getSimpleName()));
}

@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
StateTag<MapState<KeyT, ValueT>> address,
String id,
StateSpec<MapState<KeyT, ValueT>> spec,
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder, mapValueCoder));
return new SparkMapState<>(namespace, id, MapCoder.of(mapKeyCoder, mapValueCoder));
}

@Override
public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
StateTag<MultimapState<KeyT, ValueT>> spec,
String id,
StateSpec<MultimapState<KeyT, ValueT>> spec,
Coder<KeyT> keyCoder,
Coder<ValueT> valueCoder) {
throw new UnsupportedOperationException(
Expand All @@ -143,63 +146,63 @@ public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(

@Override
public <T> OrderedListState<T> bindOrderedList(
StateTag<OrderedListState<T>> spec, Coder<T> elemCoder) {
String id, StateSpec<OrderedListState<T>> spec, Coder<T> elemCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", OrderedListState.class.getSimpleName()));
}

@Override
public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombiningValue(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombining(
String id,
StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
Coder<AccumT> accumCoder,
CombineFn<InputT, AccumT, OutputT> combineFn) {
return new SparkCombiningState<>(namespace, address, accumCoder, combineFn);
return new SparkCombiningState<>(namespace, id, accumCoder, combineFn);
}

@Override
public <InputT, AccumT, OutputT>
CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
CombiningState<InputT, AccumT, OutputT> bindCombiningWithContext(
String id,
StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
Coder<AccumT> accumCoder,
CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
return new SparkCombiningState<>(
namespace, address, accumCoder, CombineFnUtil.bindContext(combineFn, c));
namespace, id, accumCoder, CombineFnUtil.bindContext(combineFn, stateContext));
}

@Override
public WatermarkHoldState bindWatermark(
StateTag<WatermarkHoldState> address, TimestampCombiner timestampCombiner) {
return new SparkWatermarkHoldState(namespace, address, timestampCombiner);
String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
return new SparkWatermarkHoldState(namespace, id, timestampCombiner);
}
}

private class AbstractState<T> {
final StateNamespace namespace;
final StateTag<? extends State> address;
final String id;
final Coder<T> coder;

private AbstractState(
StateNamespace namespace, StateTag<? extends State> address, Coder<T> coder) {
private AbstractState(StateNamespace namespace, String id, Coder<T> coder) {
this.namespace = namespace;
this.address = address;
this.id = id;
this.coder = coder;
}

T readValue() {
byte[] buf = stateTable.get(namespace.stringKey(), address.getId());
byte[] buf = stateTable.get(namespace.stringKey(), id);
if (buf != null) {
return CoderHelpers.fromByteArray(buf, coder);
}
return null;
}

void writeValue(T input) {
stateTable.put(
namespace.stringKey(), address.getId(), CoderHelpers.toByteArray(input, coder));
stateTable.put(namespace.stringKey(), id, CoderHelpers.toByteArray(input, coder));
}

public void clear() {
stateTable.remove(namespace.stringKey(), address.getId());
stateTable.remove(namespace.stringKey(), id);
}

@Override
Expand All @@ -212,22 +215,21 @@ public boolean equals(@Nullable Object o) {
}
@SuppressWarnings("unchecked")
AbstractState<?> that = (AbstractState<?>) o;
return namespace.equals(that.namespace) && address.equals(that.address);
return namespace.equals(that.namespace) && id.equals(that.id);
}

@Override
public int hashCode() {
int result = namespace.hashCode();
result = 31 * result + address.hashCode();
result = 31 * result + id.hashCode();
return result;
}
}

private class SparkValueState<T> extends AbstractState<T> implements ValueState<T> {

private SparkValueState(
StateNamespace namespace, StateTag<ValueState<T>> address, Coder<T> coder) {
super(namespace, address, coder);
private SparkValueState(StateNamespace namespace, String id, Coder<T> coder) {
super(namespace, id, coder);
}

@Override
Expand All @@ -252,10 +254,8 @@ private class SparkWatermarkHoldState extends AbstractState<Instant>
private final TimestampCombiner timestampCombiner;

SparkWatermarkHoldState(
StateNamespace namespace,
StateTag<WatermarkHoldState> address,
TimestampCombiner timestampCombiner) {
super(namespace, address, InstantCoder.of());
StateNamespace namespace, String id, TimestampCombiner timestampCombiner) {
super(namespace, id, InstantCoder.of());
this.timestampCombiner = timestampCombiner;
}

Expand Down Expand Up @@ -287,7 +287,7 @@ public ReadableState<Boolean> readLater() {

@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}
};
}
Expand All @@ -299,22 +299,22 @@ public TimestampCombiner getTimestampCombiner() {
}

@SuppressWarnings("TypeParameterShadowing")
private class SparkCombiningState<K, InputT, AccumT, OutputT> extends AbstractState<AccumT>
private class SparkCombiningState<KeyT, InputT, AccumT, OutputT> extends AbstractState<AccumT>
implements CombiningState<InputT, AccumT, OutputT> {

private final CombineFn<InputT, AccumT, OutputT> combineFn;

private SparkCombiningState(
StateNamespace namespace,
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
String id,
Coder<AccumT> coder,
CombineFn<InputT, AccumT, OutputT> combineFn) {
super(namespace, address, coder);
super(namespace, id, coder);
this.combineFn = combineFn;
}

@Override
public SparkCombiningState<K, InputT, AccumT, OutputT> readLater() {
public SparkCombiningState<KeyT, InputT, AccumT, OutputT> readLater() {
return this;
}

Expand Down Expand Up @@ -348,7 +348,7 @@ public ReadableState<Boolean> readLater() {

@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}
};
}
Expand All @@ -369,10 +369,8 @@ private final class SparkMapState<MapKeyT, MapValueT>
extends AbstractState<Map<MapKeyT, MapValueT>> implements MapState<MapKeyT, MapValueT> {

private SparkMapState(
StateNamespace namespace,
StateTag<? extends State> address,
Coder<Map<MapKeyT, MapValueT>> coder) {
super(namespace, address, coder);
StateNamespace namespace, String id, Coder<Map<MapKeyT, MapValueT>> coder) {
super(namespace, id, coder);
}

@Override
Expand Down Expand Up @@ -490,7 +488,7 @@ public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}

@Override
Expand All @@ -502,8 +500,8 @@ public ReadableState<Boolean> readLater() {
}

private final class SparkBagState<T> extends AbstractState<List<T>> implements BagState<T> {
private SparkBagState(StateNamespace namespace, StateTag<BagState<T>> address, Coder<T> coder) {
super(namespace, address, ListCoder.of(coder));
private SparkBagState(StateNamespace namespace, String id, Coder<T> coder) {
super(namespace, id, ListCoder.of(coder));
}

@Override
Expand Down Expand Up @@ -537,7 +535,7 @@ public ReadableState<Boolean> readLater() {

@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}
};
}
Expand Down

0 comments on commit 780eef9

Please sign in to comment.