Skip to content

Commit

Permalink
Refactor lmes reconcile options
Browse files Browse the repository at this point in the history
Signed-off-by: ted chang <[email protected]>
  • Loading branch information
tedhtchang committed Oct 16, 2024
1 parent ab6bc98 commit cb4f1bb
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 207 deletions.
108 changes: 108 additions & 0 deletions controllers/lmes/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2024.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package lmes

import (
"fmt"
"reflect"
"strconv"
"strings"
"time"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
)

// options is the package-level set of service settings. It starts out with the
// built-in Default* values and is updated in place by
// constructOptionsFromConfigMap.
var options = &serviceOptions{
	// Container images and how they are pulled.
	PodImage:        DefaultPodImage,
	DriverImage:     DefaultDriverImage,
	ImagePullPolicy: DefaultImagePullPolicy,

	// Pod status checking cadence.
	PodCheckingInterval: DefaultPodCheckingInterval,

	// Evaluation job tuning.
	DefaultBatchSize: DefaultBatchSize,
	MaxBatchSize:     DefaultMaxBatchSize,
	DetectDevice:     DefaultDetectDevice,
}

// serviceOptions groups the settings of the LMES controller that can be
// overridden through a ConfigMap.
//
// constructOptionsFromConfigMap walks these fields by reflection in declaration
// order and matches them by field name against optionKeys, so renaming or
// reordering a field changes which config key it maps to and the order in which
// invalid settings are reported.
type serviceOptions struct {
// PodImage is the image of the "main" container of a job's pod.
PodImage string
// DriverImage is the image of the "driver" init container of a job's pod.
DriverImage string
// PodCheckingInterval is the delay used when re-queuing a job to check its pod again.
PodCheckingInterval time.Duration
// ImagePullPolicy is applied to both the driver and the main containers.
ImagePullPolicy corev1.PullPolicy
// MaxBatchSize is the upper bound applied to the --batch_size argument.
MaxBatchSize int
// DefaultBatchSize is the --batch_size used when the job does not specify a positive one.
DefaultBatchSize int
// DetectDevice adds the --detect-device flag to the generated command when true.
DetectDevice bool
}

// constructOptionsFromConfigMap updates the package-level options in place with
// the values found in configmap.Data.
//
// Every field of serviceOptions whose name has an entry in optionKeys is looked
// up in the configmap under that key and parsed according to the field's type
// (string, bool, int, time.Duration or corev1.PullPolicy). A value that fails
// to parse does not abort the function: the field keeps its current value
// (bool fields are reset to DefaultDetectDevice), and the rejected input is
// collected into a single error log entry written to log, when log is non-nil.
//
// It returns an error when configmap is nil, or when a mapped field has a type
// this function does not know how to parse.
func constructOptionsFromConfigMap(log *logr.Logger, configmap *corev1.ConfigMap) error {
	// Guard against a nil pointer instead of panicking on configmap.Data below.
	if configmap == nil {
		return fmt.Errorf("can not construct options from a nil configmap")
	}

	rv := reflect.ValueOf(options).Elem()
	rt := rv.Type()
	var msgs []string

	for idx, numFields := 0, rv.NumField(); idx < numFields; idx++ {
		frv := rv.Field(idx)

		// Only fields that are mapped to a config key are configurable.
		configKey, ok := optionKeys[rt.Field(idx).Name]
		if !ok {
			continue
		}

		v, found := configmap.Data[configKey]
		if !found {
			continue
		}

		var err error
		switch frv.Type().Name() {
		case "string":
			frv.SetString(v)
		case "bool":
			// Assign to the outer err (no := shadowing) so the invalid input is
			// reported by the shared error path below, like every other type.
			var val bool
			val, err = strconv.ParseBool(v)
			if err != nil {
				val = DefaultDetectDevice
			}
			frv.SetBool(val)
		case "int":
			var intVal int
			intVal, err = strconv.Atoi(v)
			if err == nil {
				frv.SetInt(int64(intVal))
			}
		case "Duration":
			var d time.Duration
			d, err = time.ParseDuration(v)
			if err == nil {
				frv.Set(reflect.ValueOf(d))
			}
		case "PullPolicy":
			if p, found := pullPolicyMap[corev1.PullPolicy(v)]; found {
				frv.Set(reflect.ValueOf(p))
			} else {
				err = fmt.Errorf("invalid PullPolicy")
			}
		default:
			return fmt.Errorf("can not handle the config %v, type: %v", configKey, frv.Type().Name())
		}

		if err != nil {
			msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", configKey, v))
		}
	}

	// Report all rejected settings in one log entry.
	if len(msgs) > 0 && log != nil {
		log.Error(fmt.Errorf("some settings in the configmap are invalid"), strings.Join(msgs, "\n"))
	}

	return nil
}
131 changes: 24 additions & 107 deletions controllers/lmes/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import (
"context"
"fmt"
"maps"
"reflect"
"slices"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -92,22 +90,11 @@ type LMEvalJobReconciler struct {
Recorder record.EventRecorder
ConfigMap string
Namespace string
options *ServiceOptions
restConfig *rest.Config
restClient rest.Interface
pullingJobs *syncedMap4Reconciler
}

// ServiceOptions groups the controller settings that were stored per
// reconciler in LMEvalJobReconciler.options and populated by the
// constructOptionsFromConfigMap method.
type ServiceOptions struct {
// PodImage is the image of the "main" container of a job's pod.
PodImage string
// DriverImage is the image of the "driver" init container of a job's pod.
DriverImage string
// PodCheckingInterval is the delay used when re-queuing a job to check its pod again.
PodCheckingInterval time.Duration
// ImagePullPolicy is applied to both the driver and the main containers.
ImagePullPolicy corev1.PullPolicy
// MaxBatchSize is the upper bound applied to the --batch_size argument.
MaxBatchSize int
// DefaultBatchSize is the --batch_size used when the job does not specify a positive one.
DefaultBatchSize int
// DetectDevice adds the --detect-device flag to the generated command when true.
DetectDevice bool
}

// The registered function to set up LMES controller
func ControllerSetUp(mgr manager.Manager, ns, configmap string, recorder record.EventRecorder) error {
clientset, err := kubernetes.NewForConfig(mgr.GetConfig())
Expand Down Expand Up @@ -209,6 +196,7 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
// Add a runnable to retrieve the settings from the specified configmap
if err := mgr.Add(manager.RunnableFunc(func(ctx context.Context) error {
var cm corev1.ConfigMap
log := log.FromContext(ctx)
if err := r.Get(
ctx,
types.NamespacedName{Namespace: r.Namespace, Name: r.ConfigMap},
Expand All @@ -221,8 +209,8 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error {

return err
}

if err := r.constructOptionsFromConfigMap(ctx, &cm); err != nil {
var err error
if err = constructOptionsFromConfigMap(&log, &cm); err != nil {
return err
}

Expand Down Expand Up @@ -324,77 +312,6 @@ func (r *LMEvalJobReconciler) remoteCommand(ctx context.Context, job *lmesv1alph
return outBuff.Bytes(), errBuf.Bytes(), nil
}

// constructOptionsFromConfigMap resets r.options to the built-in defaults and
// then overrides each field that has an entry in optionKeys with the value
// found under that key in configmap.Data.
//
// Values that fail to parse are collected and reported in a single error log
// entry; they do not make the function fail. An error is returned only when a
// mapped field has a type that is not handled by the switch below.
func (r *LMEvalJobReconciler) constructOptionsFromConfigMap(
ctx context.Context, configmap *corev1.ConfigMap) error {
// Start from the defaults on every call.
r.options = &ServiceOptions{
DriverImage: DefaultDriverImage,
PodImage: DefaultPodImage,
PodCheckingInterval: DefaultPodCheckingInterval,
ImagePullPolicy: DefaultImagePullPolicy,
MaxBatchSize: DefaultMaxBatchSize,
DetectDevice: DefaultDetectDevice,
DefaultBatchSize: DefaultBatchSize,
}

log := log.FromContext(ctx)
rv := reflect.ValueOf(r.options).Elem()
var msgs []string

// Walk the struct fields in declaration order.
for idx, cap := 0, rv.NumField(); idx < cap; idx++ {
frv := rv.Field(idx)
fname := rv.Type().Field(idx).Name
// Fields without a config key are not configurable.
configKey, ok := optionKeys[fname]
if !ok {
continue
}

if v, found := configmap.Data[configKey]; found {
var err error
// Dispatch on the field's type name to pick the parser.
switch frv.Type().Name() {
case "string":
frv.SetString(v)
case "bool":
// NOTE(review): := declares a new err scoped to this case, so the shared
// error check after the switch never fires for bool fields; the message
// built here formats val (the fallback) rather than the rejected input v.
val, err := strconv.ParseBool(v)
if err != nil {
val = DefaultDetectDevice
msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], val))
}
frv.SetBool(val)
case "int":
var intVal int
intVal, err = strconv.Atoi(v)
if err == nil {
frv.SetInt(int64(intVal))
}
case "Duration":
var d time.Duration
d, err = time.ParseDuration(v)
if err == nil {
frv.Set(reflect.ValueOf(d))
}
case "PullPolicy":
// Only values present in pullPolicyMap are accepted.
if p, found := pullPolicyMap[corev1.PullPolicy(v)]; found {
frv.Set(reflect.ValueOf(p))
} else {
err = fmt.Errorf("invalid PullPolicy")
}
default:
return fmt.Errorf("can not handle the config %v, type: %v", optionKeys[fname], frv.Type().Name())
}

// Shared error path: the field keeps its default and the input is recorded.
if err != nil {
msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], v))
}
}
}

// Report all rejected settings in one log entry.
if len(msgs) > 0 {
log.Error(fmt.Errorf("some settings in the configmap are invalid"), strings.Join(msgs, "\n"))
}

return nil
}

func (r *LMEvalJobReconciler) handleDeletion(ctx context.Context, job *lmesv1alpha1.LMEvalJob, log logr.Logger) (reconcile.Result, error) {
defer r.pullingJobs.remove(string(job.GetUID()))

Expand Down Expand Up @@ -456,7 +373,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger,

// construct a new pod and create a pod for the job
currentTime := v1.Now()
pod := r.createPod(job, log)
pod := createPod(options, job, log)
if err := r.Create(ctx, pod, &client.CreateOptions{}); err != nil {
// Failed to create the pod. Mark the status as complete with failed
job.Status.State = lmesv1alpha1.CompleteJobState
Expand All @@ -483,7 +400,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger,
job.Namespace))
log.Info("Successfully create a Pod for the Job")
// Check the pod after the config interval
return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil
return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil
}

func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) {
Expand All @@ -508,7 +425,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo

if mainIdx := getContainerByName(&pod.Status, "main"); mainIdx == -1 {
// waiting for the main container to be up
return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil
return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil
} else if podFailed, msg := isContainerFailed(&pod.Status.ContainerStatuses[mainIdx]); podFailed {
job.Status.State = lmesv1alpha1.CompleteJobState
job.Status.Reason = lmesv1alpha1.FailedReason
Expand All @@ -519,7 +436,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo
log.Info("detect an error on the job's pod. marked the job as done", "name", job.Name)
return ctrl.Result{}, err
} else if pod.Status.ContainerStatuses[mainIdx].State.Running == nil {
return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil
return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil
}

// pull status from the driver
Expand All @@ -530,7 +447,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo
if err != nil {
log.Error(err, "unable to retrieve the status from the job's pod. retry after the pulling interval")
}
return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil
return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil
}

func (r *LMEvalJobReconciler) getPod(ctx context.Context, job *lmesv1alpha1.LMEvalJob) (*corev1.Pod, error) {
Expand Down Expand Up @@ -579,7 +496,7 @@ func (r *LMEvalJobReconciler) handleComplete(ctx context.Context, log logr.Logge
// send shutdown command if the main container is running
if err := r.shutdownDriver(ctx, job); err != nil {
log.Error(err, "failed to shutdown the job pod. retry after the pulling interval")
return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil
return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil
}
}
} else {
Expand Down Expand Up @@ -617,8 +534,8 @@ func (r *LMEvalJobReconciler) handleCancel(ctx context.Context, log logr.Logger,
job.Status.Reason = lmesv1alpha1.CancelledReason
if err := r.deleteJobPod(ctx, job); err != nil {
// leave the state as is and retry again
log.Error(err, "failed to delete pod. scheduled a retry", "interval", r.options.PodCheckingInterval.String())
return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), err
log.Error(err, "failed to delete pod. scheduled a retry", "interval", options.PodCheckingInterval.String())
return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), err
}
}

Expand Down Expand Up @@ -658,7 +575,7 @@ func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, lo
return nil
}

func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod {
func createPod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod {
var allowPrivilegeEscalation = false
var runAsNonRootUser = true
var ownerRefController = true
Expand Down Expand Up @@ -712,8 +629,8 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo
InitContainers: []corev1.Container{
{
Name: "driver",
Image: r.options.DriverImage,
ImagePullPolicy: r.options.ImagePullPolicy,
Image: svcOpts.DriverImage,
ImagePullPolicy: svcOpts.ImagePullPolicy,
Command: []string{DriverPath, "--copy", DestDriverPath},
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
Expand All @@ -735,11 +652,11 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo
Containers: []corev1.Container{
{
Name: "main",
Image: r.options.PodImage,
ImagePullPolicy: r.options.ImagePullPolicy,
Image: svcOpts.PodImage,
ImagePullPolicy: svcOpts.ImagePullPolicy,
Env: envVars,
Command: r.generateCmd(job),
Args: r.generateArgs(job, log),
Command: generateCmd(svcOpts, job),
Args: generateArgs(svcOpts, job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
RunAsUser: &runAsUser,
Expand Down Expand Up @@ -810,7 +727,7 @@ func mergeMapWithFilters(dest, src map[string]string, prefixFitlers []string, lo
}
}

func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string {
func generateArgs(options *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string {
if job == nil {
return nil
}
Expand Down Expand Up @@ -844,13 +761,13 @@ func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr
cmds = append(cmds, "--log_samples")
}
// --batch_size
var batchSize = r.options.DefaultBatchSize
var batchSize = options.DefaultBatchSize
if job.Spec.BatchSize != nil && *job.Spec.BatchSize > 0 {
batchSize = *job.Spec.BatchSize
}
// This could be done in the webhook if it's enabled.
if batchSize > r.options.MaxBatchSize {
batchSize = r.options.MaxBatchSize
if batchSize > options.MaxBatchSize {
batchSize = options.MaxBatchSize
log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead")
}
cmds = append(cmds, "--batch_size", fmt.Sprintf("%d", batchSize))
Expand All @@ -870,7 +787,7 @@ func concatTasks(tasks lmesv1alpha1.TaskList) []string {
return append(tasks.TaskNames, recipesName...)
}

func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string {
func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string {
if job == nil {
return nil
}
Expand All @@ -879,7 +796,7 @@ func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string
"--output-path", "/opt/app-root/src/output",
}

if r.options.DetectDevice {
if svcOpts.DetectDevice {
cmds = append(cmds, "--detect-device")
}

Expand Down
Loading

0 comments on commit cb4f1bb

Please sign in to comment.