Skip to content

Commit

Permalink
Enable artifact staging during Prism Runner lifecycle (#32084)
Browse files Browse the repository at this point in the history
  • Loading branch information
damondouglas authored Aug 6, 2024
1 parent e9b5dc6 commit 99a2383
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 0 deletions.
1 change: 1 addition & 0 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies {
implementation project(path: ":model:pipeline", configuration: "shadow")
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation project(path: ":sdks:java:harness", configuration: "shadow")
implementation project(":runners:java-fn-execution")
implementation project(":runners:portability:java")

implementation library.java.joda_time
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService;
import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stages {@link org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline} artifacts of prepared jobs.
 *
 * <p>Instances own a {@link ManagedChannel} and an {@link ArtifactRetrievalService}; both are
 * released by {@link #close}.
 */
@AutoValue
abstract class PrismArtifactStager implements AutoCloseable {

  private static final Logger LOG = LoggerFactory.getLogger(PrismArtifactStager.class);

  /** Maximum time {@link #close} waits for the {@link ManagedChannel} to terminate. */
  private static final long CHANNEL_TERMINATION_TIMEOUT_MILLIS = 3000L;

  /**
   * Instantiate a {@link PrismArtifactStager} via call to {@link #of(String, String)}, assigning
   * {@link Builder#setStagingEndpoint} using {@code prepareJobResponse} {@link
   * JobApi.PrepareJobResponse#getArtifactStagingEndpoint} and {@link
   * JobApi.PrepareJobResponse#getStagingSessionToken}.
   */
  static PrismArtifactStager of(JobApi.PrepareJobResponse prepareJobResponse) {
    return of(
        prepareJobResponse.getArtifactStagingEndpoint().getUrl(),
        prepareJobResponse.getStagingSessionToken());
  }

  /**
   * Instantiates a {@link PrismArtifactStager} from the {@code stagingEndpoint} URL and {@code
   * stagingSessionToken} to instantiate the {@link #getRetrievalService}, {@link
   * #getManagedChannel}, and {@link #getStagingServiceStub} defaults. See the referenced getters
   * for more details.
   */
  static PrismArtifactStager of(String stagingEndpoint, String stagingSessionToken) {
    return PrismArtifactStager.builder()
        .setStagingEndpoint(stagingEndpoint)
        .setStagingSessionToken(stagingSessionToken)
        .build();
  }

  /** Returns a new, empty {@link Builder}. */
  static Builder builder() {
    return new AutoValue_PrismArtifactStager.Builder();
  }

  /**
   * Stage the {@link org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline} artifacts via {@link
   * ArtifactStagingService#offer} supplying {@link #getRetrievalService}, {@link
   * #getStagingServiceStub}, and {@link #getStagingSessionToken}.
   *
   * @throws ExecutionException propagated from {@link ArtifactStagingService#offer}
   * @throws InterruptedException propagated from {@link ArtifactStagingService#offer}
   */
  void stage() throws ExecutionException, InterruptedException {
    LOG.info("staging artifacts at {}", getStagingEndpoint());
    ArtifactStagingService.offer(
        getRetrievalService(), getStagingServiceStub(), getStagingSessionToken());
  }

  /** The URL of the {@link ArtifactStagingService}. */
  abstract String getStagingEndpoint();

  /**
   * Token associated with a staging session and acquired from a {@link
   * JobServiceGrpc.JobServiceStub#prepare}'s {@link JobApi.PrepareJobResponse}.
   */
  abstract String getStagingSessionToken();

  /**
   * The service that retrieves artifacts; defaults to instantiating from the default {@link
   * ArtifactRetrievalService#ArtifactRetrievalService()} constructor.
   */
  abstract ArtifactRetrievalService getRetrievalService();

  /**
   * Used to instantiate the {@link #getStagingServiceStub}. By default, instantiates using {@link
   * ManagedChannelFactory#forDescriptor(Endpoints.ApiServiceDescriptor)}, where {@link
   * Endpoints.ApiServiceDescriptor} is instantiated via {@link
   * Endpoints.ApiServiceDescriptor.Builder#setUrl(String)} and the URL provided by {@link
   * #getStagingEndpoint}.
   */
  abstract ManagedChannel getManagedChannel();

  /**
   * Required by {@link ArtifactStagingService#offer}. By default, instantiates using {@link
   * ArtifactStagingServiceGrpc#newStub} and {@link #getManagedChannel}.
   */
  abstract ArtifactStagingServiceGrpc.ArtifactStagingServiceStub getStagingServiceStub();

  /**
   * Closes the {@link #getRetrievalService} and shuts down the {@link #getManagedChannel}, waiting
   * up to {@value #CHANNEL_TERMINATION_TIMEOUT_MILLIS} milliseconds for the channel to terminate.
   * If the calling thread is interrupted while waiting, the wait is abandoned and the thread's
   * interrupt status is restored so callers can observe it.
   */
  @Override
  public void close() {
    LOG.info("shutting down {}", PrismArtifactStager.class);
    getRetrievalService().close();
    ManagedChannel channel = getManagedChannel();
    channel.shutdown();
    try {
      channel.awaitTermination(CHANNEL_TERMINATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
    } catch (InterruptedException e) {
      // close() cannot propagate the checked exception, so preserve the interrupt for the caller
      // instead of silently discarding it.
      Thread.currentThread().interrupt();
    }
  }

  /** Builder for {@link PrismArtifactStager}; use {@link #build} rather than {@link #autoBuild}. */
  @AutoValue.Builder
  abstract static class Builder {

    abstract Builder setStagingEndpoint(String stagingEndpoint);

    abstract Optional<String> getStagingEndpoint();

    abstract Builder setStagingSessionToken(String stagingSessionToken);

    abstract Builder setRetrievalService(ArtifactRetrievalService retrievalService);

    abstract Optional<ArtifactRetrievalService> getRetrievalService();

    abstract Builder setManagedChannel(ManagedChannel managedChannel);

    abstract Optional<ManagedChannel> getManagedChannel();

    abstract Builder setStagingServiceStub(
        ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stub);

    abstract Optional<ArtifactStagingServiceGrpc.ArtifactStagingServiceStub>
        getStagingServiceStub();

    abstract PrismArtifactStager autoBuild();

    /**
     * Fills in any unset optional properties with their defaults and builds the {@link
     * PrismArtifactStager}.
     *
     * <p>Defaults are applied in dependency order: the managed channel is derived from the staging
     * endpoint, and the staging service stub is derived from the managed channel.
     *
     * @throws IllegalStateException if the staging endpoint has not been set
     */
    final PrismArtifactStager build() {

      checkState(getStagingEndpoint().isPresent(), "missing staging endpoint");

      if (!getManagedChannel().isPresent()) {
        // Only create a channel factory when the caller did not supply a channel.
        Endpoints.ApiServiceDescriptor descriptor =
            Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getStagingEndpoint().get()).build();
        setManagedChannel(ManagedChannelFactory.createDefault().forDescriptor(descriptor));
      }

      if (!getStagingServiceStub().isPresent()) {
        setStagingServiceStub(ArtifactStagingServiceGrpc.newStub(getManagedChannel().get()));
      }

      if (!getRetrievalService().isPresent()) {
        setRetrievalService(new ArtifactRetrievalService());
      }

      return autoBuild();
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static com.google.common.truth.Truth.assertThat;
import static org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService.EMBEDDED_ARTIFACT_URN;
import static org.junit.Assert.assertThrows;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService;
import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link PrismArtifactStager}. */
@RunWith(JUnit4.class)
public class PrismArtifactStagerTest {

@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();

final ArtifactStagingService stagingService =
new ArtifactStagingService(new TestDestinationProvider());

@Test
public void givenValidArtifacts_stages()
throws IOException, ExecutionException, InterruptedException {
PrismArtifactStager underTest = prismArtifactStager(validArtifacts());
assertThat(underTest.getManagedChannel().isShutdown()).isFalse();
underTest.stage();
assertThat(stagingService.getStagedArtifacts(underTest.getStagingSessionToken())).isNotEmpty();
underTest.close();
assertThat(underTest.getManagedChannel().isShutdown()).isTrue();
}

@Test
public void givenErrors_performsGracefulCleanup() throws IOException {
PrismArtifactStager underTest = prismArtifactStager(invalidArtifacts());
assertThat(underTest.getManagedChannel().isShutdown()).isFalse();
ExecutionException error = assertThrows(ExecutionException.class, underTest::stage);
assertThat(error.getMessage()).contains("Unexpected artifact type: invalid-type-urn");
assertThat(underTest.getManagedChannel().isShutdown()).isFalse();
underTest.close();
assertThat(underTest.getManagedChannel().isShutdown()).isTrue();
}

private PrismArtifactStager prismArtifactStager(
Map<String, List<RunnerApi.ArtifactInformation>> artifacts) throws IOException {
String serverName = InProcessServerBuilder.generateName();
ArtifactRetrievalService retrievalService = new ArtifactRetrievalService();
String stagingToken = "staging-token";
stagingService.registerJob(stagingToken, artifacts);

grpcCleanup.register(
InProcessServerBuilder.forName(serverName)
.directExecutor()
.addService(stagingService)
.addService(retrievalService)
.build()
.start());

ManagedChannel channel =
grpcCleanup.register(InProcessChannelBuilder.forName(serverName).build());

return PrismArtifactStager.builder()
.setStagingEndpoint("ignore")
.setStagingSessionToken(stagingToken)
.setManagedChannel(channel)
.build();
}

private Map<String, List<RunnerApi.ArtifactInformation>> validArtifacts() {
return ImmutableMap.of(
"env1",
Collections.singletonList(
RunnerApi.ArtifactInformation.newBuilder()
.setTypeUrn(EMBEDDED_ARTIFACT_URN)
.setTypePayload(
RunnerApi.EmbeddedFilePayload.newBuilder()
.setData(ByteString.copyFromUtf8("type-payload"))
.build()
.toByteString())
.setRoleUrn("role-urn")
.build()));
}

private Map<String, List<RunnerApi.ArtifactInformation>> invalidArtifacts() {
return ImmutableMap.of(
"env1",
Collections.singletonList(
RunnerApi.ArtifactInformation.newBuilder()
.setTypeUrn("invalid-type-urn")
.setTypePayload(
RunnerApi.EmbeddedFilePayload.newBuilder()
.setData(ByteString.copyFromUtf8("type-payload"))
.build()
.toByteString())
.setRoleUrn("role-urn")
.build()));
}

private static class TestDestinationProvider
implements ArtifactStagingService.ArtifactDestinationProvider {

@Override
public ArtifactStagingService.ArtifactDestination getDestination(
String stagingToken, String name) throws IOException {
return ArtifactStagingService.ArtifactDestination.create(
EMBEDDED_ARTIFACT_URN, ByteString.EMPTY, new ByteArrayOutputStream());
}

@Override
public void removeStagedArtifacts(String stagingToken) throws IOException {}
}
}

0 comments on commit 99a2383

Please sign in to comment.