diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql b/pkg/repository/prisma/dbsqlc/step_runs.sql index 79963e266..4a44c750c 100644 --- a/pkg/repository/prisma/dbsqlc/step_runs.sql +++ b/pkg/repository/prisma/dbsqlc/step_runs.sql @@ -664,6 +664,29 @@ WHERE sr."id" = input."id" RETURNING sr."id"; +-- name: GetDesiredLabels :many +SELECT + "key", + "strValue", + "intValue", + "required", + "weight", + "comparator" +FROM + "StepDesiredWorkerLabel" +WHERE + "stepId" = @stepId::uuid; + +-- name: GetWorkerLabels :many +SELECT + "key", + "strValue", + "intValue" +FROM + "WorkerLabel" +WHERE + "workerId" = @workerId::uuid; + -- name: AcquireWorkerSemaphoreSlotAndAssign :one WITH valid_workers AS ( SELECT diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql.go b/pkg/repository/prisma/dbsqlc/step_runs.sql.go index 7c5023b55..359061e62 100644 --- a/pkg/repository/prisma/dbsqlc/step_runs.sql.go +++ b/pkg/repository/prisma/dbsqlc/step_runs.sql.go @@ -763,6 +763,56 @@ func (q *Queries) CreateStepRunEvent(ctx context.Context, db DBTX, arg CreateSte return err } +const getDesiredLabels = `-- name: GetDesiredLabels :many +SELECT + "key", + "strValue", + "intValue", + "required", + "weight", + "comparator" +FROM + "StepDesiredWorkerLabel" +WHERE + "stepId" = $1::uuid +` + +type GetDesiredLabelsRow struct { + Key string `json:"key"` + StrValue pgtype.Text `json:"strValue"` + IntValue pgtype.Int4 `json:"intValue"` + Required bool `json:"required"` + Weight int32 `json:"weight"` + Comparator WorkerLabelComparator `json:"comparator"` +} + +func (q *Queries) GetDesiredLabels(ctx context.Context, db DBTX, stepid pgtype.UUID) ([]*GetDesiredLabelsRow, error) { + rows, err := db.Query(ctx, getDesiredLabels, stepid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*GetDesiredLabelsRow + for rows.Next() { + var i GetDesiredLabelsRow + if err := rows.Scan( + &i.Key, + &i.StrValue, + &i.IntValue, + &i.Required, + &i.Weight, + &i.Comparator, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getLaterStepRunsForReplay = `-- name: GetLaterStepRunsForReplay :many WITH RECURSIVE currStepRun AS ( SELECT id, "createdAt", "updatedAt", "deletedAt", "tenantId", "jobRunId", "stepId", "order", "workerId", "tickerId", status, input, output, "requeueAfter", "scheduleTimeoutAt", error, "startedAt", "finishedAt", "timeoutAt", "cancelledAt", "cancelledReason", "cancelledError", "inputSchema", "callerFiles", "gitRepoBranch", "retryCount", "semaphoreReleased", queue, "queueOrder" @@ -1250,6 +1300,43 @@ func (q *Queries) GetStepRunQueueOrder(ctx context.Context, db DBTX, arg GetStep return queueOrder, err } +const getWorkerLabels = `-- name: GetWorkerLabels :many +SELECT + "key", + "strValue", + "intValue" +FROM + "WorkerLabel" +WHERE + "workerId" = $1::uuid +` + +type GetWorkerLabelsRow struct { + Key string `json:"key"` + StrValue pgtype.Text `json:"strValue"` + IntValue pgtype.Int4 `json:"intValue"` +} + +func (q *Queries) GetWorkerLabels(ctx context.Context, db DBTX, workerid pgtype.UUID) ([]*GetWorkerLabelsRow, error) { + rows, err := db.Query(ctx, getWorkerLabels, workerid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*GetWorkerLabelsRow + for rows.Next() { + var i GetWorkerLabelsRow + if err := rows.Scan(&i.Key, &i.StrValue, &i.IntValue); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listNonFinalChildStepRuns = `-- name: ListNonFinalChildStepRuns :many WITH RECURSIVE currStepRun AS ( SELECT id, "createdAt", "updatedAt", "deletedAt", "tenantId", "jobRunId", "stepId", "order", "workerId", "tickerId", status, input, output, "requeueAfter", "scheduleTimeoutAt", error, "startedAt", "finishedAt", "timeoutAt", "cancelledAt", "cancelledReason", "cancelledError", "inputSchema", "callerFiles", "gitRepoBranch", "retryCount", "semaphoreReleased", queue, "queueOrder" diff --git a/pkg/repository/prisma/step_run.go b/pkg/repository/prisma/step_run.go index d1dba75dc..9edea0cd2 100644 --- a/pkg/repository/prisma/step_run.go +++ b/pkg/repository/prisma/step_run.go @@ -974,6 +974,17 @@ func (s *stepRunEngineRepository) UnassignStepRunFromWorker(ctx context.Context, }) } +func UniqueSet[T any](i []T, keyFunc func(T) string) map[string]struct{} { + set := make(map[string]struct{}) + + for _, item := range i { + key := keyFunc(item) + set[key] = struct{}{} + } + + return set +} + type debugInfo struct { UniqueActions []string `json:"unique_actions"` TotalStepRuns int `json:"total_step_runs"` @@ -1061,7 +1072,7 @@ func (s *stepRunEngineRepository) QueueStepRuns(ctx context.Context, tenantId st durationsOfQueueListResults := make([]string, 0) - queueItems := make([]scheduling.QueueItemWithOrder, 0) + queueItems := make([]*scheduling.QueueItemWithOrder, 0) // TODO: verify whether this is multithreaded and if it is, whether thread safe results.Query(func(i int, qi []*dbsqlc.QueueItem, err error) { @@ -1072,7 +1083,7 @@ func (s *stepRunEngineRepository) QueueStepRuns(ctx context.Context, tenantId st queueName := "" for i := range qi { - queueItems = append(queueItems, scheduling.QueueItemWithOrder{ + queueItems = append(queueItems, &scheduling.QueueItemWithOrder{ QueueItem: qi[i], Order: i, }) @@ -1174,12 +1185,50 @@ func (s *stepRunEngineRepository) QueueStepRuns(ctx context.Context, tenantId st return emptyRes, fmt.Errorf("could not list semaphore slots to assign: %w", err) } + // GET UNIQUE STEP IDS + stepIdSet := UniqueSet(queueItems, func(x *scheduling.QueueItemWithOrder) string { + return sqlchelpers.UUIDToStr(x.StepId) + }) + + desiredLabels := make(map[string][]*dbsqlc.GetDesiredLabelsRow) + hasDesired := false + + // GET DESIRED LABELS + // OPTIMIZATION: CACHEABLE + for stepId := range stepIdSet { + labels, err := s.queries.GetDesiredLabels(ctx, tx, sqlchelpers.UUIDFromStr(stepId)) + if err != nil { + return emptyRes, fmt.Errorf("could not get desired labels: %w", err) + } + desiredLabels[stepId] = labels + hasDesired = true + } + + var workerLabels = make(map[string][]*dbsqlc.GetWorkerLabelsRow) + + if hasDesired { + // GET UNIQUE WORKER LABELS + workerIdSet := UniqueSet(slots, func(x *dbsqlc.ListSemaphoreSlotsToAssignRow) string { + return sqlchelpers.UUIDToStr(x.WorkerId) + }) + + for workerId := range workerIdSet { + labels, err := s.queries.GetWorkerLabels(ctx, tx, sqlchelpers.UUIDFromStr(workerId)) + if err != nil { + return emptyRes, fmt.Errorf("could not get worker labels: %w", err) + } + workerLabels[workerId] = labels + } + } + plan, err := scheduling.GeneratePlan( slots, uniqueActionsArr, queueItems, stepRateUnits, currRateLimitValues, + workerLabels, + desiredLabels, ) if err != nil { diff --git a/pkg/scheduling/affinity.go b/pkg/scheduling/affinity.go new file mode 100644 index 000000000..1e5017831 --- /dev/null +++ b/pkg/scheduling/affinity.go @@ -0,0 +1,52 @@ +package scheduling + +import "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" + +func ComputeWeight(s []*dbsqlc.GetDesiredLabelsRow, l []*dbsqlc.GetWorkerLabelsRow) int { + totalWeight := 0 + + for _, desiredLabel := range s { + labelFound := false + for _, workerLabel := range l { + if desiredLabel.Key == workerLabel.Key { + labelFound = true + switch desiredLabel.Comparator { + case dbsqlc.WorkerLabelComparatorEQUAL: + if (desiredLabel.StrValue.Valid && workerLabel.StrValue.Valid && desiredLabel.StrValue.String == workerLabel.StrValue.String) || + (desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && desiredLabel.IntValue.Int32 == workerLabel.IntValue.Int32) { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorNOTEQUAL: + if (desiredLabel.StrValue.Valid && workerLabel.StrValue.Valid && desiredLabel.StrValue.String != workerLabel.StrValue.String) || + (desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && desiredLabel.IntValue.Int32 != workerLabel.IntValue.Int32) { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorGREATERTHAN: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 > desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorLESSTHAN: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 < desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorGREATERTHANOREQUAL: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 >= desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + case dbsqlc.WorkerLabelComparatorLESSTHANOREQUAL: + if desiredLabel.IntValue.Valid && workerLabel.IntValue.Valid && workerLabel.IntValue.Int32 <= desiredLabel.IntValue.Int32 { + totalWeight += int(desiredLabel.Weight) + } + } + break // Move to the next desired label + } + } + + // If the label is required but not found, return -1 to indicate an invalid match + if desiredLabel.Required && !labelFound { + return -1 + } + } + + return totalWeight +} diff --git a/pkg/scheduling/affinity_output.json b/pkg/scheduling/affinity_output.json new file mode 100644 index 000000000..9d9b18593 --- /dev/null +++ b/pkg/scheduling/affinity_output.json @@ -0,0 +1,40 @@ +{ + "StepRunIds": [ + "16fa711e-c03e-435c-88d9-62adf1591d98", + "064e787a-6cfd-4f82-8b6d-d8031459fdee" + ], + "StepRunTimeouts": [ + "60s", + "60s" + ], + "SlotIds": [ + "8cf68f09-b914-4f31-9777-8082b751a2d4", + "e2c744b8-b914-4f31-9777-8082b751a2d4" + ], + "WorkerIds": [ + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa" + ], + "UnassignedStepRunIds": [], + "QueuedStepRuns": [ + { + "StepRunId": "16fa711e-c03e-435c-88d9-62adf1591d98", + "WorkerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + }, + { + "StepRunId": "064e787a-6cfd-4f82-8b6d-d8031459fdee", + "WorkerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + } + ], + "TimedOutStepRuns": [], + "QueuedItems": [ + 137295, + 152259 + ], + "ShouldContinue": false, + "MinQueuedIds": { + "child:process2": 137295 + } +} diff --git a/pkg/scheduling/fixtures/affinity.json b/pkg/scheduling/fixtures/affinity.json new file mode 100644 index 000000000..44f833344 --- /dev/null +++ b/pkg/scheduling/fixtures/affinity.json @@ -0,0 +1,89 @@ +{ + "Slots": [ + { + "id": "8cf68f09-b914-4f31-9777-8082b751a2d4", + "workerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + }, + { + "id": "e2c744b8-b914-4f31-9777-8082b751a2d4", + "workerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + }, + { + "id": "aed946be-43db-4b5f-876a-c6c71f4d52aa", + "workerId": "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + }, + { + "id": "e2c744b8-43db-4b5f-876a-c6c71f4d52aa", + "workerId": "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa", + "dispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21", + "actionId": "child:process2" + } + ], + "UniqueActions": [ + "child:process2" + ], + "QueueItems": [ + { + "id": 137295, + "stepRunId": "16fa711e-c03e-435c-88d9-62adf1591d98", + "stepId": "29fae13a-6672-48cb-9aed-3b0bf481f570", + "actionId": "child:process2", + "scheduleTimeoutAt": "2024-08-07T19:43:30.692Z", + "stepTimeout": "60s", + "isQueued": true, + "tenantId": "707d0855-80ab-4e1f-a156-f1c4546cbf52", + "queue": "child:process2", + "Order": 0 + }, + { + "id": 152259, + "stepRunId": "064e787a-6cfd-4f82-8b6d-d8031459fdee", + "stepId": "e2c744b8-bebb-4b05-8292-d2ce36f27928", + "actionId": "child:process2", + "scheduleTimeoutAt": "2024-08-07T21:12:49.56Z", + "stepTimeout": "60s", + "isQueued": true, + "tenantId": "707d0855-80ab-4e1f-a156-f1c4546cbf52", + "queue": "child:process2", + "Order": 0 + } + ], + "WorkerLabels": { + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa": [ + { + "Key": "MODEL", + "StrVal": "A" + } + ], + "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa": [ + { + "Key": "MODEL", + "StrVal": "B" + } + ] + }, + "StepDesiredLabels": { + "29fae13a-6672-48cb-9aed-3b0bf481f570": [ + { + "Key": "MODEL", + "StrValue": "A", + "Required": true, + "Weight": 1 + } + ], + "e2c744b8-bebb-4b05-8292-d2ce36f27928": [ + { + "Key": "MODEL", + "StrValue": "B", + "Required": true, + "Weight": 1 + } + ] + } +} diff --git a/pkg/scheduling/fixtures/affinity_output.json b/pkg/scheduling/fixtures/affinity_output.json new file mode 100644 index 000000000..10680c3e3 --- /dev/null +++ b/pkg/scheduling/fixtures/affinity_output.json @@ -0,0 +1,40 @@ +{ + "StepRunIds": [ + "16fa711e-c03e-435c-88d9-62adf1591d98", + "064e787a-6cfd-4f82-8b6d-d8031459fdee" + ], + "StepRunTimeouts": [ + "60s", + "60s" + ], + "SlotIds": [ + "8cf68f09-b914-4f31-9777-8082b751a2d4", + "aed946be-43db-4b5f-876a-c6c71f4d52aa" + ], + "WorkerIds": [ + "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa" + ], + "UnassignedStepRunIds": [], + "QueuedStepRuns": [ + { + "StepRunId": "16fa711e-c03e-435c-88d9-62adf1591d98", + "WorkerId": "aaaaaaaa-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + }, + { + "StepRunId": "064e787a-6cfd-4f82-8b6d-d8031459fdee", + "WorkerId": "bbbbbbbb-43db-4b5f-876a-c6c71f4d52aa", + "DispatcherId": "9994a9eb-430d-46da-934d-d9dd953cfd21" + } + ], + "TimedOutStepRuns": [], + "QueuedItems": [ + 137295, + 152259 + ], + "ShouldContinue": false, + "MinQueuedIds": { + "child:process2": 137295 + } +} diff --git a/pkg/scheduling/schedule_plan.go b/pkg/scheduling/schedule_plan.go index f288108c8..8ca9c8829 100644 --- a/pkg/scheduling/schedule_plan.go +++ b/pkg/scheduling/schedule_plan.go @@ -26,7 +26,7 @@ type SchedulePlan struct { RateLimitUnitsConsumed map[string]int32 } -func (sp *SchedulePlan) UpdateMinQueuedIds(qi QueueItemWithOrder) []repository.QueuedStepRun { +func (sp *SchedulePlan) UpdateMinQueuedIds(qi *QueueItemWithOrder) []repository.QueuedStepRun { if qi.Priority == 1 { if currMinQueued, ok := sp.MinQueuedIds[qi.Queue]; !ok { sp.MinQueuedIds[qi.Queue] = qi.ID @@ -38,25 +38,25 @@ func (sp *SchedulePlan) UpdateMinQueuedIds(qi QueueItemWithOrder) []repository.Q return sp.QueuedStepRuns } -func (plan *SchedulePlan) HandleTimedOut(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleTimedOut(qi *QueueItemWithOrder) { plan.TimedOutStepRuns = append(plan.TimedOutStepRuns, qi.StepRunId) // mark as queued so that we don't requeue plan.QueuedItems = append(plan.QueuedItems, qi.ID) } -func (plan *SchedulePlan) HandleNoSlots(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleNoSlots(qi *QueueItemWithOrder) { plan.UnassignedStepRunIds = append(plan.UnassignedStepRunIds, qi.StepRunId) } -func (plan *SchedulePlan) HandleUnassigned(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleUnassigned(qi *QueueItemWithOrder) { plan.UnassignedStepRunIds = append(plan.UnassignedStepRunIds, qi.StepRunId) } -func (plan *SchedulePlan) HandleRateLimited(qi QueueItemWithOrder) { +func (plan *SchedulePlan) HandleRateLimited(qi *QueueItemWithOrder) { plan.RateLimitedStepRuns = append(plan.RateLimitedStepRuns, qi.StepRunId) } -func (plan *SchedulePlan) AssignQiToSlot(qi QueueItemWithOrder, slot *dbsqlc.ListSemaphoreSlotsToAssignRow) { +func (plan *SchedulePlan) AssignQiToSlot(qi *QueueItemWithOrder, slot *dbsqlc.ListSemaphoreSlotsToAssignRow) { plan.StepRunIds = append(plan.StepRunIds, qi.StepRunId) plan.StepRunTimeouts = append(plan.StepRunTimeouts, qi.StepTimeout.String) plan.SlotIds = append(plan.SlotIds, slot.ID) diff --git a/pkg/scheduling/scheduling.go b/pkg/scheduling/scheduling.go index b6af4f71a..dfe0ce605 100644 --- a/pkg/scheduling/scheduling.go +++ b/pkg/scheduling/scheduling.go @@ -16,13 +16,14 @@ type QueueItemWithOrder struct { Order int } -// Generate generates a random string of n bytes. func GeneratePlan( slots []*dbsqlc.ListSemaphoreSlotsToAssignRow, uniqueActionsArr []string, - queueItems []QueueItemWithOrder, + queueItems []*QueueItemWithOrder, stepRateUnits map[string]map[string]int32, currRateLimits map[string]*dbsqlc.ListRateLimitsForTenantRow, + workerLabels map[string][]*dbsqlc.GetWorkerLabelsRow, + stepDesiredLabels map[string][]*dbsqlc.GetDesiredLabelsRow, ) (SchedulePlan, error) { plan := SchedulePlan{ @@ -45,9 +46,12 @@ func GeneratePlan( // initialize worker states for _, slot := range slots { - if _, ok := workers[sqlchelpers.UUIDToStr(slot.WorkerId)]; !ok { - workers[sqlchelpers.UUIDToStr(slot.WorkerId)] = NewWorkerState( - sqlchelpers.UUIDToStr(slot.WorkerId), + workerId := sqlchelpers.UUIDToStr(slot.WorkerId) + + if _, ok := workers[workerId]; !ok { + workers[workerId] = NewWorkerState( + workerId, + workerLabels[workerId], ) } workers[sqlchelpers.UUIDToStr(slot.WorkerId)].AddSlot(slot) @@ -105,7 +109,7 @@ func GeneratePlan( if !isRateLimited { for _, worker := range workers { - slot, isEmpty := worker.AssignSlot(qi) + slot, isEmpty := worker.AssignSlot(qi, stepDesiredLabels[sqlchelpers.UUIDToStr(qi.StepId)]) if slot == nil { // if we can't assign the slot to the worker then continue @@ -159,7 +163,7 @@ func GeneratePlan( if len(workers) > 0 && len(plan.UnassignedStepRunIds) > 0 { for _, qi := range queueItems { for _, worker := range workers { - if worker.CanAssign(qi) { + if worker.CanAssign(qi, stepDesiredLabels[sqlchelpers.UUIDToStr(qi.StepId)]) { plan.ShouldContinue = true break } diff --git a/pkg/scheduling/scheduling_test.go b/pkg/scheduling/scheduling_test.go index 8955b4e1c..a6a303824 100644 --- a/pkg/scheduling/scheduling_test.go +++ b/pkg/scheduling/scheduling_test.go @@ -14,9 +14,11 @@ import ( ) type args struct { - Slots []*dbsqlc.ListSemaphoreSlotsToAssignRow - UniqueActionsArr []string - QueueItems []QueueItemWithOrder + Slots []*dbsqlc.ListSemaphoreSlotsToAssignRow + UniqueActionsArr []string + QueueItems []*QueueItemWithOrder + WorkerLabels map[string][]*dbsqlc.GetWorkerLabelsRow + StepDesiredLabels map[string][]*dbsqlc.GetDesiredLabelsRow } func loadFixture(filename string, noTimeout bool) (*args, error) { @@ -45,7 +47,7 @@ func loadFixture(filename string, noTimeout bool) (*args, error) { return args, nil } -func assertResult(res SchedulePlan, filename string) (bool, error) { +func assertResult(actual SchedulePlan, filename string) (bool, error) { data, err := os.ReadFile(filename) if err != nil { return false, err @@ -58,39 +60,45 @@ func assertResult(res SchedulePlan, filename string) (bool, error) { } // Compare the results - if res.ShouldContinue != expected.ShouldContinue { + if actual.ShouldContinue != expected.ShouldContinue { return false, fmt.Errorf("ShouldContinue does not match") } - if len(res.StepRunIds) != len(expected.StepRunIds) { + if len(actual.StepRunIds) != len(expected.StepRunIds) { return false, fmt.Errorf("StepRunIds length does not match") } - if len(res.StepRunTimeouts) != len(expected.StepRunTimeouts) { + if len(actual.StepRunTimeouts) != len(expected.StepRunTimeouts) { return false, fmt.Errorf("StepRunTimeouts length does not match") } - if len(res.SlotIds) != len(expected.SlotIds) { + if len(actual.SlotIds) != len(expected.SlotIds) { return false, fmt.Errorf("SlotIds length does not match") } - if len(res.WorkerIds) != len(expected.WorkerIds) { + for i := range actual.QueuedStepRuns { + if actual.QueuedStepRuns[i].WorkerId != expected.QueuedStepRuns[i].WorkerId { + return false, fmt.Errorf("Expected worker mismatch") + } + } + + if len(actual.WorkerIds) != len(expected.WorkerIds) { return false, fmt.Errorf("WorkerIds length does not match") } - if len(res.UnassignedStepRunIds) != len(expected.UnassignedStepRunIds) { + if len(actual.UnassignedStepRunIds) != len(expected.UnassignedStepRunIds) { return false, fmt.Errorf("UnassignedStepRunIds length does not match") } - if len(res.QueuedStepRuns) != len(expected.QueuedStepRuns) { + if len(actual.QueuedStepRuns) != len(expected.QueuedStepRuns) { return false, fmt.Errorf("QueuedStepRuns length does not match") } - if len(res.TimedOutStepRuns) != len(expected.TimedOutStepRuns) { + if len(actual.TimedOutStepRuns) != len(expected.TimedOutStepRuns) { return false, fmt.Errorf("TimedOutStepRuns length does not match") } - if len(res.QueuedItems) != len(expected.QueuedItems) { + if len(actual.QueuedItems) != len(expected.QueuedItems) { return false, fmt.Errorf("QueuedItems length does not match") } @@ -144,6 +152,25 @@ func TestGeneratePlan(t *testing.T) { }, wantErr: assert.NoError, }, + { + name: "GeneratePlan_Affinity", + args: args{ + fixtureArgs: "./fixtures/affinity.json", + fixtureResult: "./fixtures/affinity_output.json", + noTimeout: true, + }, + want: func(s SchedulePlan, fixtureResult string) bool { + // DumpResults(s, "affinity_output.json") + + assert, err := assertResult(s, fixtureResult) + if err != nil { + fmt.Println(err) + } + + return assert + }, + wantErr: assert.NoError, + }, { name: "GeneratePlan_TimedOut", args: args{ @@ -173,7 +200,15 @@ func TestGeneratePlan(t *testing.T) { t.Fatalf("Failed to load fixture: %v", err) } - got, err := GeneratePlan(fixtureData.Slots, fixtureData.UniqueActionsArr, fixtureData.QueueItems, nil, nil) + got, err := GeneratePlan( + fixtureData.Slots, + fixtureData.UniqueActionsArr, + fixtureData.QueueItems, + nil, + nil, + fixtureData.WorkerLabels, + fixtureData.StepDesiredLabels, + ) if !tt.wantErr(t, err, "GeneratePlan_Simple") { return diff --git a/pkg/scheduling/timeout.go b/pkg/scheduling/timeout.go index 457241c64..0eb34d9a9 100644 --- a/pkg/scheduling/timeout.go +++ b/pkg/scheduling/timeout.go @@ -2,7 +2,7 @@ package scheduling import "time" -func IsTimedout(qi QueueItemWithOrder) bool { +func IsTimedout(qi *QueueItemWithOrder) bool { // if the current time is after the scheduleTimeoutAt, then mark this as timed out now := time.Now().UTC().UTC() scheduleTimeoutAt := qi.ScheduleTimeoutAt.Time diff --git a/pkg/scheduling/worker_state.go b/pkg/scheduling/worker_state.go index d4f3b7e4f..6106e9032 100644 --- a/pkg/scheduling/worker_state.go +++ b/pkg/scheduling/worker_state.go @@ -1,6 +1,8 @@ package scheduling import ( + "fmt" + "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" ) @@ -8,13 +10,15 @@ type WorkerState struct { workerId string slots []*dbsqlc.ListSemaphoreSlotsToAssignRow actionIds map[string]struct{} + labels []*dbsqlc.GetWorkerLabelsRow } -func NewWorkerState(workerId string) *WorkerState { +func NewWorkerState(workerId string, labels []*dbsqlc.GetWorkerLabelsRow) *WorkerState { return &WorkerState{ workerId: workerId, slots: make([]*dbsqlc.ListSemaphoreSlotsToAssignRow, 0), actionIds: make(map[string]struct{}), + labels: labels, } } @@ -26,18 +30,27 @@ func (w *WorkerState) AddSlot(slot *dbsqlc.ListSemaphoreSlotsToAssignRow) { } } -func (w *WorkerState) CanAssign(qi QueueItemWithOrder) bool { +func (w *WorkerState) CanAssign(qi *QueueItemWithOrder, desiredLabels []*dbsqlc.GetDesiredLabelsRow) bool { if _, ok := w.actionIds[qi.ActionId.String]; !ok { return false } + if len(desiredLabels) > 0 { + + // TODO cache + weight := ComputeWeight(desiredLabels, w.labels) + + fmt.Println(weight) + return weight >= 0 + } + return true } -func (w *WorkerState) AssignSlot(qi QueueItemWithOrder) (*dbsqlc.ListSemaphoreSlotsToAssignRow, bool) { +func (w *WorkerState) AssignSlot(qi *QueueItemWithOrder, desiredLabels []*dbsqlc.GetDesiredLabelsRow) (*dbsqlc.ListSemaphoreSlotsToAssignRow, bool) { // if the actionId is not in the worker's actionIds, then we can't assign this slot - if !w.CanAssign(qi) { + if !w.CanAssign(qi, desiredLabels) { return nil, false } diff --git a/scheduling.json b/scheduling.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/scheduling.json @@ -0,0 +1 @@ +{}