Skip to content

Commit

Permalink
Introduce filter in VariableSizedArray type
Browse files Browse the repository at this point in the history
  • Loading branch information
darkdrag00nv2 committed Jul 24, 2023
1 parent 181924e commit 19dbd0e
Show file tree
Hide file tree
Showing 4 changed files with 413 additions and 0 deletions.
97 changes: 97 additions & 0 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -2443,6 +2443,28 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, locationRange LocationR
)
},
)

case sema.ArrayTypeFilterFunctionName:
return NewHostFunctionValue(
interpreter,
sema.ArrayFilterFunctionType(
v.SemaType(interpreter).ElementType(false),
),
func(invocation Invocation) Value {
interpreter := invocation.Interpreter

funcArgument, ok := invocation.Arguments[0].(FunctionValue)
if !ok {
panic(errors.NewUnreachableError())
}

return v.Filter(
interpreter,
invocation.LocationRange,
funcArgument,
)
},
)
}

return nil
Expand Down Expand Up @@ -2946,6 +2968,81 @@ func (v *ArrayValue) Reverse(
)
}

func (v *ArrayValue) Filter(
interpreter *Interpreter,
locationRange LocationRange,
procedure FunctionValue,
) Value {
filteredValues := make([]Value, 0)
filteredValuesCount := 0
i := 0

iterationInvocation := func(arrayElement Value) Invocation {
return NewInvocation(
interpreter,
nil,
nil,
[]Value{arrayElement},
[]sema.Type{sema.BoolType},
nil,
locationRange,
)
}

iterate := func() {
err := v.array.Iterate(
func(item atree.Value) (bool, error) {
arrayElement := MustConvertStoredValue(interpreter, item)

shouldInclude, ok := procedure.invoke(iterationInvocation(arrayElement)).(BoolValue)
if !ok {
panic(errors.NewUnreachableError())
}

if shouldInclude {
filteredValues = append(filteredValues, arrayElement)
filteredValuesCount++
}

i++
return true, nil
},
)

if err != nil {
panic(errors.NewExternalError(err))
}
}

iterate()

iterationIndex := 0
return NewArrayValueWithIterator(
interpreter,
NewVariableSizedStaticType(interpreter, v.Type.ElementType()),
common.ZeroAddress,
uint64(filteredValuesCount),
func() Value {

if iterationIndex == filteredValuesCount {
return nil
}

value := filteredValues[iterationIndex]
iterationIndex++

return value.Transfer(
interpreter,
locationRange,
atree.Address{},
false,
nil,
nil,
)
},
)
}

// NumberValue
type NumberValue interface {
ComparableValue
Expand Down
60 changes: 60 additions & 0 deletions runtime/sema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,13 @@ Returns a new array with contents in the reversed order.
Available if the array element type is not resource-kinded.
`

const ArrayTypeFilterFunctionName = "filter"

const arrayTypeFilterFunctionDocString = `
Returns a new array whose elements are filtered by applying the filter function on each element of the original array.
Available if the array element type is not resource-kinded.
`

func getArrayMembers(arrayType ArrayType) map[string]MemberResolver {

members := map[string]MemberResolver{
Expand Down Expand Up @@ -2083,6 +2090,32 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver {
)
},
}

members[ArrayTypeFilterFunctionName] = MemberResolver{
Kind: common.DeclarationKindFunction,
Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member {

elementType := arrayType.ElementType(false)

if elementType.IsResourceType() {
report(
&InvalidResourceArrayMemberError{
Name: identifier,
DeclarationKind: common.DeclarationKindFunction,
Range: targetRange,
},
)
}

return NewPublicFunctionMember(
memoryGauge,
arrayType,
identifier,
ArrayFilterFunctionType(elementType),
arrayTypeFilterFunctionDocString,
)
},
}
}

return withBuiltinMembers(arrayType, members)
Expand Down Expand Up @@ -2232,6 +2265,33 @@ func ArrayReverseFunctionType(arrayType ArrayType) *FunctionType {
}
}

func ArrayFilterFunctionType(elementType Type) *FunctionType {
// fun filter(_ function: ((T): Bool)): [T]
// funcType: elementType -> Bool
funcType := &FunctionType{
Parameters: []Parameter{
{
Identifier: "element",
TypeAnnotation: NewTypeAnnotation(elementType),
},
},
ReturnTypeAnnotation: NewTypeAnnotation(BoolType),
}

return &FunctionType{
Parameters: []Parameter{
{
Label: ArgumentLabelNotRequired,
Identifier: "f",
TypeAnnotation: NewTypeAnnotation(funcType),
},
},
ReturnTypeAnnotation: NewTypeAnnotation(&VariableSizedType{
Type: elementType,
}),
}
}

// VariableSizedType is a variable sized array type
type VariableSizedType struct {
Type Type
Expand Down
103 changes: 103 additions & 0 deletions runtime/tests/checker/arrays_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,109 @@ func TestCheckResourceArrayReverseInvalid(t *testing.T) {
assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0])
}

func TestCheckArrayFilter(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun test() {
let x = [1, 2, 3]
let onlyEven =
fun (_ x: Int): Bool {
return x % 2 == 0
}
let y = x.filter(onlyEven)
}
`)

require.NoError(t, err)
}

func TestCheckArrayFilterInvalidArgs(t *testing.T) {

t.Parallel()

testInvalidArgs := func(code string, expectedErrors []sema.SemanticError) {
_, err := ParseAndCheck(t, code)

errs := RequireCheckerErrors(t, err, len(expectedErrors))

for i, e := range expectedErrors {
assert.IsType(t, e, errs[i])
}
}

testInvalidArgs(`
fun test() {
let x = [1, 2, 3]
let y = x.filter(100)
}
`,
[]sema.SemanticError{
&sema.TypeMismatchError{},
},
)

testInvalidArgs(`
fun test() {
let x = [1, 2, 3]
let onlyEvenInt16 =
fun (_ x: Int16): Bool {
return x % 2 == 0
}
let y = x.filter(onlyEvenInt16)
}
`,
[]sema.SemanticError{
&sema.TypeMismatchError{},
},
)

testInvalidArgs(`
fun test() {
let x : [Int; 5] = [1, 2, 3, 21, 30]
let onlyEvenInt =
fun (_ x: Int): Bool {
return x % 2 == 0
}
let y = x.filter(onlyEvenInt)
}
`,
[]sema.SemanticError{
&sema.NotDeclaredMemberError{},
},
)
}

func TestCheckResourceArrayFilterInvalid(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
resource X {}
fun test(): @[X] {
let xs <- [<-create X()]
let allResources =
fun (_ x: @X): Bool {
destroy x
return true
}
let filteredXs <-xs.filter(allResources)
destroy xs
return <- filteredXs
}
`)

errs := RequireCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0])
}

func TestCheckArrayContains(t *testing.T) {

t.Parallel()
Expand Down
Loading

0 comments on commit 19dbd0e

Please sign in to comment.