diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index 2b0635ca6125..96ab4e70a579 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -30,6 +30,7 @@ dependencies { implementation project(path: ":model:pipeline", configuration: "shadow") implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(path: ":sdks:java:harness", configuration: "shadow") + implementation project(":runners:java-fn-execution") implementation project(":runners:portability:java") implementation library.java.joda_time diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismArtifactStager.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismArtifactStager.java new file mode 100644 index 000000000000..f1d99a213eea --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismArtifactStager.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import com.google.auto.value.AutoValue; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc; +import org.apache.beam.model.jobmanagement.v1.JobApi; +import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService; +import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService; +import org.apache.beam.sdk.fn.channel.ManagedChannelFactory; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Stages {@link org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline} artifacts of prepared jobs. + */ +@AutoValue +abstract class PrismArtifactStager implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(PrismArtifactStager.class); + + /** + * Instantiate a {@link PrismArtifactStager} via call to {@link #of(String, String)}, assigning + * {@link Builder#setStagingEndpoint} using {@param prepareJobResponse} {@link + * JobApi.PrepareJobResponse#getArtifactStagingEndpoint} and {@link + * JobApi.PrepareJobResponse#getStagingSessionToken}. + */ + static PrismArtifactStager of(JobApi.PrepareJobResponse prepareJobResponse) { + return of( + prepareJobResponse.getArtifactStagingEndpoint().getUrl(), + prepareJobResponse.getStagingSessionToken()); + } + + /** + * Instantiates a {@link PrismArtifactStager} from the {@param stagingEndpoint} URL and {@param + * stagingSessionToken} to instantiate the {@link #getRetrievalService}, {@link + * #getManagedChannel}, and {@link #getStagingServiceStub} defaults. See the referenced getters + * for more details. + */ + static PrismArtifactStager of(String stagingEndpoint, String stagingSessionToken) { + return PrismArtifactStager.builder() + .setStagingEndpoint(stagingEndpoint) + .setStagingSessionToken(stagingSessionToken) + .build(); + } + + static Builder builder() { + return new AutoValue_PrismArtifactStager.Builder(); + } + + /** + * Stage the {@link org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline} artifacts via {@link + * ArtifactStagingService#offer} supplying {@link #getRetrievalService}, {@link + * #getStagingServiceStub}, and {@link #getStagingSessionToken}. + */ + void stage() throws ExecutionException, InterruptedException { + LOG.info("staging artifacts at {}", getStagingEndpoint()); + ArtifactStagingService.offer( + getRetrievalService(), getStagingServiceStub(), getStagingSessionToken()); + } + + /** The URL of the {@link ArtifactStagingService}. */ + abstract String getStagingEndpoint(); + + /** + * Token associated with a staging session and acquired from a {@link + * JobServiceGrpc.JobServiceStub#prepare}'s {@link JobApi.PrepareJobResponse}. + */ + abstract String getStagingSessionToken(); + + /** + * The service that retrieves artifacts; defaults to instantiating from the default {@link + * ArtifactRetrievalService#ArtifactRetrievalService()} constructor. + */ + abstract ArtifactRetrievalService getRetrievalService(); + + /** + * Used to instantiate the {@link #getStagingServiceStub}. By default, instantiates using {@link + * ManagedChannelFactory#forDescriptor(Endpoints.ApiServiceDescriptor)}, where {@link + * Endpoints.ApiServiceDescriptor} is instantiated via {@link + * Endpoints.ApiServiceDescriptor.Builder#setUrl(String)} and the URL provided by {@link + * #getStagingEndpoint}. + */ + abstract ManagedChannel getManagedChannel(); + + /** + * Required by {@link ArtifactStagingService#offer}. By default, instantiates using {@link + * ArtifactStagingServiceGrpc#newStub} and {@link #getManagedChannel}. + */ + abstract ArtifactStagingServiceGrpc.ArtifactStagingServiceStub getStagingServiceStub(); + + @Override + public void close() { + LOG.info("shutting down {}", PrismArtifactStager.class); + getRetrievalService().close(); + getManagedChannel().shutdown(); + try { + getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS); + } catch (InterruptedException ignored) { + } + } + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setStagingEndpoint(String stagingEndpoint); + + abstract Optional getStagingEndpoint(); + + abstract Builder setStagingSessionToken(String stagingSessionToken); + + abstract Builder setRetrievalService(ArtifactRetrievalService retrievalService); + + abstract Optional getRetrievalService(); + + abstract Builder setManagedChannel(ManagedChannel managedChannel); + + abstract Optional getManagedChannel(); + + abstract Builder setStagingServiceStub( + ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stub); + + abstract Optional + getStagingServiceStub(); + + abstract PrismArtifactStager autoBuild(); + + final PrismArtifactStager build() { + + checkState(getStagingEndpoint().isPresent(), "missing staging endpoint"); + ManagedChannelFactory channelFactory = ManagedChannelFactory.createDefault(); + + if (!getManagedChannel().isPresent()) { + Endpoints.ApiServiceDescriptor descriptor = + Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getStagingEndpoint().get()).build(); + setManagedChannel(channelFactory.forDescriptor(descriptor)); + } + + if (!getStagingServiceStub().isPresent()) { + setStagingServiceStub(ArtifactStagingServiceGrpc.newStub(getManagedChannel().get())); + } + + if (!getRetrievalService().isPresent()) { + setRetrievalService(new ArtifactRetrievalService()); + } + + return autoBuild(); + } + } +} diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismArtifactStagerTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismArtifactStagerTest.java new file mode 100644 index 000000000000..d3ac8a72eafb --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismArtifactStagerTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService.EMBEDDED_ARTIFACT_URN; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService; +import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.commons.io.output.ByteArrayOutputStream; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PrismArtifactStager}. */ +@RunWith(JUnit4.class) +public class PrismArtifactStagerTest { + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + final ArtifactStagingService stagingService = + new ArtifactStagingService(new TestDestinationProvider()); + + @Test + public void givenValidArtifacts_stages() + throws IOException, ExecutionException, InterruptedException { + PrismArtifactStager underTest = prismArtifactStager(validArtifacts()); + assertThat(underTest.getManagedChannel().isShutdown()).isFalse(); + underTest.stage(); + assertThat(stagingService.getStagedArtifacts(underTest.getStagingSessionToken())).isNotEmpty(); + underTest.close(); + assertThat(underTest.getManagedChannel().isShutdown()).isTrue(); + } + + @Test + public void givenErrors_performsGracefulCleanup() throws IOException { + PrismArtifactStager underTest = prismArtifactStager(invalidArtifacts()); + assertThat(underTest.getManagedChannel().isShutdown()).isFalse(); + ExecutionException error = assertThrows(ExecutionException.class, underTest::stage); + assertThat(error.getMessage()).contains("Unexpected artifact type: invalid-type-urn"); + assertThat(underTest.getManagedChannel().isShutdown()).isFalse(); + underTest.close(); + assertThat(underTest.getManagedChannel().isShutdown()).isTrue(); + } + + private PrismArtifactStager prismArtifactStager( + Map> artifacts) throws IOException { + String serverName = InProcessServerBuilder.generateName(); + ArtifactRetrievalService retrievalService = new ArtifactRetrievalService(); + String stagingToken = "staging-token"; + stagingService.registerJob(stagingToken, artifacts); + + grpcCleanup.register( + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(stagingService) + .addService(retrievalService) + .build() + .start()); + + ManagedChannel channel = + grpcCleanup.register(InProcessChannelBuilder.forName(serverName).build()); + + return PrismArtifactStager.builder() + .setStagingEndpoint("ignore") + .setStagingSessionToken(stagingToken) + .setManagedChannel(channel) + .build(); + } + + private Map> validArtifacts() { + return ImmutableMap.of( + "env1", + Collections.singletonList( + RunnerApi.ArtifactInformation.newBuilder() + .setTypeUrn(EMBEDDED_ARTIFACT_URN) + .setTypePayload( + RunnerApi.EmbeddedFilePayload.newBuilder() + .setData(ByteString.copyFromUtf8("type-payload")) + .build() + .toByteString()) + .setRoleUrn("role-urn") + .build())); + } + + private Map> invalidArtifacts() { + return ImmutableMap.of( + "env1", + Collections.singletonList( + RunnerApi.ArtifactInformation.newBuilder() + .setTypeUrn("invalid-type-urn") + .setTypePayload( + RunnerApi.EmbeddedFilePayload.newBuilder() + .setData(ByteString.copyFromUtf8("type-payload")) + .build() + .toByteString()) + .setRoleUrn("role-urn") + .build())); + } + + private static class TestDestinationProvider + implements ArtifactStagingService.ArtifactDestinationProvider { + + @Override + public ArtifactStagingService.ArtifactDestination getDestination( + String stagingToken, String name) throws IOException { + return ArtifactStagingService.ArtifactDestination.create( + EMBEDDED_ARTIFACT_URN, ByteString.EMPTY, new ByteArrayOutputStream()); + } + + @Override + public void removeStagedArtifacts(String stagingToken) throws IOException {} + } +}