diff --git a/modes/defaultmode/defaultmode.go b/modes/defaultmode/defaultmode.go index 6b4be5b..6c2a1e2 100644 --- a/modes/defaultmode/defaultmode.go +++ b/modes/defaultmode/defaultmode.go @@ -125,7 +125,7 @@ func (d *DefaultMode) generateCode() (err error) { continue } - options := overrideRunOptions(&model.DefaultParams, &d.commandLineOptions.Options) + options := model.DefaultParams.Override(&d.commandLineOptions.Options) fileName, err := f.WriteModel(d.settings.TempPath, options) if err != nil { @@ -143,7 +143,7 @@ func (d *DefaultMode) runCode(frameworks framework.List) { for _, f := range frameworks { model := f.Model() - options := overrideRunOptions(&model.DefaultParams, &d.commandLineOptions.Options) + options := model.DefaultParams.Override(&d.commandLineOptions.Options) result, err := f.Run(options) if err != nil { @@ -156,22 +156,3 @@ func (d *DefaultMode) runCode(frameworks framework.List) { fmt.Println() } } - -// overrideRunOptions overrides options set in the model with any set on the command line. -func overrideRunOptions(modelOptions, cliOptions *runoptions.Options) *runoptions.Options { - options := *modelOptions - - if cliOptions.LogLevel != nil { - options.LogLevel = cliOptions.LogLevel - } - - if cliOptions.TraceActivations != nil { - options.TraceActivations = cliOptions.TraceActivations - } - - if cliOptions.RandomSeed != nil { - options.RandomSeed = cliOptions.RandomSeed - } - - return &options -} diff --git a/util/runoptions/runoptions.go b/util/runoptions/runoptions.go index 058d773..17baa73 100644 --- a/util/runoptions/runoptions.go +++ b/util/runoptions/runoptions.go @@ -69,6 +69,25 @@ func New() Options { } } +// Override applies any overrides set in another Options struct. It returns a new struct. +func (o Options) Override(cliOptions *Options) *Options { + options := o + + if cliOptions.LogLevel != nil { + options.LogLevel = cliOptions.LogLevel + } + + if cliOptions.TraceActivations != nil { + options.TraceActivations = cliOptions.TraceActivations + } + + if cliOptions.RandomSeed != nil { + options.RandomSeed = cliOptions.RandomSeed + } + + return &options +} + // IsValidFramework returns if the framework name is in our list of valid ones or not. func IsValidFramework(frameworkName string) bool { return slices.Contains(ValidFrameworks, frameworkName)