From cb4f1bbca7b60c8f0f53aaedd99f7390b80ac910 Mon Sep 17 00:00:00 2001 From: ted chang Date: Tue, 15 Oct 2024 14:58:17 -0700 Subject: [PATCH] Refactor lmes reconcile optoins Signed-off-by: ted chang --- controllers/lmes/config.go | 108 +++++++++++ controllers/lmes/lmevaljob_controller.go | 131 +++---------- controllers/lmes/lmevaljob_controller_test.go | 176 ++++++++---------- 3 files changed, 208 insertions(+), 207 deletions(-) create mode 100644 controllers/lmes/config.go diff --git a/controllers/lmes/config.go b/controllers/lmes/config.go new file mode 100644 index 0000000..53b4700 --- /dev/null +++ b/controllers/lmes/config.go @@ -0,0 +1,108 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package lmes + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" +) + +var options *serviceOptions = &serviceOptions{ + DriverImage: DefaultDriverImage, + PodImage: DefaultPodImage, + PodCheckingInterval: DefaultPodCheckingInterval, + ImagePullPolicy: DefaultImagePullPolicy, + MaxBatchSize: DefaultMaxBatchSize, + DetectDevice: DefaultDetectDevice, + DefaultBatchSize: DefaultBatchSize, +} + +type serviceOptions struct { + PodImage string + DriverImage string + PodCheckingInterval time.Duration + ImagePullPolicy corev1.PullPolicy + MaxBatchSize int + DefaultBatchSize int + DetectDevice bool +} + +func constructOptionsFromConfigMap(log *logr.Logger, configmap *corev1.ConfigMap) error { + + rv := reflect.ValueOf(options).Elem() + var msgs []string + + for idx, cap := 0, rv.NumField(); idx < cap; idx++ { + frv := rv.Field(idx) + fname := rv.Type().Field(idx).Name + configKey, ok := optionKeys[fname] + if !ok { + continue + } + + if v, found := configmap.Data[configKey]; found { + var err error + switch frv.Type().Name() { + case "string": + frv.SetString(v) + case "bool": + val, err := strconv.ParseBool(v) + if err != nil { + val = DefaultDetectDevice + msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], val)) + } + frv.SetBool(val) + case "int": + var intVal int + intVal, err = strconv.Atoi(v) + if err == nil { + frv.SetInt(int64(intVal)) + } + case "Duration": + var d time.Duration + d, err = time.ParseDuration(v) + if err == nil { + frv.Set(reflect.ValueOf(d)) + } + case "PullPolicy": + if p, found := pullPolicyMap[corev1.PullPolicy(v)]; found { + frv.Set(reflect.ValueOf(p)) + } else { + err = fmt.Errorf("invalid PullPolicy") + } + default: + return fmt.Errorf("can not handle the config %v, type: %v", optionKeys[fname], frv.Type().Name()) + } + + if err != nil { + msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], v)) + } + } + } + + if len(msgs) > 0 && log != nil { + log.Error(fmt.Errorf("some settings in the configmap are invalid"), strings.Join(msgs, "\n")) + } + + return nil +} diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index 3240036..f1985ab 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -21,9 +21,7 @@ import ( "context" "fmt" "maps" - "reflect" "slices" - "strconv" "strings" "sync" "time" @@ -92,22 +90,11 @@ type LMEvalJobReconciler struct { Recorder record.EventRecorder ConfigMap string Namespace string - options *ServiceOptions restConfig *rest.Config restClient rest.Interface pullingJobs *syncedMap4Reconciler } -type ServiceOptions struct { - PodImage string - DriverImage string - PodCheckingInterval time.Duration - ImagePullPolicy corev1.PullPolicy - MaxBatchSize int - DefaultBatchSize int - DetectDevice bool -} - // The registered function to set up LMES controller func ControllerSetUp(mgr manager.Manager, ns, configmap string, recorder record.EventRecorder) error { clientset, err := kubernetes.NewForConfig(mgr.GetConfig()) @@ -209,6 +196,7 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { // Add a runnable to retrieve the settings from the specified configmap if err := mgr.Add(manager.RunnableFunc(func(ctx context.Context) error { var cm corev1.ConfigMap + log := log.FromContext(ctx) if err := r.Get( ctx, types.NamespacedName{Namespace: r.Namespace, Name: r.ConfigMap}, @@ -221,8 +209,8 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } - - if err := r.constructOptionsFromConfigMap(ctx, &cm); err != nil { + var err error + if err = constructOptionsFromConfigMap(&log, &cm); err != nil { return err } @@ -324,77 +312,6 @@ func (r *LMEvalJobReconciler) remoteCommand(ctx context.Context, job *lmesv1alph return outBuff.Bytes(), errBuf.Bytes(), nil } -func (r *LMEvalJobReconciler) constructOptionsFromConfigMap( - ctx context.Context, configmap *corev1.ConfigMap) error { - r.options = &ServiceOptions{ - DriverImage: DefaultDriverImage, - PodImage: DefaultPodImage, - PodCheckingInterval: DefaultPodCheckingInterval, - ImagePullPolicy: DefaultImagePullPolicy, - MaxBatchSize: DefaultMaxBatchSize, - DetectDevice: DefaultDetectDevice, - DefaultBatchSize: DefaultBatchSize, - } - - log := log.FromContext(ctx) - rv := reflect.ValueOf(r.options).Elem() - var msgs []string - - for idx, cap := 0, rv.NumField(); idx < cap; idx++ { - frv := rv.Field(idx) - fname := rv.Type().Field(idx).Name - configKey, ok := optionKeys[fname] - if !ok { - continue - } - - if v, found := configmap.Data[configKey]; found { - var err error - switch frv.Type().Name() { - case "string": - frv.SetString(v) - case "bool": - val, err := strconv.ParseBool(v) - if err != nil { - val = DefaultDetectDevice - msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], val)) - } - frv.SetBool(val) - case "int": - var intVal int - intVal, err = strconv.Atoi(v) - if err == nil { - frv.SetInt(int64(intVal)) - } - case "Duration": - var d time.Duration - d, err = time.ParseDuration(v) - if err == nil { - frv.Set(reflect.ValueOf(d)) - } - case "PullPolicy": - if p, found := pullPolicyMap[corev1.PullPolicy(v)]; found { - frv.Set(reflect.ValueOf(p)) - } else { - err = fmt.Errorf("invalid PullPolicy") - } - default: - return fmt.Errorf("can not handle the config %v, type: %v", optionKeys[fname], frv.Type().Name()) - } - - if err != nil { - msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], v)) - } - } - } - - if len(msgs) > 0 { - log.Error(fmt.Errorf("some settings in the configmap are invalid"), strings.Join(msgs, "\n")) - } - - return nil -} - func (r *LMEvalJobReconciler) handleDeletion(ctx context.Context, job *lmesv1alpha1.LMEvalJob, log logr.Logger) (reconcile.Result, error) { defer r.pullingJobs.remove(string(job.GetUID())) @@ -456,7 +373,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, // construct a new pod and create a pod for the job currentTime := v1.Now() - pod := r.createPod(job, log) + pod := createPod(options, job, log) if err := r.Create(ctx, pod, &client.CreateOptions{}); err != nil { // Failed to create the pod. Mark the status as complete with failed job.Status.State = lmesv1alpha1.CompleteJobState @@ -483,7 +400,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, job.Namespace)) log.Info("Successfully create a Pod for the Job") // Check the pod after the config interval - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) { @@ -508,7 +425,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo if mainIdx := getContainerByName(&pod.Status, "main"); mainIdx == -1 { // waiting for the main container to be up - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } else if podFailed, msg := isContainerFailed(&pod.Status.ContainerStatuses[mainIdx]); podFailed { job.Status.State = lmesv1alpha1.CompleteJobState job.Status.Reason = lmesv1alpha1.FailedReason @@ -519,7 +436,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo log.Info("detect an error on the job's pod. marked the job as done", "name", job.Name) return ctrl.Result{}, err } else if pod.Status.ContainerStatuses[mainIdx].State.Running == nil { - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } // pull status from the driver @@ -530,7 +447,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo if err != nil { log.Error(err, "unable to retrieve the status from the job's pod. retry after the pulling interval") } - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } func (r *LMEvalJobReconciler) getPod(ctx context.Context, job *lmesv1alpha1.LMEvalJob) (*corev1.Pod, error) { @@ -579,7 +496,7 @@ func (r *LMEvalJobReconciler) handleComplete(ctx context.Context, log logr.Logge // send shutdown command if the main container is running if err := r.shutdownDriver(ctx, job); err != nil { log.Error(err, "failed to shutdown the job pod. retry after the pulling interval") - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } } } else { @@ -617,8 +534,8 @@ func (r *LMEvalJobReconciler) handleCancel(ctx context.Context, log logr.Logger, job.Status.Reason = lmesv1alpha1.CancelledReason if err := r.deleteJobPod(ctx, job); err != nil { // leave the state as is and retry again - log.Error(err, "failed to delete pod. scheduled a retry", "interval", r.options.PodCheckingInterval.String()) - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), err + log.Error(err, "failed to delete pod. scheduled a retry", "interval", options.PodCheckingInterval.String()) + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), err } } @@ -658,7 +575,7 @@ func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, lo return nil } -func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { +func createPod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { var allowPrivilegeEscalation = false var runAsNonRootUser = true var ownerRefController = true @@ -712,8 +629,8 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo InitContainers: []corev1.Container{ { Name: "driver", - Image: r.options.DriverImage, - ImagePullPolicy: r.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -735,11 +652,11 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo Containers: []corev1.Container{ { Name: "main", - Image: r.options.PodImage, - ImagePullPolicy: r.options.ImagePullPolicy, + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Env: envVars, - Command: r.generateCmd(job), - Args: r.generateArgs(job, log), + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -810,7 +727,7 @@ func mergeMapWithFilters(dest, src map[string]string, prefixFitlers []string, lo } } -func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string { +func generateArgs(options *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string { if job == nil { return nil } @@ -844,13 +761,13 @@ func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr cmds = append(cmds, "--log_samples") } // --batch_size - var batchSize = r.options.DefaultBatchSize + var batchSize = options.DefaultBatchSize if job.Spec.BatchSize != nil && *job.Spec.BatchSize > 0 { batchSize = *job.Spec.BatchSize } // This could be done in the webhook if it's enabled. - if batchSize > r.options.MaxBatchSize { - batchSize = r.options.MaxBatchSize + if batchSize > options.MaxBatchSize { + batchSize = options.MaxBatchSize log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead") } cmds = append(cmds, "--batch_size", fmt.Sprintf("%d", batchSize)) @@ -870,7 +787,7 @@ func concatTasks(tasks lmesv1alpha1.TaskList) []string { return append(tasks.TaskNames, recipesName...) } -func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string { +func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string { if job == nil { return nil } @@ -879,7 +796,7 @@ func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string "--output-path", "/opt/app-root/src/output", } - if r.options.DetectDevice { + if svcOpts.DetectDevice { cmds = append(cmds, "--detect-device") } diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index b04b21d..2e08759 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -18,6 +18,7 @@ package lmes import ( "context" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -38,14 +39,12 @@ var ( func Test_SimplePod(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } + var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -92,8 +91,8 @@ func Test_SimplePod(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -115,10 +114,10 @@ func Test_SimplePod(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -153,20 +152,17 @@ func Test_SimplePod(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) assert.Equal(t, expect, newPod) } func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -254,8 +250,8 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -277,10 +273,10 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -333,7 +329,7 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) assert.Equal(t, expect, newPod) @@ -348,19 +344,16 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { "custom/annotation1": "annotation1", } - newPod = lmevalRec.createPod(job, log) + newPod = createPod(svcOpts, job, log) assert.Equal(t, expect, newPod) } func Test_EnvSecretsPod(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -425,8 +418,8 @@ func Test_EnvSecretsPod(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -448,8 +441,8 @@ func Test_EnvSecretsPod(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Env: []corev1.EnvVar{ { Name: "my_env", @@ -463,8 +456,8 @@ func Test_EnvSecretsPod(t *testing.T) { }, }, }, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -499,20 +492,17 @@ func Test_EnvSecretsPod(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } func Test_FileSecretsPod(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -587,8 +577,8 @@ func Test_FileSecretsPod(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -610,10 +600,10 @@ func Test_FileSecretsPod(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -667,22 +657,19 @@ func Test_FileSecretsPod(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } func Test_GenerateArgBatchSize(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - MaxBatchSize: 24, - DefaultBatchSize: 8, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: 20, + DefaultBatchSize: 4, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -708,16 +695,16 @@ func Test_GenerateArgBatchSize(t *testing.T) { // no batchSize in the job, use default batchSize assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.DefaultBatchSize), + }, generateArgs(svcOpts, job, log)) // exceed the max-batch-size, use max-batch-size var biggerBatchSize = 30 job.Spec.BatchSize = &biggerBatchSize assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size 24", - }, lmevalRec.generateArgs(job, log)) + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.MaxBatchSize), + }, generateArgs(svcOpts, job, log)) // normal batchSize var normalBatchSize = 16 @@ -725,20 +712,17 @@ func Test_GenerateArgBatchSize(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size 16", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) } func Test_GenerateArgCmdTaskRecipes(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - DefaultBatchSize: DefaultBatchSize, - MaxBatchSize: DefaultMaxBatchSize, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: options.MaxBatchSize, + DefaultBatchSize: options.DefaultBatchSize, } var format = "unitxt.format" var numDemos = 5 @@ -778,14 +762,14 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", "--output-path", "/opt/app-root/src/output", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, lmevalRec.generateCmd(job)) + }, generateCmd(svcOpts, job)) job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, lmesv1alpha1.TaskRecipe{ @@ -803,7 +787,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0,tr_1 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", @@ -811,20 +795,17 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, lmevalRec.generateCmd(job)) + }, generateCmd(svcOpts, job)) } func Test_GenerateArgCmdCustomCard(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - DefaultBatchSize: DefaultBatchSize, - MaxBatchSize: DefaultMaxBatchSize, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: options.MaxBatchSize, + DefaultBatchSize: options.DefaultBatchSize, } var format = "unitxt.format" var numDemos = 5 @@ -865,7 +846,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", @@ -873,18 +854,13 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, "--", - }, lmevalRec.generateCmd(job)) + }, generateCmd(svcOpts, job)) } func Test_CustomCardValidation(t *testing.T) { log := log.FromContext(context.Background()) lmevalRec := LMEvalJobReconciler{ Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{