Skip to content

Commit

Permalink
StateWatcher watches and reports changed Pipeline State (#32040)
Browse files Browse the repository at this point in the history
* StateWatcher watches for changed Pipeline State

* Add Javadoc
  • Loading branch information
damondouglas authored Aug 1, 2024
1 parent 56aa17b commit 21009e6
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 0 deletions.
2 changes: 2 additions & 0 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ description = "Apache Beam :: Runners :: Prism :: Java"
ext.summary = "Support for executing a pipeline on Prism."

dependencies {
implementation project(path: ":model:job-management", configuration: "shadow")
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation project(":runners:portability:java")

implementation library.java.joda_time
implementation library.java.slf4j_api
implementation library.java.vendored_grpc_1_60_1
implementation library.java.vendored_guava_32_1_2_jre

testImplementation library.java.junit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.PipelineResult;

/** Listens for {@link PipelineResult.State} changes reported by the {@link StateWatcher}. */
interface StateListener {

/** Callback invoked when {@link StateWatcher} discovers a {@link PipelineResult.State} change. */
void onStateChanged(PipelineResult.State state);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ChannelCredentials;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.InsecureChannelCredentials;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.NettyChannelBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;

/**
* {@link StateWatcher} {@link #watch}es for and reports {@link PipelineResult.State} changes to
* {@link StateListener}s.
*/
@AutoValue
abstract class StateWatcher implements AutoCloseable {

private Optional<PipelineResult.State> latestState = Optional.empty();

/**
* Instantiates a {@link StateWatcher} with {@link InsecureChannelCredentials}. {@link
* StateWatcher} will report to each {@link StateListener} of {@param listeners} of any changed
* {@link PipelineResult.State}.
*/
static StateWatcher insecure(String endpoint, StateListener... listeners) {
return StateWatcher.builder()
.setEndpoint(HostAndPort.fromString(endpoint))
.setCredentials(InsecureChannelCredentials.create())
.setListeners(Arrays.asList(listeners))
.build();
}

/**
* Watch for a Job's {@link PipelineResult.State} change. A {@link
* org.apache.beam.model.jobmanagement.v1.JobApi.GetJobStateRequest} identifies a Job to watch via
* its {@link JobApi.GetJobStateRequest#getJobId()}. The method is blocking until the {@link
* JobApi.JobStateEvent} {@link StreamObserver#onCompleted()}.
*/
void watch(String jobId) {
JobApi.GetJobStateRequest request =
JobApi.GetJobStateRequest.newBuilder().setJobId(jobId).build();
Iterator<JobApi.JobStateEvent> iterator = getJobServiceBlockingStub().getStateStream(request);
while (iterator.hasNext()) {
JobApi.JobStateEvent event = iterator.next();
PipelineResult.State state = PipelineResult.State.valueOf(event.getState().name());
publish(state);
}
}

private void publish(PipelineResult.State state) {
if (latestState.isPresent() && latestState.get().equals(state)) {
return;
}
latestState = Optional.of(state);
for (StateListener listener : getListeners()) {
listener.onStateChanged(state);
}
}

static Builder builder() {
return new AutoValue_StateWatcher.Builder();
}

abstract HostAndPort getEndpoint();

abstract ChannelCredentials getCredentials();

abstract List<StateListener> getListeners();

abstract ManagedChannel getManagedChannel();

abstract JobServiceGrpc.JobServiceBlockingStub getJobServiceBlockingStub();

@Override
public void close() {
getManagedChannel().shutdown();
try {
getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException ignored) {
}
}

@AutoValue.Builder
abstract static class Builder {

abstract Builder setEndpoint(HostAndPort endpoint);

abstract Optional<HostAndPort> getEndpoint();

abstract Builder setCredentials(ChannelCredentials credentials);

abstract Optional<ChannelCredentials> getCredentials();

abstract Builder setListeners(List<StateListener> listeners);

abstract Builder setManagedChannel(ManagedChannel managedChannel);

abstract Builder setJobServiceBlockingStub(
JobServiceGrpc.JobServiceBlockingStub jobServiceBlockingStub);

abstract StateWatcher autoBuild();

final StateWatcher build() {
if (!getEndpoint().isPresent()) {
throw new IllegalStateException("missing endpoint");
}
if (!getCredentials().isPresent()) {
throw new IllegalStateException("missing credentials");
}
HostAndPort endpoint = getEndpoint().get();
ManagedChannel channel =
NettyChannelBuilder.forAddress(
endpoint.getHost(), endpoint.getPort(), getCredentials().get())
.build();
setManagedChannel(channel);
setJobServiceBlockingStub(JobServiceGrpc.newBlockingStub(channel));

return autoBuild();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static com.google.common.truth.Truth.assertThat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Grpc;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.InsecureServerCredentials;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class StateWatcherTest {

@Test
public void givenSingleListener_watches() {
Server server = serverOf(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
TestStateListener listener = new TestStateListener();
try (StateWatcher underTest = StateWatcher.insecure("0.0.0.0:" + server.getPort(), listener)) {
underTest.watch("job-001");
assertThat(listener.states)
.containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
shutdown(server);
}
}

@Test
public void givenMultipleListeners_watches() {
Server server = serverOf(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
TestStateListener listenerA = new TestStateListener();
TestStateListener listenerB = new TestStateListener();
try (StateWatcher underTest =
StateWatcher.insecure("0.0.0.0:" + server.getPort(), listenerA, listenerB)) {
underTest.watch("job-001");
assertThat(listenerA.states)
.containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
assertThat(listenerB.states)
.containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
shutdown(server);
}
}

@Test
public void publishesOnlyChangedState() {
Server server =
serverOf(
PipelineResult.State.RUNNING,
PipelineResult.State.RUNNING,
PipelineResult.State.RUNNING,
PipelineResult.State.RUNNING,
PipelineResult.State.RUNNING,
PipelineResult.State.RUNNING,
PipelineResult.State.RUNNING,
PipelineResult.State.DONE);
TestStateListener listener = new TestStateListener();
try (StateWatcher underTest = StateWatcher.insecure("0.0.0.0:" + server.getPort(), listener)) {
underTest.watch("job-001");
assertThat(listener.states)
.containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
shutdown(server);
}
}

private static class TestStateListener implements StateListener {
private final List<PipelineResult.State> states = new ArrayList<>();

@Override
public void onStateChanged(PipelineResult.State state) {
states.add(state);
}
}

private static class TestJobServiceStateStream extends JobServiceGrpc.JobServiceImplBase {
private final List<PipelineResult.State> states;

TestJobServiceStateStream(PipelineResult.State... states) {
this.states = Arrays.asList(states);
}

@Override
public void getStateStream(
JobApi.GetJobStateRequest request, StreamObserver<JobApi.JobStateEvent> responseObserver) {
for (PipelineResult.State state : states) {
responseObserver.onNext(
JobApi.JobStateEvent.newBuilder()
.setState(JobApi.JobState.Enum.valueOf(state.name()))
.build());
}
responseObserver.onCompleted();
}
}

private static Server serverOf(PipelineResult.State... states) {
try {
return Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
.addService(new TestJobServiceStateStream(states))
.build()
.start();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static void shutdown(Server server) {
server.shutdownNow();
try {
server.awaitTermination();
} catch (InterruptedException ignored) {
}
}
}

0 comments on commit 21009e6

Please sign in to comment.