diff --git a/usvm-core/src/main/kotlin/org/usvm/Composition.kt b/usvm-core/src/main/kotlin/org/usvm/Composition.kt index 8f76aada61..49bf1e04a6 100644 --- a/usvm-core/src/main/kotlin/org/usvm/Composition.kt +++ b/usvm-core/src/main/kotlin/org/usvm/Composition.kt @@ -26,9 +26,6 @@ open class UComposer( ) : UExprTransformer(ctx) { open fun compose(expr: UExpr): UExpr = apply(expr) - override fun transform(expr: USymbol): UExpr = - error("You must override `transform` function in org.usvm.UComposer for ${expr::class}") - override fun transform(expr: UIteExpr): UExpr = transformExprAfterTransformed(expr, expr.condition) { condition -> when { @@ -42,12 +39,6 @@ open class UComposer( expr: URegisterReading, ): UExpr = with(expr) { memory.stack.readRegister(idx, sort) } - override fun transform(expr: UCollectionReading<*, *, *>): UExpr = - error("You must override `transform` function in org.usvm.UComposer for ${expr::class}") - - override fun transform(expr: UMockSymbol): UExpr = - error("You must override `transform` function in org.usvm.UComposer for ${expr::class}") - override fun transform( expr: UIndexedMethodReturnValue, ): UExpr = memory.mocker.eval(expr) diff --git a/usvm-core/src/main/kotlin/org/usvm/Context.kt b/usvm-core/src/main/kotlin/org/usvm/Context.kt index e19a6e77b0..34f15e9e83 100644 --- a/usvm-core/src/main/kotlin/org/usvm/Context.kt +++ b/usvm-core/src/main/kotlin/org/usvm/Context.kt @@ -43,9 +43,9 @@ import org.usvm.collection.set.ref.UInputRefSetWithAllocatedElementsReading import org.usvm.collection.set.ref.UInputRefSetWithInputElements import org.usvm.collection.set.ref.UInputRefSetWithInputElementsReading import org.usvm.memory.splitUHeapRef +import org.usvm.regions.Region import org.usvm.solver.USolverBase import org.usvm.types.UTypeSystem -import org.usvm.regions.Region @Suppress("LeakingThis") open class UContext( @@ -67,9 +67,7 @@ open class UContext( return currentStateId++ } - @Suppress("UNCHECKED_CAST") - fun solver(): USolverBase = - this.solver as USolverBase + fun solver(): USolverBase = this.solver.uncheckedCast() @Suppress("UNCHECKED_CAST") fun typeSystem(): UTypeSystem = diff --git a/usvm-core/src/main/kotlin/org/usvm/ExprTransformer.kt b/usvm-core/src/main/kotlin/org/usvm/ExprTransformer.kt index 0064003d20..ed5f93cc5a 100644 --- a/usvm-core/src/main/kotlin/org/usvm/ExprTransformer.kt +++ b/usvm-core/src/main/kotlin/org/usvm/ExprTransformer.kt @@ -20,12 +20,8 @@ import org.usvm.collection.set.ref.UInputRefSetWithInputElementsReading import org.usvm.regions.Region interface UTransformer : KTransformer { - fun transform(expr: USymbol): UExpr - fun transform(expr: URegisterReading): UExpr - fun transform(expr: UCollectionReading<*, *, *>): UExpr - fun transform(expr: UInputFieldReading): UExpr fun transform(expr: UAllocatedArrayReading): UExpr @@ -60,8 +56,6 @@ interface UTransformer : KTransformer { fun transform(expr: UInputRefSetWithInputElementsReading): UBoolExpr - fun transform(expr: UMockSymbol): UExpr - fun transform(expr: UIndexedMethodReturnValue): UExpr fun transform(expr: UIsSubtypeExpr): UBoolExpr diff --git a/usvm-core/src/main/kotlin/org/usvm/State.kt b/usvm-core/src/main/kotlin/org/usvm/State.kt index 93df656690..aed49332e2 100644 --- a/usvm-core/src/main/kotlin/org/usvm/State.kt +++ b/usvm-core/src/main/kotlin/org/usvm/State.kt @@ -9,20 +9,21 @@ import org.usvm.model.UModelBase import org.usvm.solver.USatResult import org.usvm.solver.UUnknownResult import org.usvm.solver.UUnsatResult +import org.usvm.targets.UTarget typealias StateId = UInt abstract class UState( // TODO: add interpreter-specific information - ctx: UContext, + val ctx: Context, open val callStack: UCallStack, - open val pathConstraints: UPathConstraints, + open val pathConstraints: UPathConstraints, open val memory: UMemory, open var models: List>, open var pathLocation: PathsTrieNode, targets: List = emptyList(), ) where Context : UContext, - Target : UTarget, + Target : UTarget, State : UState { /** * Deterministic state id. @@ -53,7 +54,7 @@ abstract class UState( * Creates new state structurally identical to this. * If [newConstraints] is null, clones [pathConstraints]. Otherwise, uses [newConstraints] in cloned state. */ - abstract fun clone(newConstraints: UPathConstraints? = null): State + abstract fun clone(newConstraints: UPathConstraints? = null): State override fun equals(other: Any?): Boolean { if (this === other) return true @@ -105,7 +106,7 @@ abstract class UState( val previousTargetCount = targetsImpl.size targetsImpl = targetsImpl.remove(target) - if (previousTargetCount == targetsImpl.size || !target.isRemoved) { + if (previousTargetCount == targetsImpl.size || target.isRemoved) { return false } @@ -155,7 +156,7 @@ private fun , Type, Context : UContext> fo } else { newConstraintToOriginalState } - val solver = newConstraintToForkedState.uctx.solver() + val solver = newConstraintToForkedState.uctx.solver() val satResult = solver.checkWithSoftConstraints(constraintsToCheck) return when (satResult) { diff --git a/usvm-core/src/main/kotlin/org/usvm/UComponents.kt b/usvm-core/src/main/kotlin/org/usvm/UComponents.kt index d4b83a6150..80e34e9d2d 100644 --- a/usvm-core/src/main/kotlin/org/usvm/UComponents.kt +++ b/usvm-core/src/main/kotlin/org/usvm/UComponents.kt @@ -1,13 +1,29 @@ package org.usvm +import org.usvm.model.ULazyModelDecoder +import org.usvm.model.UModelDecoder +import org.usvm.solver.UExprTranslator import org.usvm.solver.USolverBase import org.usvm.types.UTypeSystem /** * Provides core USVM components tuned for specific language. - * Instatiated once per [UContext]. + * Instantiated once per [UContext]. */ interface UComponents { - fun mkSolver(ctx: Context): USolverBase + fun mkSolver(ctx: UContext): USolverBase fun mkTypeSystem(ctx: UContext): UTypeSystem + + /** + * Initializes [UExprTranslator] and [UModelDecoder] and returns them. We can safely reuse them while [UContext] is + * alive. + */ + fun buildTranslatorAndLazyDecoder( + ctx: UContext, + ): Pair, ULazyModelDecoder> { + val translator = UExprTranslator(ctx) + val decoder = ULazyModelDecoder(translator) + + return translator to decoder + } } \ No newline at end of file diff --git a/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt b/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt index 63f3af01bb..d37f2b3a51 100644 --- a/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt +++ b/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt @@ -21,8 +21,8 @@ import org.usvm.uctx /** * Mutable representation of path constraints. */ -open class UPathConstraints private constructor( - val ctx: Context, +open class UPathConstraints private constructor( + private val ctx: UContext, logicalConstraints: PersistentSet = persistentSetOf(), /** * Specially represented equalities and disequalities between objects, used in various part of constraints management. @@ -51,7 +51,7 @@ open class UPathConstraints private constructor( var logicalConstraints: PersistentSet = logicalConstraints private set - constructor(ctx: Context) : this(ctx, persistentSetOf()) + constructor(ctx: UContext) : this(ctx, persistentSetOf()) open val isFalse: Boolean get() = equalityConstraints.isContradicting || @@ -128,7 +128,7 @@ open class UPathConstraints private constructor( } } - open fun clone(): UPathConstraints { + open fun clone(): UPathConstraints { val clonedEqualityConstraints = equalityConstraints.clone() val clonedTypeConstraints = typeConstraints.clone(clonedEqualityConstraints) val clonedNumericConstraints = numericConstraints.clone() diff --git a/usvm-core/src/main/kotlin/org/usvm/model/LazyModelDecoder.kt b/usvm-core/src/main/kotlin/org/usvm/model/LazyModelDecoder.kt index fbba3b26db..f91322a78f 100644 --- a/usvm-core/src/main/kotlin/org/usvm/model/LazyModelDecoder.kt +++ b/usvm-core/src/main/kotlin/org/usvm/model/LazyModelDecoder.kt @@ -19,18 +19,6 @@ interface UModelDecoder { fun decode(model: KModel): Model } -/** - * Initializes [UExprTranslator] and [UModelDecoder] and returns them. We can safely reuse them while [UContext] is - * alive. - */ -fun buildTranslatorAndLazyDecoder( - ctx: UContext, -): Pair, ULazyModelDecoder> { - val translator = UExprTranslator(ctx) - val decoder = ULazyModelDecoder(translator) - - return translator to decoder -} typealias AddressesMapping = Map, UConcreteHeapRef> @@ -53,18 +41,18 @@ open class ULazyModelDecoder( * equivalence classes of addresses and work with their number in the future. */ private fun buildMapping(model: KModel, nullRef: UConcreteHeapRef): AddressesMapping { - val interpreterdNullRef = model.eval(translatedNullRef, isComplete = true) + val interpretedNullRef = model.eval(translatedNullRef, isComplete = true) val result = mutableMapOf, UConcreteHeapRef>() // The null value has the NULL_ADDRESS - result[interpreterdNullRef] = nullRef + result[interpretedNullRef] = nullRef val universe = model.uninterpretedSortUniverse(ctx.addressSort) ?: return result // All the numbers are enumerated from the INITIAL_INPUT_ADDRESS to the Int.MIN_VALUE var counter = INITIAL_INPUT_ADDRESS for (interpretedAddress in universe) { - if (interpretedAddress == interpreterdNullRef) { + if (interpretedAddress == interpretedNullRef) { continue } diff --git a/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt b/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt index 321c04ff56..8667877d45 100644 --- a/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt +++ b/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt @@ -5,7 +5,6 @@ import org.usvm.PathSelectorCombinationStrategy import org.usvm.UMachineOptions import org.usvm.UPathSelector import org.usvm.UState -import org.usvm.UTarget import org.usvm.algorithms.DeterministicPriorityCollection import org.usvm.algorithms.RandomizedPriorityCollection import org.usvm.statistics.ApplicationGraph @@ -17,18 +16,23 @@ import org.usvm.statistics.distances.InterprocDistance import org.usvm.statistics.distances.InterprocDistanceCalculator import org.usvm.statistics.distances.MultiTargetDistanceCalculator import org.usvm.statistics.distances.ReachabilityKind +import org.usvm.targets.UTarget +import org.usvm.targets.UTargetController import org.usvm.util.log2 import kotlin.math.max import kotlin.random.Random -fun , State : UState<*, Method, Statement, *, Target, State>> createPathSelector( +fun createPathSelector( initialState: State, options: UMachineOptions, applicationGraph: ApplicationGraph, coverageStatistics: () -> CoverageStatistics? = { null }, cfgStatistics: () -> CfgStatistics? = { null }, - callGraphStatistics: () -> CallGraphStatistics? = { null } -): UPathSelector { + callGraphStatistics: () -> CallGraphStatistics? = { null }, +): UPathSelector + where Target : UTarget, + State : UState<*, Method, Statement, *, Target, State> { + val strategies = options.pathSelectionStrategies require(strategies.isNotEmpty()) { "At least one path selector strategy should be specified" } @@ -56,6 +60,7 @@ fun , State : USta requireNotNull(cfgStatistics()) { "CFG statistics is required for closest to uncovered path selector" }, applicationGraph ) + PathSelectionStrategy.CLOSEST_TO_UNCOVERED_RANDOM -> createClosestToUncoveredPathSelector( requireNotNull(coverageStatistics()) { "Coverage statistics is required for closest to uncovered path selector" }, requireNotNull(cfgStatistics()) { "CFG statistics is required for closest to uncovered path selector" }, @@ -68,6 +73,7 @@ fun , State : USta requireNotNull(callGraphStatistics()) { "Call graph statistics is required for targeted path selector" }, applicationGraph ) + PathSelectionStrategy.TARGETED_RANDOM -> createTargetedPathSelector( requireNotNull(cfgStatistics()) { "CFG statistics is required for targeted path selector" }, requireNotNull(callGraphStatistics()) { "Call graph statistics is required for targeted path selector" }, @@ -79,6 +85,7 @@ fun , State : USta requireNotNull(cfgStatistics()) { "CFG statistics is required for targeted call stack local path selector" }, applicationGraph ) + PathSelectionStrategy.TARGETED_CALL_STACK_LOCAL_RANDOM -> createTargetedPathSelector( requireNotNull(cfgStatistics()) { "CFG statistics is required for targeted call stack local path selector" }, applicationGraph, @@ -160,7 +167,12 @@ private fun distanceCalculator.removeTarget(method, statement) } + coverageStatistics.addOnCoveredObserver { _, method, statement -> + distanceCalculator.removeTarget( + method, + statement + ) + } if (random == null) { return WeightedPathSelector( @@ -171,7 +183,12 @@ private fun , State : UState<*, Method, Statement, *, Target, State>> createTargetedPathSelector( +internal fun createTargetedPathSelector( cfgStatistics: CfgStatistics, applicationGraph: ApplicationGraph, random: Random? = null, -): UPathSelector { +): UPathSelector + where Target : UTarget, + State : UState<*, Method, Statement, *, Target, State> { + val distanceCalculator = MultiTargetDistanceCalculator { loc -> CallStackDistanceCalculator( targets = listOf(loc), @@ -249,12 +269,15 @@ private fun InterprocDistance.logWeight(): UInt { return weight } -internal fun , State : UState<*, Method, Statement, *, Target, State>> createTargetedPathSelector( +internal fun createTargetedPathSelector( cfgStatistics: CfgStatistics, callGraphStatistics: CallGraphStatistics, applicationGraph: ApplicationGraph, random: Random? = null, -): UPathSelector { +): UPathSelector + where Target : UTarget, + State : UState<*, Method, Statement, *, Target, State> { + val distanceCalculator = MultiTargetDistanceCalculator { stmt -> InterprocDistanceCalculator( targetLocation = stmt, diff --git a/usvm-core/src/main/kotlin/org/usvm/solver/ExprTranslator.kt b/usvm-core/src/main/kotlin/org/usvm/solver/ExprTranslator.kt index 8f48358561..0de57c2690 100644 --- a/usvm-core/src/main/kotlin/org/usvm/solver/ExprTranslator.kt +++ b/usvm-core/src/main/kotlin/org/usvm/solver/ExprTranslator.kt @@ -8,7 +8,6 @@ import io.ksmt.utils.uncheckedCast import org.usvm.UAddressSort import org.usvm.UBoolExpr import org.usvm.UBoolSort -import org.usvm.UCollectionReading import org.usvm.UConcreteHeapRef import org.usvm.UContext import org.usvm.UExpr @@ -18,12 +17,10 @@ import org.usvm.UIndexedMethodReturnValue import org.usvm.UIsExpr import org.usvm.UIsSubtypeExpr import org.usvm.UIsSupertypeExpr -import org.usvm.UMockSymbol import org.usvm.UNullRef import org.usvm.URegisterReading import org.usvm.USizeSort import org.usvm.USort -import org.usvm.USymbol import org.usvm.USymbolicHeapRef import org.usvm.collection.array.UAllocatedArrayReading import org.usvm.collection.array.UArrayRegionDecoder @@ -78,20 +75,11 @@ open class UExprTranslator( ) : UExprTransformer(ctx) { open fun translate(expr: UExpr): KExpr = apply(expr) - override fun transform(expr: USymbol): KExpr = - error("You must override `transform` function in UExprTranslator for ${expr::class}") - override fun transform(expr: URegisterReading): KExpr { val registerConst = expr.sort.mkConst("r${expr.idx}_${expr.sort}") return registerConst } - override fun transform(expr: UCollectionReading<*, *, *>): KExpr = - error("You must override `transform` function in UExprTranslator for ${expr::class}") - - override fun transform(expr: UMockSymbol): KExpr = - error("You must override `transform` function in UExprTranslator for ${expr::class}") - override fun transform(expr: UIndexedMethodReturnValue): KExpr { val const = expr.sort.mkConst("m${expr.method}_${expr.callIndex}_${expr.sort}") return const diff --git a/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt b/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt index 9be38f9ac3..eea6baeab5 100644 --- a/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt +++ b/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt @@ -29,14 +29,14 @@ abstract class USolver { abstract fun check(query: Query): USolverResult } -open class USolverBase( - protected val ctx: Context, +open class USolverBase( + protected val ctx: UContext, protected val smtSolver: KSolver<*>, protected val typeSolver: UTypeSolver, protected val translator: UExprTranslator, protected val decoder: UModelDecoder>, protected val softConstraintsProvider: USoftConstraintsProvider, -) : USolver, UModelBase>(), AutoCloseable { +) : USolver, UModelBase>(), AutoCloseable { protected fun translateLogicalConstraints(constraints: Iterable) { for (constraint in constraints) { @@ -98,22 +98,22 @@ open class USolverBase( } } - protected fun translateToSmt(pc: UPathConstraints) { + protected fun translateToSmt(pc: UPathConstraints) { translateEqualityConstraints(pc.equalityConstraints) translateLogicalConstraints(pc.numericConstraints.constraints().asIterable()) translateLogicalConstraints(pc.logicalConstraints) } - override fun check(query: UPathConstraints): USolverResult> = + override fun check(query: UPathConstraints): USolverResult> = internalCheck(query, useSoftConstraints = false) fun checkWithSoftConstraints( - pc: UPathConstraints, + pc: UPathConstraints, ) = internalCheck(pc, useSoftConstraints = true) private fun internalCheck( - pc: UPathConstraints, + pc: UPathConstraints, useSoftConstraints: Boolean, ): USolverResult> { if (pc.isFalse) { diff --git a/usvm-core/src/main/kotlin/org/usvm/solver/USoftConstraintsProvider.kt b/usvm-core/src/main/kotlin/org/usvm/solver/USoftConstraintsProvider.kt index abe69c3e15..7cbe0c2c53 100644 --- a/usvm-core/src/main/kotlin/org/usvm/solver/USoftConstraintsProvider.kt +++ b/usvm-core/src/main/kotlin/org/usvm/solver/USoftConstraintsProvider.kt @@ -31,12 +31,10 @@ import org.usvm.UExpr import org.usvm.UIndexedMethodReturnValue import org.usvm.UIsSubtypeExpr import org.usvm.UIsSupertypeExpr -import org.usvm.UMockSymbol import org.usvm.UNullRef import org.usvm.URegisterReading import org.usvm.USizeExpr import org.usvm.USort -import org.usvm.USymbol import org.usvm.UTransformer import org.usvm.collection.array.UAllocatedArrayReading import org.usvm.collection.array.UInputArrayReading @@ -53,8 +51,8 @@ import org.usvm.collection.set.primitive.UInputSetReading import org.usvm.collection.set.ref.UAllocatedRefSetWithInputElementsReading import org.usvm.collection.set.ref.UInputRefSetWithAllocatedElementsReading import org.usvm.collection.set.ref.UInputRefSetWithInputElementsReading -import org.usvm.uctx import org.usvm.regions.Region +import org.usvm.uctx class USoftConstraintsProvider(override val ctx: UContext) : UTransformer { // We have a list here since sometimes we want to add several soft constraints @@ -89,19 +87,8 @@ class USoftConstraintsProvider(override val ctx: UContext) : UTransformer< // region USymbol specific methods - override fun transform(expr: USymbol): UExpr = - error("You must override `transform` function in UExprTranslator for ${expr::class}") - override fun transform(expr: URegisterReading): UExpr = transformExpr(expr) - override fun transform( - expr: UCollectionReading<*, *, *>, - ): UExpr = error("You must override `transform` function in UExprTranslator for ${expr::class}") - - override fun transform( - expr: UMockSymbol, - ): UExpr = error("You must override `transform` function in UExprTranslator for ${expr::class}") - override fun transform( expr: UIndexedMethodReturnValue, ): UExpr = transformAppIfPossible(expr) diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/UInterpreterObserver.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/UInterpreterObserver.kt new file mode 100644 index 0000000000..0a8e99401b --- /dev/null +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/UInterpreterObserver.kt @@ -0,0 +1,5 @@ +package org.usvm.statistics + +interface UInterpreterObserver { + // Empty +} \ No newline at end of file diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt index fdb1ee437f..aa7ea88ae6 100644 --- a/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt @@ -104,6 +104,10 @@ internal class InterprocDistanceCalculator( currentStatement: Statement, callStack: UCallStack ): InterprocDistance { + if (callStack.isEmpty()) { + return InterprocDistance(UInt.MAX_VALUE, ReachabilityKind.NONE) + } + val lastMethod = callStack.lastMethod() val lastFrameDistance = calculateFrameDistance(lastMethod, currentStatement) diff --git a/usvm-core/src/main/kotlin/org/usvm/stopstrategies/StopStrategyFactory.kt b/usvm-core/src/main/kotlin/org/usvm/stopstrategies/StopStrategyFactory.kt index 580cc1c9d4..9d752a2dc1 100644 --- a/usvm-core/src/main/kotlin/org/usvm/stopstrategies/StopStrategyFactory.kt +++ b/usvm-core/src/main/kotlin/org/usvm/stopstrategies/StopStrategyFactory.kt @@ -1,8 +1,8 @@ package org.usvm.stopstrategies import org.usvm.UMachineOptions -import org.usvm.UTarget import org.usvm.statistics.CoverageStatistics +import org.usvm.targets.UTarget fun createStopStrategy( options: UMachineOptions, diff --git a/usvm-core/src/main/kotlin/org/usvm/stopstrategies/TargetsReachedStopStrategy.kt b/usvm-core/src/main/kotlin/org/usvm/stopstrategies/TargetsReachedStopStrategy.kt index 18e15df004..118d2d0788 100644 --- a/usvm-core/src/main/kotlin/org/usvm/stopstrategies/TargetsReachedStopStrategy.kt +++ b/usvm-core/src/main/kotlin/org/usvm/stopstrategies/TargetsReachedStopStrategy.kt @@ -1,6 +1,6 @@ package org.usvm.stopstrategies -import org.usvm.UTarget +import org.usvm.targets.UTarget /** * A stop strategy which stops when all terminal targets in [targets] are reached. diff --git a/usvm-core/src/main/kotlin/org/usvm/UTarget.kt b/usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt similarity index 86% rename from usvm-core/src/main/kotlin/org/usvm/UTarget.kt rename to usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt index ab3a7ef70b..05e0a2d035 100644 --- a/usvm-core/src/main/kotlin/org/usvm/UTarget.kt +++ b/usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt @@ -1,4 +1,6 @@ -package org.usvm +package org.usvm.targets + +import org.usvm.UState /** * Base class for a symbolic execution target. A target can be understood as a 'task' for symbolic machine @@ -12,15 +14,15 @@ package org.usvm * a state which has reached the target which has no children, it is logically removed from the targets tree. * The other states ignore such removed targets. */ -abstract class UTarget( +abstract class UTarget( /** * Optional location of the target. */ val location: Statement? = null, -) where Target : UTarget, - State : UState<*, *, Statement, *, Target, State> { +) where Target : UTarget { private val childrenImpl = mutableListOf() - private var parent: Target? = null + var parent: Target? = null + private set /** * List of the child targets which should be reached after this target. @@ -58,9 +60,9 @@ abstract class UTarget( * should try to propagate the target. If the target without children has been * visited, it is logically removed from tree. */ - protected fun propagate(byState: State) { + fun > propagate(byState: State) { @Suppress("UNCHECKED_CAST") - if (byState.tryPropagateTarget(this as Target) && isTerminal) { + if (byState.tryPropagateTarget(this as T) && isTerminal) { remove() } } diff --git a/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt b/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt new file mode 100644 index 0000000000..def51e4eb1 --- /dev/null +++ b/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt @@ -0,0 +1,5 @@ +package org.usvm.targets + +interface UTargetController { + val targets: MutableCollection> +} diff --git a/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt b/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt index d8e4f8ea2d..a61412d505 100644 --- a/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt +++ b/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt @@ -8,6 +8,8 @@ import org.usvm.memory.UMemory import org.usvm.memory.USymbolicCollectionKeyInfo import org.usvm.model.UModelBase import org.usvm.regions.Region +import org.usvm.targets.UTarget +import org.usvm.targets.UTargetController typealias Field = java.lang.reflect.Field typealias Type = kotlin.reflect.KClass<*> @@ -29,7 +31,12 @@ internal fun pseudoRandom(i: Int): Int { return res } -internal class TestTarget(method: String, offset: Int) : UTarget( +internal class TestTargetController : UTargetController { + override val targets: MutableCollection> + get() = TODO("Not yet implemented") +} + +internal class TestTarget(method: String, offset: Int) : UTarget( TestInstruction(method, offset) ) { fun reach(state: TestState) { @@ -39,13 +46,13 @@ internal class TestTarget(method: String, offset: Int) : UTarget, pathConstraints: UPathConstraints, + callStack: UCallStack, pathConstraints: UPathConstraints, memory: UMemory, models: List>, pathLocation: PathsTrieNode, targetTrees: List = emptyList() ) : UState(ctx, callStack, pathConstraints, memory, models, pathLocation, targetTrees) { - override fun clone(newConstraints: UPathConstraints?): TestState = this + override fun clone(newConstraints: UPathConstraints?): TestState = this override val isExceptional = false } diff --git a/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt b/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt index b69e338cf9..6fd88ea634 100644 --- a/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt +++ b/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt @@ -13,35 +13,36 @@ import org.usvm.UComponents import org.usvm.UContext import org.usvm.UExpr import org.usvm.UState -import org.usvm.UTarget import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory -import org.usvm.model.buildTranslatorAndLazyDecoder +import org.usvm.model.ULazyModelDecoder import org.usvm.solver.UExprTranslator import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase import org.usvm.solver.UTypeSolver +import org.usvm.targets.UTarget +import org.usvm.targets.UTargetController import org.usvm.types.single.SingleTypeSystem import kotlin.test.assertEquals abstract class SymbolicCollectionTestBase { lateinit var ctx: UContext - lateinit var pathConstraints: UPathConstraints + lateinit var pathConstraints: UPathConstraints lateinit var memory: UMemory lateinit var scope: StepScope lateinit var translator: UExprTranslator - lateinit var uSolver: USolverBase + lateinit var uSolver: USolverBase @BeforeEach fun initializeContext() { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() every { components.mkSolver(any()) } answers { uSolver } - ctx = UContext(components) val softConstraintProvider = USoftConstraintsProvider(ctx) - val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) + val translator = UExprTranslator(ctx) + val decoder = ULazyModelDecoder(translator) this.translator = translator val typeSolver = UTypeSolver(SingleTypeSystem) uSolver = USolverBase(ctx, KZ3Solver(ctx), typeSolver, translator, decoder, softConstraintProvider) @@ -52,17 +53,22 @@ abstract class SymbolicCollectionTestBase { scope = StepScope(StateStub(ctx, pathConstraints, memory)) } - class TargetStub : UTarget() + class TargetControllerStub : UTargetController { + override val targets: MutableCollection> + get() = TODO("Not yet implemented") + } + + class TargetStub : UTarget() class StateStub( ctx: UContext, - pathConstraints: UPathConstraints, + pathConstraints: UPathConstraints, memory: UMemory, ) : UState( ctx, UCallStack(), pathConstraints, memory, emptyList(), ctx.mkInitialLocation() ) { - override fun clone(newConstraints: UPathConstraints?): StateStub { + override fun clone(newConstraints: UPathConstraints?): StateStub { val clonedConstraints = newConstraints ?: pathConstraints.clone() return StateStub(memory.ctx, clonedConstraints, memory.clone(clonedConstraints.typeConstraints)) } diff --git a/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt b/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt index 4cece0627a..ec15339a1c 100644 --- a/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt @@ -18,11 +18,12 @@ import org.usvm.api.readArrayIndex import org.usvm.api.readField import org.usvm.api.writeArrayIndex import org.usvm.api.writeField +import org.usvm.collection.array.UArrayIndexLValue import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory import org.usvm.memory.URegisterStackLValue import org.usvm.memory.URegistersStack -import org.usvm.collection.array.UArrayIndexLValue +import org.usvm.solver.UExprTranslator import org.usvm.solver.USatResult import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase @@ -35,9 +36,9 @@ private typealias Type = SingleTypeSystem.SingleType class ModelDecodingTest { private lateinit var ctx: UContext - private lateinit var solver: USolverBase + private lateinit var solver: USolverBase - private lateinit var pc: UPathConstraints + private lateinit var pc: UPathConstraints private lateinit var stack: URegistersStack private lateinit var heap: UMemory private lateinit var mocker: UIndexedMocker @@ -49,7 +50,8 @@ class ModelDecodingTest { ctx = UContext(components) val softConstraintsProvider = USoftConstraintsProvider(ctx) - val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) + val translator = UExprTranslator(ctx) + val decoder = ULazyModelDecoder(translator) val typeSolver = UTypeSolver(SingleTypeSystem) solver = USolverBase(ctx, KZ3Solver(ctx), typeSolver, translator, decoder, softConstraintsProvider) diff --git a/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt b/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt index 3c095f6ef3..16c58562d7 100644 --- a/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt @@ -10,21 +10,20 @@ import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.usvm.UComponents import org.usvm.UContext -import org.usvm.constraints.UPathConstraints import org.usvm.collection.array.length.UInputArrayLengthId +import org.usvm.constraints.UPathConstraints import org.usvm.model.ULazyModelDecoder -import org.usvm.model.buildTranslatorAndLazyDecoder import org.usvm.types.single.SingleTypeSystem import kotlin.test.assertSame private typealias Type = SingleTypeSystem.SingleType -open class SoftConstraintsTest { +open class SoftConstraintsTest { private lateinit var ctx: UContext private lateinit var softConstraintsProvider: USoftConstraintsProvider private lateinit var translator: UExprTranslator private lateinit var decoder: ULazyModelDecoder - private lateinit var solver: USolverBase + private lateinit var solver: USolverBase @BeforeEach fun initialize() { @@ -34,10 +33,9 @@ open class SoftConstraintsTest { ctx = UContext(components) softConstraintsProvider = USoftConstraintsProvider(ctx) - val translatorWithDecoder = buildTranslatorAndLazyDecoder(ctx) + translator = UExprTranslator(ctx) + decoder = ULazyModelDecoder(translator) - translator = translatorWithDecoder.first - decoder = translatorWithDecoder.second val typeSolver = UTypeSolver(SingleTypeSystem) solver = USolverBase(ctx, KZ3Solver(ctx), typeSolver, translator, decoder, softConstraintsProvider) } @@ -48,7 +46,7 @@ open class SoftConstraintsTest { val sndRegister = mkRegisterReading(idx = 1, bv32Sort) val expr = mkBvSignedLessOrEqualExpr(fstRegister, sndRegister) - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx) pc += expr val result = solver.checkWithSoftConstraints(pc) as USatResult @@ -75,7 +73,7 @@ open class SoftConstraintsTest { every { softConstraintsProvider.provide(any()) } answers { callOriginal() } - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx) pc += fstExpr pc += sndExpr pc += sameAsFirstExpr @@ -122,7 +120,7 @@ open class SoftConstraintsTest { val reading = region.read(secondInputRef) - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx) pc += reading eq size.toBv() pc += inputRef eq secondInputRef pc += (inputRef eq nullRef).not() @@ -143,7 +141,7 @@ open class SoftConstraintsTest { .emptyRegion() .write(inputRef, mkRegisterReading(3, sizeSort), guard = trueExpr) - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx) pc += (inputRef eq nullRef).not() val result = (solver.checkWithSoftConstraints(pc)) as USatResult @@ -159,7 +157,7 @@ open class SoftConstraintsTest { val bvValue = 0.toBv() val expression = mkBvSignedLessOrEqualExpr(bvValue, inputRef).not() - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx) pc += expression val result = (solver.checkWithSoftConstraints(pc)) as USatResult diff --git a/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt b/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt index 7377a76c6a..5d2c72a9e6 100644 --- a/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt @@ -14,14 +14,15 @@ import org.usvm.UContext import org.usvm.api.readField import org.usvm.api.typeStreamOf import org.usvm.api.writeField +import org.usvm.collection.array.UInputArrayId import org.usvm.constraints.UPathConstraints import org.usvm.isFalse import org.usvm.isTrue import org.usvm.memory.UMemory -import org.usvm.collection.array.UInputArrayId +import org.usvm.model.ULazyModelDecoder import org.usvm.model.UModelBase -import org.usvm.model.buildTranslatorAndLazyDecoder import org.usvm.solver.TypeSolverQuery +import org.usvm.solver.UExprTranslator import org.usvm.solver.USatResult import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase @@ -55,11 +56,12 @@ class TypeSolverTest { private val typeSystem = testTypeSystem private val components = mockk>() private val ctx = UContext(components) - private val solver: USolverBase + private val solver: USolverBase private val typeSolver: UTypeSolver init { - val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) + val translator = UExprTranslator(ctx) + val decoder = ULazyModelDecoder(translator) val softConstraintsProvider = USoftConstraintsProvider(ctx) typeSolver = UTypeSolver(typeSystem) @@ -69,7 +71,7 @@ class TypeSolverTest { every { components.mkTypeSystem(ctx) } returns typeSystem } - private val pc = UPathConstraints(ctx) + private val pc = UPathConstraints(ctx) private val memory = UMemory(ctx, pc.typeConstraints) @Test diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/JcTarget.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/JcTarget.kt index b6fea3e074..ab466861a8 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/JcTarget.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/JcTarget.kt @@ -1,12 +1,12 @@ package org.usvm.api.targets import org.jacodb.api.cfg.JcInst -import org.usvm.UTarget -import org.usvm.machine.state.JcState +import org.usvm.targets.UTarget +import org.usvm.targets.UTargetController /** * Base class for JcMachine targets. */ -abstract class JcTarget( +abstract class JcTarget( location: JcInst? = null -) : UTarget(location) +) : UTarget, TargetController>(location) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/Position.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/Position.kt new file mode 100644 index 0000000000..5de34e81d5 --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/Position.kt @@ -0,0 +1,51 @@ +package org.usvm.api.targets + +import io.ksmt.utils.asExpr +import org.usvm.UBoolExpr +import org.usvm.UExpr +import org.usvm.UHeapRef +import org.usvm.uctx + + +interface PositionResolver { + fun resolve(position: Position): ResolvedPosition<*>? +} + +class CallPositionResolver( + val instance: UHeapRef?, + val args: List>, + val result: UExpr<*>?, +) : PositionResolver { + override fun resolve(position: Position): ResolvedPosition<*>? = when (position) { + ThisArgument -> instance?.let { mkResolvedPosition(position, it) } + + is Argument -> { + val index = position.number.toInt() + args.getOrNull(index)?.let { mkResolvedPosition(position, it) } + } + + Result -> result?.let { mkResolvedPosition(position, it) } + } + + private fun mkResolvedPosition(position: Position, resolved: UExpr<*>): ResolvedPosition<*>? = + with(resolved.uctx) { + when (resolved.sort) { + addressSort -> ResolvedRefPosition(position, resolved.asExpr(addressSort)) + boolSort -> ResolvedBoolPosition(position, resolved.asExpr(boolSort)) + else -> null + } + } +} + + +sealed interface Position + +object ThisArgument : Position + +class Argument(val number: UInt) : Position + +object Result : Position + +sealed class ResolvedPosition(val position: Position, val resolved: T) +class ResolvedRefPosition(position: Position, resolved: UHeapRef) : ResolvedPosition(position, resolved) +class ResolvedBoolPosition(position: Position, resolved: UBoolExpr) : ResolvedPosition(position, resolved) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAction.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAction.kt new file mode 100644 index 0000000000..7ab07ef190 --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAction.kt @@ -0,0 +1,80 @@ +package org.usvm.api.targets + +import org.usvm.UBoolExpr +import org.usvm.UHeapRef +import org.usvm.machine.JcContext +import org.usvm.machine.interpreter.JcStepScope + + +sealed interface TaintActionVisitor { + fun visit(action: CopyAllMarks, stepScope: JcStepScope, condition: UBoolExpr?) + fun visit(action: AssignMark, stepScope: JcStepScope, condition: UBoolExpr?) + fun visit(action: RemoveAllMarks, stepScope: JcStepScope, condition: UBoolExpr?) + fun visit(action: RemoveMark, stepScope: JcStepScope, condition: UBoolExpr?) +} + +class TaintActionResolver( + private val ctx: JcContext, + private val positionResolver: PositionResolver, + private val readMark: (ref: UHeapRef, mark: JcTaintMark, JcStepScope) -> UBoolExpr, + private val writeMark: (ref: UHeapRef, mark: JcTaintMark, UBoolExpr, JcStepScope) -> Unit, + private val removeMark: (ref: UHeapRef, JcTaintMark, UBoolExpr, JcStepScope) -> Unit, + private val allMarks: Set, +) : TaintActionVisitor { + + override fun visit(action: CopyAllMarks, stepScope: JcStepScope, condition: UBoolExpr?) { + val resolvedFrom = positionResolver.resolve(action.from) as? ResolvedRefPosition ?: return + val resolvedTo = positionResolver.resolve(action.to) as? ResolvedRefPosition ?: return + + allMarks.forEach { + val fromValue = readMark(resolvedFrom.resolved, it, stepScope) + writeMark(resolvedTo.resolved, it, ctx.mkAnd(condition ?: ctx.trueExpr, fromValue), stepScope) + } + } + + override fun visit(action: AssignMark, stepScope: JcStepScope, condition: UBoolExpr?) { + val resolvedRef = positionResolver.resolve(action.position) as? ResolvedRefPosition ?: return + writeMark(resolvedRef.resolved, action.mark, condition ?: ctx.trueExpr, stepScope) + } + + override fun visit(action: RemoveAllMarks, stepScope: JcStepScope, condition: UBoolExpr?) { + val resolvedRef = positionResolver.resolve(action.position) as? ResolvedRefPosition ?: return + allMarks.forEach { + removeMark(resolvedRef.resolved, it, condition ?: ctx.trueExpr, stepScope) + } + } + + override fun visit(action: RemoveMark, stepScope: JcStepScope, condition: UBoolExpr?) { + val resolvedRef = positionResolver.resolve(action.position) as? ResolvedRefPosition ?: return + removeMark(resolvedRef.resolved, action.mark, condition ?: ctx.trueExpr, stepScope) + } +} + +sealed interface Action { + fun accept(visitor: TaintActionVisitor, stepScope: JcStepScope, condition: UBoolExpr?) +} + +// TODO add marks for aliases (if you pass an object and return it from the function) +class CopyAllMarks(val from: Position, val to: Position) : Action { + override fun accept(visitor: TaintActionVisitor, stepScope: JcStepScope, condition: UBoolExpr?) { + visitor.visit(this, stepScope, condition) + } +} + +class AssignMark(val position: Position, val mark: JcTaintMark) : Action { + override fun accept(visitor: TaintActionVisitor, stepScope: JcStepScope, condition: UBoolExpr?) { + visitor.visit(this, stepScope, condition) + } +} + +class RemoveAllMarks(val position: Position) : Action { + override fun accept(visitor: TaintActionVisitor, stepScope: JcStepScope, condition: UBoolExpr?) { + visitor.visit(this, stepScope, condition) + } +} + +class RemoveMark(val position: Position, val mark: JcTaintMark) : Action { + override fun accept(visitor: TaintActionVisitor, stepScope: JcStepScope, condition: UBoolExpr?) { + visitor.visit(this, stepScope, condition) + } +} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt new file mode 100644 index 0000000000..9f816e1900 --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt @@ -0,0 +1,310 @@ +package org.usvm.api.targets + +import io.ksmt.utils.asExpr +import io.ksmt.utils.uncheckedCast +import org.jacodb.api.cfg.JcAssignInst +import org.jacodb.api.cfg.JcCallExpr +import org.jacodb.api.cfg.JcCallInst +import org.jacodb.api.cfg.JcInst +import org.jacodb.api.cfg.JcInstanceCallExpr +import org.jacodb.api.ext.cfg.callExpr +import org.usvm.UBoolExpr +import org.usvm.UConcreteHeapRef +import org.usvm.UHeapRef +import org.usvm.api.allocateConcreteRef +import org.usvm.collection.set.ref.URefSetEntryLValue +import org.usvm.machine.JcContext +import org.usvm.machine.JcInterpreterObserver +import org.usvm.machine.JcMethodCallBaseInst +import org.usvm.machine.JcMethodEntrypointInst +import org.usvm.machine.interpreter.JcSimpleValueResolver +import org.usvm.machine.interpreter.JcStepScope +import org.usvm.machine.state.JcMethodResult +import org.usvm.machine.state.JcState +import org.usvm.statistics.UMachineObserver +import org.usvm.targets.UTargetController + +class TaintAnalysis( + private val configuration: TaintConfiguration, + override val targets: MutableCollection = mutableListOf(), +) : UTargetController, JcInterpreterObserver, UMachineObserver { + private val taintTargets: MutableMap> = mutableMapOf() + + init { + targets.forEach { + exposeTargets(it, taintTargets) + } + } + + // TODO save mapping between initial targets and the states that reach them + // Replace with the corresponding observer-collector? + val collectedStates: MutableList = mutableListOf() + + private fun exposeTargets(target: TaintTarget, result: MutableMap>) { + result.getOrPut(target.location!!) { hashSetOf() }.add(target) + + target.children.forEach { + exposeTargets(it as TaintTarget, result) + } + } + + private val marksAddresses: MutableMap = mutableMapOf() + + private fun getMarkAddress(mark: JcTaintMark, stepScope: JcStepScope): UConcreteHeapRef = + marksAddresses.getOrPut(mark) { + stepScope.calcOnState { memory.allocateConcreteRef() } + } + + private fun writeMark(ref: UHeapRef, mark: JcTaintMark, guard: UBoolExpr, stepScope: JcStepScope) { + stepScope.doWithState { + memory.write(createLValue(ref, mark, stepScope), ctx.trueExpr, guard) + } + } + + private fun removeMark(ref: UHeapRef, mark: JcTaintMark, guard: UBoolExpr, stepScope: JcStepScope) { + stepScope.doWithState { + memory.write(createLValue(ref, mark, stepScope), ctx.falseExpr, guard) + } + } + + private fun readMark(ref: UHeapRef, mark: JcTaintMark, stepScope: JcStepScope): UBoolExpr = + stepScope.calcOnState { + memory.read(createLValue(ref, mark, stepScope)) + } + + private fun createLValue( + ref: UHeapRef, + mark: Mark, + stepScope: JcStepScope, + ): URefSetEntryLValue = URefSetEntryLValue(ref, getMarkAddress(mark, stepScope), mark) + + + fun addTarget(target: JcTarget): TaintAnalysis { + require(target is TaintTarget) + + targets += target + exposeTargets(target, taintTargets) + + return this + } + + private fun findTaintTargets(stmt: JcInst, state: JcState): List = + taintTargets[stmt]?.let { targets -> + state.targets.filter { it.uncheckedCast() in targets } + }.orEmpty().toList().uncheckedCast() + + override fun onAssignStatement(exprResolver: JcSimpleValueResolver, stmt: JcAssignInst, stepScope: JcStepScope) { + // Sinks are already processed at this moment since we resolved it on a call statement + + stmt.callExpr?.let { processTaintConfiguration(it, stepScope, exprResolver) } + + // TODO add fields processing + } + + private fun processTaintConfiguration( + callExpr: JcCallExpr, + stepScope: JcStepScope, + simpleValueResolver: JcSimpleValueResolver, + ) { + val ctx = stepScope.ctx + val methodResult = stepScope.calcOnState { methodResult } + val method = callExpr.method.method + + require(methodResult is JcMethodResult.Success) { + "Other result statuses must be processed in `onMethodCallWithUnresolvedArguments`" + } + + val callPositionResolver = createCallPositionResolver(ctx, callExpr, simpleValueResolver, methodResult) + + val conditionResolver = ConditionResolver(ctx, callPositionResolver, ::readMark) + val actionResolver = TaintActionResolver( + ctx, + callPositionResolver, + ::readMark, + ::writeMark, + ::removeMark, + marksAddresses.keys + ) + + val sourceConfigurations = configuration.methodSources[method] + val currentStatement = stepScope.calcOnState { currentStatement } + + val sourceTargets = findTaintTargets(currentStatement, stepScope.state) + .filterIsInstance() + .associateBy { it.configurationRule } + + sourceConfigurations?.forEach { + val target = sourceTargets[it] + + val resolvedCondition = + conditionResolver.visit(it.condition, simpleValueResolver, stepScope) ?: ctx.trueExpr + + val targetCondition = target?.condition ?: ConstantTrue + val resolvedTargetCondition = + conditionResolver.visit(targetCondition, simpleValueResolver, stepScope) ?: ctx.trueExpr + + val combinedCondition = ctx.mkAnd(resolvedTargetCondition, resolvedCondition) + + it.action.accept(actionResolver, stepScope, combinedCondition) + + target?.propagate(stepScope.state) + } + + val cleanerConfigurations = configuration.cleaners[method] + cleanerConfigurations?.forEach { + val resolvedCondition = conditionResolver.visit(it.condition, simpleValueResolver, stepScope) + + it.action.accept(actionResolver, stepScope, resolvedCondition) + } + + val passThroughConfigurations = configuration.passThrough[method] + passThroughConfigurations?.forEach { + val resolvedCondition = conditionResolver.visit(it.condition, simpleValueResolver, stepScope) + + it.action.accept(actionResolver, stepScope, resolvedCondition) + } + } + + private val JcStepScope.ctx get() = calcOnState { ctx } + + private fun createCallPositionResolver( + ctx: JcContext, + callExpr: JcCallExpr, + simpleValueResolver: JcSimpleValueResolver, + methodResult: JcMethodResult.Success?, + ) = CallPositionResolver( + resolveCallInstance(callExpr)?.accept(simpleValueResolver)?.asExpr(ctx.addressSort), + callExpr.args.map { it.accept(simpleValueResolver) }, + methodResult?.value + ) + + override fun onEntryPoint( + simpleValueResolver: JcSimpleValueResolver, + stmt: JcMethodEntrypointInst, + stepScope: JcStepScope, + ) { + // TODO entry point configuration + } + + override fun onMethodCallWithUnresolvedArguments( + simpleValueResolver: JcSimpleValueResolver, + stmt: JcCallExpr, + stepScope: JcStepScope, + ) { + val method = stmt.method.method + + val methodResult = stepScope.calcOnState { methodResult } + require(methodResult is JcMethodResult.NoCall) { "This signal must be sent before the method call" } + + val ctx = stepScope.ctx + + val positionResolver = createCallPositionResolver(ctx, stmt, simpleValueResolver, methodResult = null) + val conditionResolver = ConditionResolver(ctx, positionResolver, ::readMark) + + if (inTargetedMode) { + val currentStatement = stepScope.calcOnState { currentStatement } + val targets = findTaintTargets(currentStatement, stepScope.state) + + targets + .filterIsInstance() + .forEach { + processSink(it.configRule, it.condition, conditionResolver, simpleValueResolver, stepScope, it) + } + } else { + val methodSinks = configuration.methodSinks[method] ?: return + methodSinks.forEach { + processSink(it, ConstantTrue, conditionResolver, simpleValueResolver, stepScope) + } + } + } + + private val JcStepScope.state get() = calcOnState { this } + + private val inTargetedMode: Boolean + get() = targets.isNotEmpty() + + private fun processSink( + methodSink: TaintMethodSink, + sinkCondition: Condition, + conditionResolver: ConditionResolver, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + target: TaintMethodSinkTarget? = null, + ) { + val resolvedConfigCondition = + methodSink.condition.visit(conditionResolver, simpleValueResolver, stepScope) ?: return + + val resolvedSinkCondition = sinkCondition.visit(conditionResolver, simpleValueResolver, stepScope) ?: return + + val resolvedCondition = stepScope.ctx.mkAnd(resolvedConfigCondition, resolvedSinkCondition) + + val (originalStateCopy, taintedStepScope) = stepScope.calcOnState { + val originalStateCopy = clone() + originalStateCopy to JcStepScope(originalStateCopy) + } + + taintedStepScope.assert(resolvedCondition)?.let { + // TODO remove corresponding target + collectedStates += originalStateCopy + target?.propagate(taintedStepScope.state) + } + } + + override fun onMethodCallWithResolvedArguments( + simpleValueResolver: JcSimpleValueResolver, + stmt: JcMethodCallBaseInst, + stepScope: JcStepScope, + ) { + // It is a redundant signal + } + + override fun onCallStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcCallInst, stepScope: JcStepScope) { + processTaintConfiguration(stmt.callExpr, stepScope, simpleValueResolver) + } + + override fun onState(parent: JcState, forks: Sequence) { + propagateIntermediateTarget(parent) + + forks.forEach { propagateIntermediateTarget(it) } + } + + private fun propagateIntermediateTarget(state: JcState) { + val parent = state.pathLocation.parent ?: error("This is impossible by construction") + val targets = findTaintTargets(parent.statement, state) + + targets.forEach { + when (it) { + is TaintIntermediateTarget -> it.propagate(state) + is TaintMethodSourceTarget, is TaintMethodSinkTarget -> return@forEach + } + } + } + + private fun resolveCallInstance( + callExpr: JcCallExpr, + ) = if (callExpr is JcInstanceCallExpr) callExpr.instance else null + + sealed class TaintTarget(location: JcInst) : JcTarget(location) + + class TaintMethodSourceTarget( + location: JcInst, + val condition: Condition, + val configurationRule: TaintMethodSource, + ) : TaintTarget(location) + // TODO add field sources and sinks targets + + class TaintIntermediateTarget(location: JcInst) : TaintTarget(location) + + class TaintMethodSinkTarget( + location: JcInst, + val condition: Condition, + val configRule: TaintMethodSink, + ) : TaintTarget(location) +} + + +sealed interface JcTaintMark + +object SqlInjection : JcTaintMark + +object SensitiveData : JcTaintMark diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintCondition.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintCondition.kt new file mode 100644 index 0000000000..00b9b837d5 --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintCondition.kt @@ -0,0 +1,87 @@ +package org.usvm.api.targets + +import org.usvm.UBoolExpr +import org.usvm.UHeapRef +import org.usvm.machine.JcContext +import org.usvm.machine.interpreter.JcSimpleValueResolver +import org.usvm.machine.interpreter.JcStepScope + + +interface ConditionVisitor { + fun visit(condition: Condition, simpleValueResolver: JcSimpleValueResolver, stepScope: JcStepScope): R +} + +class ConditionResolver( + private val ctx: JcContext, + private val positionResolver: PositionResolver, + private val readMark: (ref: UHeapRef, mark: JcTaintMark, JcStepScope) -> UBoolExpr, +) : ConditionVisitor { + override fun visit( + condition: Condition, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + ): UBoolExpr? = + when (condition) { + ConstantTrue -> ctx.trueExpr + + is BooleanFromArgument -> (positionResolver.resolve(condition.argument) as? ResolvedBoolPosition)?.resolved + + is CallParameterContainsMark -> { + val resolvedPosition = positionResolver.resolve(condition.position) as? ResolvedRefPosition + + resolvedPosition?.let { + readMark(it.resolved, condition.mark, stepScope) + } + } + + is Negation -> condition.condition.visit(this, simpleValueResolver, stepScope)?.let { + ctx.mkNot(it) + } + // TODO this code is completely non extendable due to when usage + } +} + +sealed interface Condition { + fun visit( + conditionVisitor: ConditionVisitor, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + ): R +} + +sealed class UnaryCondition : Condition + +sealed class BinaryCondition : Condition + +class Negation(val condition: Condition) : UnaryCondition() { + override fun visit( + conditionVisitor: ConditionVisitor, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + ): R = conditionVisitor.visit(this, simpleValueResolver, stepScope) +} + +object ConstantTrue : Condition { + override fun visit( + conditionVisitor: ConditionVisitor, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + ): R = conditionVisitor.visit(this, simpleValueResolver, stepScope) +} + +class BooleanFromArgument(val argument: Argument) : Condition { + override fun visit( + conditionVisitor: ConditionVisitor, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + ): R = conditionVisitor.visit(this, simpleValueResolver, stepScope) +} + +class CallParameterContainsMark(val position: Position, val mark: JcTaintMark) : UnaryCondition() { + override fun visit( + conditionVisitor: ConditionVisitor, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + ): R = conditionVisitor.visit(this, simpleValueResolver, stepScope) +} + diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt new file mode 100644 index 0000000000..536e46e255 --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt @@ -0,0 +1,58 @@ +package org.usvm.api.targets + +import org.jacodb.api.JcField +import org.jacodb.api.JcMethod + +// TODO separate cleaning actions from the ones who can taint data +data class TaintConfiguration( + val entryPoints: Map>, + val methodSources: Map>, + val fieldSources: Map>, + val passThrough: Map>, + val cleaners: Map>, + val methodSinks: Map>, + val fieldSinks: Map>, +) + +sealed interface TaintConfigurationItem + +class TaintEntryPointSource( + val method: JcMethod, + val condition: Condition, + val action: Action, +) : TaintConfigurationItem + +class TaintMethodSource( + val method: JcMethod, + val condition: Condition, + val action: Action, +) : TaintConfigurationItem + +class TaintFieldSource( + val field: JcField, + val condition: Condition, + val action: Action, +) : TaintConfigurationItem + +class TaintMethodSink( + val condition: Condition, + val method: JcMethod, +) : TaintConfigurationItem + +class TaintFieldSink( + val condition: Condition, + val field: JcField, +) : TaintConfigurationItem + +class TaintPassThrough( + val methodInfo: JcMethod, + val condition: Condition, + val action: Action, +) : TaintConfigurationItem + +class TaintCleaner( + val methodInfo: JcMethod, + val condition: Condition, + val action: Action, +) : TaintConfigurationItem + diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestResolver.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestResolver.kt index 3a6898c364..82dfd05201 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestResolver.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestResolver.kt @@ -72,7 +72,7 @@ class JcTestResolver( val model = state.models.first() val memory = state.memory - val ctx = state.pathConstraints.ctx + val ctx = state.ctx val initialScope = MemoryScope(ctx, model, model, method, classLoader) val afterScope = MemoryScope(ctx, model, memory, method, classLoader) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt index 7ab46177b9..6694dbc467 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt @@ -124,12 +124,14 @@ class JcMethodApproximationResolver( val classNameRef = arguments.single() val predefinedTypeNames = ctx.primitiveTypes.associateBy { - exprResolver.resolveStringConstant(it.typeName) + exprResolver.simpleValueResolver.resolveStringConstant(it.typeName) } val primitive = predefinedTypeNames[classNameRef] ?: return false - val classRef = exprResolver.resolveClassRef(primitive) + val classRef = with(exprResolver.simpleValueResolver) { + resolveClassRef(primitive) + } scope.doWithState { skipMethodInvocationWithValue(methodCall, classRef) @@ -172,7 +174,10 @@ class JcMethodApproximationResolver( * */ val type = possibleTypes.singleOrNull() ?: return false - val result = exprResolver.resolveClassRef(type) + val result = with(exprResolver.simpleValueResolver) { + resolveClassRef(type) + } + scope.doWithState { skipMethodInvocationWithValue(methodCall, result) } @@ -204,7 +209,9 @@ class JcMethodApproximationResolver( ) val primitiveArrayRefScale = primitiveArrayScale.mapKeys { (type, _) -> - exprResolver.resolveClassRef(ctx.cp.arrayTypeOf(type)) + with(exprResolver.simpleValueResolver) { + resolveClassRef(ctx.cp.arrayTypeOf(type)) + } } val arrayTypeRef = arguments.last().asExpr(ctx.addressSort) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt index b1642be142..cd5bbddfe6 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt @@ -7,28 +7,28 @@ import org.jacodb.api.JcType import org.usvm.SolverType import org.usvm.UComponents import org.usvm.UContext -import org.usvm.model.buildTranslatorAndLazyDecoder import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase import org.usvm.solver.UTypeSolver class JcComponents( private val typeSystem: JcTypeSystem, - private val solverType: SolverType + private val solverType: SolverType, ) : UComponents { private val closeableResources = mutableListOf() - override fun mkSolver(ctx: Context): USolverBase { - val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) + + override fun mkSolver(ctx: UContext): USolverBase { + val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) val softConstraintsProvider = USoftConstraintsProvider(ctx) - val smtSolver = - when (solverType) { - // Yices with Fp support via SymFpu - SolverType.YICES -> KSymFpuSolver(KYicesSolver(ctx), ctx) - SolverType.Z3 -> KZ3Solver(ctx) - } + val smtSolver = when (solverType) { + // Yices with Fp support via SymFpu + SolverType.YICES -> KSymFpuSolver(KYicesSolver(ctx), ctx) + SolverType.Z3 -> KZ3Solver(ctx) + } val typeSolver = UTypeSolver(typeSystem) closeableResources += smtSolver + return USolverBase(ctx, smtSolver, typeSolver, translator, decoder, softConstraintsProvider) } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcInterpreterObserver.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcInterpreterObserver.kt new file mode 100644 index 0000000000..7d014e169a --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcInterpreterObserver.kt @@ -0,0 +1,31 @@ +package org.usvm.machine + +import org.jacodb.api.cfg.JcAssignInst +import org.jacodb.api.cfg.JcCallExpr +import org.jacodb.api.cfg.JcCallInst +import org.jacodb.api.cfg.JcEnterMonitorInst +import org.jacodb.api.cfg.JcExitMonitorInst +import org.jacodb.api.cfg.JcGotoInst +import org.jacodb.api.cfg.JcIfInst +import org.jacodb.api.cfg.JcReturnInst +import org.jacodb.api.cfg.JcSwitchInst +import org.jacodb.api.cfg.JcThrowInst +import org.usvm.machine.interpreter.JcSimpleValueResolver +import org.usvm.machine.interpreter.JcStepScope +import org.usvm.statistics.UInterpreterObserver + +interface JcInterpreterObserver : UInterpreterObserver { + fun onAssignStatement(exprResolver: JcSimpleValueResolver, stmt: JcAssignInst, stepScope: JcStepScope) {} + fun onEntryPoint(simpleValueResolver: JcSimpleValueResolver, stmt: JcMethodEntrypointInst, stepScope: JcStepScope) + fun onMethodCallWithUnresolvedArguments(simpleValueResolver: JcSimpleValueResolver, stmt: JcCallExpr, stepScope: JcStepScope) {} + fun onMethodCallWithResolvedArguments(simpleValueResolver: JcSimpleValueResolver, stmt: JcMethodCallBaseInst, stepScope: JcStepScope) {} + fun onIfStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcIfInst, stepScope: JcStepScope) {} + fun onReturnStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcReturnInst, stepScope: JcStepScope) {} + fun onGotoStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcGotoInst, stepScope: JcStepScope) {} + fun onCatchStatement(simpleValueResolver: JcSimpleValueResolver, stepScope: JcStepScope) {} + fun onSwitchStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcSwitchInst, stepScope: JcStepScope) {} + fun onThrowStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcThrowInst, stepScope: JcStepScope) {} + fun onCallStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcCallInst, stepScope: JcStepScope) {} + fun onEnterMonitorStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcEnterMonitorInst, stepScope: JcStepScope) {} + fun onExitMonitorStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcExitMonitorInst, stepScope: JcStepScope) {} +} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt index e0dff78f9b..ee14726902 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt @@ -25,12 +25,14 @@ import org.usvm.statistics.collectors.TargetsReachedStatesCollector import org.usvm.statistics.distances.CfgStatisticsImpl import org.usvm.statistics.distances.PlainCallGraphStatistics import org.usvm.stopstrategies.createStopStrategy +import org.usvm.targets.UTargetController val logger = object : KLogging() {}.logger class JcMachine( cp: JcClasspath, - private val options: UMachineOptions + private val options: UMachineOptions, + private val interpreterObserver: JcInterpreterObserver? = null ) : UMachine() { private val applicationGraph = JcApplicationGraph(cp) @@ -38,11 +40,11 @@ class JcMachine( private val components = JcComponents(typeSystem, options.solverType) private val ctx = JcContext(cp, components) - private val interpreter = JcInterpreter(ctx, applicationGraph) + private val interpreter = JcInterpreter(ctx, applicationGraph, interpreterObserver) private val cfgStatistics = CfgStatisticsImpl(applicationGraph) - fun analyze(method: JcMethod, targets: List = emptyList()): List { + fun analyze(method: JcMethod, targets: List> = emptyList()): List { logger.debug("{}.analyze({}, {})", this, method, targets) val initialState = interpreter.getInitialState(method, targets) @@ -99,6 +101,11 @@ class JcMachine( val observers = mutableListOf>(coverageStatistics) observers.add(TerminatedStateRemover()) + if (interpreterObserver is UMachineObserver<*>) { + @Suppress("UNCHECKED_CAST") + observers.add(interpreterObserver as UMachineObserver) + } + if (options.coverageZone == CoverageZone.TRANSITIVE) { observers.add( TransitiveCoverageZoneObserver( diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt index 7fd0043522..ad6bd1a3e3 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt @@ -60,6 +60,7 @@ import org.jacodb.api.cfg.JcRemExpr import org.jacodb.api.cfg.JcShlExpr import org.jacodb.api.cfg.JcShort import org.jacodb.api.cfg.JcShrExpr +import org.jacodb.api.cfg.JcSimpleValue import org.jacodb.api.cfg.JcSpecialCallExpr import org.jacodb.api.cfg.JcStaticCallExpr import org.jacodb.api.cfg.JcStringConstant @@ -122,12 +123,20 @@ import org.usvm.util.write class JcExprResolver( private val ctx: JcContext, private val scope: JcStepScope, - private val localToIdx: (JcMethod, JcLocal) -> Int, - private val mkTypeRef: (JcType, JcState) -> UConcreteHeapRef, - private val mkStringConstRef: (String, JcState) -> UConcreteHeapRef, + localToIdx: (JcMethod, JcLocal) -> Int, + mkTypeRef: (JcType, JcState) -> UConcreteHeapRef, + mkStringConstRef: (String, JcState) -> UConcreteHeapRef, private val classInitializerAnalysisAlwaysRequiredForType: (JcRefType) -> Boolean, private val hardMaxArrayLength: Int = 1_500, // TODO: move to options ) : JcExprVisitor?> { + val simpleValueResolver: JcSimpleValueResolver = JcSimpleValueResolver( + ctx, + scope, + localToIdx, + mkTypeRef, + mkStringConstRef + ) + /** * Resolves the [expr] and casts it to match the desired [type]. * @@ -164,7 +173,7 @@ class JcExprResolver( when (value) { is JcFieldRef -> resolveFieldRef(value.instance, value.field) is JcArrayAccess -> resolveArrayAccess(value.array, value.index) - is JcLocal -> resolveLocal(value) + is JcLocal -> simpleValueResolver.resolveLocal(value) else -> error("Unexpected value: $value") } @@ -262,105 +271,41 @@ class JcExprResolver( // region constants - override fun visitJcBool(value: JcBool): UExpr = with(ctx) { - mkBool(value.value) - } - - override fun visitJcChar(value: JcChar): UExpr = with(ctx) { - mkBv(value.value.code, charSort) - } - - override fun visitJcByte(value: JcByte): UExpr = with(ctx) { - mkBv(value.value, byteSort) - } - - override fun visitJcShort(value: JcShort): UExpr = with(ctx) { - mkBv(value.value, shortSort) - } - - override fun visitJcInt(value: JcInt): UExpr = with(ctx) { - mkBv(value.value, integerSort) - } - - override fun visitJcLong(value: JcLong): UExpr = with(ctx) { - mkBv(value.value, longSort) - } - - override fun visitJcFloat(value: JcFloat): UExpr = with(ctx) { - mkFp(value.value, floatSort) - } - - override fun visitJcDouble(value: JcDouble): UExpr = with(ctx) { - mkFp(value.value, doubleSort) - } - - override fun visitJcNullConstant(value: JcNullConstant): UExpr = with(ctx) { - nullRef - } - - override fun visitJcStringConstant(value: JcStringConstant): UExpr = with(ctx) { - scope.calcOnState { - // Equal string constants always have equal references - val ref = resolveStringConstant(value.value) - val stringValueLValue = UFieldLValue(addressSort, ref, stringValueField.field) - val stringCoderLValue = UFieldLValue(byteSort, ref, stringCoderField.field) - - // String.value type depends on the JVM version - val charValues = when (stringValueField.fieldType.ifArrayGetElementType) { - cp.char -> value.value.asSequence().map { mkBv(it.code, charSort) } - cp.byte -> value.value.encodeToByteArray().asSequence().map { mkBv(it, byteSort) } - else -> error("Unexpected string values type: ${stringValueField.fieldType}") - } - - val valuesArrayDescriptor = arrayDescriptorOf(stringValueField.fieldType as JcArrayType) - val elementType = requireNotNull(stringValueField.fieldType.ifArrayGetElementType) - val charArrayRef = memory.allocateArrayInitialized( - valuesArrayDescriptor, - typeToSort(elementType), - charValues.uncheckedCast() - ) + override fun visitJcBool(value: JcBool): UExpr = simpleValueResolver.visitJcBool(value) - // overwrite array type because descriptor is element type - memory.types.allocate(charArrayRef.address, stringValueField.fieldType) + override fun visitJcChar(value: JcChar): UExpr = simpleValueResolver.visitJcChar(value) - // String constants are immutable. Therefore, it is correct to overwrite value, coder and type. - memory.write(stringValueLValue, charArrayRef) - memory.write(stringCoderLValue, mkBv(0, byteSort)) - memory.types.allocate(ref.address, stringType) + override fun visitJcByte(value: JcByte): UExpr = simpleValueResolver.visitJcByte(value) - ref - } - } + override fun visitJcShort(value: JcShort): UExpr = simpleValueResolver.visitJcShort(value) - fun resolveStringConstant(value: String): UConcreteHeapRef = scope.calcOnState { - mkStringConstRef(value, this) - } + override fun visitJcInt(value: JcInt): UExpr = simpleValueResolver.visitJcInt(value) - override fun visitJcMethodConstant(value: JcMethodConstant): UExpr { - TODO("Method constant") - } + override fun visitJcLong(value: JcLong): UExpr = simpleValueResolver.visitJcLong(value) - override fun visitJcMethodType(value: JcMethodType): UExpr { - TODO("Method type") - } + override fun visitJcFloat(value: JcFloat): UExpr = simpleValueResolver.visitJcFloat(value) - fun resolveClassRef(type: JcType): UConcreteHeapRef = scope.calcOnState { - val ref = mkTypeRef(type, this) - val classRefTypeLValue = UFieldLValue(ctx.addressSort, ref, ctx.classTypeSyntheticField) + override fun visitJcDouble(value: JcDouble): UExpr = simpleValueResolver.visitJcDouble(value) - // Ref type is java.lang.Class - memory.types.allocate(ref.address, ctx.classType) + override fun visitJcNullConstant( + value: JcNullConstant, + ): UExpr = simpleValueResolver.visitJcNullConstant(value) - // Save ref original class type with the negative address - val classRefType = memory.allocStatic(type) - memory.write(classRefTypeLValue, classRefType) + override fun visitJcStringConstant( + value: JcStringConstant, + ): UExpr = simpleValueResolver.visitJcStringConstant(value) - ref - } + override fun visitJcMethodConstant( + value: JcMethodConstant, + ): UExpr = simpleValueResolver.visitJcMethodConstant(value) - override fun visitJcClassConstant(value: JcClassConstant): UExpr = - resolveClassRef(value.klass) + override fun visitJcMethodType( + value: JcMethodType, + ): UExpr = simpleValueResolver.visitJcMethodType(value) + override fun visitJcClassConstant( + value: JcClassConstant, + ): UExpr = simpleValueResolver.visitJcClassConstant(value) // endregion override fun visitJcCastExpr(expr: JcCastExpr): UExpr? = resolveCast(expr.operand, expr.type) @@ -513,20 +458,11 @@ class JcExprResolver( // region jc locals - override fun visitJcLocalVar(value: JcLocalVar): UExpr = with(ctx) { - val ref = resolveLocal(value) - scope.calcOnState { memory.read(ref) } - } + override fun visitJcLocalVar(value: JcLocalVar): UExpr = simpleValueResolver.visitJcLocalVar(value) - override fun visitJcThis(value: JcThis): UExpr = with(ctx) { - val ref = resolveLocal(value) - scope.calcOnState { memory.read(ref) } - } + override fun visitJcThis(value: JcThis): UExpr = simpleValueResolver.visitJcThis(value) - override fun visitJcArgument(value: JcArgument): UExpr = with(ctx) { - val ref = resolveLocal(value) - scope.calcOnState { memory.read(ref) } - } + override fun visitJcArgument(value: JcArgument): UExpr = simpleValueResolver.visitJcArgument(value) // endregion @@ -730,7 +666,9 @@ class JcExprResolver( return body() } - val classRef = resolveClassRef(type) + val classRef = with(simpleValueResolver) { + resolveClassRef(type) + } val initializedFlag = staticFieldsInitializedFlag(type, classRef) @@ -791,13 +729,6 @@ class JcExprResolver( return UArrayIndexLValue(cellSort, arrayRef, idx, arrayDescriptor) } - private fun resolveLocal(local: JcLocal): URegisterStackLValue<*> { - val method = requireNotNull(scope.calcOnState { lastEnteredMethod }) - val localIdx = localToIdx(method, local) - val sort = ctx.typeToSort(local.type) - return URegisterStackLValue(sort, localIdx) - } - // endregion // region implicit exceptions @@ -1029,3 +960,240 @@ class JcExprResolver( } } } + +class JcSimpleValueResolver( + private val ctx: JcContext, + private val scope: JcStepScope, + private val localToIdx: (JcMethod, JcLocal) -> Int, + private val mkTypeRef: (JcType, JcState) -> UConcreteHeapRef, + private val mkStringConstRef: (String, JcState) -> UConcreteHeapRef, +) : JcExprVisitor> { + override fun visitJcArgument(value: JcArgument): UExpr = with(ctx) { + val ref = resolveLocal(value) + scope.calcOnState { memory.read(ref) } + } + + override fun visitJcBool(value: JcBool): UExpr = with(ctx) { + mkBool(value.value) + } + + override fun visitJcChar(value: JcChar): UExpr = with(ctx) { + mkBv(value.value.code, charSort) + } + + override fun visitJcByte(value: JcByte): UExpr = with(ctx) { + mkBv(value.value, byteSort) + } + + override fun visitJcShort(value: JcShort): UExpr = with(ctx) { + mkBv(value.value, shortSort) + } + + override fun visitJcInt(value: JcInt): UExpr = with(ctx) { + mkBv(value.value, integerSort) + } + + override fun visitJcLong(value: JcLong): UExpr = with(ctx) { + mkBv(value.value, longSort) + } + + override fun visitJcFloat(value: JcFloat): UExpr = with(ctx) { + mkFp(value.value, floatSort) + } + + override fun visitJcDouble(value: JcDouble): UExpr = with(ctx) { + mkFp(value.value, doubleSort) + } + + override fun visitJcNullConstant(value: JcNullConstant): UExpr = with(ctx) { + nullRef + } + + override fun visitJcStringConstant(value: JcStringConstant): UExpr = with(ctx) { + scope.calcOnState { + // Equal string constants always have equal references + val ref = resolveStringConstant(value.value) + val stringValueLValue = UFieldLValue(addressSort, ref, stringValueField.field) + val stringCoderLValue = UFieldLValue(byteSort, ref, stringCoderField.field) + + // String.value type depends on the JVM version + val charValues = when (stringValueField.fieldType.ifArrayGetElementType) { + cp.char -> value.value.asSequence().map { mkBv(it.code, charSort) } + cp.byte -> value.value.encodeToByteArray().asSequence().map { mkBv(it, byteSort) } + else -> error("Unexpected string values type: ${stringValueField.fieldType}") + } + + val valuesArrayDescriptor = arrayDescriptorOf(stringValueField.fieldType as JcArrayType) + val elementType = requireNotNull(stringValueField.fieldType.ifArrayGetElementType) + val charArrayRef = memory.allocateArrayInitialized( + valuesArrayDescriptor, + typeToSort(elementType), + charValues.uncheckedCast() + ) + + // overwrite array type because descriptor is element type + memory.types.allocate(charArrayRef.address, stringValueField.fieldType) + + // String constants are immutable. Therefore, it is correct to overwrite value, coder and type. + memory.write(stringValueLValue, charArrayRef) + memory.write(stringCoderLValue, mkBv(0, byteSort)) + memory.types.allocate(ref.address, stringType) + + ref + } + } + + + override fun visitJcClassConstant(value: JcClassConstant): UExpr = + resolveClassRef(value.klass) + + override fun visitJcMethodConstant(value: JcMethodConstant): UExpr { + TODO("Method constant") + } + + override fun visitJcMethodType(value: JcMethodType): UExpr { + TODO("Method type") + } + + override fun visitJcLocalVar(value: JcLocalVar): UExpr = with(ctx) { + val ref = resolveLocal(value) + scope.calcOnState { memory.read(ref) } + } + + override fun visitJcThis(value: JcThis): UExpr = with(ctx) { + val ref = resolveLocal(value) + scope.calcOnState { memory.read(ref) } + } + + override fun visitExternalJcExpr(expr: JcExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcAddExpr(expr: JcAddExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcAndExpr(expr: JcAndExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcArrayAccess(value: JcArrayAccess): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcCastExpr(expr: JcCastExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcCmpExpr(expr: JcCmpExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcCmpgExpr(expr: JcCmpgExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcCmplExpr(expr: JcCmplExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcDivExpr(expr: JcDivExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcDynamicCallExpr(expr: JcDynamicCallExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcEqExpr(expr: JcEqExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcFieldRef(value: JcFieldRef): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcGeExpr(expr: JcGeExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcGtExpr(expr: JcGtExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcInstanceOfExpr(expr: JcInstanceOfExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcLambdaExpr(expr: JcLambdaExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcLeExpr(expr: JcLeExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcLengthExpr(expr: JcLengthExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcLtExpr(expr: JcLtExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcMulExpr(expr: JcMulExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcNegExpr(expr: JcNegExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcNeqExpr(expr: JcNeqExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcNewArrayExpr(expr: JcNewArrayExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcNewExpr(expr: JcNewExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcOrExpr(expr: JcOrExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcPhiExpr(expr: JcPhiExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcRemExpr(expr: JcRemExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcShlExpr(expr: JcShlExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcShrExpr(expr: JcShrExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcSpecialCallExpr(expr: JcSpecialCallExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcStaticCallExpr(expr: JcStaticCallExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcSubExpr(expr: JcSubExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcUshrExpr(expr: JcUshrExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcVirtualCallExpr(expr: JcVirtualCallExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + override fun visitJcXorExpr(expr: JcXorExpr): UExpr = + error("Simple expr resolver must resolve only inheritors of ${JcSimpleValue::class}.") + + fun resolveLocal(local: JcLocal): URegisterStackLValue<*> { + val method = requireNotNull(scope.calcOnState { lastEnteredMethod }) + val localIdx = localToIdx(method, local) + val sort = ctx.typeToSort(local.type) + return URegisterStackLValue(sort, localIdx) + } + + fun resolveClassRef(type: JcType): UConcreteHeapRef = scope.calcOnState { + val ref = mkTypeRef(type, this) + val classRefTypeLValue = UFieldLValue(ctx.addressSort, ref, ctx.classTypeSyntheticField) + + // Ref type is java.lang.Class + memory.types.allocate(ref.address, ctx.classType) + + // Save ref original class type with the negative address + val classRefType = memory.allocStatic(type) + memory.write(classRefTypeLValue, classRefType) + + ref + } + + + fun resolveStringConstant(value: String): UConcreteHeapRef = + scope.calcOnState { + mkStringConstRef(value, this) + } +} \ No newline at end of file diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt index 5e6f15c3b9..2575f57b02 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt @@ -28,6 +28,7 @@ import org.jacodb.api.cfg.JcSwitchInst import org.jacodb.api.cfg.JcThis import org.jacodb.api.cfg.JcThrowInst import org.jacodb.api.ext.boolean +import org.jacodb.api.ext.cfg.callExpr import org.jacodb.api.ext.isEnum import org.jacodb.api.ext.void import org.usvm.StepResult @@ -44,6 +45,7 @@ import org.usvm.isStaticHeapRef import org.usvm.machine.JcApplicationGraph import org.usvm.machine.JcConcreteMethodCallInst import org.usvm.machine.JcContext +import org.usvm.machine.JcInterpreterObserver import org.usvm.machine.JcMethodApproximationResolver import org.usvm.machine.JcMethodCall import org.usvm.machine.JcMethodCallBaseInst @@ -63,6 +65,7 @@ import org.usvm.machine.state.throwExceptionAndDropStackFrame import org.usvm.machine.state.throwExceptionWithoutStackFrameDrop import org.usvm.memory.URegisterStackLValue import org.usvm.solver.USatResult +import org.usvm.targets.UTargetController import org.usvm.types.first import org.usvm.util.findMethod import org.usvm.util.write @@ -75,13 +78,14 @@ typealias JcStepScope = StepScope class JcInterpreter( private val ctx: JcContext, private val applicationGraph: JcApplicationGraph, + private val observer: JcInterpreterObserver? = null, ) : UInterpreter() { companion object { val logger = object : KLogging() {}.logger } - fun getInitialState(method: JcMethod, targets: List = emptyList()): JcState { + fun getInitialState(method: JcMethod, targets: List> = emptyList()): JcState { val state = JcState(ctx, targets = targets) val typedMethod = with(applicationGraph) { method.typed } @@ -111,7 +115,7 @@ class JcInterpreter( } } - val solver = ctx.solver() + val solver = ctx.solver() val model = (solver.checkWithSoftConstraints(state.pathConstraints) as USatResult).model state.models = listOf(model) @@ -149,6 +153,7 @@ class JcInterpreter( is JcExitMonitorInst -> visitMonitorExitStmt(scope, stmt) else -> error("Unknown stmt: $stmt") } + return scope.stepResult() } @@ -196,14 +201,20 @@ class JcInterpreter( val catchSectionMiss = typeConditionToMiss to functionBlockOnMiss + // TODO observer?.onCatchStatement + scope.forkMulti(catchForks + catchSectionMiss) } private val typeSelector = JcFixedInheritorsNumberTypeSelector() private fun visitMethodCall(scope: JcStepScope, stmt: JcMethodCallBaseInst) { + val exprResolver = exprResolverWithScope(scope) + val simpleValueResolver = exprResolver.simpleValueResolver + when (stmt) { is JcMethodEntrypointInst -> { + observer?.onEntryPoint(simpleValueResolver, stmt, scope) scope.doWithState { if (callStack.isEmpty()) { val method = stmt.method @@ -212,7 +223,6 @@ class JcInterpreter( } } - val exprResolver = exprResolverWithScope(scope) // Run static initializer for all enum arguments of the entrypoint for ((type, ref) in stmt.entrypointArguments) { exprResolver.ensureExprCorrectness(ref, type) ?: return @@ -226,6 +236,7 @@ class JcInterpreter( } is JcConcreteMethodCallInst -> { + observer?.onMethodCallWithResolvedArguments(simpleValueResolver, stmt, scope) if (approximateMethod(scope, stmt)) { return } @@ -241,6 +252,8 @@ class JcInterpreter( } is JcVirtualMethodCallInst -> { + observer?.onMethodCallWithResolvedArguments(simpleValueResolver, stmt, scope) + if (approximateMethod(scope, stmt)) { return } @@ -252,6 +265,18 @@ class JcInterpreter( private fun visitAssignInst(scope: JcStepScope, stmt: JcAssignInst) { val exprResolver = exprResolverWithScope(scope) + + + stmt.callExpr?.let { + val methodResult = scope.calcOnState { methodResult } + + when (methodResult) { + is JcMethodResult.NoCall -> observer?.onMethodCallWithUnresolvedArguments(exprResolver.simpleValueResolver, it, scope) + is JcMethodResult.Success -> observer?.onAssignStatement(exprResolver.simpleValueResolver, stmt, scope) + is JcMethodResult.JcException -> error("Exceptions must be processed earlier") + } + } + val lvalue = exprResolver.resolveLValue(stmt.lhv) ?: return val expr = exprResolver.resolveJcExpr(stmt.rhv, stmt.lhv.type) ?: return @@ -265,6 +290,8 @@ class JcInterpreter( private fun visitIfStmt(scope: JcStepScope, stmt: JcIfInst) { val exprResolver = exprResolverWithScope(scope) + observer?.onIfStatement(exprResolver.simpleValueResolver, stmt, scope) + val boolExpr = exprResolver .resolveJcExpr(stmt.condition) ?.asExpr(ctx.boolSort) @@ -282,6 +309,9 @@ class JcInterpreter( private fun visitReturnStmt(scope: JcStepScope, stmt: JcReturnInst) { val exprResolver = exprResolverWithScope(scope) + + observer?.onReturnStatement(exprResolver.simpleValueResolver, stmt, scope) + val method = requireNotNull(scope.calcOnState { callStack.lastMethod() }) val returnType = with(applicationGraph) { method.typed }.returnType @@ -295,6 +325,10 @@ class JcInterpreter( } private fun visitGotoStmt(scope: JcStepScope, stmt: JcGotoInst) { + val exprResolver = exprResolverWithScope(scope) + + observer?.onGotoStatement(exprResolver.simpleValueResolver, stmt, scope) + val nextStmt = stmt.location.method.instList[stmt.target.index] scope.doWithState { newStmt(nextStmt) } } @@ -307,6 +341,8 @@ class JcInterpreter( private fun visitSwitchStmt(scope: JcStepScope, stmt: JcSwitchInst) { val exprResolver = exprResolverWithScope(scope) + observer?.onSwitchStatement(exprResolver.simpleValueResolver, stmt, scope) + val switchKey = stmt.key // Note that the switch key can be an rvalue, for example, a simple int constant. val instList = stmt.location.method.instList @@ -330,8 +366,11 @@ class JcInterpreter( } private fun visitThrowStmt(scope: JcStepScope, stmt: JcThrowInst) { - val resolver = exprResolverWithScope(scope) - val address = resolver.resolveJcExpr(stmt.throwable)?.asExpr(ctx.addressSort) ?: return + val exprResolver = exprResolverWithScope(scope) + + observer?.onThrowStatement(exprResolver.simpleValueResolver, stmt, scope) + + val address = exprResolver.resolveJcExpr(stmt.throwable)?.asExpr(ctx.addressSort) ?: return scope.calcOnState { throwExceptionWithoutStackFrameDrop(address, stmt.throwable.type) @@ -340,7 +379,16 @@ class JcInterpreter( private fun visitCallStmt(scope: JcStepScope, stmt: JcCallInst) { val exprResolver = exprResolverWithScope(scope) - exprResolver.resolveJcExpr(stmt.callExpr) ?: return + val callExpr = stmt.callExpr + val methodResult = scope.calcOnState { methodResult } + + when (methodResult) { + is JcMethodResult.NoCall -> observer?.onMethodCallWithUnresolvedArguments(exprResolver.simpleValueResolver, callExpr, scope) + is JcMethodResult.Success -> observer?.onCallStatement(exprResolver.simpleValueResolver, stmt, scope) + is JcMethodResult.JcException -> error("Exceptions must be processed earlier") + } + + exprResolver.resolveJcExpr(callExpr) ?: return scope.doWithState { val nextStmt = stmt.nextStmt @@ -352,6 +400,8 @@ class JcInterpreter( val exprResolver = exprResolverWithScope(scope) exprResolver.resolveJcNotNullRefExpr(stmt.monitor, stmt.monitor.type) ?: return + observer?.onEnterMonitorStatement(exprResolver.simpleValueResolver, stmt, scope) + // Monitor enter makes sense only in multithreaded environment scope.doWithState { @@ -363,6 +413,8 @@ class JcInterpreter( val exprResolver = exprResolverWithScope(scope) exprResolver.resolveJcNotNullRefExpr(stmt.monitor, stmt.monitor.type) ?: return + observer?.onExitMonitorStatement(exprResolver.simpleValueResolver, stmt, scope) + // Monitor exit makes sense only in multithreaded environment scope.doWithState { @@ -511,7 +563,7 @@ class JcInterpreter( private fun mockNativeMethod( scope: JcStepScope, - methodCall: JcConcreteMethodCallInst + methodCall: JcConcreteMethodCallInst, ) = with(methodCall) { logger.warn { "Mocked: ${method.enclosingClass.name}::${method.name}" } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt index 066c9ded34..5ea9303483 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt @@ -11,17 +11,18 @@ import org.usvm.constraints.UPathConstraints import org.usvm.machine.JcContext import org.usvm.memory.UMemory import org.usvm.model.UModelBase +import org.usvm.targets.UTargetController class JcState( ctx: JcContext, callStack: UCallStack = UCallStack(), - pathConstraints: UPathConstraints = UPathConstraints(ctx), + pathConstraints: UPathConstraints = UPathConstraints(ctx), memory: UMemory = UMemory(ctx, pathConstraints.typeConstraints), models: List> = listOf(), override var pathLocation: PathsTrieNode = ctx.mkInitialLocation(), var methodResult: JcMethodResult = JcMethodResult.NoCall, - targets: List = emptyList(), -) : UState( + targets: List> = emptyList(), +) : UState, JcState>( ctx, callStack, pathConstraints, @@ -30,10 +31,10 @@ class JcState( pathLocation, targets ) { - override fun clone(newConstraints: UPathConstraints?): JcState { + override fun clone(newConstraints: UPathConstraints?): JcState { val clonedConstraints = newConstraints ?: pathConstraints.clone() return JcState( - pathConstraints.ctx, + ctx, callStack.clone(), clonedConstraints, memory.clone(clonedConstraints.typeConstraints), diff --git a/usvm-jvm/src/samples/java/org/usvm/samples/taint/Taint.java b/usvm-jvm/src/samples/java/org/usvm/samples/taint/Taint.java new file mode 100644 index 0000000000..5519d8b776 --- /dev/null +++ b/usvm-jvm/src/samples/java/org/usvm/samples/taint/Taint.java @@ -0,0 +1,87 @@ +package org.usvm.samples.taint; + +public class Taint { + public void taintedEntrySource(String taintedVariable, String cleanVariable, boolean returnTainted) { + String value; + if (returnTainted) { + value = taintedVariable; + } else { + value = cleanVariable; + } + + consumerOfInjections(value); + } + + public int simpleTaint(boolean x) { + String value = stringProducer(x); + + consumerOfInjections(value); + + return value.length(); + } + + public int simpleFalsePositive(boolean x) { + String value = stringProducer(x); + String[] array = new String[2]; + + array[0] = value; + array[1] = "safe_data"; + + consumerOfInjections(array[1]); + + return value.length(); + } + + public int simpleTruePositive(boolean x, int i) { + String value = stringProducer(x); + String[] array = new String[2]; + + array[0] = value; + array[1] = "safe_data"; + + consumerOfInjections(array[i]); + + return value.length(); + } + + public int taintWithReturningValue(boolean x) { + String value = stringProducer(x); + + return consumerWithReturningValue(value); + } + + public void goThroughCleaner() { + String value = stringProducer(true); + + String cleanData = cleaner(value); + consumerOfInjections(cleanData); + } + + // TODO add tests for PassThrough + + + public void consumerOfInjections(String data) { + // empty + } + + public void consumerOfSensitiveData(String data) { + // empty + } + + public int consumerWithReturningValue(String data) { + return 1; + } + + + public String cleaner(String data) { + return data; + } + + public String stringProducer(boolean produceTaint) { + if (produceTaint) { + return "taintedData"; + } + + return ""; + } +} diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt index bdd691376e..61b3b860e9 100644 --- a/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt @@ -11,7 +11,9 @@ import org.usvm.api.JcParametersState import org.usvm.api.JcTest import org.usvm.api.targets.JcTarget import org.usvm.api.util.JcTestResolver +import org.usvm.machine.JcInterpreterObserver import org.usvm.machine.JcMachine +import org.usvm.targets.UTargetController import org.usvm.test.util.TestRunner import org.usvm.test.util.checkers.AnalysisResultsNumberMatcher import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults @@ -29,18 +31,26 @@ import kotlin.reflect.jvm.javaMethod @TestInstance(TestInstance.Lifecycle.PER_CLASS) open class JavaMethodTestRunner : TestRunner, KClass<*>?, JcClassCoverage>() { - private var targets: List = emptyList() + private var targets: List> = emptyList() + private var interpreterObserver: JcInterpreterObserver? = null /** * Sets JcTargets to run JcMachine with in the scope of [action]. */ - protected fun withTargets(targets: List, action: () -> T): T { + protected fun withTargets( + targets: List>, + interpreterObserver: JcInterpreterObserver, + action: () -> T, + ): T { val prevTargets = this.targets + val prevInterpreterObserver = this.interpreterObserver try { this.targets = targets + this.interpreterObserver = interpreterObserver return action() } finally { this.targets = prevTargets + this.interpreterObserver = prevInterpreterObserver } } @@ -748,7 +758,7 @@ open class JavaMethodTestRunner : TestRunner, KClass<*>?, J val jcClass = cp.findClass(declaringClassName).toType() val jcMethod = jcClass.declaredMethods.first { it.name == method.name } - JcMachine(cp, options).use { machine -> + JcMachine(cp, options, interpreterObserver).use { machine -> val states = machine.analyze(jcMethod.method, targets) states.map { testResolver.resolve(jcMethod, it) } } diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt new file mode 100644 index 0000000000..c28f57a81d --- /dev/null +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt @@ -0,0 +1,399 @@ +package org.usvm.samples.taint + +import io.ksmt.utils.cast +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcField +import org.usvm.PathSelectionStrategy +import org.usvm.UMachineOptions +import org.usvm.api.targets.Argument +import org.usvm.api.targets.AssignMark +import org.usvm.api.targets.BooleanFromArgument +import org.usvm.api.targets.CallParameterContainsMark +import org.usvm.api.targets.ConstantTrue +import org.usvm.api.targets.CopyAllMarks +import org.usvm.api.targets.JcTarget +import org.usvm.api.targets.RemoveAllMarks +import org.usvm.api.targets.Result +import org.usvm.api.targets.SensitiveData +import org.usvm.api.targets.SqlInjection +import org.usvm.api.targets.TaintAnalysis +import org.usvm.api.targets.TaintCleaner +import org.usvm.api.targets.TaintConfiguration +import org.usvm.api.targets.TaintEntryPointSource +import org.usvm.api.targets.TaintFieldSource +import org.usvm.api.targets.TaintMethodSink +import org.usvm.api.targets.TaintMethodSource +import org.usvm.api.targets.TaintPassThrough +import org.usvm.samples.JavaMethodTestRunner +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults +import org.usvm.util.Options +import org.usvm.util.UsvmTest +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class TaintTest : JavaMethodTestRunner() { + @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) + fun testSimpleTaint(options: UMachineOptions) { + withOptions(options) { + val sampleAnalysis = constructSampleTaintAnalysis(cp) + + withTargets(sampleAnalysis.targets.toList().cast(), sampleAnalysis) { + checkDiscoveredProperties( + Taint::simpleTaint, + ignoreNumberOfAnalysisResults, + ) + } + + val collectedStates = sampleAnalysis.collectedStates + assertEquals(expected = 1, actual = collectedStates.size) + + val reachedTargets = collectedStates.single().reachedTerminalTargets.singleOrNull() as? JcTarget<*> + + assertNotNull(reachedTargets) + assertTrue { reachedTargets.isTerminal } + assertTrue { reachedTargets.isRemoved } + assertTrue { reachedTargets is TaintAnalysis.TaintMethodSinkTarget } + assertTrue { reachedTargets.parent is TaintAnalysis.TaintMethodSourceTarget } + } + } + + @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) + fun testSimpleFalsePositive(options: UMachineOptions) { + withOptions(options) { + val sampleAnalysis = constructSampleTaintAnalysis(cp) + + withTargets(sampleAnalysis.targets.toList().cast(), sampleAnalysis) { + checkDiscoveredProperties( + Taint::simpleFalsePositive, + ignoreNumberOfAnalysisResults, + ) + } + + val collectedStates = sampleAnalysis.collectedStates + assertEquals(expected = 0, actual = collectedStates.size) + } + } + + @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) + fun testSimpleTruePositive(options: UMachineOptions) { + withOptions(options) { + val sampleAnalysis = constructSampleTaintAnalysis(cp) + + withTargets(sampleAnalysis.targets.toList().cast(), sampleAnalysis) { + checkDiscoveredProperties( + Taint::simpleTruePositive, + ignoreNumberOfAnalysisResults, + ) + } + + val collectedStates = sampleAnalysis.collectedStates + assertEquals(expected = 1, actual = collectedStates.size) + + val reachedTargets = collectedStates.single().reachedTerminalTargets.singleOrNull() as? JcTarget<*> + + assertNotNull(reachedTargets) + assertTrue { reachedTargets.isTerminal } + assertTrue { reachedTargets.isRemoved } + assertTrue { reachedTargets is TaintAnalysis.TaintMethodSinkTarget } + } + } + + @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) + fun testTaintWithReturningValue(options: UMachineOptions) { + withOptions(options) { + val sampleAnalysis = constructSampleTaintAnalysis(cp) + + withTargets(sampleAnalysis.targets.toList().cast(), sampleAnalysis) { + checkDiscoveredProperties( + Taint::taintWithReturningValue, + ignoreNumberOfAnalysisResults + ) + } + + val collectedStates = sampleAnalysis.collectedStates + assertEquals(expected = 1, actual = collectedStates.size) + + val reachedTargets = collectedStates.single().reachedTerminalTargets.singleOrNull() as? JcTarget<*> + + assertNotNull(reachedTargets) + assertTrue { reachedTargets.isTerminal } + assertTrue { reachedTargets.isRemoved } + assertTrue { reachedTargets is TaintAnalysis.TaintMethodSinkTarget } + assertTrue { reachedTargets.parent is TaintAnalysis.TaintMethodSourceTarget } + } + } + + @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) + fun testGoThroughCleaner(options: UMachineOptions) { + withOptions(options) { + val sampleAnalysis = constructSampleTaintAnalysis(cp) + + withTargets(sampleAnalysis.targets.toList().cast(), sampleAnalysis) { + checkDiscoveredProperties( + Taint::goThroughCleaner, + ignoreNumberOfAnalysisResults + ) + } + + val collectedStates = sampleAnalysis.collectedStates + assertEquals(expected = 0, actual = collectedStates.size) + } + } + + private fun sampleConfiguration(cp: JcClasspath): TaintConfiguration { + fun findMethod(className: String, methodName: String) = cp + .findClassOrNull(className)!! + .declaredMethods + .first { it.name == methodName } + + val sampleClassName = "org.usvm.samples.taint.Taint" + + val taintEntryPointSourceMethod = findMethod(sampleClassName, "taintedEntrySource") + val taintEntryPointSourceCondition = ConstantTrue + + val sampleEntryPointsSources = mapOf( + taintEntryPointSourceMethod to listOf( + TaintEntryPointSource( + taintEntryPointSourceMethod, + taintEntryPointSourceCondition, AssignMark(Argument(0u), SqlInjection) + ) + ) + ) + + val sampleSourceMethod = findMethod(sampleClassName, "stringProducer") + val sampleCondition = BooleanFromArgument(Argument(0u)) + val sampleMethodSources = mapOf( + sampleSourceMethod to listOf( + TaintMethodSource( + sampleSourceMethod, + sampleCondition, AssignMark(Result, SqlInjection) + ), + TaintMethodSource( + sampleSourceMethod, + sampleCondition, AssignMark(Result, SensitiveData) + ), + ) + ) + + // TODO + val sampleFieldSources = emptyMap>() + + + val samplePassThoughMethod = findMethod("java.lang.String", "concat") + val samplePassThroughCondition = ConstantTrue + val samplePassThrough = mapOf( + samplePassThoughMethod to listOf( + TaintPassThrough( + samplePassThoughMethod, + samplePassThroughCondition, CopyAllMarks(Argument(0u), Result) + ), + TaintPassThrough( + samplePassThoughMethod, + samplePassThroughCondition, CopyAllMarks(Argument(1u), Result) + ), + ) + ) + + val sampleCleanerMethod = findMethod(sampleClassName, "cleaner") + val cleanerCondition = ConstantTrue + val sampleCleaners = mapOf( + sampleCleanerMethod to listOf( + TaintCleaner( + sampleCleanerMethod, + cleanerCondition, RemoveAllMarks(Result) + ) + ) + ) + + val consumerOfInjections = findMethod(sampleClassName, "consumerOfInjections") + val consumerOfSensitiveData = findMethod(sampleClassName, "consumerOfSensitiveData") + val consumerWithReturningValue = findMethod(sampleClassName, "consumerWithReturningValue") + + val sampleSinks = mapOf( + consumerOfInjections to listOf( + TaintMethodSink( + CallParameterContainsMark(Argument(0u), SqlInjection), + consumerOfInjections + ) + ), + consumerOfSensitiveData to listOf( + TaintMethodSink( + CallParameterContainsMark(Argument(0u), SqlInjection), + consumerOfSensitiveData + ) + ), + consumerWithReturningValue to listOf( + TaintMethodSink(CallParameterContainsMark(Argument(0u), SqlInjection), consumerWithReturningValue), + TaintMethodSink(CallParameterContainsMark(Argument(0u), SensitiveData), consumerWithReturningValue) + ), + ) + + return TaintConfiguration( + sampleEntryPointsSources, + sampleMethodSources, + sampleFieldSources, + samplePassThrough, + sampleCleaners, + sampleSinks, + emptyMap() // TODO field sinks + ) + } + + private fun constructSampleTaintAnalysis(cp: JcClasspath): TaintAnalysis { + fun findMethod(className: String, methodName: String) = cp + .findClassOrNull(className)!! + .declaredMethods + .first { it.name == methodName } + + val sampleClassName = "org.usvm.samples.taint.Taint" + + val configuration = sampleConfiguration(cp) + + val consumerOfInjections = findMethod(sampleClassName, "consumerOfInjections") + val consumerSinkRule = configuration.methodSinks[consumerOfInjections]!!.single() + + val targetForTaintedEntrySink = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "taintedEntrySource") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + + val sampleSourceMethod = findMethod(sampleClassName, "stringProducer") + val stringProducerRule = configuration.methodSources[sampleSourceMethod]!!.first() + + val sourceTargetForSimpleTaint = TaintAnalysis.TaintMethodSourceTarget( + findMethod(sampleClassName, "simpleTaint") + .instList + .last { "stringProducer" in it.toString() }, + stringProducerRule.condition, + stringProducerRule + ) + + val sinkTargetForSimpleTaint = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "simpleTaint") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + sourceTargetForSimpleTaint.addChild(sinkTargetForSimpleTaint) + + + val sourceTargetForFalsePositive = TaintAnalysis.TaintMethodSourceTarget( + findMethod(sampleClassName, "simpleFalsePositive") + .instList + .first { "stringProducer" in it.toString() }, + stringProducerRule.condition, + stringProducerRule + ) + + val intermediateTarget = TaintAnalysis.TaintIntermediateTarget( + findMethod(sampleClassName, "simpleFalsePositive") + .instList + .first { "[0]" in it.toString() }, + ) + + val secondIntermediateTarget = TaintAnalysis.TaintIntermediateTarget( + findMethod(sampleClassName, "simpleFalsePositive") + .instList + .first { "[1]" in it.toString() }, + ) + + val sinkTargetForFalsePositive = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "simpleFalsePositive") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + + secondIntermediateTarget.addChild(sinkTargetForFalsePositive) + intermediateTarget.addChild(secondIntermediateTarget) + sourceTargetForFalsePositive.addChild(intermediateTarget) + + val sourceTargetForTruePositive = TaintAnalysis.TaintMethodSourceTarget( + findMethod(sampleClassName, "simpleTruePositive") + .instList + .first { "stringProducer" in it.toString() }, + stringProducerRule.condition, + stringProducerRule + ) + + val intermediateTargetTruePositive = TaintAnalysis.TaintIntermediateTarget( + findMethod(sampleClassName, "simpleTruePositive") + .instList + .first { "[0]" in it.toString() }, + ) + + val secondIntermediateTargetTruePositive = TaintAnalysis.TaintIntermediateTarget( + findMethod(sampleClassName, "simpleTruePositive") + .instList + .first { "[1]" in it.toString() }, + ) + + val sinkTargetForTruePositive = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "simpleTruePositive") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + + secondIntermediateTargetTruePositive.addChild(sinkTargetForTruePositive) + intermediateTargetTruePositive.addChild(secondIntermediateTargetTruePositive) + sourceTargetForTruePositive.addChild(intermediateTargetTruePositive) + + + val sourceTaintWithReturningValue = TaintAnalysis.TaintMethodSourceTarget( + findMethod(sampleClassName, "taintWithReturningValue") + .instList + .first { "stringProducer" in it.toString() }, + stringProducerRule.condition, + stringProducerRule + ) + + val consumerWithReturningValue = findMethod(sampleClassName, "consumerWithReturningValue") + val consumerWithReturningValueSinkRule = configuration.methodSinks[consumerWithReturningValue]!!.first() + + val sinkTaintWithRetuningValue = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "taintWithReturningValue") + .instList + .first { "consumerWithReturningValue" in it.toString() }, + consumerWithReturningValueSinkRule.condition, + consumerWithReturningValueSinkRule + ) + + sourceTaintWithReturningValue.addChild(sinkTaintWithRetuningValue) + + + val sourceTaintGoThroughCleaner = TaintAnalysis.TaintMethodSourceTarget( + findMethod(sampleClassName, "goThroughCleaner") + .instList + .first { "stringProducer" in it.toString() }, + stringProducerRule.condition, + stringProducerRule + ) + + val sinkTaintGoThroughCleaner = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "goThroughCleaner") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + + sourceTaintGoThroughCleaner.addChild(sinkTaintGoThroughCleaner) + + return TaintAnalysis(configuration) + .addTarget(targetForTaintedEntrySink) + .addTarget(sourceTargetForSimpleTaint) + .addTarget(sourceTargetForFalsePositive) + .addTarget(sourceTargetForTruePositive) + .addTarget(sourceTaintWithReturningValue) + .addTarget(sourceTaintGoThroughCleaner) + } +} + diff --git a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleLanguageComponents.kt b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleLanguageComponents.kt index 9d60c50041..fce2952c10 100644 --- a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleLanguageComponents.kt +++ b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleLanguageComponents.kt @@ -6,7 +6,6 @@ import org.usvm.SolverType import org.usvm.UComponents import org.usvm.UContext import org.usvm.language.SampleType -import org.usvm.model.buildTranslatorAndLazyDecoder import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase import org.usvm.solver.UTypeSolver @@ -14,17 +13,16 @@ import org.usvm.types.UTypeSystem class SampleLanguageComponents( private val typeSystem: SampleTypeSystem, - private val solverType: SolverType + private val solverType: SolverType, ) : UComponents { - override fun mkSolver(ctx: Context): USolverBase { - val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) + override fun mkSolver(ctx: UContext): USolverBase { + val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) val softConstraintsProvider = USoftConstraintsProvider(ctx) - val solver = - when (solverType) { - SolverType.YICES -> KYicesSolver(ctx) - SolverType.Z3 -> KZ3Solver(ctx) - } + val solver = when (solverType) { + SolverType.YICES -> KYicesSolver(ctx) + SolverType.Z3 -> KZ3Solver(ctx) + } val typeSolver = UTypeSolver(typeSystem) return USolverBase(ctx, solver, typeSolver, translator, decoder, softConstraintsProvider) diff --git a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt index 86a69913a9..3428f7548f 100644 --- a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt +++ b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt @@ -20,6 +20,7 @@ import org.usvm.statistics.distances.CallGraphStatisticsImpl import org.usvm.statistics.distances.CfgStatisticsImpl import org.usvm.statistics.distances.PlainCallGraphStatistics import org.usvm.stopstrategies.createStopStrategy +import org.usvm.targets.UTargetController /** * Entry point for a sample language analyzer. @@ -32,14 +33,17 @@ class SampleMachine( private val typeSystem = SampleTypeSystem() private val components = SampleLanguageComponents(typeSystem, options.solverType) private val ctx = UContext(components) - private val solver = ctx.solver() + private val solver = ctx.solver() private val interpreter = SampleInterpreter(ctx, applicationGraph) private val resultModelConverter = ResultModelConverter(ctx) private val cfgStatistics = CfgStatisticsImpl(applicationGraph) - fun analyze(method: Method<*>, targets: List = emptyList()): Collection { + fun analyze( + method: Method<*>, + targets: List> = emptyList() + ): Collection { val initialState = getInitialState(method, targets) val coverageStatistics: CoverageStatistics, Stmt, SampleState> = CoverageStatistics(setOf(method), applicationGraph) @@ -92,7 +96,10 @@ class SampleMachine( return statesCollector.collectedStates.map { resultModelConverter.convert(it, method) } } - private fun getInitialState(method: Method<*>, targets: List): SampleState = + private fun getInitialState( + method: Method<*>, + targets: List> + ): SampleState = SampleState(ctx, targets = targets).apply { addEntryMethodCall(applicationGraph, method) val model = solver.emptyModel() diff --git a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt index 73c712b242..bf9f565e9d 100644 --- a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt +++ b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt @@ -15,18 +15,19 @@ import org.usvm.language.argumentCount import org.usvm.language.localsCount import org.usvm.memory.UMemory import org.usvm.model.UModelBase +import org.usvm.targets.UTargetController class SampleState( ctx: UContext, callStack: UCallStack, Stmt> = UCallStack(), - pathConstraints: UPathConstraints = UPathConstraints(ctx), + pathConstraints: UPathConstraints = UPathConstraints(ctx), memory: UMemory> = UMemory(ctx, pathConstraints.typeConstraints), models: List> = listOf(), pathLocation: PathsTrieNode = ctx.mkInitialLocation(), var returnRegister: UExpr? = null, var exceptionRegister: ProgramException? = null, - targets: List = emptyList() -) : UState, Stmt, UContext, SampleTarget, SampleState>( + targets: List> = emptyList() +) : UState, Stmt, UContext, SampleTarget, SampleState>( ctx, callStack, pathConstraints, @@ -35,10 +36,10 @@ class SampleState( pathLocation, targets ) { - override fun clone(newConstraints: UPathConstraints?): SampleState { + override fun clone(newConstraints: UPathConstraints?): SampleState { val clonedConstraints = newConstraints ?: pathConstraints.clone() return SampleState( - pathConstraints.ctx, + ctx, callStack.clone(), clonedConstraints, memory.clone(clonedConstraints.typeConstraints), diff --git a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleTarget.kt b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleTarget.kt index 9f76e10b48..87bf96c94a 100644 --- a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleTarget.kt +++ b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleTarget.kt @@ -1,9 +1,12 @@ package org.usvm.machine -import org.usvm.UTarget import org.usvm.language.Stmt +import org.usvm.targets.UTarget +import org.usvm.targets.UTargetController /** * Base class for SampleMachine targets. */ -abstract class SampleTarget(location: Stmt) : UTarget(location) +abstract class SampleTarget( + location: Stmt, +) : UTarget, TargetController>(location)