diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 6b11e3737a..7767571b24 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -2443,6 +2443,28 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, locationRange LocationR ) }, ) + + case sema.ArrayTypeFilterFunctionName: + return NewHostFunctionValue( + interpreter, + sema.ArrayFilterFunctionType( + v.SemaType(interpreter).ElementType(false), + ), + func(invocation Invocation) Value { + interpreter := invocation.Interpreter + + funcArgument, ok := invocation.Arguments[0].(FunctionValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + return v.Filter( + interpreter, + invocation.LocationRange, + funcArgument, + ) + }, + ) } return nil @@ -2946,6 +2968,81 @@ func (v *ArrayValue) Reverse( ) } +func (v *ArrayValue) Filter( + interpreter *Interpreter, + locationRange LocationRange, + procedure FunctionValue, +) Value { + filteredValues := make([]Value, 0) + filteredValuesCount := 0 + i := 0 + + iterationInvocation := func(arrayElement Value) Invocation { + return NewInvocation( + interpreter, + nil, + nil, + []Value{arrayElement}, + []sema.Type{sema.BoolType}, + nil, + locationRange, + ) + } + + iterate := func() { + err := v.array.Iterate( + func(item atree.Value) (bool, error) { + arrayElement := MustConvertStoredValue(interpreter, item) + + shouldInclude, ok := procedure.invoke(iterationInvocation(arrayElement)).(BoolValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + if shouldInclude { + filteredValues = append(filteredValues, arrayElement) + filteredValuesCount++ + } + + i++ + return true, nil + }, + ) + + if err != nil { + panic(errors.NewExternalError(err)) + } + } + + iterate() + + iterationIndex := 0 + return NewArrayValueWithIterator( + interpreter, + NewVariableSizedStaticType(interpreter, v.Type.ElementType()), + common.ZeroAddress, + uint64(filteredValuesCount), + func() Value { + + if iterationIndex == filteredValuesCount { + return nil + } + + value := filteredValues[iterationIndex] + iterationIndex++ + + return value.Transfer( + interpreter, + locationRange, + atree.Address{}, + false, + nil, + nil, + ) + }, + ) +} + // NumberValue type NumberValue interface { ComparableValue diff --git a/runtime/sema/type.go b/runtime/sema/type.go index e57342f501..80d3e4e1d6 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -1795,6 +1795,13 @@ Returns a new array with contents in the reversed order. Available if the array element type is not resource-kinded. ` +const ArrayTypeFilterFunctionName = "filter" + +const arrayTypeFilterFunctionDocString = ` +Returns a new array whose elements are filtered by applying the filter function on each element of the original array. +Available if the array element type is not resource-kinded. +` + func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { members := map[string]MemberResolver{ @@ -2083,6 +2090,32 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { ) }, } + + members[ArrayTypeFilterFunctionName] = MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member { + + elementType := arrayType.ElementType(false) + + if elementType.IsResourceType() { + report( + &InvalidResourceArrayMemberError{ + Name: identifier, + DeclarationKind: common.DeclarationKindFunction, + Range: targetRange, + }, + ) + } + + return NewPublicFunctionMember( + memoryGauge, + arrayType, + identifier, + ArrayFilterFunctionType(elementType), + arrayTypeFilterFunctionDocString, + ) + }, + } } return withBuiltinMembers(arrayType, members) @@ -2232,6 +2265,33 @@ func ArrayReverseFunctionType(arrayType ArrayType) *FunctionType { } } +func ArrayFilterFunctionType(elementType Type) *FunctionType { + // fun filter(_ function: ((T): Bool)): [T] + // funcType: elementType -> Bool + funcType := &FunctionType{ + Parameters: []Parameter{ + { + Identifier: "element", + TypeAnnotation: NewTypeAnnotation(elementType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(BoolType), + } + + return &FunctionType{ + Parameters: []Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "f", + TypeAnnotation: NewTypeAnnotation(funcType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(&VariableSizedType{ + Type: elementType, + }), + } +} + // VariableSizedType is a variable sized array type type VariableSizedType struct { Type Type diff --git a/runtime/tests/checker/arrays_dictionaries_test.go b/runtime/tests/checker/arrays_dictionaries_test.go index a7d47a673e..cfea7f1c34 100644 --- a/runtime/tests/checker/arrays_dictionaries_test.go +++ b/runtime/tests/checker/arrays_dictionaries_test.go @@ -1128,6 +1128,109 @@ func TestCheckResourceArrayReverseInvalid(t *testing.T) { assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0]) } +func TestCheckArrayFilter(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let x = [1, 2, 3] + let onlyEven = + fun (_ x: Int): Bool { + return x % 2 == 0 + } + + let y = x.filter(onlyEven) + } + `) + + require.NoError(t, err) +} + +func TestCheckArrayFilterInvalidArgs(t *testing.T) { + + t.Parallel() + + testInvalidArgs := func(code string, expectedErrors []sema.SemanticError) { + _, err := ParseAndCheck(t, code) + + errs := RequireCheckerErrors(t, err, len(expectedErrors)) + + for i, e := range expectedErrors { + assert.IsType(t, e, errs[i]) + } + } + + testInvalidArgs(` + fun test() { + let x = [1, 2, 3] + let y = x.filter(100) + } + `, + []sema.SemanticError{ + &sema.TypeMismatchError{}, + }, + ) + + testInvalidArgs(` + fun test() { + let x = [1, 2, 3] + let onlyEvenInt16 = + fun (_ x: Int16): Bool { + return x % 2 == 0 + } + + let y = x.filter(onlyEvenInt16) + } + `, + []sema.SemanticError{ + &sema.TypeMismatchError{}, + }, + ) + + testInvalidArgs(` + fun test() { + let x : [Int; 5] = [1, 2, 3, 21, 30] + let onlyEvenInt = + fun (_ x: Int): Bool { + return x % 2 == 0 + } + + let y = x.filter(onlyEvenInt) + } + `, + []sema.SemanticError{ + &sema.NotDeclaredMemberError{}, + }, + ) +} + +func TestCheckResourceArrayFilterInvalid(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource X {} + + fun test(): @[X] { + let xs <- [<-create X()] + let allResources = + fun (_ x: @X): Bool { + destroy x + return true + } + + let filteredXs <-xs.filter(allResources) + destroy xs + return <- filteredXs + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0]) +} + func TestCheckArrayContains(t *testing.T) { t.Parallel() diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 2ef3540596..a876b86d2e 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -10741,6 +10741,159 @@ func TestInterpretArrayReverse(t *testing.T) { } } +func TestInterpretArrayFilter(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let xs = [1, 2, 3, 100, 200] + let emptyVals: [Int] = [] + + fun filterxs(): [Int] { + let onlyEven = + fun (_ x: Int): Bool { + return x % 2 == 0 + } + + return xs.filter(onlyEven) + } + fun originalxs(): [Int] { + return xs + } + + fun filterempty(): [Int] { + let onlyEven = + fun (_ x: Int): Bool { + return x % 2 == 0 + } + + return emptyVals.filter(onlyEven) + } + fun originalempty(): [Int] { + return emptyVals + } + + pub struct TestStruct { + pub var test: Int + + init(_ t: Int) { + self.test = t + } + } + let sa = [TestStruct(1), TestStruct(2), TestStruct(3)] + + fun filtersa(): [Int] { + let onlyOdd = + fun (_ x: TestStruct): Bool { + return x.test % 2 == 1 + } + + let sa_filtered = sa.filter(onlyOdd) + + let res: [Int] = []; + for s in sa_filtered { + res.append(s.test) + } + + return res + } + fun originalsa(): [Int] { + let res: [Int] = []; + for s in sa { + res.append(s.test) + } + + return res + } + `) + + runValidCase := func(t *testing.T, filterFuncName, originalFuncName string, filteredArray, originalArray *interpreter.ArrayValue) { + val, err := inter.Invoke(filterFuncName) + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + filteredArray, + val, + ) + + origVal, err := inter.Invoke(originalFuncName) + require.NoError(t, err) + + // Original array remains unchanged + AssertValuesEqual( + t, + inter, + originalArray, + origVal, + ) + } + + runValidCase(t, "filterempty", "originalempty", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + common.ZeroAddress, + ), interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + common.ZeroAddress, + )) + + runValidCase(t, "filterxs", "originalxs", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + common.ZeroAddress, + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(100), + interpreter.NewUnmeteredIntValueFromInt64(200), + ), interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + common.ZeroAddress, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(3), + interpreter.NewUnmeteredIntValueFromInt64(100), + interpreter.NewUnmeteredIntValueFromInt64(200), + )) + + runValidCase(t, "filtersa", "originalsa", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + common.ZeroAddress, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(3), + ), interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + common.ZeroAddress, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(3), + )) +} + func TestInterpretOptionalReference(t *testing.T) { t.Parallel()