Skip to content

Commit

Permalink
Support lambda expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
Saloed committed Oct 30, 2023
1 parent 3733d76 commit 827fcc0
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 37 deletions.
2 changes: 1 addition & 1 deletion buildSrc/src/main/kotlin/Versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ object Versions {
const val ksmt = "0.5.7"
const val collections = "0.3.5"
const val coroutines = "1.6.4"
const val jcdb = "1.3.0"
const val jcdb = "1.4.0"
const val mockk = "1.13.4"
const val junitParams = "5.9.3"
const val logback = "1.4.8"
Expand Down
3 changes: 3 additions & 0 deletions usvm-jvm/src/main/kotlin/org/usvm/machine/JcContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.jacodb.impl.bytecode.JcFieldImpl
import org.jacodb.impl.types.FieldInfo
import org.usvm.UBv32Sort
import org.usvm.UContext
import org.usvm.machine.interpreter.JcLambdaCallSiteRegionId
import org.usvm.util.extractJcRefType

internal typealias USizeSort = UBv32Sort
Expand Down Expand Up @@ -60,6 +61,8 @@ class JcContext(
?: error("No enum type in classpath")
}

val lambdaCallSiteRegionId by lazy { JcLambdaCallSiteRegionId(this) }

// TODO store it in JcComponents? Make it mutable?
internal val useNegativeAddressesInStaticInitializer: Boolean = false

Expand Down
16 changes: 16 additions & 0 deletions usvm-jvm/src/main/kotlin/org/usvm/machine/JcMethodCallInst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.usvm.machine

import org.jacodb.api.JcMethod
import org.jacodb.api.JcRefType
import org.jacodb.api.cfg.JcDynamicCallExpr
import org.jacodb.api.cfg.JcExpr
import org.jacodb.api.cfg.JcInst
import org.jacodb.api.cfg.JcInstLocation
Expand Down Expand Up @@ -86,3 +87,18 @@ data class JcVirtualMethodCallInst(

override val originalInst: JcInst = returnSite
}

/**
* Invoke dynamic instruction.
* The [dynamicCall] can't be processed and the machine
* must resolve it to some [JcConcreteMethodCallInst] or approximate.
* */
data class JcDynamicMethodCallInst(
val dynamicCall: JcDynamicCallExpr,
override val arguments: List<UExpr<out USort>>,
override val returnSite: JcInst,
) : JcMethodCallBaseInst, JcMethodCall {
override val location: JcInstLocation = returnSite.location
override val method: JcMethod = dynamicCall.method.method
override val originalInst: JcInst = returnSite
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package org.usvm.machine

import io.ksmt.expr.KExpr
import io.ksmt.utils.asExpr
import org.jacodb.api.JcType
import org.jacodb.api.ext.toType
import org.usvm.UAddressSort
import org.usvm.UBoolExpr
import org.usvm.UConcreteHeapRef
import org.usvm.UHeapRef
Expand All @@ -13,14 +11,16 @@ import org.usvm.api.evalTypeEquals
import org.usvm.api.typeStreamOf
import org.usvm.isAllocatedConcreteHeapRef
import org.usvm.isStaticHeapRef
import org.usvm.machine.interpreter.JcLambdaCallSite
import org.usvm.machine.interpreter.JcLambdaCallSiteMemoryRegion
import org.usvm.machine.interpreter.JcStepScope
import org.usvm.machine.interpreter.JcTypeSelector
import org.usvm.machine.state.JcState
import org.usvm.machine.state.newStmt
import org.usvm.memory.foldHeapRefWithStaticAsSymbolic
import org.usvm.memory.foldHeapRef
import org.usvm.model.UModelBase
import org.usvm.types.UTypeStream
import org.usvm.types.first
import org.usvm.types.single
import org.usvm.util.findMethod

/**
Expand Down Expand Up @@ -68,14 +68,21 @@ private fun resolveVirtualInvokeWithModel(
val concreteRef = model.eval(instance) as UConcreteHeapRef

if (isAllocatedConcreteHeapRef(concreteRef) || isStaticHeapRef(concreteRef)) {
val concreteCall = makeConcreteMethodCall(scope, concreteRef, methodCall)
val callSite = findLambdaCallSite(methodCall, scope, concreteRef)
val concreteCall = if (callSite != null) {
makeLambdaCallSiteCall(callSite)
} else {
makeConcreteMethodCall(scope, concreteRef, methodCall)
}

scope.doWithState {
newStmt(concreteCall)
}

return@with
}

// Resolved lambda call site can't be an input ref
val typeStream = model.typeStreamOf(concreteRef)
val typeConstraintsWithBlockOnStates = makeConcreteCallsForPossibleTypes(
scope,
Expand All @@ -100,17 +107,36 @@ private fun resolveVirtualInvokeWithoutModel(
val instance = arguments.first().asExpr(ctx.addressSort)

val refsWithConditions = mutableListOf<Pair<UHeapRef, UBoolExpr>>()
foldHeapRefWithStaticAsSymbolic(
val lambdaCallSitesWithConditions = mutableListOf<Pair<JcLambdaCallSite, UBoolExpr>>()
foldHeapRef(
instance,
refsWithConditions,
Unit,
initialGuard = ctx.trueExpr,
ignoreNullRefs = true,
collapseHeapRefs = false,
blockOnConcrete = { curRefsWithConditions, (ref, condition) -> curRefsWithConditions.also { it += ref to condition } },
blockOnSymbolic = { curRefsWithConditions, (ref, condition) -> curRefsWithConditions.also { it += ref to condition } },
blockOnConcrete = { _, (ref, condition) ->
val lambdaCallSite = findLambdaCallSite(methodCall, scope, ref)
if (lambdaCallSite != null) {
lambdaCallSitesWithConditions += lambdaCallSite to condition
} else {
refsWithConditions += ref to condition
}
},
blockOnStatic = { _, (ref, condition) ->
val lambdaCallSite = findLambdaCallSite(methodCall, scope, ref)
if (lambdaCallSite != null) {
lambdaCallSitesWithConditions += lambdaCallSite to condition
} else {
refsWithConditions += ref to condition
}
},
blockOnSymbolic = { _, (ref, condition) ->
// Resolved lambda call site can't be a symbolic ref
refsWithConditions.also { it += ref to condition }
},
)

val conditionsWithBlocks = refsWithConditions.flatMap { (ref, condition) ->
val conditionsWithBlocks = refsWithConditions.flatMapTo(mutableListOf()) { (ref, condition) ->
when {
isAllocatedConcreteHeapRef(ref) || isStaticHeapRef(ref) -> {
val concreteCall = makeConcreteMethodCall(scope, ref, methodCall)
Expand Down Expand Up @@ -140,6 +166,11 @@ private fun resolveVirtualInvokeWithoutModel(
}
}

lambdaCallSitesWithConditions.mapTo(conditionsWithBlocks) { (callSite, condition) ->
val concreteCall = makeLambdaCallSiteCall(callSite)
condition to { state: JcState -> state.newStmt(concreteCall) }
}

scope.forkMulti(conditionsWithBlocks)
}

Expand All @@ -149,7 +180,7 @@ private fun JcVirtualMethodCallInst.makeConcreteMethodCall(
methodCall: JcVirtualMethodCallInst,
): JcConcreteMethodCallInst {
// We have only one type for allocated and static heap refs
val type = scope.calcOnState { memory.typeStreamOf(concreteRef) }.first()
val type = scope.calcOnState { memory.typeStreamOf(concreteRef) }.single()

val concreteMethod = type.findMethod(method)
?: error("Can't find method $method in type ${type.typeName}")
Expand All @@ -162,7 +193,7 @@ private fun JcVirtualMethodCallInst.makeConcreteCallsForPossibleTypes(
methodCall: JcVirtualMethodCallInst,
typeStream: UTypeStream<JcType>,
typeSelector: JcTypeSelector,
instance: KExpr<UAddressSort>,
instance: UHeapRef,
ctx: JcContext,
condition: UBoolExpr,
forkOnRemainingTypes: Boolean,
Expand Down Expand Up @@ -196,3 +227,35 @@ private fun JcVirtualMethodCallInst.makeConcreteCallsForPossibleTypes(

return typeConstraintsWithBlockOnStates
}

private fun findLambdaCallSite(
methodCall: JcVirtualMethodCallInst,
scope: JcStepScope,
ref: UConcreteHeapRef,
): JcLambdaCallSite? = with(methodCall) {
val callSites = scope.calcOnState { memory.getRegion(ctx.lambdaCallSiteRegionId) as JcLambdaCallSiteMemoryRegion }
val callSite = callSites.findCallSite(ref) ?: return null

val lambdaMethodType = callSite.lambda.dynamicMethodType

// Match function signature
when {
method.name != callSite.lambda.callSiteMethodName -> return null
method.returnType != lambdaMethodType.returnType -> return null
lambdaMethodType.argumentTypes != method.parameters.map { it.type } -> return null
}

return callSite
}

private fun JcVirtualMethodCallInst.makeLambdaCallSiteCall(
callSite: JcLambdaCallSite,
): JcConcreteMethodCallInst {
val lambdaMethod = callSite.lambda.actualMethod.method

// Instance was already resolved to the call site
val callArgsWithoutInstance = this.arguments.drop(1)
val lambdaMethodArgs = callSite.callSiteArgs + callArgsWithoutInstance

return JcConcreteMethodCallInst(location, lambdaMethod.method, lambdaMethodArgs, returnSite)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package org.usvm.machine.interpreter

import kotlinx.collections.immutable.PersistentMap
import kotlinx.collections.immutable.persistentHashMapOf
import org.jacodb.api.cfg.JcLambdaExpr
import org.usvm.UAddressSort
import org.usvm.UBoolExpr
import org.usvm.UConcreteHeapAddress
import org.usvm.UConcreteHeapRef
import org.usvm.UExpr
import org.usvm.machine.JcContext
import org.usvm.memory.UMemoryRegion
import org.usvm.memory.UMemoryRegionId

class JcLambdaCallSiteRegionId(private val ctx: JcContext) : UMemoryRegionId<Nothing, UAddressSort> {
override val sort: UAddressSort
get() = ctx.addressSort

override fun emptyRegion(): UMemoryRegion<Nothing, UAddressSort> =
JcLambdaCallSiteMemoryRegion(ctx)
}

internal class JcLambdaCallSiteMemoryRegion(
private val ctx: JcContext,
private val callSites: PersistentMap<UConcreteHeapAddress, JcLambdaCallSite> = persistentHashMapOf()
) : UMemoryRegion<Nothing, UAddressSort> {
fun writeCallSite(callSite: JcLambdaCallSite) =
JcLambdaCallSiteMemoryRegion(ctx, callSites.put(callSite.ref.address, callSite))

fun findCallSite(ref: UConcreteHeapRef): JcLambdaCallSite? = callSites[ref.address]

override fun read(key: Nothing): UExpr<UAddressSort> {
error("Unsupported operation for call site region")
}

override fun write(
key: Nothing,
value: UExpr<UAddressSort>,
guard: UBoolExpr
): UMemoryRegion<Nothing, UAddressSort> {
error("Unsupported operation for call site region")
}
}

data class JcLambdaCallSite(
val ref: UConcreteHeapRef,
val lambda: JcLambdaExpr,
val callSiteArgs: List<UExpr<*>>
)
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ import org.usvm.machine.operator.wideTo32BitsIfNeeded
import org.usvm.machine.state.JcMethodResult
import org.usvm.machine.state.JcState
import org.usvm.machine.state.addConcreteMethodCallStmt
import org.usvm.machine.state.addDynamicCall
import org.usvm.machine.state.addVirtualMethodCallStmt
import org.usvm.machine.state.throwExceptionWithoutStackFrameDrop
import org.usvm.memory.ULValue
import org.usvm.memory.URegisterStackLValue
import org.usvm.memory.UWritableMemory
import org.usvm.mkSizeExpr
import org.usvm.sizeSort
import org.usvm.util.allocHeapRef
Expand Down Expand Up @@ -390,22 +392,30 @@ class JcExprResolver(
resolveInvoke(
expr.method,
instanceExpr = null,
argumentExprs = expr::args,
argumentTypes = expr::callSiteArgTypes
) { arguments ->
TODO("Dynamic invoke: ${expr.method.method} $arguments")
argumentExprs = { expr.callSiteArgs },
argumentTypes = { expr.callSiteArgTypes }
) { callSiteArguments ->
scope.doWithState { addDynamicCall(expr, callSiteArguments) }
}

override fun visitJcLambdaExpr(expr: JcLambdaExpr): UExpr<out USort>? =
resolveInvoke(
expr.method,
instanceExpr = null,
argumentExprs = expr::args,
argumentTypes = { expr.method.parameters.map { it.type } }
) { arguments ->
scope.doWithState { addConcreteMethodCallStmt(expr.method.method, arguments) }
override fun visitJcLambdaExpr(expr: JcLambdaExpr): UExpr<out USort>? {
val callSiteArgs = expr.callSiteArgs.zip(expr.callSiteArgTypes) { arg, type ->
resolveJcExpr(arg, type) ?: return null
}

val callSiteRef = scope.calcOnState { memory.allocConcrete(expr.callSiteReturnType) }
val callSite = JcLambdaCallSite(callSiteRef, expr, callSiteArgs)
scope.doWithState { memory.writeCallSite(callSite) }

return callSiteRef
}

private fun UWritableMemory<JcType>.writeCallSite(callSite: JcLambdaCallSite) {
val callSiteRegion = getRegion(ctx.lambdaCallSiteRegionId) as JcLambdaCallSiteMemoryRegion
val updatedRegion = callSiteRegion.writeCallSite(callSite)
setRegion(ctx.lambdaCallSiteRegionId, updatedRegion)
}

private inline fun resolveInvoke(
method: JcTypedMethod,
instanceExpr: JcValue?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import org.usvm.forkblacklists.UForkBlackList
import org.usvm.machine.JcApplicationGraph
import org.usvm.machine.JcConcreteMethodCallInst
import org.usvm.machine.JcContext
import org.usvm.machine.JcDynamicMethodCallInst
import org.usvm.machine.JcInterpreterObserver
import org.usvm.machine.JcMethodApproximationResolver
import org.usvm.machine.JcMethodCall
Expand Down Expand Up @@ -240,7 +241,7 @@ class JcInterpreter(
}

if (stmt.method.isNative) {
mockNativeMethod(scope, stmt)
mockMethod(scope, stmt)
return
}

Expand All @@ -256,7 +257,17 @@ class JcInterpreter(
return
}

resolveVirtualInvoke(stmt, scope, typeSelector, forkOnRemainingTypes = false)
resolveVirtualInvoke(stmt, scope, forkOnRemainingTypes = false)
}

is JcDynamicMethodCallInst -> {
observer?.onMethodCallWithResolvedArguments(simpleValueResolver, stmt, scope)

if (approximateMethod(scope, stmt)) {
return
}

mockMethod(scope, stmt, stmt.dynamicCall.callSiteReturnType)
}
}
}
Expand Down Expand Up @@ -504,7 +515,6 @@ class JcInterpreter(
private fun resolveVirtualInvoke(
methodCall: JcVirtualMethodCallInst,
scope: JcStepScope,
typeSelector: JcTypeSelector,
forkOnRemainingTypes: Boolean,
): Unit = resolveVirtualInvoke(ctx, methodCall, scope, typeSelector, forkOnRemainingTypes)

Expand All @@ -515,13 +525,13 @@ class JcInterpreter(
return approximationResolver.approximate(scope, exprResolver, methodCall)
}

private fun mockNativeMethod(
scope: JcStepScope,
methodCall: JcConcreteMethodCallInst,
) = with(methodCall) {
logger.warn { "Mocked: ${method.enclosingClass.name}::${method.name}" }
private fun mockMethod(scope: JcStepScope, methodCall: JcMethodCall) {
val returnType = with(applicationGraph) { methodCall.method.typed }.returnType
mockMethod(scope, methodCall, returnType)
}

val returnType = with(applicationGraph) { method.typed }.returnType
private fun mockMethod(scope: JcStepScope, methodCall: JcMethodCall, returnType: JcType) = with(methodCall) {
logger.warn { "Mocked: ${method.enclosingClass.name}::${method.name}" }

if (returnType == ctx.cp.void) {
scope.doWithState { skipMethodInvocationWithValue(methodCall, ctx.voidValue) }
Expand Down
Loading

0 comments on commit 827fcc0

Please sign in to comment.