Skip to content

Commit

Permalink
Eagerly check function arguments when called from inside iterable
Browse files Browse the repository at this point in the history
This mitigates an issue where lazy mappings and listings widen an existing bug.

This is a follow-up to apple#752.
  • Loading branch information
bioball committed Nov 4, 2024
1 parent 4b4d81b commit f9e2984
Show file tree
Hide file tree
Showing 22 changed files with 226 additions and 49 deletions.
11 changes: 9 additions & 2 deletions pkl-core/src/main/java/org/pkl/core/ast/builder/AstBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,7 @@ private ExpressionNode doVisitMethodAccessExpr(QualifiedAccessExprContext ctx) {
visitArgumentList(argCtx),
MemberLookupMode.EXPLICIT_RECEIVER,
needsConst,
symbolTable.getCurrentScope().isVisitingIterable(),
PropagateNullReceiverNodeGen.create(unavailableSourceSection(), receiver),
GetClassNodeGen.create(null)));
}
Expand All @@ -1998,6 +1999,7 @@ private ExpressionNode doVisitMethodAccessExpr(QualifiedAccessExprContext ctx) {
visitArgumentList(argCtx),
MemberLookupMode.EXPLICIT_RECEIVER,
needsConst,
symbolTable.getCurrentScope().isVisitingIterable(),
receiver,
GetClassNodeGen.create(null));
}
Expand Down Expand Up @@ -2072,7 +2074,11 @@ public ExpressionNode visitSuperAccessExpr(SuperAccessExprContext ctx) {
}

return InvokeSuperMethodNodeGen.create(
sourceSection, memberName, visitArgumentList(argCtx), needsConst);
sourceSection,
memberName,
symbolTable.getCurrentScope().isVisitingIterable(),
visitArgumentList(argCtx),
needsConst);
}

// superproperty call
Expand Down Expand Up @@ -2130,7 +2136,8 @@ public ExpressionNode visitUnqualifiedAccessExpr(UnqualifiedAccessExprContext ct
isBaseModule,
scope.isCustomThisScope(),
scope.getConstLevel(),
scope.getConstDepth());
scope.getConstDepth(),
scope.isVisitingIterable());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ public Object executeGeneric(VirtualFrame frame) {

var value = valueNode.executeGeneric(frame);

return callNode.call(function.getThisValue(), function, value);
return callNode.call(function.getThisValue(), function, false, value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ public Object executeGeneric(VirtualFrame frame) {
var arguments = new Object[frameArguments.length];
arguments[0] = functionToAmend.getThisValue();
arguments[1] = functionToAmend;
System.arraycopy(frameArguments, 2, arguments, 2, frameArguments.length - 2);
arguments[2] = false;
System.arraycopy(frameArguments, 3, arguments, 3, frameArguments.length - 3);

var valueToAmend = callNode.call(functionToAmend.getCallTarget(), arguments);
if (!(valueToAmend instanceof VmFunction newFunctionToAmend)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,35 @@ public final class InvokeMethodDirectNode extends ExpressionNode {
private final VmObjectLike owner;
@Child private ExpressionNode receiverNode;
@Children private final ExpressionNode[] argumentNodes;
private final boolean isInIterable;

@Child private DirectCallNode callNode;

public InvokeMethodDirectNode(
SourceSection sourceSection,
ClassMethod method,
ExpressionNode receiverNode,
ExpressionNode[] argumentNodes) {
ExpressionNode[] argumentNodes,
boolean isInIterable) {

super(sourceSection);
this.owner = method.getOwner();
this.receiverNode = receiverNode;
this.argumentNodes = argumentNodes;
this.isInIterable = isInIterable;

callNode = DirectCallNode.create(method.getCallTarget(sourceSection));
}

@Override
@ExplodeLoop
public Object executeGeneric(VirtualFrame frame) {
var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
args[0] = receiverNode.executeGeneric(frame);
args[1] = owner;
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

return callNode.call(args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,33 @@ public final class InvokeMethodLexicalNode extends ExpressionNode {
private final int levelsUp;

@Child private DirectCallNode callNode;
private final boolean isInIterable;

InvokeMethodLexicalNode(
SourceSection sourceSection,
CallTarget callTarget,
int levelsUp,
ExpressionNode[] argumentNodes) {
ExpressionNode[] argumentNodes,
boolean isInIterable) {

super(sourceSection);
this.levelsUp = levelsUp;
this.argumentNodes = argumentNodes;

callNode = DirectCallNode.create(callTarget);
this.isInIterable = isInIterable;
}

@Override
@ExplodeLoop
public Object executeGeneric(VirtualFrame frame) {
var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
var enclosingFrame = getEnclosingFrame(frame);
args[0] = VmUtils.getReceiver(enclosingFrame);
args[1] = VmUtils.getOwner(enclosingFrame);
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

return callNode.call(args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.pkl.core.runtime.VmFunction;

/** A virtual method call. */
@SuppressWarnings("DuplicatedCode")
@ImportStatic(Identifier.class)
@NodeChild(value = "receiverNode", type = ExpressionNode.class)
@NodeChild(value = "receiverClassNode", type = GetClassNode.class, executeWith = "receiverNode")
Expand All @@ -44,27 +45,31 @@ public abstract class InvokeMethodVirtualNode extends ExpressionNode {
@Children private final ExpressionNode[] argumentNodes;
private final MemberLookupMode lookupMode;
private final boolean needsConst;
private final boolean isInIterable;

protected InvokeMethodVirtualNode(
SourceSection sourceSection,
Identifier methodName,
ExpressionNode[] argumentNodes,
MemberLookupMode lookupMode,
boolean needsConst) {
boolean needsConst,
boolean isInIterable) {

super(sourceSection);
this.methodName = methodName;
this.argumentNodes = argumentNodes;
this.lookupMode = lookupMode;
this.needsConst = needsConst;
this.isInIterable = isInIterable;
}

protected InvokeMethodVirtualNode(
SourceSection sourceSection,
Identifier methodName,
ExpressionNode[] argumentNodes,
MemberLookupMode lookupMode) {
this(sourceSection, methodName, argumentNodes, lookupMode, false);
MemberLookupMode lookupMode,
boolean isInIterable) {
this(sourceSection, methodName, argumentNodes, lookupMode, false, isInIterable);
}

/**
Expand All @@ -84,11 +89,12 @@ protected Object evalFunctionCached(
RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) {

var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
args[0] = receiver.getThisValue();
args[1] = receiver;
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

return callNode.call(args);
Expand All @@ -103,11 +109,12 @@ protected Object evalFunction(
@SuppressWarnings("unused") VmClass receiverClass,
@Exclusive @Cached("create()") IndirectCallNode callNode) {

var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
args[0] = receiver.getThisValue();
args[1] = receiver;
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

return callNode.call(receiver.getCallTarget(), args);
Expand All @@ -123,11 +130,12 @@ protected Object evalCached(
@Cached("resolveMethod(receiverClass)") ClassMethod method,
@Cached("create(method.getCallTarget(sourceSection))") DirectCallNode callNode) {

var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
args[0] = receiver;
args[1] = method.getOwner();
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

return callNode.call(args);
Expand All @@ -142,11 +150,12 @@ protected Object eval(
@Exclusive @Cached("create()") IndirectCallNode callNode) {

var method = resolveMethod(receiverClass);
var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
args[0] = receiver;
args[1] = method.getOwner();
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

// Deprecation should not report here (getCallTarget(sourceSection)), as this happens for each
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,18 @@
public abstract class InvokeSuperMethodNode extends ExpressionNode {
private final Identifier methodName;
@Children private final ExpressionNode[] argumentNodes;
private final boolean isInIterable;
private final boolean needsConst;

protected InvokeSuperMethodNode(
SourceSection sourceSection,
Identifier methodName,
boolean isInIterable,
ExpressionNode[] argumentNodes,
boolean needsConst) {

super(sourceSection);
this.isInIterable = isInIterable;
this.needsConst = needsConst;

assert !methodName.isLocalMethod();
Expand All @@ -54,11 +57,12 @@ protected Object eval(
@Cached(value = "findSupermethod(frame)", neverDefault = true) ClassMethod supermethod,
@Cached("create(supermethod.getCallTarget(sourceSection))") DirectCallNode callNode) {

var args = new Object[2 + argumentNodes.length];
var args = new Object[3 + argumentNodes.length];
args[0] = VmUtils.getReceiverOrNull(frame);
args[1] = supermethod.getOwner();
args[2] = isInIterable;
for (int i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame);
args[3 + i] = argumentNodes[i].executeGeneric(frame);
}

return callNode.call(args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public final class ResolveMethodNode extends ExpressionNode {
private final boolean isCustomThisScope;
private final ConstLevel constLevel;
private final int constDepth;
private final boolean isInIterable;

public ResolveMethodNode(
SourceSection sourceSection,
Expand All @@ -58,7 +59,8 @@ public ResolveMethodNode(
boolean isBaseModule,
boolean isCustomThisScope,
ConstLevel constLevel,
int constDepth) {
int constDepth,
boolean isInIterable) {

super(sourceSection);

Expand All @@ -68,6 +70,7 @@ public ResolveMethodNode(
this.isCustomThisScope = isCustomThisScope;
this.constLevel = constLevel;
this.constDepth = constDepth;
this.isInIterable = isInIterable;
}

@Override
Expand All @@ -91,15 +94,23 @@ private ExpressionNode doResolve(VmObjectLike initialOwner) {
assert localMethod.isLocal();
checkConst(currOwner, localMethod, levelsUp);
return new InvokeMethodLexicalNode(
sourceSection, localMethod.getCallTarget(sourceSection), levelsUp, argumentNodes);
sourceSection,
localMethod.getCallTarget(sourceSection),
levelsUp,
argumentNodes,
isInIterable);
}
var method = currOwner.getVmClass().getDeclaredMethod(methodName);
if (method != null) {
assert !method.isLocal();
checkConst(currOwner, method, levelsUp);
if (method.getDeclaringClass().isClosed()) {
return new InvokeMethodLexicalNode(
sourceSection, method.getCallTarget(sourceSection), levelsUp, argumentNodes);
sourceSection,
method.getCallTarget(sourceSection),
levelsUp,
argumentNodes,
isInIterable);
}

//noinspection ConstantConditions
Expand All @@ -108,6 +119,7 @@ private ExpressionNode doResolve(VmObjectLike initialOwner) {
methodName,
argumentNodes,
MemberLookupMode.IMPLICIT_LEXICAL,
isInIterable,
levelsUp == 0 ? new GetReceiverNode() : new GetEnclosingReceiverNode(levelsUp),
GetClassNodeGen.create(null));
}
Expand All @@ -122,7 +134,7 @@ private ExpressionNode doResolve(VmObjectLike initialOwner) {
(CallTarget) localMethod.getCallTarget().call(currOwner, currOwner);

return new InvokeMethodLexicalNode(
sourceSection, methodCallTarget, levelsUp, argumentNodes);
sourceSection, methodCallTarget, levelsUp, argumentNodes, isInIterable);
}
}

Expand All @@ -138,7 +150,7 @@ private ExpressionNode doResolve(VmObjectLike initialOwner) {
if (method != null) {
assert !method.isLocal();
return new InvokeMethodDirectNode(
sourceSection, method, new ConstantValueNode(baseModule), argumentNodes);
sourceSection, method, new ConstantValueNode(baseModule), argumentNodes, isInIterable);
}
}

Expand All @@ -158,6 +170,7 @@ private ExpressionNode doResolve(VmObjectLike initialOwner) {
argumentNodes,
MemberLookupMode.IMPLICIT_THIS,
needsConst,
isInIterable,
VmUtils.createThisNode(VmUtils.unavailableSourceSection(), isCustomThisScope),
GetClassNodeGen.create(null));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ protected InvokeMethodVirtualNode createInvokeNode() {
Identifier.TO_STRING,
new ExpressionNode[] {},
MemberLookupMode.EXPLICIT_RECEIVER,
false,
null,
null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ protected Object evalDirect(
RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) {

return callNode.call(function.getThisValue(), function);
return callNode.call(function.getThisValue(), function, false);
}

@Specialization(replaces = "evalDirect")
protected Object eval(VmFunction function, @Cached("create()") IndirectCallNode callNode) {

return callNode.call(function.getCallTarget(), function.getThisValue(), function);
return callNode.call(function.getCallTarget(), function.getThisValue(), function, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ protected Object evalDirect(
RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) {

return callNode.call(function.getThisValue(), function, arg1);
return callNode.call(function.getThisValue(), function, false, arg1);
}

@Specialization(replaces = "evalDirect")
protected Object eval(
VmFunction function, Object arg1, @Cached("create()") IndirectCallNode callNode) {

return callNode.call(function.getCallTarget(), function.getThisValue(), function, arg1);
return callNode.call(function.getCallTarget(), function.getThisValue(), function, false, arg1);
}
}
Loading

0 comments on commit f9e2984

Please sign in to comment.