From bb964015e62b73262915346b0af7ecd92420a053 Mon Sep 17 00:00:00 2001 From: gabriel ruttner Date: Thu, 8 Aug 2024 15:30:07 -0400 Subject: [PATCH] wip: affinity --- pkg/repository/prisma/dbsqlc/step_runs.sql | 23 +++++ pkg/repository/prisma/dbsqlc/step_runs.sql.go | 87 ++++++++++++++++++ pkg/repository/prisma/step_run.go | 53 ++++++++++- pkg/scheduling/affinity.go | 52 +++++++++++ pkg/scheduling/affinity_output.json | 40 +++++++++ pkg/scheduling/fixtures/affinity.json | 89 +++++++++++++++++++ pkg/scheduling/fixtures/affinity_output.json | 40 +++++++++ pkg/scheduling/schedule_plan.go | 10 +-- pkg/scheduling/scheduling.go | 18 ++-- pkg/scheduling/scheduling_test.go | 61 ++++++++++--- pkg/scheduling/timeout.go | 2 +- pkg/scheduling/worker_state.go | 21 ++++- scheduling.json | 1 + 13 files changed, 464 insertions(+), 33 deletions(-) create mode 100644 pkg/scheduling/affinity.go create mode 100644 pkg/scheduling/affinity_output.json create mode 100644 pkg/scheduling/fixtures/affinity.json create mode 100644 pkg/scheduling/fixtures/affinity_output.json create mode 100644 scheduling.json diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql b/pkg/repository/prisma/dbsqlc/step_runs.sql index 79963e266..4a44c750c 100644 --- a/pkg/repository/prisma/dbsqlc/step_runs.sql +++ b/pkg/repository/prisma/dbsqlc/step_runs.sql @@ -664,6 +664,29 @@ WHERE sr."id" = input."id" RETURNING sr."id"; +-- name: GetDesiredLabels :many +SELECT + "key", + "strValue", + "intValue", + "required", + "weight", + "comparator" +FROM + "StepDesiredWorkerLabel" +WHERE + "stepId" = @stepId::uuid; + +-- name: GetWorkerLabels :many +SELECT + "key", + "strValue", + "intValue" +FROM + "WorkerLabel" +WHERE + "workerId" = @workerId::uuid; + -- name: AcquireWorkerSemaphoreSlotAndAssign :one WITH valid_workers AS ( SELECT diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql.go b/pkg/repository/prisma/dbsqlc/step_runs.sql.go index 7c5023b55..359061e62 100644 --- a/pkg/repository/prisma/dbsqlc/step_runs.sql.go +++ b/pkg/repository/prisma/dbsqlc/step_runs.sql.go @@ -763,6 +763,56 @@ func (q *Queries) CreateStepRunEvent(ctx context.Context, db DBTX, arg CreateSte return err } +const getDesiredLabels = `-- name: GetDesiredLabels :many +SELECT + "key", + "strValue", + "intValue", + "required", + "weight", + "comparator" +FROM + "StepDesiredWorkerLabel" +WHERE + "stepId" = $1::uuid +` + +type GetDesiredLabelsRow struct { + Key string `json:"key"` + StrValue pgtype.Text `json:"strValue"` + IntValue pgtype.Int4 `json:"intValue"` + Required bool `json:"required"` + Weight int32 `json:"weight"` + Comparator WorkerLabelComparator `json:"comparator"` +} + +func (q *Queries) GetDesiredLabels(ctx context.Context, db DBTX, stepid pgtype.UUID) ([]*GetDesiredLabelsRow, error) { + rows, err := db.Query(ctx, getDesiredLabels, stepid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*GetDesiredLabelsRow + for rows.Next() { + var i GetDesiredLabelsRow + if err := rows.Scan( + &i.Key, + &i.StrValue, + &i.IntValue, + &i.Required, + &i.Weight, + &i.Comparator, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getLaterStepRunsForReplay = `-- name: GetLaterStepRunsForReplay :many WITH RECURSIVE currStepRun AS ( SELECT id, "createdAt", "updatedAt", "deletedAt", "tenantId", "jobRunId", "stepId", "order", "workerId", "tickerId", status, input, output, "requeueAfter", "scheduleTimeoutAt", error, "startedAt", "finishedAt", "timeoutAt", "cancelledAt", "cancelledReason", "cancelledError", "inputSchema", "callerFiles", "gitRepoBranch", "retryCount", "semaphoreReleased", queue, "queueOrder" @@ -1250,6 +1300,43 @@ func (q *Queries) GetStepRunQueueOrder(ctx context.Context, db DBTX, arg GetStep return queueOrder, err } +const getWorkerLabels = `-- name: GetWorkerLabels :many +SELECT + "key", + "strValue", + "intValue" +FROM + "WorkerLabel" +WHERE + "workerId" = $1::uuid +` + +type GetWorkerLabelsRow struct { + Key string `json:"key"` + StrValue pgtype.Text `json:"strValue"` + IntValue pgtype.Int4 `json:"intValue"` +} + +func (q *Queries) GetWorkerLabels(ctx context.Context, db DBTX, workerid pgtype.UUID) ([]*GetWorkerLabelsRow, error) { + rows, err := db.Query(ctx, getWorkerLabels, workerid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*GetWorkerLabelsRow + for rows.Next() { + var i GetWorkerLabelsRow + if err := rows.Scan(&i.Key, &i.StrValue, &i.IntValue); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listNonFinalChildStepRuns = `-- name: ListNonFinalChildStepRuns :many WITH RECURSIVE currStepRun AS ( SELECT id, "createdAt", "updatedAt", "deletedAt", "tenantId", "jobRunId", "stepId", "order", "workerId", "tickerId", status, input, output, "requeueAfter", "scheduleTimeoutAt", error, "startedAt", "finishedAt", "timeoutAt", "cancelledAt", "cancelledReason", "cancelledError", "inputSchema", "callerFiles", "gitRepoBranch", "retryCount", "semaphoreReleased", queue, "queueOrder" diff --git a/pkg/repository/prisma/step_run.go b/pkg/repository/prisma/step_run.go index 69bc49784..58a68e679 100644 --- a/pkg/repository/prisma/step_run.go +++ b/pkg/repository/prisma/step_run.go @@ -940,6 +940,17 @@ func (s *stepRunEngineRepository) UnassignStepRunFromWorker(ctx context.Context, }) } +func UniqueSet[T any](i []T, keyFunc func(T) string) map[string]struct{} { + set := make(map[string]struct{}) + + for _, item := range i { + key := keyFunc(item) + set[key] = struct{}{} + } + + return set +} + type debugInfo struct { UniqueActions []string `json:"unique_actions"` TotalStepRuns int `json:"total_step_runs"` @@ -1022,7 +1033,7 @@ func (s *stepRunEngineRepository) QueueStepRuns(ctx context.Context, tenantId st durationsOfQueueListResults := make([]string, 0) - queueItems := make([]scheduling.QueueItemWithOrder, 0) + queueItems := make([]*scheduling.QueueItemWithOrder, 0) // TODO: verify whether this is multithreaded and if it is, whether thread safe results.Query(func(i int, qi []*dbsqlc.QueueItem, err error) { @@ -1033,7 +1044,7 @@ func (s *stepRunEngineRepository) QueueStepRuns(ctx context.Context, tenantId st queueName := "" for i := range qi { - queueItems = append(queueItems, scheduling.QueueItemWithOrder{ + queueItems = append(queueItems, &scheduling.QueueItemWithOrder{ QueueItem: qi[i], Order: i, }) @@ -1085,10 +1096,48 @@ func (s *stepRunEngineRepository) QueueStepRuns(ctx context.Context, tenantId st return emptyRes, fmt.Errorf("could not list semaphore slots to assign: %w", err) } + // GET UNIQUE STEP IDS + stepIdSet := UniqueSet(queueItems, func(x *scheduling.QueueItemWithOrder) string { + return sqlchelpers.UUIDToStr(x.StepId) + }) + + desiredLabels := make(map[string][]*dbsqlc.GetDesiredLabelsRow) + hasDesired := false + + // GET DESIRED LABELS + // OPTIMIZATION: CACHEABLE + for stepId := range stepIdSet { + labels, err := s.queries.GetDesiredLabels(ctx, tx, sqlchelpers.UUIDFromStr(stepId)) + if err != nil { + return emptyRes, fmt.Errorf("could not get desired labels: %w", err) + } + desiredLabels[stepId] = labels + hasDesired = true + } + + var workerLabels = make(map[string][]*dbsqlc.GetWorkerLabelsRow) + + if hasDesired { + // GET UNIQUE WORKER LABELS + workerIdSet := UniqueSet(slots, func(x *dbsqlc.ListSemaphoreSlotsToAssignRow) string { + return sqlchelpers.UUIDToStr(x.WorkerId) + }) + + for workerId := range workerIdSet { + labels, err := s.queries.GetWorkerLabels(ctx, tx, sqlchelpers.UUIDFromStr(workerId)) + if err != nil { + return emptyRes, fmt.Errorf("could not get worker labels: %w", err) + } + workerLabels[workerId] = labels + } + } + plan, err := scheduling.GeneratePlan( slots, uniqueActionsArr, queueItems, + workerLabels, + desiredLabels, ) if err != nil { diff --git a/pkg/scheduling/affinity.go b/pkg/scheduling/affinity.go new file mode 100644 index 000000000..1e5017831 --- /dev/null +++ b/pkg/scheduling/affinity.go @@ -0,0 +1,52 @@ +package scheduling + +import "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" + +func ComputeWeight(s []*dbsqlc.GetDesiredLabelsRow, l []*dbsqlc.GetWorkerLabelsRow) int { + totalWeight := 0 + + for _, desiredLabel := range s { + labelFound := false + for _, workerLabel := range l { + if desiredLabel.Key == workerLabel.Key { + labelFound = true + switch desiredLabel.Comparator { + case dbsqlc.WorkerLabelComparatorEQUAL: + if (desiredLabel.StrValue.Valid && workerLabel.StrValue.Valid && desiredLabel.StrValue.String == workerLabel.StrValue.String) || + (desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && desiredLabel.IntValue.Int32 == workerLabel.IntValue.Int32) { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorNOTEQUAL: + if (desiredLabel.StrValue.Valid && workerLabel.StrValue.Valid && desiredLabel.StrValue.String != workerLabel.StrValue.String) || + (desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && desiredLabel.IntValue.Int32 != workerLabel.IntValue.Int32) { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorGREATERTHAN: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 > desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorLESSTHAN: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 < desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorGREATERTHANOREQUAL: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 >= desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorLESSTHANOREQUAL: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 <= desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + } + break // Move to the next desired label + } + } + + // If the label is required but not found, return -1 to indicate an invalid match + if desiredLabel.Required && !labelFound { + return -1 + } + } + + return totalWeight +} diff --git a/pkg/scheduling/affinity_output.json b/pkg/scheduling/affinity_output.json new file mode 100644 index 000000000..9d9b18593 --- /dev/null +++ b/pkg/scheduling/affinity_output.json @@ -0,0 +1,40 @@ +{ + "StepRunIds": [ + "16fa711e-c03e-435c-88d9-62adf1591d98", + "064e787a-6cfd-4f82-8b6d-d8031459fdee" + ], + "StepRunTimeouts": [ + "60s", + "60s" + ], + "SlotIds": [ + "8cf68f09-b914-4f31-9777-8082b751a2d4", + "e2c744b8-b914-4f31-9777-8082b751a2d4" + ], + "WorkerIds": [ + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa" + ], + "UnassignedStepRunIds": [], + "QueuedStepRuns": [ + { + "StepRunId": "16fa711e-c03e-435c-88d9-62adf1591d98", + "WorkerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + }, + { + "StepRunId": "064e787a-6cfd-4f82-8b6d-d8031459fdee", + "WorkerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + } + ], + "TimedOutStepRuns": [], + "QueuedItems": [ + 137295, + 152259 + ], + "ShouldContinue": false, + "MinQueuedIds": { + "child:process2": 137295 + } +} diff --git a/pkg/scheduling/fixtures/affinity.json b/pkg/scheduling/fixtures/affinity.json new file mode 100644 index 000000000..44f833344 --- /dev/null +++ b/pkg/scheduling/fixtures/affinity.json @@ -0,0 +1,89 @@ +{ + "Slots": [ + { + "id": "8cf68f09-b914-4f31-9777-8082b751a2d4", + "workerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + }, + { + "id": "e2c744b8-b914-4f31-9777-8082b751a2d4", + "workerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + }, + { + "id": "aed946be-43db-4b5f-876a-c6c71f4d52aa", + "workerId": "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + }, + { + "id": "e2c744b8-43db-4b5f-876a-c6c71f4d52aa", + "workerId": "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + } + ], + "UniqueActions": [ + "child:process2" + ], + "QueueItems": [ + { + "id": 137295, + "stepRunId": "16fa711e-c03e-435c-88d9-62adf1591d98", + "stepId": "29fae13a-6672-48cb-9aed-3b0bf481f570", + "actionId": "child:process2", + "scheduleTimeoutAt": "2024-08-07T19:43:30.692Z", + "stepTimeout": "60s", + "isQueued": true, + "tenantId": "707d0855-80ab-4e1f-a156-f1c4546cbf52", + "queue": "child:process2", + "Order": 0 + }, + { + "id": 152259, + "stepRunId": "064e787a-6cfd-4f82-8b6d-d8031459fdee", + "stepId": "e2c744b8-bebb-4b05-8292-d2ce36f27928", + "actionId": "child:process2", + "scheduleTimeoutAt": "2024-08-07T21:12:49.56Z", + "stepTimeout": "60s", + "isQueued": true, + "tenantId": "707d0855-80ab-4e1f-a156-f1c4546cbf52", + "queue": "child:process2", + "Order": 0 + } + ], + "WorkerLabels": { + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa": [ + { + "Key": "MODEL", + "StrVal": "A" + } + ], + "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa": [ + { + "Key": "MODEL", + "StrVal": "B" + } + ] + }, + "StepDesiredLabels": { + "29fae13a-6672-48cb-9aed-3b0bf481f570": [ + { + "Key": "MODEL", + "StrValue": "A", + "Required": true, + "Weight": 1 + } + ], + "e2c744b8-bebb-4b05-8292-d2ce36f27928": [ + { + "Key": "MODEL", + "StrValue": "B", + "Required": true, + "Weight": 1 + } + ] + } +} diff --git a/pkg/scheduling/fixtures/affinity_output.json b/pkg/scheduling/fixtures/affinity_output.json new file mode 100644 index 000000000..10680c3e3 --- /dev/null +++ b/pkg/scheduling/fixtures/affinity_output.json @@ -0,0 +1,40 @@ +{ + "StepRunIds": [ + "16fa711e-c03e-435c-88d9-62adf1591d98", + "064e787a-6cfd-4f82-8b6d-d8031459fdee" + ], + "StepRunTimeouts": [ + "60s", + "60s" + ], + "SlotIds": [ + "8cf68f09-b914-4f31-9777-8082b751a2d4", + "aed946be-43db-4b5f-876a-c6c71f4d52aa" + ], + "WorkerIds": [ + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa" + ], + "UnassignedStepRunIds": [], + "QueuedStepRuns": [ + { + "StepRunId": "16fa711e-c03e-435c-88d9-62adf1591d98", + "WorkerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + }, + { + "StepRunId": "064e787a-6cfd-4f82-8b6d-d8031459fdee", + "WorkerId": "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + } + ], + "TimedOutStepRuns": [], + "QueuedItems": [ + 137295, + 152259 + ], + "ShouldContinue": false, + "MinQueuedIds": { + "child:process2": 137295 + } +} diff --git a/pkg/scheduling/schedule_plan.go b/pkg/scheduling/schedule_plan.go index 7b2b4f30c..14c41f289 100644 --- a/pkg/scheduling/schedule_plan.go +++ b/pkg/scheduling/schedule_plan.go @@ -21,7 +21,7 @@ type SchedulePlan struct { MinQueuedIds map[string]int64 } -func (sp *SchedulePlan) UpdateMinQueuedIds(qi QueueItemWithOrder) []repository.QueuedStepRun { +func (sp *SchedulePlan) UpdateMinQueuedIds(qi *QueueItemWithOrder) []repository.QueuedStepRun { if qi.Priority == 1 { if currMinQueued, ok := sp.MinQueuedIds[qi.Queue]; !ok { sp.MinQueuedIds[qi.Queue] = qi.ID @@ -33,21 +33,21 @@ func (sp *SchedulePlan) UpdateMinQueuedIds(qi QueueItemWithOrder) []repository.Q return sp.QueuedStepRuns } -func (plan *SchedulePlan) HandleTimedOut(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleTimedOut(qi *QueueItemWithOrder) { plan.TimedOutStepRuns = append(plan.TimedOutStepRuns, qi.StepRunId) // mark as queued so that we don't requeue plan.QueuedItems = append(plan.QueuedItems, qi.ID) } -func (plan *SchedulePlan) HandleNoSlots(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleNoSlots(qi *QueueItemWithOrder) { plan.UnassignedStepRunIds = append(plan.UnassignedStepRunIds, qi.StepRunId) } -func (plan *SchedulePlan) HandleUnassigned(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleUnassigned(qi *QueueItemWithOrder) { plan.UnassignedStepRunIds = append(plan.UnassignedStepRunIds, qi.StepRunId) } -func (plan *SchedulePlan) AssignQiToSlot(qi QueueItemWithOrder, slot *dbsqlc.ListSemaphoreSlotsToAssignRow) { +func (plan *SchedulePlan) AssignQiToSlot(qi *QueueItemWithOrder, slot *dbsqlc.ListSemaphoreSlotsToAssignRow) { plan.StepRunIds = append(plan.StepRunIds, qi.StepRunId) plan.StepRunTimeouts = append(plan.StepRunTimeouts, qi.StepTimeout.String) plan.SlotIds = append(plan.SlotIds, slot.ID) diff --git a/pkg/scheduling/scheduling.go b/pkg/scheduling/scheduling.go index fdc84ce9d..073b45ab5 100644 --- a/pkg/scheduling/scheduling.go +++ b/pkg/scheduling/scheduling.go @@ -14,11 +14,12 @@ type QueueItemWithOrder struct { Order int } -// Generate generates a random string of n bytes. func GeneratePlan( slots []*dbsqlc.ListSemaphoreSlotsToAssignRow, uniqueActionsArr []string, - queueItems []QueueItemWithOrder, + queueItems []*QueueItemWithOrder, + workerLabels map[string][]*dbsqlc.GetWorkerLabelsRow, + stepDesiredLabels map[string][]*dbsqlc.GetDesiredLabelsRow, ) (SchedulePlan, error) { plan := SchedulePlan{ @@ -38,9 +39,12 @@ func GeneratePlan( // initialize worker states for _, slot := range slots { - if _, ok := workers[sqlchelpers.UUIDToStr(slot.WorkerId)]; !ok { - workers[sqlchelpers.UUIDToStr(slot.WorkerId)] = NewWorkerState( - sqlchelpers.UUIDToStr(slot.WorkerId), + workerId := sqlchelpers.UUIDToStr(slot.WorkerId) + + if _, ok := workers[workerId]; !ok { + workers[workerId] = NewWorkerState( + workerId, + workerLabels[workerId], ) } workers[sqlchelpers.UUIDToStr(slot.WorkerId)].AddSlot(slot) @@ -69,7 +73,7 @@ func GeneratePlan( // pick a worker to assign the slot to assigned := false for _, worker := range workers { - slot, isEmpty := worker.AssignSlot(qi) + slot, isEmpty := worker.AssignSlot(qi, stepDesiredLabels[sqlchelpers.UUIDToStr(qi.StepId)]) if slot == nil { // if we can't assign the slot to the worker then continue @@ -101,7 +105,7 @@ func GeneratePlan( if len(workers) > 0 && len(plan.UnassignedStepRunIds) > 0 { for _, qi := range queueItems { for _, worker := range workers { - if worker.CanAssign(qi) { + if worker.CanAssign(qi, stepDesiredLabels[sqlchelpers.UUIDToStr(qi.StepId)]) { plan.ShouldContinue = true break } diff --git a/pkg/scheduling/scheduling_test.go b/pkg/scheduling/scheduling_test.go index 9ac5faf56..f34c91c81 100644 --- a/pkg/scheduling/scheduling_test.go +++ b/pkg/scheduling/scheduling_test.go @@ -14,9 +14,11 @@ import ( ) type args struct { - Slots []*dbsqlc.ListSemaphoreSlotsToAssignRow - UniqueActionsArr []string - QueueItems []QueueItemWithOrder + Slots []*dbsqlc.ListSemaphoreSlotsToAssignRow + UniqueActionsArr []string + QueueItems []*QueueItemWithOrder + WorkerLabels map[string][]*dbsqlc.GetWorkerLabelsRow + StepDesiredLabels map[string][]*dbsqlc.GetDesiredLabelsRow } func loadFixture(filename string, noTimeout bool) (*args, error) { @@ -45,7 +47,7 @@ func loadFixture(filename string, noTimeout bool) (*args, error) { return args, nil } -func assertResult(res SchedulePlan, filename string) (bool, error) { +func assertResult(actual SchedulePlan, filename string) (bool, error) { data, err := os.ReadFile(filename) if err != nil { return false, err @@ -58,39 +60,45 @@ func assertResult(res SchedulePlan, filename string) (bool, error) { } // Compare the results - if res.ShouldContinue != expected.ShouldContinue { + if actual.ShouldContinue != expected.ShouldContinue { return false, fmt.Errorf("ShouldContinue does not match") } - if len(res.StepRunIds) != len(expected.StepRunIds) { + if len(actual.StepRunIds) != len(expected.StepRunIds) { return false, fmt.Errorf("StepRunIds length does not match") } - if len(res.StepRunTimeouts) != len(expected.StepRunTimeouts) { + if len(actual.StepRunTimeouts) != len(expected.StepRunTimeouts) { return false, fmt.Errorf("StepRunTimeouts length does not match") } - if len(res.SlotIds) != len(expected.SlotIds) { + if len(actual.SlotIds) != len(expected.SlotIds) { return false, fmt.Errorf("SlotIds length does not match") } - if len(res.WorkerIds) != len(expected.WorkerIds) { + for i := range actual.QueuedStepRuns { + if actual.QueuedStepRuns[i].WorkerId != expected.QueuedStepRuns[i].WorkerId { + return false, fmt.Errorf("Expected worker mismatch") + } + } + + if len(actual.WorkerIds) != len(expected.WorkerIds) { return false, fmt.Errorf("WorkerIds length does not match") } - if len(res.UnassignedStepRunIds) != len(expected.UnassignedStepRunIds) { + if len(actual.UnassignedStepRunIds) != len(expected.UnassignedStepRunIds) { return false, fmt.Errorf("UnassignedStepRunIds length does not match") } - if len(res.QueuedStepRuns) != len(expected.QueuedStepRuns) { + if len(actual.QueuedStepRuns) != len(expected.QueuedStepRuns) { return false, fmt.Errorf("QueuedStepRuns length does not match") } - if len(res.TimedOutStepRuns) != len(expected.TimedOutStepRuns) { + if len(actual.TimedOutStepRuns) != len(expected.TimedOutStepRuns) { return false, fmt.Errorf("TimedOutStepRuns length does not match") } - if len(res.QueuedItems) != len(expected.QueuedItems) { + if len(actual.QueuedItems) != len(expected.QueuedItems) { return false, fmt.Errorf("QueuedItems length does not match") } @@ -144,6 +152,25 @@ func TestGeneratePlan(t *testing.T) { }, wantErr: assert.NoError, }, + { + name: "GeneratePlan_Affinity", + args: args{ + fixtureArgs: "./fixtures/affinity.json", + fixtureResult: "./fixtures/affinity_output.json", + noTimeout: true, + }, + want: func(s SchedulePlan, fixtureResult string) bool { + // DumpResults(s, "affinity_output.json") + + assert, err := assertResult(s, fixtureResult) + if err != nil { + fmt.Println(err) + } + + return assert + }, + wantErr: assert.NoError, + }, { name: "GeneratePlan_TimedOut", args: args{ @@ -173,7 +200,13 @@ func TestGeneratePlan(t *testing.T) { t.Fatalf("Failed to load fixture: %v", err) } - got, err := GeneratePlan(fixtureData.Slots, fixtureData.UniqueActionsArr, fixtureData.QueueItems) + got, err := GeneratePlan( + fixtureData.Slots, + fixtureData.UniqueActionsArr, + fixtureData.QueueItems, + fixtureData.WorkerLabels, + fixtureData.StepDesiredLabels, + ) if !tt.wantErr(t, err, "GeneratePlan_Simple") { return diff --git a/pkg/scheduling/timeout.go b/pkg/scheduling/timeout.go index 457241c64..0eb34d9a9 100644 --- a/pkg/scheduling/timeout.go +++ b/pkg/scheduling/timeout.go @@ -2,7 +2,7 @@ package scheduling import "time" -func IsTimedout(qi QueueItemWithOrder) bool { +func IsTimedout(qi *QueueItemWithOrder) bool { // if the current time is after the scheduleTimeoutAt, then mark this as timed out now := time.Now().UTC().UTC() scheduleTimeoutAt := qi.ScheduleTimeoutAt.Time diff --git a/pkg/scheduling/worker_state.go b/pkg/scheduling/worker_state.go index d4f3b7e4f..6106e9032 100644 --- a/pkg/scheduling/worker_state.go +++ b/pkg/scheduling/worker_state.go @@ -1,6 +1,8 @@ package scheduling import ( + "fmt" + "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" ) @@ -8,13 +10,15 @@ type WorkerState struct { workerId string slots []*dbsqlc.ListSemaphoreSlotsToAssignRow actionIds map[string]struct{} + labels []*dbsqlc.GetWorkerLabelsRow } -func NewWorkerState(workerId string) *WorkerState { +func NewWorkerState(workerId string, labels []*dbsqlc.GetWorkerLabelsRow) *WorkerState { return &WorkerState{ workerId: workerId, slots: make([]*dbsqlc.ListSemaphoreSlotsToAssignRow, 0), actionIds: make(map[string]struct{}), + labels: labels, } } @@ -26,18 +30,27 @@ func (w *WorkerState) AddSlot(slot *dbsqlc.ListSemaphoreSlotsToAssignRow) { } } -func (w *WorkerState) CanAssign(qi QueueItemWithOrder) bool { +func (w *WorkerState) CanAssign(qi *QueueItemWithOrder, desiredLabels []*dbsqlc.GetDesiredLabelsRow) bool { if _, ok := w.actionIds[qi.ActionId.String]; !ok { return false } + if len(desiredLabels) > 0 { + + // TODO cache + weight := ComputeWeight(desiredLabels, w.labels) + + fmt.Println(weight) + return weight >= 0 + } + return true } -func (w *WorkerState) AssignSlot(qi QueueItemWithOrder) (*dbsqlc.ListSemaphoreSlotsToAssignRow, bool) { +func (w *WorkerState) AssignSlot(qi *QueueItemWithOrder, desiredLabels []*dbsqlc.GetDesiredLabelsRow) (*dbsqlc.ListSemaphoreSlotsToAssignRow, bool) { // if the actionId is not in the worker's actionIds, then we can't assign this slot - if !w.CanAssign(qi) { + if !w.CanAssign(qi, desiredLabels) { return nil, false } diff --git a/scheduling.json b/scheduling.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/scheduling.json @@ -0,0 +1 @@ +{}