Skip to content

Commit

Permalink
Use pointers in run options so we can tell what's been set
Browse files Browse the repository at this point in the history
  • Loading branch information
asmaloney committed Mar 4, 2024
1 parent 2f32495 commit e5ef0db
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 24 deletions.
6 changes: 3 additions & 3 deletions actr/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,15 @@ func (model *Model) SetParam(kv *keyvalue.KeyValue) (err error) {

switch kv.Key {
case "log_level":
model.DefaultParams.LogLevel = runoptions.ACTRLogLevel(*value.Str)
logLevel := runoptions.ACTRLogLevel(*value.Str) // already validated
model.DefaultParams.LogLevel = &logLevel

case "trace_activations":
boolVal, _ := value.AsBool() // already validated
model.DefaultParams.TraceActivations = boolVal
model.DefaultParams.TraceActivations = &boolVal

case "random_seed":
seed := uint32(*value.Number)

model.DefaultParams.RandomSeed = &seed

default:
Expand Down
12 changes: 6 additions & 6 deletions framework/ccm_pyactr/ccm_pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (c *CCMPyACTR) WriteModel(path string, options *runoptions.Options) (output
}

// If our model is tracing activations, then write out our support file
if options.TraceActivations {
if *options.TraceActivations {
err = framework.WriteSupportFile(path, gactarActivateTraceFileName, gactarActivateTraceFile)
if err != nil {
return
Expand Down Expand Up @@ -238,7 +238,7 @@ func (c *CCMPyACTR) GenerateCode(options *runoptions.Options) (code []byte, err
c.Writeln(" %s = Memory(%s)", memory.ModuleName(), memory.BufferName())
}

if options.TraceActivations {
if *options.TraceActivations {
c.Writeln(" trace = ActivateTrace(%s)", memory.ModuleName())
}

Expand Down Expand Up @@ -287,7 +287,7 @@ func (c *CCMPyACTR) GenerateCode(options *runoptions.Options) (code []byte, err
c.Writeln("")
}

if options.LogLevel == "info" {
if *options.LogLevel == "info" {
// this turns on some logging at the high level
c.Writeln(" def __init__(self):")
c.Writeln(" super().__init__(log=True)")
Expand Down Expand Up @@ -380,7 +380,7 @@ func (c CCMPyACTR) writeImports(runOptions *runoptions.Options) {
c.Write("from python_actr import %s\n", strings.Join(additionalImports, ", "))
}

if runOptions.LogLevel == "detail" {
if *runOptions.LogLevel == "detail" {
c.Writeln("from python_actr import log, log_everything")
}

Expand All @@ -389,7 +389,7 @@ func (c CCMPyACTR) writeImports(runOptions *runoptions.Options) {
c.Writeln(fmt.Sprintf("from %s import CCMPrint", ccmPrintImportName))
}

if runOptions.TraceActivations {
if *runOptions.TraceActivations {
c.Writeln("")
c.Writeln(fmt.Sprintf("from %s import ActivateTrace", gactarActivateTraceImportName))
}
Expand Down Expand Up @@ -494,7 +494,7 @@ func (c CCMPyACTR) writeMain(runOptions *runoptions.Options) {
c.Writeln("if __name__ == \"__main__\":")
c.Writeln(fmt.Sprintf(" model = %s()", c.className))

if runOptions.LogLevel == "detail" {
if *runOptions.LogLevel == "detail" {
c.Writeln(" log(summary=1)")
c.Writeln(" log_everything(model)")
}
Expand Down
6 changes: 3 additions & 3 deletions framework/pyactr/pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (p *PyACTR) GenerateCode(options *runoptions.Options) (code []byte, err err
p.Writeln(" rule_firing=%s,", numbers.Float64Str(*procedural.DefaultActionTime))
}

if options.TraceActivations {
if *options.TraceActivations {
p.Writeln(" activation_trace=True,")
}

Expand Down Expand Up @@ -498,14 +498,14 @@ func (p PyACTR) writeMain(runOptions *runoptions.Options) {

options := []string{"gui=False"}

if runOptions.LogLevel == "min" {
if *runOptions.LogLevel == "min" {
options = append(options, "trace=False")
}

p.Writeln(" sim = %s.simulation( %s )", p.className, strings.Join(options, ", "))
p.Writeln(" sim.run()")

if runOptions.LogLevel != "min" {
if *runOptions.LogLevel != "min" {
p.Writeln(" if goal.test_buffer('full'):")
p.Writeln(" print('chunk left in goal: ' + str(goal.pop()))")
p.Writeln(" if %s.retrieval.test_buffer('full'):", p.className)
Expand Down
4 changes: 2 additions & 2 deletions framework/vanilla_actr/vanilla_actr.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func (v *VanillaACTR) GenerateCode(options *runoptions.Options) (code []byte, er
v.Writeln("\t:dat %s", numbers.Float64Str(*procedural.DefaultActionTime))
}

switch options.LogLevel {
switch *options.LogLevel {
case "min":
v.Writeln("\t:trace-detail low")
case "info":
Expand All @@ -252,7 +252,7 @@ func (v *VanillaACTR) GenerateCode(options *runoptions.Options) (code []byte, er
v.Writeln("\t:trace-detail high")
}

if options.TraceActivations {
if *options.TraceActivations {
v.Writeln("\t:act t")
}

Expand Down
15 changes: 11 additions & 4 deletions modes/web/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,23 @@ func (w Web) actrOptionsFromJSON(defaults *runoptions.Options, options *runOptio
opts.Frameworks = options.Frameworks

if options.LogLevel != nil {
opts.LogLevel = runoptions.ACTRLogLevel(*options.LogLevel)
if !runoptions.ValidLogLevel(*options.LogLevel) {
err = runoptions.ErrInvalidLogLevel{Level: *options.LogLevel}
} else {
logLevel := runoptions.ACTRLogLevel(*options.LogLevel)
opts.LogLevel = &logLevel
}
}

if options.TraceActivations != nil {
opts.TraceActivations = *options.TraceActivations
opts.TraceActivations = options.TraceActivations
}

opts.RandomSeed = options.RandomSeed
if options.RandomSeed != nil {
opts.RandomSeed = options.RandomSeed
}

return &opts, nil
return &opts, err
}

func generateModel(amodFile string) (model *actr.Model, err error) {
Expand Down
15 changes: 13 additions & 2 deletions util/runoptions/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package runoptions

import "fmt"
import (
"fmt"
"strings"
)

type ErrFrameworkNotActive struct {
Name string
Expand All @@ -15,5 +18,13 @@ type ErrInvalidFrameworkName struct {
}

func (e ErrInvalidFrameworkName) Error() string {
return fmt.Sprintf("invalid framework name: %q", e.Name)
return fmt.Sprintf("invalid framework name: %q; expected one of %q", e.Name, strings.Join(ValidFrameworks, ", "))
}

type ErrInvalidLogLevel struct {
Level string
}

func (e ErrInvalidLogLevel) Error() string {
return fmt.Sprintf("invalid log level: %q; expected one of %q", e.Level, strings.Join(ACTRLoggingLevels, ", "))
}
11 changes: 7 additions & 4 deletions util/runoptions/runoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ type Options struct {
InitialBuffers InitialBuffers

// One of 'min', 'info', or 'detail'
LogLevel ACTRLogLevel
LogLevel *ACTRLogLevel

// If true, output detailed info about activations
TraceActivations bool
TraceActivations *bool

// The seed to use for generating pseudo-random numbers (allows for reproducible runs)
// For all frameworks, if it is not set it uses current system time.
Expand All @@ -58,10 +58,13 @@ type Options struct {

// New returns a default-initialized Options struct.
func New() Options {
logLevel := ACTRLogLevel("info")
activations := false

return Options{
Frameworks: FrameworkNameList{"all"},
LogLevel: ACTRLogLevel("info"),
TraceActivations: false,
LogLevel: &logLevel,
TraceActivations: &activations,
RandomSeed: nil,
}
}
Expand Down

0 comments on commit e5ef0db

Please sign in to comment.