Skip to content

Commit

Permalink
[YUNIKORN-2465] Remove Task objects from the shim upon pod completion (
Browse files Browse the repository at this point in the history
…apache#796)

Closes: apache#796

Signed-off-by: Craig Condit <[email protected]>
  • Loading branch information
pbacsko authored and craigcondit committed Mar 5, 2024
1 parent 288661c commit 5fff0bc
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 3 deletions.
18 changes: 17 additions & 1 deletion pkg/cache/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,13 @@ func (app *Application) addTask(task *Task) {
app.taskMap[task.taskID] = task
}

func (app *Application) removeTask(taskID string) {
// RemoveTask is the exported, lock-taking entry point for dropping a task
// from the application. It acquires the application lock and then hands the
// given taskID to the unexported removeTask, which expects the lock to be
// held by its caller.
func (app *Application) RemoveTask(taskID string) {
	app.lock.Lock()
	defer app.lock.Unlock()

	app.removeTask(taskID)
}

func (app *Application) removeTask(taskID string) {
if _, ok := app.taskMap[taskID]; !ok {
log.Log(log.ShimCacheApplication).Debug("Attempted to remove non-existent task", zap.String("taskID", taskID))
return
Expand Down Expand Up @@ -363,6 +367,7 @@ func (app *Application) Schedule() bool {
app.scheduleTasks(func(t *Task) bool {
return t.placeholder
})
app.removeCompletedTasks()
if len(app.GetNewTasks()) == 0 {
return false
}
Expand All @@ -372,6 +377,7 @@ func (app *Application) Schedule() bool {
app.scheduleTasks(func(t *Task) bool {
return !t.placeholder
})
app.removeCompletedTasks()
if len(app.GetNewTasks()) == 0 {
return false
}
Expand Down Expand Up @@ -663,3 +669,13 @@ func (app *Application) SetPlaceholderTimeout(timeout int64) {
defer app.lock.Unlock()
app.placeholderTimeoutInSec = timeout
}

// removeCompletedTasks sweeps the application's task map and passes every
// task that reports isTerminated() to removeTask. The application lock is
// held for the duration of the sweep, so the unexported (non-locking)
// removeTask is used rather than RemoveTask.
func (app *Application) removeCompletedTasks() {
	app.lock.Lock()
	defer app.lock.Unlock()

	for _, candidate := range app.taskMap {
		if !candidate.isTerminated() {
			// still active: keep it in the map
			continue
		}
		app.removeTask(candidate.taskID)
	}
}
23 changes: 23 additions & 0 deletions pkg/cache/application_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,29 @@ func TestApplication_onReservationStateChange(t *testing.T) {
assertAppState(t, app, ApplicationStates().Running, 1*time.Second)
}

// TestTaskRemoval verifies that Application.Schedule() leaves no tasks in the
// Completed state behind, first with the application in the Reserving state
// (regular task plus placeholder task) and then in the Running state.
func TestTaskRemoval(t *testing.T) {
	app := NewApplication(appID, "root.a", "testuser", testGroups, map[string]string{}, newMockSchedulerAPI())
	context := initContextForTest()

	// "Reserving" state
	app.SetState(ApplicationStates().Reserving)
	task := NewTask("task0001", app, context, &v1.Pod{})
	phTask := NewTaskPlaceholder("ph-task0001", app, context, &v1.Pod{})
	app.addTask(task)
	// The placeholder has to be registered with the application too;
	// otherwise it is never part of the task map and its removal is not
	// actually exercised by the assertion below.
	app.addTask(phTask)
	task.sm.SetState(TaskStates().Completed)
	phTask.sm.SetState(TaskStates().Completed)
	app.Schedule()
	assert.Equal(t, 0, len(app.getTasks(TaskStates().Completed)))

	// "Running" state
	app.SetState(ApplicationStates().Running)
	task = NewTask("task0002", app, context, &v1.Pod{})
	app.addTask(task)
	task.sm.SetState(TaskStates().Completed)
	app.Schedule()
	assert.Equal(t, 0, len(app.getTasks(TaskStates().Completed)))
}

func (ctx *Context) addApplicationToContext(app *Application) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
Expand Down
3 changes: 1 addition & 2 deletions pkg/cache/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ func (ctx *Context) DeletePod(obj interface{}) {
func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) {
ctx.lock.Lock()
defer ctx.lock.Unlock()

if taskMeta, ok := getTaskMetadata(pod); ok {
if app := ctx.getApplication(taskMeta.ApplicationID); app != nil {
ctx.notifyTaskComplete(taskMeta.ApplicationID, taskMeta.TaskID)
Expand Down Expand Up @@ -1150,7 +1149,7 @@ func (ctx *Context) RemoveTask(appID, taskID string) {
log.Log(log.ShimContext).Debug("Attempted to remove task from non-existent application", zap.String("appID", appID))
return
}
app.removeTask(taskID)
app.RemoveTask(taskID)
}

func (ctx *Context) getTask(appID string, taskID string) *Task {
Expand Down
54 changes: 54 additions & 0 deletions pkg/cache/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,60 @@ func TestInitializeState(t *testing.T) {
assert.Assert(t, task3 == nil, "pod3 was found")
}

// TestTaskRemoveOnCompletion checks the end-to-end removal path: a task is
// added through the context, driven to the Completed state via
// NotifyTaskComplete, and must no longer be retrievable from the application
// after the next app.Schedule() call.
func TestTaskRemoveOnCompletion(t *testing.T) {
context := initContextForTest()
// Start the dispatcher and route app/task events to the context's handlers.
// The deferred calls run in reverse order: the dispatcher is stopped first,
// then the handlers are unregistered.
dispatcher.Start()
dispatcher.RegisterEventHandler("TestAppHandler", dispatcher.EventTypeApp, context.ApplicationEventHandler())
dispatcher.RegisterEventHandler("TestTaskHandler", dispatcher.EventTypeTask, context.TaskEventHandler())
defer dispatcher.UnregisterAllEventHandlers()
defer dispatcher.Stop()

// pod1UID and taskUID1 hold the same value: the task is added under
// pod1UID and later completed/looked up under taskUID1.
const (
pod1UID = "task00001"
taskUID1 = "task00001"
pod1Name = "my-pod-1"
fakeNodeName = "fake-node"
)

app := context.AddApplication(&AddApplicationRequest{
Metadata: ApplicationMetadata{
ApplicationID: appID,
QueueName: queue,
User: "test-user",
Tags: nil,
},
})

task := context.AddTask(&AddTaskRequest{
Metadata: TaskMetadata{
ApplicationID: appID,
TaskID: pod1UID,
Pod: newPodHelper(pod1Name, namespace, pod1UID, fakeNodeName, appID, v1.PodRunning),
},
})

// Put the app into Running and trigger a scheduling pass, then wait for the
// task to report the Scheduling state.
// NOTE(review): the WaitForCondition durations are presumably
// (poll interval, timeout) - confirm against utils.WaitForCondition.
app.SetState("Running")
app.Schedule()
err := utils.WaitForCondition(func() bool {
return task.GetTaskState() == TaskStates().Scheduling
}, 100*time.Millisecond, time.Second)
assert.NilError(t, err)

// Signal completion through the context and wait until the task reports
// the Completed state.
context.NotifyTaskComplete(appID, taskUID1)
err = utils.WaitForCondition(func() bool {
return task.GetTaskState() == TaskStates().Completed
}, 100*time.Millisecond, time.Second)
assert.NilError(t, err)

// Another scheduling pass must have removed the completed task: the lookup
// is expected to return a nil task together with a "doesn't exist" error.
app.Schedule()
appTask, err := app.GetTask(taskUID1)
assert.Assert(t, appTask == nil)
assert.Error(t, err, "task task00001 doesn't exist in application app01")
}

func waitForNodeAcceptedEvent(recorder *k8sEvents.FakeRecorder) error {
// fetch the "node accepted" event
err := utils.WaitForCondition(func() bool {
Expand Down
18 changes: 18 additions & 0 deletions pkg/shim/scheduler_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,24 @@ func (fc *MockScheduler) GetActiveNodeCountInCore(partition string) int {
return len(coreNodes)
}

// waitForApplicationStateInCore blocks until the application identified by
// appID in the given partition of the scheduler core reports expectedState,
// and returns whatever utils.WaitForCondition returns. Each unsuccessful
// check (application missing, or state mismatch) is logged on the test
// logger.
// NOTE(review): the durations are presumably a one-second poll interval and
// a five-second timeout - confirm against utils.WaitForCondition.
func (fc *MockScheduler) waitForApplicationStateInCore(appID, partition, expectedState string) error {
	stateReached := func() bool {
		coreApp := fc.coreContext.Scheduler.GetClusterContext().GetApplication(appID, partition)
		if coreApp == nil {
			log.Log(log.Test).Info("Application not found in the scheduler core", zap.String("appID", appID))
			return false
		}
		if actual := coreApp.CurrentState(); actual != expectedState {
			log.Log(log.Test).Info("waiting for app state in core",
				zap.String("expected", expectedState),
				zap.String("actual", actual))
			return false
		}
		return true
	}
	return utils.WaitForCondition(stateReached, time.Second, 5*time.Second)
}

// GetPodBindStats returns the pod bind statistics as reported by the mock
// scheduler's API provider.
func (fc *MockScheduler) GetPodBindStats() client.BindStats {
	stats := fc.apiProvider.GetPodBindStats()
	return stats
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/shim/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,19 @@ partitions:
cluster.waitAndAssertApplicationState(t, "app0001", cache.ApplicationStates().Running)
cluster.waitAndAssertTaskState(t, "app0001", "task0001", cache.TaskStates().Bound)
cluster.waitAndAssertTaskState(t, "app0001", "task0002", cache.TaskStates().Bound)

// complete pods
task1Upd := task1.DeepCopy()
task1Upd.Status.Phase = v1.PodSucceeded
cluster.UpdatePod(task1, task1Upd)
cluster.waitAndAssertTaskState(t, "app0001", "task0001", cache.TaskStates().Completed)
cluster.waitAndAssertApplicationState(t, "app0001", cache.ApplicationStates().Running)
task2Upd := task2.DeepCopy()
task2Upd.Status.Phase = v1.PodSucceeded
cluster.UpdatePod(task2, task2Upd)
cluster.waitAndAssertTaskState(t, "app0001", "task0002", cache.TaskStates().Completed)
err = cluster.waitForApplicationStateInCore("app0001", partitionName, "Completing")
assert.NilError(t, err)
}

func TestRejectApplications(t *testing.T) {
Expand Down

0 comments on commit 5fff0bc

Please sign in to comment.