From 5fff0bc8e9ec76f62f797c91c997c7c6bea8a6f3 Mon Sep 17 00:00:00 2001 From: Peter Bacsko Date: Tue, 5 Mar 2024 09:52:19 -0600 Subject: [PATCH] [YUNIKORN-2465] Remove Task objects from the shim upon pod completion (#796) Closes: #796 Signed-off-by: Craig Condit --- pkg/cache/application.go | 18 ++++++++++- pkg/cache/application_test.go | 23 ++++++++++++++ pkg/cache/context.go | 3 +- pkg/cache/context_test.go | 54 +++++++++++++++++++++++++++++++++ pkg/shim/scheduler_mock_test.go | 18 +++++++++++ pkg/shim/scheduler_test.go | 13 ++++++++ 6 files changed, 126 insertions(+), 3 deletions(-) diff --git a/pkg/cache/application.go b/pkg/cache/application.go index bb5bba2b2..ffead77ad 100644 --- a/pkg/cache/application.go +++ b/pkg/cache/application.go @@ -228,9 +228,13 @@ func (app *Application) addTask(task *Task) { app.taskMap[task.taskID] = task } -func (app *Application) removeTask(taskID string) { +func (app *Application) RemoveTask(taskID string) { app.lock.Lock() defer app.lock.Unlock() + app.removeTask(taskID) +} + +func (app *Application) removeTask(taskID string) { if _, ok := app.taskMap[taskID]; !ok { log.Log(log.ShimCacheApplication).Debug("Attempted to remove non-existent task", zap.String("taskID", taskID)) return @@ -363,6 +367,7 @@ func (app *Application) Schedule() bool { app.scheduleTasks(func(t *Task) bool { return t.placeholder }) + app.removeCompletedTasks() if len(app.GetNewTasks()) == 0 { return false } @@ -372,6 +377,7 @@ func (app *Application) Schedule() bool { app.scheduleTasks(func(t *Task) bool { return !t.placeholder }) + app.removeCompletedTasks() if len(app.GetNewTasks()) == 0 { return false } @@ -663,3 +669,13 @@ func (app *Application) SetPlaceholderTimeout(timeout int64) { defer app.lock.Unlock() app.placeholderTimeoutInSec = timeout } + +func (app *Application) removeCompletedTasks() { + app.lock.Lock() + defer app.lock.Unlock() + for _, task := range app.taskMap { + if task.isTerminated() { + app.removeTask(task.taskID) + } + } +} diff --git a/pkg/cache/application_test.go b/pkg/cache/application_test.go index 03a572fce..edd3846f4 100644 --- a/pkg/cache/application_test.go +++ b/pkg/cache/application_test.go @@ -1288,6 +1288,29 @@ func TestApplication_onReservationStateChange(t *testing.T) { assertAppState(t, app, ApplicationStates().Running, 1*time.Second) } +func TestTaskRemoval(t *testing.T) { + app := NewApplication(appID, "root.a", "testuser", testGroups, map[string]string{}, newMockSchedulerAPI()) + context := initContextForTest() + + // "Reserving" state + app.SetState(ApplicationStates().Reserving) + task := NewTask("task0001", app, context, &v1.Pod{}) + phTask := NewTaskPlaceholder("ph-task0001", app, context, &v1.Pod{}) + app.addTask(task) + task.sm.SetState(TaskStates().Completed) + phTask.sm.SetState(TaskStates().Completed) + app.Schedule() + assert.Equal(t, 0, len(app.getTasks(TaskStates().Completed))) + + // "Running" state + app.SetState(ApplicationStates().Running) + task = NewTask("task0002", app, context, &v1.Pod{}) + app.addTask(task) + task.sm.SetState(TaskStates().Completed) + app.Schedule() + assert.Equal(t, 0, len(app.getTasks(TaskStates().Completed))) +} + func (ctx *Context) addApplicationToContext(app *Application) { ctx.lock.Lock() defer ctx.lock.Unlock() diff --git a/pkg/cache/context.go b/pkg/cache/context.go index c9f33a605..46a230016 100644 --- a/pkg/cache/context.go +++ b/pkg/cache/context.go @@ -428,7 +428,6 @@ func (ctx *Context) DeletePod(obj interface{}) { func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) { ctx.lock.Lock() defer ctx.lock.Unlock() - if taskMeta, ok := getTaskMetadata(pod); ok { if app := ctx.getApplication(taskMeta.ApplicationID); app != nil { ctx.notifyTaskComplete(taskMeta.ApplicationID, taskMeta.TaskID) @@ -1150,7 +1149,7 @@ func (ctx *Context) RemoveTask(appID, taskID string) { log.Log(log.ShimContext).Debug("Attempted to remove task from non-existent application", zap.String("appID", appID)) return } - app.removeTask(taskID) + app.RemoveTask(taskID) } func (ctx *Context) getTask(appID string, taskID string) *Task { diff --git a/pkg/cache/context_test.go b/pkg/cache/context_test.go index dab395e3d..cd3a23fb6 100644 --- a/pkg/cache/context_test.go +++ b/pkg/cache/context_test.go @@ -2084,6 +2084,60 @@ func TestInitializeState(t *testing.T) { assert.Assert(t, task3 == nil, "pod3 was found") } +func TestTaskRemoveOnCompletion(t *testing.T) { + context := initContextForTest() + dispatcher.Start() + dispatcher.RegisterEventHandler("TestAppHandler", dispatcher.EventTypeApp, context.ApplicationEventHandler()) + dispatcher.RegisterEventHandler("TestTaskHandler", dispatcher.EventTypeTask, context.TaskEventHandler()) + defer dispatcher.UnregisterAllEventHandlers() + defer dispatcher.Stop() + + const ( + pod1UID = "task00001" + taskUID1 = "task00001" + pod1Name = "my-pod-1" + fakeNodeName = "fake-node" + ) + + app := context.AddApplication(&AddApplicationRequest{ + Metadata: ApplicationMetadata{ + ApplicationID: appID, + QueueName: queue, + User: "test-user", + Tags: nil, + }, + }) + + task := context.AddTask(&AddTaskRequest{ + Metadata: TaskMetadata{ + ApplicationID: appID, + TaskID: pod1UID, + Pod: newPodHelper(pod1Name, namespace, pod1UID, fakeNodeName, appID, v1.PodRunning), + }, + }) + + // task gets scheduled + app.SetState("Running") + app.Schedule() + err := utils.WaitForCondition(func() bool { + return task.GetTaskState() == TaskStates().Scheduling + }, 100*time.Millisecond, time.Second) + assert.NilError(t, err) + + // mark completion + context.NotifyTaskComplete(appID, taskUID1) + err = utils.WaitForCondition(func() bool { + return task.GetTaskState() == TaskStates().Completed + }, 100*time.Millisecond, time.Second) + assert.NilError(t, err) + + // check removal + app.Schedule() + appTask, err := app.GetTask(taskUID1) + assert.Assert(t, appTask == nil) + assert.Error(t, err, "task task00001 doesn't exist in application app01") +} + func waitForNodeAcceptedEvent(recorder *k8sEvents.FakeRecorder) error { // fetch the "node accepted" event err := utils.WaitForCondition(func() bool { diff --git a/pkg/shim/scheduler_mock_test.go b/pkg/shim/scheduler_mock_test.go index 51253b1dc..dd12a93de 100644 --- a/pkg/shim/scheduler_mock_test.go +++ b/pkg/shim/scheduler_mock_test.go @@ -304,6 +304,24 @@ func (fc *MockScheduler) GetActiveNodeCountInCore(partition string) int { return len(coreNodes) } +func (fc *MockScheduler) waitForApplicationStateInCore(appID, partition, expectedState string) error { + return utils.WaitForCondition(func() bool { + app := fc.coreContext.Scheduler.GetClusterContext().GetApplication(appID, partition) + if app == nil { + log.Log(log.Test).Info("Application not found in the scheduler core", zap.String("appID", appID)) + return false + } + current := app.CurrentState() + if current != expectedState { + log.Log(log.Test).Info("waiting for app state in core", + zap.String("expected", expectedState), + zap.String("actual", current)) + return false + } + return true + }, time.Second, 5*time.Second) +} + func (fc *MockScheduler) GetPodBindStats() client.BindStats { return fc.apiProvider.GetPodBindStats() } diff --git a/pkg/shim/scheduler_test.go b/pkg/shim/scheduler_test.go index 9651cd0c2..38fe981fc 100644 --- a/pkg/shim/scheduler_test.go +++ b/pkg/shim/scheduler_test.go @@ -91,6 +91,19 @@ partitions: cluster.waitAndAssertApplicationState(t, "app0001", cache.ApplicationStates().Running) cluster.waitAndAssertTaskState(t, "app0001", "task0001", cache.TaskStates().Bound) cluster.waitAndAssertTaskState(t, "app0001", "task0002", cache.TaskStates().Bound) + + // complete pods + task1Upd := task1.DeepCopy() + task1Upd.Status.Phase = v1.PodSucceeded + cluster.UpdatePod(task1, task1Upd) + cluster.waitAndAssertTaskState(t, "app0001", "task0001", cache.TaskStates().Completed) + cluster.waitAndAssertApplicationState(t, "app0001", cache.ApplicationStates().Running) + task2Upd := task2.DeepCopy() + task2Upd.Status.Phase = v1.PodSucceeded + cluster.UpdatePod(task2, task2Upd) + cluster.waitAndAssertTaskState(t, "app0001", "task0002", cache.TaskStates().Completed) + err = cluster.waitForApplicationStateInCore("app0001", partitionName, "Completing") + assert.NilError(t, err) } func TestRejectApplications(t *testing.T) {