diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java index 54b3b8f46770..143eab8be85f 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java @@ -183,7 +183,7 @@ private String replaceFnString( String doFnName; Class enclosingClass = fnClass.getEnclosingClass(); if (enclosingClass != null && enclosingClass.equals(MapElements.class)) { - Field parent = fnClass.getDeclaredField("this$0"); + Field parent = fnClass.getSuperclass().getDeclaredField("outer"); parent.setAccessible(true); Field fnField = enclosingClass.getDeclaredField(fnFieldName); fnField.setAccessible(true); diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java index ef3e9589ba2e..6e4acc6fb930 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java @@ -17,13 +17,17 @@ */ package org.apache.beam.runners.spark; +import static org.apache.beam.sdk.transforms.Contextful.fn; +import static org.apache.beam.sdk.transforms.Requirements.requiresSideInputs; import static org.hamcrest.MatcherAssert.assertThat; +import java.util.Arrays; import java.util.Collections; import org.apache.beam.runners.spark.examples.WordCount; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.kafka.KafkaIO; import org.apache.beam.sdk.options.PipelineOptions; @@ -39,12 +43,15 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.kafka.common.serialization.StringSerializer; import org.hamcrest.Matchers; @@ -160,6 +167,40 @@ public void debugStreamingPipeline() { Matchers.equalTo(expectedPipeline)); } + @Test + public void debugBatchPipelineWithContextfulTransform() { + PipelineOptions options = contextRule.configure(PipelineOptionsFactory.create()); + options.setRunner(SparkRunnerDebugger.class); + Pipeline pipeline = Pipeline.create(options); + + final PCollectionView view = + pipeline.apply("Dummy", Create.of(0)).apply(View.asSingleton()); + + pipeline + .apply(Create.of(Arrays.asList(0))) + .setCoder(VarIntCoder.of()) + .apply( + MapElements.into(new TypeDescriptor() {}) + .via(fn((element, c) -> element, requiresSideInputs(view)))); + + SparkRunnerDebugger.DebugSparkPipelineResult result = + (SparkRunnerDebugger.DebugSparkPipelineResult) pipeline.run(); + + final String expectedPipeline = + "sparkContext.()\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Create$Values$2())\n" + + "_.aggregate(..., new org.apache.beam.sdk.transforms.View$SingletonCombineFn(), ...)\n" + + "_.\n" + + "sparkContext.()\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Create$Values$2())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())"; + + assertThat( + "Debug pipeline did not equal expected", + result.getDebugString(), + Matchers.equalTo(expectedPipeline)); + } + private static class FormatKVFn extends DoFn, String> { @SuppressWarnings("unused") @ProcessElement diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java index 7210e82e0ca1..6b123d3bd106 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java @@ -159,6 +159,10 @@ public void processElement( /** A DoFn implementation that handles a trivial map call. */ private abstract class MapDoFn extends DoFn { + + /** Holds {@link MapDoFn#outer instance} of enclosing class, used by runner implementations. */ + final MapElements outer = MapElements.this; + @Override public void populateDisplayData(DisplayData.Builder builder) { builder.delegate(MapElements.this);