Skip to content

Commit

Permalink
Use param.ParametersInterface to validate model parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
asmaloney committed Aug 6, 2023
1 parent c32faa5 commit 52d0a7d
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 31 deletions.
59 changes: 35 additions & 24 deletions actr/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
package actr

import (
"fmt"
"strings"

"github.com/asmaloney/gactar/actr/buffer"
"github.com/asmaloney/gactar/actr/modules"
"github.com/asmaloney/gactar/actr/param"
Expand Down Expand Up @@ -60,6 +57,9 @@ type Model struct {
Productions []*Production

Options

// Used to validate our parameters
parameters param.ParametersInterface
}

type Initializer struct {
Expand Down Expand Up @@ -93,6 +93,32 @@ func (model *Model) Initialize() {
model.Modules = append(model.Modules, model.Procedural)

model.LogLevel = "info"

// Declare our parameters
loggingParam := param.NewStr(
"log_level",
"Level of logging output",
ACTRLoggingLevels,
)

traceParam := param.NewBool(
"trace_activations",
"output detailed info about activations",
)

seedParam := param.NewInt(
"random_seed",
"the seed to use for generating pseudo-random numbers",
param.Ptr(0), nil,
)

parameters := param.NewParameters(param.InfoMap{
"log_level": loggingParam,
"trace_activations": traceParam,
"random_seed": seedParam,
})

model.parameters = parameters
}

func (m *Model) SetRunOptions(options *Options) {
Expand Down Expand Up @@ -256,37 +282,22 @@ func (model Model) LookupBuffer(bufferName string) buffer.Interface {
}

func (model *Model) SetParam(kv *keyvalue.KeyValue) (err error) {
err = model.parameters.ValidateParam(kv)
if err != nil {
return
}

value := kv.Value

switch kv.Key {
case "log_level":
if (value.Str == nil) || !ValidLogLevel(*value.Str) {
context := fmt.Sprintf("(expected one of: %s)", strings.Join(ACTRLoggingLevels, ", "))

val := value.String()

return param.ErrInvalidValue{
ParameterName: "log_level",
Value: val,
Context: &context,
}
}

model.LogLevel = ACTRLogLevel(*value.Str)

case "trace_activations":
boolVal, err := value.AsBool()
if err != nil {
return err
}

boolVal, _ := value.AsBool() // already validated
model.TraceActivations = boolVal

case "random_seed":
if value.Number == nil {
return keyvalue.ErrInvalidType{ExpectedType: keyvalue.Number}
}

seed := uint32(*value.Number)

model.RandomSeed = &seed
Expand Down
37 changes: 36 additions & 1 deletion actr/param/param.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ func (e ErrValueOutOfRange) Error() string {
}

type ErrInvalidType struct {
FoundType string
ExpectedType string
}

func (e ErrInvalidType) Error() string {
return fmt.Sprintf("invalid type (expected %s)", e.ExpectedType)
return fmt.Sprintf("invalid type (found %s; expected %s)", e.FoundType, e.ExpectedType)
}

type ErrInvalidValue struct {
Expand Down Expand Up @@ -78,6 +79,11 @@ type info struct {
description string
}

// Bool is a boolean parameter
type Bool struct {
info
}

// Str is a string parameter
type Str struct {
info
Expand Down Expand Up @@ -156,11 +162,13 @@ func (p parameters) ValidateParam(param *keyvalue.KeyValue) (err error) {
}

value := param.Value
valueType := value.Type()

switch pInfo := paramInfo.(type) {
case Str:
if value.Str == nil {
return ErrInvalidType{
FoundType: valueType,
ExpectedType: "string",
}
}
Expand All @@ -179,6 +187,7 @@ func (p parameters) ValidateParam(param *keyvalue.KeyValue) (err error) {
case Int:
if value.Number == nil {
return ErrInvalidType{
FoundType: valueType,
ExpectedType: "number",
}
}
Expand All @@ -193,6 +202,7 @@ func (p parameters) ValidateParam(param *keyvalue.KeyValue) (err error) {
case Float:
if value.Number == nil {
return ErrInvalidType{
FoundType: valueType,
ExpectedType: "number",
}
}
Expand All @@ -203,6 +213,24 @@ func (p parameters) ValidateParam(param *keyvalue.KeyValue) (err error) {
if err != nil {
return
}

case Bool:
if value.ID == nil {
return ErrInvalidType{
FoundType: valueType,
ExpectedType: "true or false",
}
}

if !slices.Contains(keyvalue.BooleanValues, *value.ID) {
context := fmt.Sprintf("(expected one of: %s)", strings.Join(keyvalue.BooleanValues, ", "))

return ErrInvalidValue{
ParameterName: param.Key,
Value: *value.ID,
Context: &context,
}
}
}

if !value.IsSet() {
Expand Down Expand Up @@ -283,3 +311,10 @@ func NewFloat(name, description string, min, max *float64) Float {
min, max,
}
}

// NewBool creates a new boolean param
func NewBool(name, description string) Bool {
return Bool{
info: info{name, description},
}
}
8 changes: 4 additions & 4 deletions amod/amod_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func Example_gactarUnrecognizedLogLevel() {
~~ productions ~~`)

// Output:
// ERROR: 'log_level' invalid value "bar" for option "log_level" (expected one of: min, info, detail) (line 5, col 21)
// ERROR: 'log_level' invalid type (found id; expected string) (line 5, col 21)
}

func Example_gactarUnrecognizedNestedValue() {
Expand All @@ -64,7 +64,7 @@ func Example_gactarFieldNotANestedValue() {
~~ productions ~~`)

// Output:
// ERROR: 'log_level' invalid value "" for option "log_level" (expected one of: min, info, detail) (line 5, col 21)
// ERROR: 'log_level' invalid type (found <none>; expected string) (line 5, col 21)
}

func Example_gactarSpaceSeparator() {
Expand Down Expand Up @@ -101,7 +101,7 @@ func Example_gactarTraceActivationsNonBool() {
~~ productions ~~`)

// Output:
// ERROR: 'trace_activations' must be 'true' or 'false' (line 5, col 29)
// ERROR: 'trace_activations' invalid type (found number; expected true or false) (line 5, col 29)
}

func Example_chunkInternalType() {
Expand Down Expand Up @@ -248,7 +248,7 @@ func Example_imaginalFieldType() {
~~ productions ~~`)

// Output:
// ERROR: imaginal "delay" invalid type (expected number) (line 6, col 20)
// ERROR: imaginal "delay" invalid type (found string; expected number) (line 6, col 20)
}

func Example_imaginalFieldRange() {
Expand Down
22 changes: 20 additions & 2 deletions util/keyvalue/keyvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"golang.org/x/exp/slices"
)

var boolean = []string{
var BooleanValues = []string{
"true",
"false",
}
Expand Down Expand Up @@ -48,8 +48,26 @@ func (v Value) String() string {
return ""
}

func (v Value) Type() string {
switch {
case v.ID != nil:
return "id"

case v.Str != nil:
return "string"

case v.Number != nil:
return "number"

case v.Field != nil:
return "field"
}

return "<none>"
}

func (v Value) AsBool() (bool, error) {
if (v.ID == nil) || !slices.Contains(boolean, *v.ID) {
if (v.ID == nil) || !slices.Contains(BooleanValues, *v.ID) {
return false, ErrInvalidType{ExpectedType: Boolean}
}

Expand Down

0 comments on commit 52d0a7d

Please sign in to comment.