diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index f65c1d3b97..ef81e143a1 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -1607,7 +1607,7 @@ func NewArrayValueWithIterator( interpreter *Interpreter, arrayType ArrayStaticType, address common.Address, - count uint64, + countOverestimate uint64, values func() Value, ) *ArrayValue { interpreter.ReportComputation(common.ComputationKindCreateArrayValue, 1) @@ -1652,7 +1652,7 @@ func NewArrayValueWithIterator( return array } // must assign to v here for tracing to work properly - v = newArrayValueFromConstructor(interpreter, arrayType, count, constructor) + v = newArrayValueFromConstructor(interpreter, arrayType, countOverestimate, constructor) return v } @@ -1669,14 +1669,14 @@ func newArrayValueFromAtreeValue( func newArrayValueFromConstructor( gauge common.MemoryGauge, staticType ArrayStaticType, - count uint64, + countOverestimate uint64, constructor func() *atree.Array, ) (array *ArrayValue) { var elementSize uint if staticType != nil { elementSize = staticType.ElementType().elementSize() } - baseUsage, elementUsage, dataSlabs, metaDataSlabs := common.NewArrayMemoryUsages(count, elementSize) + baseUsage, elementUsage, dataSlabs, metaDataSlabs := common.NewArrayMemoryUsages(countOverestimate, elementSize) common.UseMemory(gauge, baseUsage) common.UseMemory(gauge, elementUsage) common.UseMemory(gauge, dataSlabs) @@ -2448,6 +2448,7 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, locationRange LocationR return NewHostFunctionValue( interpreter, sema.ArrayFilterFunctionType( + interpreter, v.SemaType(interpreter).ElementType(false), ), func(invocation Invocation) Value { @@ -2973,42 +2974,20 @@ func (v *ArrayValue) Filter( locationRange LocationRange, procedure FunctionValue, ) Value { - filteredValuesCount := 0 - invocation := NewInvocation( - interpreter, - nil, - nil, - []Value{}, // Set later during invocation. - []sema.Type{v.semaType.ElementType(false)}, - nil, - locationRange, - ) iterationInvocation := func(arrayElement Value) Invocation { - invocation.Arguments = []Value{arrayElement} + invocation := NewInvocation( + interpreter, + nil, + nil, + []Value{arrayElement}, + []sema.Type{v.semaType.ElementType(false)}, + nil, + locationRange, + ) return invocation } - err := v.array.Iterate( - func(item atree.Value) (bool, error) { - arrayElement := MustConvertStoredValue(interpreter, item) - - shouldInclude, ok := procedure.invoke(iterationInvocation(arrayElement)).(BoolValue) - if !ok { - panic(errors.NewUnreachableError()) - } - - if shouldInclude { - filteredValuesCount++ - } - - return true, nil - }, - ) - if err != nil { - panic(errors.NewExternalError(err)) - } - iterator, err := v.array.Iterator() if err != nil { panic(errors.NewExternalError(err)) @@ -3018,7 +2997,7 @@ func (v *ArrayValue) Filter( interpreter, NewVariableSizedStaticType(interpreter, v.Type.ElementType()), common.ZeroAddress, - uint64(filteredValuesCount), + uint64(v.Count()), // worst case estimation. func() Value { var value Value diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 6cc205609d..a15eaed396 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -1940,7 +1940,7 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { memoryGauge, arrayType, identifier, - ArrayFilterFunctionType(elementType), + ArrayFilterFunctionType(memoryGauge, elementType), arrayTypeFilterFunctionDocString, ) }, @@ -2264,7 +2264,7 @@ func ArrayReverseFunctionType(arrayType ArrayType) *FunctionType { } } -func ArrayFilterFunctionType(elementType Type) *FunctionType { +func ArrayFilterFunctionType(memoryGauge common.MemoryGauge, elementType Type) *FunctionType { // fun filter(_ function: ((T): Bool)): [T] // funcType: elementType -> Bool funcType := &FunctionType{ @@ -2285,9 +2285,7 @@ func ArrayFilterFunctionType(elementType Type) *FunctionType { TypeAnnotation: NewTypeAnnotation(funcType), }, }, - ReturnTypeAnnotation: NewTypeAnnotation(&VariableSizedType{ - Type: elementType, - }), + ReturnTypeAnnotation: NewTypeAnnotation(NewVariableSizedType(memoryGauge, elementType)), } }