diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java b/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java index 6a8118c6..3e6b7cfa 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java @@ -19,6 +19,7 @@ import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; /** User-provided launch plan definition and configuration values. */ @@ -40,6 +41,11 @@ public abstract class LaunchPlan { */ public abstract Map defaultInputs(); + /** + * Controls the maximum number of tasknodes that can be run in parallel for the entire workflow + */ + public abstract Optional maxParallelism(); + @Nullable public abstract CronSchedule cronSchedule(); @@ -64,6 +70,8 @@ public abstract static class Builder { public abstract Builder cronSchedule(CronSchedule cronSchedule); + public abstract Builder maxParallelism(Optional maxParallelism); + public abstract LaunchPlan build(); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java index 7e234ffa..eedde0c2 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import javax.annotation.Nullable; import org.flyte.api.v1.Literal; @@ -81,6 +82,10 @@ public abstract class SdkLaunchPlan { @Nullable public abstract SdkCronSchedule cronSchedule(); + /** Returns the max parallelism of the launch plan. */ + @Nullable + public abstract Optional maxParallelism(); + /** * Creates a launch plan for specified {@link SdkLaunchPlan} with default naming, no inputs and no * schedule. The default launch plan name is {@link SdkWorkflow#getName()}. New name, inputs and @@ -322,6 +327,19 @@ public SdkLaunchPlan withDefaultInput(SdkType type, T value) { v -> createParameter(v.getValue(), literalMap.get(v.getKey()))))); } + /** + * + * @param maxParallelism Optional Integer for the max parallelism (cannot be negative). + * Default Value: Empty, it will default to what's set in the Flyte Platform. + * 0: It will try to use as much as allowed. + * @return the new launch plan + */ + public SdkLaunchPlan withMaxParallelism(Optional maxParallelism) { + return withMaxParallelism0(maxParallelism); + } + + + private SdkLaunchPlan withDefaultInputs0(Map newDefaultInputs) { verifyNonEmptyWorkflowInput(newDefaultInputs, "default"); @@ -336,6 +354,17 @@ private SdkLaunchPlan withDefaultInputs0(Map newDefaultInputs return toBuilder().defaultInputs(newCompleteDefaultInputs).build(); } + private SdkLaunchPlan withMaxParallelism0(Optional maxParallelism) { + if (maxParallelism.isPresent() && maxParallelism.get() < 0) { + String message = + String.format( + "invalid max parallelism %s, expected a positive integer", maxParallelism.get()); + throw new IllegalArgumentException(message); + } + + return toBuilder().maxParallelism(maxParallelism).build(); + } + private Map mergeInputs( Map oldInputs, Map newInputs, String inputType) { Map newCompleteInputs = new LinkedHashMap<>(oldInputs); @@ -384,11 +413,14 @@ private void verifyMatchedInput(Map newInputTypes, String i } } + + static Builder builder() { return new AutoValue_SdkLaunchPlan.Builder() .fixedInputs(Collections.emptyMap()) .defaultInputs(Collections.emptyMap()) - .workflowInputTypeMap(Collections.emptyMap()); + .workflowInputTypeMap(Collections.emptyMap()) + .maxParallelism(Optional.empty()); } abstract Builder toBuilder(); @@ -414,6 +446,8 @@ abstract static class Builder { abstract Builder workflowInputTypeMap(Map workflowInputTypeMap); + abstract Builder maxParallelism(Optional maxParallelism); + abstract SdkLaunchPlan build(); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java index 9913b18d..ac0f121c 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java @@ -96,7 +96,8 @@ Map load( .name(sdkLaunchPlan.name()) .workflowId(getWorkflowIdentifier(sdkLaunchPlan)) .fixedInputs(sdkLaunchPlan.fixedInputs()) - .defaultInputs(sdkLaunchPlan.defaultInputs()); + .defaultInputs(sdkLaunchPlan.defaultInputs()) + .maxParallelism(sdkLaunchPlan.maxParallelism()); if (sdkLaunchPlan.cronSchedule() != null) { builder.cronSchedule(getCronSchedule(sdkLaunchPlan.cronSchedule())); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java index 769b5d53..ef1498db 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import org.flyte.api.v1.CronSchedule; import org.flyte.api.v1.LaunchPlan; import org.flyte.api.v1.LaunchPlanIdentifier; @@ -137,6 +138,34 @@ void shouldTestLaunchPlansWithCronSchedule() { hasEntry(expectedIdentifierWithOffset, planWithOffset))); } + @Test + void shouldTestLaunchPlansWithMaxParallelism() { + Map launchPlans = + registrar.load(ENV, singletonList(new TestRegistryWithMaxParallelism())); + + LaunchPlanIdentifier expectedIdentifierWithOffset = + LaunchPlanIdentifier.builder() + .project("project") + .domain("domain") + .name("TestPlanScheduleWithMaxParallelism") + .version("version") + .build(); + + LaunchPlan planWithOffset = + LaunchPlan.builder() + .name("TestPlanScheduleWithMaxParallelism") + .workflowId( + PartialWorkflowIdentifier.builder() + .name("org.flyte.flytekit.SdkLaunchPlanRegistrarTest$TestWorkflow") + .build()) + .fixedInputs(Collections.emptyMap()) + .defaultInputs(Collections.emptyMap()) + .maxParallelism(Optional.of(10)) + .build(); + + assertThat(launchPlans, allOf(hasEntry(expectedIdentifierWithOffset, planWithOffset))); + } + @Test void shouldRejectLoadingLaunchPlanDuplicatesInSameRegistry() { IllegalArgumentException exception = @@ -208,6 +237,17 @@ public List getLaunchPlans() { } } + public static class TestRegistryWithMaxParallelism implements SdkLaunchPlanRegistry { + + @Override + public List getLaunchPlans() { + return Arrays.asList( + SdkLaunchPlan.of(new TestWorkflow()) + .withName("TestPlanScheduleWithMaxParallelism") + .withMaxParallelism(Optional.of(10))); + } + } + public static class TestWorkflow extends SdkWorkflow { public TestWorkflow() { diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java index bace2a7f..815c36bc 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java @@ -31,6 +31,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Map; +import java.util.Optional; import java.util.function.Consumer; import java.util.stream.Stream; import org.flyte.api.v1.Literal; @@ -91,6 +92,14 @@ void shouldCreateLaunchPlanWithCronSchedule() { assertThat(plan.cronSchedule().offset(), equalTo(Duration.ofHours(1))); } + @Test + void shouldCreateLaunchPlanWithMaxParallelism() { + SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()).withMaxParallelism(Optional.of(123)); + + assertThat(plan.maxParallelism(), notNullValue()); + assertThat(plan.maxParallelism().get(), equalTo(123)); + } + @Test void shouldAddFixedInputs() { Instant now = Instant.now();