Skip to content

Commit

Permalink
Max Parallelism
Browse files Browse the repository at this point in the history
Signed-off-by: Rafael Raposo <[email protected]>
  • Loading branch information
RRap0so committed Jun 26, 2024
1 parent de6417e commit a959945
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 2 deletions.
8 changes: 8 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;

/** User-provided launch plan definition and configuration values. */
Expand All @@ -40,6 +41,11 @@ public abstract class LaunchPlan {
*/
public abstract Map<String, Parameter> defaultInputs();

/**
* Controls the maximum number of tasknodes that can be run in parallel for the entire workflow
*/
public abstract Optional<Integer> maxParallelism();

@Nullable
public abstract CronSchedule cronSchedule();

Expand All @@ -64,6 +70,8 @@ public abstract static class Builder {

public abstract Builder cronSchedule(CronSchedule cronSchedule);

public abstract Builder maxParallelism(Optional<Integer> maxParallelism);

public abstract LaunchPlan build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.flyte.api.v1.Literal;
Expand Down Expand Up @@ -81,6 +82,10 @@ public abstract class SdkLaunchPlan {
@Nullable
public abstract SdkCronSchedule cronSchedule();

/** Returns the max parallelism of the launch plan. */
@Nullable
public abstract Optional<Integer> maxParallelism();

/**
* Creates a launch plan for specified {@link SdkLaunchPlan} with default naming, no inputs and no
* schedule. The default launch plan name is {@link SdkWorkflow#getName()}. New name, inputs and
Expand Down Expand Up @@ -322,6 +327,19 @@ public <T> SdkLaunchPlan withDefaultInput(SdkType<T> type, T value) {
v -> createParameter(v.getValue(), literalMap.get(v.getKey())))));
}

/**
*
* @param maxParallelism Optional Integer for the max parallelism (cannot be negative).
* Default Value: Empty, it will default to what's set in the Flyte Platform.
* 0: It will try to use as much as allowed.
* @return the new launch plan
*/
public SdkLaunchPlan withMaxParallelism(Optional<Integer> maxParallelism) {
return withMaxParallelism0(maxParallelism);
}



private SdkLaunchPlan withDefaultInputs0(Map<String, Parameter> newDefaultInputs) {

verifyNonEmptyWorkflowInput(newDefaultInputs, "default");
Expand All @@ -336,6 +354,17 @@ private SdkLaunchPlan withDefaultInputs0(Map<String, Parameter> newDefaultInputs
return toBuilder().defaultInputs(newCompleteDefaultInputs).build();
}

private SdkLaunchPlan withMaxParallelism0(Optional<Integer> maxParallelism) {
if (maxParallelism.isPresent() && maxParallelism.get() < 0) {
String message =
String.format(
"invalid max parallelism %s, expected a positive integer", maxParallelism.get());
throw new IllegalArgumentException(message);
}

return toBuilder().maxParallelism(maxParallelism).build();
}

private <T> Map<String, T> mergeInputs(
Map<String, T> oldInputs, Map<String, T> newInputs, String inputType) {
Map<String, T> newCompleteInputs = new LinkedHashMap<>(oldInputs);
Expand Down Expand Up @@ -384,11 +413,14 @@ private void verifyMatchedInput(Map<String, LiteralType> newInputTypes, String i
}
}



static Builder builder() {
return new AutoValue_SdkLaunchPlan.Builder()
.fixedInputs(Collections.emptyMap())
.defaultInputs(Collections.emptyMap())
.workflowInputTypeMap(Collections.emptyMap());
.workflowInputTypeMap(Collections.emptyMap())
.maxParallelism(Optional.empty());
}

abstract Builder toBuilder();
Expand All @@ -414,6 +446,8 @@ abstract static class Builder {

abstract Builder workflowInputTypeMap(Map<String, LiteralType> workflowInputTypeMap);

abstract Builder maxParallelism(Optional<Integer> maxParallelism);

abstract SdkLaunchPlan build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Map<LaunchPlanIdentifier, LaunchPlan> load(
.name(sdkLaunchPlan.name())
.workflowId(getWorkflowIdentifier(sdkLaunchPlan))
.fixedInputs(sdkLaunchPlan.fixedInputs())
.defaultInputs(sdkLaunchPlan.defaultInputs());
.defaultInputs(sdkLaunchPlan.defaultInputs())
.maxParallelism(sdkLaunchPlan.maxParallelism());

if (sdkLaunchPlan.cronSchedule() != null) {
builder.cronSchedule(getCronSchedule(sdkLaunchPlan.cronSchedule()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.flyte.api.v1.CronSchedule;
import org.flyte.api.v1.LaunchPlan;
import org.flyte.api.v1.LaunchPlanIdentifier;
Expand Down Expand Up @@ -137,6 +138,34 @@ void shouldTestLaunchPlansWithCronSchedule() {
hasEntry(expectedIdentifierWithOffset, planWithOffset)));
}

@Test
void shouldTestLaunchPlansWithMaxParallelism() {
Map<LaunchPlanIdentifier, LaunchPlan> launchPlans =
registrar.load(ENV, singletonList(new TestRegistryWithMaxParallelism()));

LaunchPlanIdentifier expectedIdentifierWithOffset =
LaunchPlanIdentifier.builder()
.project("project")
.domain("domain")
.name("TestPlanScheduleWithMaxParallelism")
.version("version")
.build();

LaunchPlan planWithOffset =
LaunchPlan.builder()
.name("TestPlanScheduleWithMaxParallelism")
.workflowId(
PartialWorkflowIdentifier.builder()
.name("org.flyte.flytekit.SdkLaunchPlanRegistrarTest$TestWorkflow")
.build())
.fixedInputs(Collections.emptyMap())
.defaultInputs(Collections.emptyMap())
.maxParallelism(Optional.of(10))
.build();

assertThat(launchPlans, allOf(hasEntry(expectedIdentifierWithOffset, planWithOffset)));
}

@Test
void shouldRejectLoadingLaunchPlanDuplicatesInSameRegistry() {
IllegalArgumentException exception =
Expand Down Expand Up @@ -208,6 +237,17 @@ public List<SdkLaunchPlan> getLaunchPlans() {
}
}

public static class TestRegistryWithMaxParallelism implements SdkLaunchPlanRegistry {

@Override
public List<SdkLaunchPlan> getLaunchPlans() {
return Arrays.asList(
SdkLaunchPlan.of(new TestWorkflow())
.withName("TestPlanScheduleWithMaxParallelism")
.withMaxParallelism(Optional.of(10)));
}
}

public static class TestWorkflow extends SdkWorkflow<Void, Void> {

public TestWorkflow() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.flyte.api.v1.Literal;
Expand Down Expand Up @@ -91,6 +92,14 @@ void shouldCreateLaunchPlanWithCronSchedule() {
assertThat(plan.cronSchedule().offset(), equalTo(Duration.ofHours(1)));
}

@Test
void shouldCreateLaunchPlanWithMaxParallelism() {
SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()).withMaxParallelism(Optional.of(123));

assertThat(plan.maxParallelism(), notNullValue());
assertThat(plan.maxParallelism().get(), equalTo(123));
}

@Test
void shouldAddFixedInputs() {
Instant now = Instant.now();
Expand Down

0 comments on commit a959945

Please sign in to comment.