diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt index fdb1ee437f..aa7ea88ae6 100644 --- a/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/distances/InterprocDistanceCalculator.kt @@ -104,6 +104,10 @@ internal class InterprocDistanceCalculator( currentStatement: Statement, callStack: UCallStack ): InterprocDistance { + if (callStack.isEmpty()) { + return InterprocDistance(UInt.MAX_VALUE, ReachabilityKind.NONE) + } + val lastMethod = callStack.lastMethod() val lastFrameDistance = calculateFrameDistance(lastMethod, currentStatement) diff --git a/usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt b/usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt index 97110ebd81..05e0a2d035 100644 --- a/usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt +++ b/usvm-core/src/main/kotlin/org/usvm/targets/UTarget.kt @@ -21,28 +21,8 @@ abstract class UTarget( val location: Statement? = null, ) where Target : UTarget { private val childrenImpl = mutableListOf() - private var parent: Target? = null - - private var status: Status = Status.UNPROCESSED - - // TODO move it - enum class Status { - UNPROCESSED, - PROCESSING, - PROCESSED - } - - fun processing() { - require(status == Status.UNPROCESSED) - - status = Status.PROCESSING - } - - fun processed() { - require(status == Status.PROCESSING) - - status = Status.PROCESSED - } + var parent: Target? = null + private set /** * List of the child targets which should be reached after this target. @@ -60,10 +40,6 @@ abstract class UTarget( var isRemoved = false private set - // TODO add docs - fun > isReachedBy(state: State): Boolean? = - location?.let { it == state.currentStatement } - /** * Adds a child target to this target. * TODO: avoid possible recursion @@ -84,9 +60,9 @@ abstract class UTarget( * should try to propagate the target. If the target without children has been * visited, it is logically removed from tree. */ - protected fun > propagate(byState: State) { + fun > propagate(byState: State) { @Suppress("UNCHECKED_CAST") - if (byState.tryPropagateTarget(this as Target) && isTerminal) { + if (byState.tryPropagateTarget(this as T) && isTerminal) { remove() } } diff --git a/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt b/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt index 25aabb209f..e5de4ca164 100644 --- a/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt +++ b/usvm-core/src/main/kotlin/org/usvm/targets/UTargetController.kt @@ -2,5 +2,5 @@ package org.usvm.targets // TODO add self generic for the controller interface UTargetController { - val targets: MutableCollection> + val targets: MutableCollection> } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt index f23a672259..d62196c62a 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintAnalysis.kt @@ -1,14 +1,13 @@ package org.usvm.api.targets import io.ksmt.utils.asExpr -import io.ksmt.utils.cast +import io.ksmt.utils.uncheckedCast import org.jacodb.api.cfg.JcAssignInst import org.jacodb.api.cfg.JcCallExpr import org.jacodb.api.cfg.JcCallInst import org.jacodb.api.cfg.JcInst import org.jacodb.api.cfg.JcInstanceCallExpr -import org.jacodb.api.cfg.JcLocalVar -import org.jacodb.api.ext.findClassOrNull +import org.jacodb.api.ext.cfg.callExpr import org.usvm.UBoolExpr import org.usvm.UConcreteHeapRef import org.usvm.UHeapRef @@ -22,24 +21,36 @@ import org.usvm.machine.interpreter.JcSimpleValueResolver import org.usvm.machine.interpreter.JcStepScope import org.usvm.machine.state.JcMethodResult import org.usvm.machine.state.JcState -import org.usvm.targets.UTarget +import org.usvm.statistics.UMachineObserver import org.usvm.targets.UTargetController // TODO do we need a context here? class TaintAnalysis( - private val ctx: JcContext, private val configuration: TaintConfiguration,// TODO how to pass initial targets???? - override val targets: MutableCollection> = mutableListOf(), -) : UTargetController, JcInterpreterObserver { - private val taintTargets: MutableMap> = targets - .filterIsInstance() - .groupByTo(mutableMapOf()) { requireNotNull(it.location) } + override val targets: MutableCollection = mutableListOf(), +) : UTargetController, JcInterpreterObserver, UMachineObserver { + private val taintTargets: MutableMap> = mutableMapOf() + + init { + targets.forEach { + exposeTargets(it, taintTargets) + } + } + // TODO recursively add children here??????????????????????????????????????? // TODO add recursively children at the beginning???????????????????????????????????????????????????????????????? // TODO save mapping between initial targets and the states that reach them // Replace with the corresponding observer-collector? - private val collectedStates: MutableList = mutableListOf() + val collectedStates: MutableList = mutableListOf() + + private fun exposeTargets(target: TaintTarget, result: MutableMap>) { + result.getOrPut(target.location!!) { hashSetOf() }.add(target) + + target.children.forEach { + exposeTargets(it as TaintTarget, result) + } + } private val marksAddresses: MutableMap = mutableMapOf() @@ -48,19 +59,19 @@ class TaintAnalysis( stepScope.calcOnState { memory.allocate() } } - fun writeMark(ref: UHeapRef, mark: JcTaintMark, guard: UBoolExpr, stepScope: JcStepScope) { + private fun writeMark(ref: UHeapRef, mark: JcTaintMark, guard: UBoolExpr, stepScope: JcStepScope) { stepScope.doWithState { memory.write(createLValue(ref, mark, stepScope), ctx.trueExpr, guard) } } - fun removeMark(ref: UHeapRef, mark: JcTaintMark, guard: UBoolExpr, stepScope: JcStepScope) { + private fun removeMark(ref: UHeapRef, mark: JcTaintMark, guard: UBoolExpr, stepScope: JcStepScope) { stepScope.doWithState { memory.write(createLValue(ref, mark, stepScope), ctx.falseExpr, guard) } } - fun readMark(ref: UHeapRef, mark: JcTaintMark, stepScope: JcStepScope): UBoolExpr = + private fun readMark(ref: UHeapRef, mark: JcTaintMark, stepScope: JcStepScope): UBoolExpr = stepScope.calcOnState { memory.read(createLValue(ref, mark, stepScope)) } @@ -76,83 +87,106 @@ class TaintAnalysis( require(target is TaintTarget) targets += target - taintTargets.getOrPut(target.location!!) { mutableListOf() }.add(target) + exposeTargets(target, taintTargets) return this } - private fun findTaintTargets(stmt: JcInst): MutableList? = taintTargets[stmt] + private fun findTaintTargets(stmt: JcInst, state: JcState): List = + taintTargets[stmt]?.let { targets -> + state.targets.filter { it.uncheckedCast() in targets } + }.orEmpty().toList().uncheckedCast() override fun onAssignStatement(exprResolver: JcSimpleValueResolver, stmt: JcAssignInst, stepScope: JcStepScope) { - val targets = findTaintTargets(stmt) ?: return + // Sinks are already processed at this moment since we resolved it on call statement + + stmt.callExpr?.let { processTaintConfiguration(it, stepScope, exprResolver) } + + // TODO add fields processing + } - // TODO sinks are already processed at this moment since we resolved it on call statement + private fun processTaintConfiguration( + callExpr: JcCallExpr, + stepScope: JcStepScope, + simpleValueResolver: JcSimpleValueResolver, + ) { + val ctx = stepScope.ctx + val methodResult = stepScope.calcOnState { methodResult } + val method = callExpr.method.method + require(methodResult is JcMethodResult.Success) { "TODO message" } - // These are the targets that are matched by the location - targets.forEach { target -> - // TODO check for the target satisfiability using additional information. - // At the moment it's not really clear what this information might look like + val callPositionResolver = createCallPositionResolver(ctx, callExpr, simpleValueResolver, methodResult) -// val valueFromConfiguration = TODO() -// val additionalCondition = -// stepScope.doWithState { tryPropagateTarget(target.uncheckedCast()) } // TODO } + val conditionResolver = ConditionResolver(ctx, callPositionResolver, ::readMark) + val actionResolver = TaintActionResolver( + ctx, + callPositionResolver, + ::readMark, + ::writeMark, + ::removeMark, + marksAddresses.keys + ) - if (stmt.rhv is JcCallExpr) { - val methodResult = stepScope.calcOnState { methodResult } - val callExpr = stmt.rhv as JcCallExpr - val method = callExpr.method.method + val sourceConfigurations = configuration.methodSources[method] + val currentStatement = stepScope.calcOnState { currentStatement } - require(methodResult is JcMethodResult.Success) { "TODO message" } + val sourceTargets = findTaintTargets(currentStatement, stepScope.state) + .filterIsInstance() + .associateBy { it.configurationRule } + sourceConfigurations?.forEach { + val target = sourceTargets[it] - val callPositionResolver = CallPositionResolver( - resolveCallInstance(callExpr)?.accept(exprResolver)?.asExpr(ctx.addressSort), - callExpr.args.map { it.accept(exprResolver) }, - methodResult.value - ) + val (condition, action) = it.conditionWithAction + val resolvedCondition = conditionResolver.visit(condition, simpleValueResolver, stepScope) ?: ctx.trueExpr - val conditionResolver = ConditionResolver(ctx, callPositionResolver, ::readMark) - val actionResolver = TaintActionResolver( - ctx, - callPositionResolver, - ::readMark, - ::writeMark, - ::removeMark, - marksAddresses.keys - ) + val targetCondition = target?.condition ?: ConstantTrue + val resolvedTargetCondition = + conditionResolver.visit(targetCondition, simpleValueResolver, stepScope) ?: ctx.trueExpr - val sourceConfigurations = configuration.methodSources[method] - sourceConfigurations?.forEach { - val (condition, action) = it.conditionWithAction - val resolvedCondition = conditionResolver.visit(condition, exprResolver, stepScope) + val combinedCondition = ctx.mkAnd(resolvedTargetCondition, resolvedCondition) - action.accept(actionResolver, stepScope, resolvedCondition) - } + action.accept(actionResolver, stepScope, combinedCondition) - val cleanerConfigurations = configuration.cleaners[method] - cleanerConfigurations?.forEach { - TODO() - } + target?.propagate(stepScope.state) + } - val passThroughConfigurations = configuration.passThrough[method] - passThroughConfigurations?.forEach { - TODO() - } + val cleanerConfigurations = configuration.cleaners[method] + cleanerConfigurations?.forEach { + val (condition, action) = it.conditionWithAction + val resolvedCondition = conditionResolver.visit(condition, simpleValueResolver, stepScope) - return@forEach - } + action.accept(actionResolver, stepScope, resolvedCondition) } - // TODO add fields processing + val passThroughConfigurations = configuration.passThrough[method] + passThroughConfigurations?.forEach { + val (condition, action) = it.conditionWithAction + val resolvedCondition = conditionResolver.visit(condition, simpleValueResolver, stepScope) + + action.accept(actionResolver, stepScope, resolvedCondition) + } } + private val JcStepScope.ctx get() = calcOnState { ctx } + + private fun createCallPositionResolver( + ctx: JcContext, + callExpr: JcCallExpr, + simpleValueResolver: JcSimpleValueResolver, + methodResult: JcMethodResult.Success?, + ) = CallPositionResolver( + resolveCallInstance(callExpr)?.accept(simpleValueResolver)?.asExpr(ctx.addressSort), + callExpr.args.map { it.accept(simpleValueResolver) }, + methodResult?.value + ) + override fun onEntryPoint( simpleValueResolver: JcSimpleValueResolver, stmt: JcMethodEntrypointInst, stepScope: JcStepScope, ) { - println() // TODO entry point configuration } @@ -164,11 +198,63 @@ class TaintAnalysis( stmt: JcCallExpr, stepScope: JcStepScope, ) { + // TODO add comment about absence of configuration val method = stmt.method.method - val sinks = configuration.sinks[method] ?: return - sinks.filter { it.method == method }.forEach { - println("TODO onMethodCallWithUnresolvedArguments") + val methodResult = stepScope.calcOnState { methodResult } + require(methodResult is JcMethodResult.NoCall) { "TODO" } + + val ctx = stepScope.ctx + + val positionResolver = createCallPositionResolver(ctx, stmt, simpleValueResolver, methodResult = null) + val conditionResolver = ConditionResolver(ctx, positionResolver, ::readMark) + + if (inTargetedMode) { + val currentStatement = stepScope.calcOnState { currentStatement } + val targets = findTaintTargets(currentStatement, stepScope.state) + + val sinks = targets.filterIsInstance() + sinks.forEach { + processSink(it.configRule, it.condition, conditionResolver, simpleValueResolver, stepScope, it) + } + } else { + val methodSinks = configuration.methodSinks[method] ?: return + methodSinks.forEach { + processSink(it, ConstantTrue, conditionResolver, simpleValueResolver, stepScope) + } + } + } + + private val JcStepScope.state get() = calcOnState { this } + + private val inTargetedMode: Boolean + get() = targets.isNotEmpty() + + private fun processSink( + methodSink: TaintMethodSink, + sinkCondition: Condition, + conditionResolver: ConditionResolver, + simpleValueResolver: JcSimpleValueResolver, + stepScope: JcStepScope, + target: TaintMethodSinkTarget? = null, + ) { + val resolvedConfigCondition = + methodSink.condition.visit(conditionResolver, simpleValueResolver, stepScope) ?: return + + val resolvedSinkCondition = sinkCondition.visit(conditionResolver, simpleValueResolver, stepScope) ?: return + + val resolvedCondition = stepScope.ctx.mkAnd(resolvedConfigCondition, resolvedSinkCondition) + + val (originalStateCopy, taintedStepScope) = stepScope.calcOnState { + val originalStateCopy = clone() + originalStateCopy to JcStepScope(originalStateCopy) + } + + taintedStepScope.assert(resolvedCondition)?.let { + // TODO remove corresponding target + // TODO probably, we should terminate this state? Yes, we should + collectedStates += originalStateCopy + target?.propagate(taintedStepScope.state) } } @@ -177,38 +263,11 @@ class TaintAnalysis( stmt: JcMethodCallBaseInst, stepScope: JcStepScope, ) { - // TODO() + // TODO message, it is a redundant signal } override fun onCallStatement(simpleValueResolver: JcSimpleValueResolver, stmt: JcCallInst, stepScope: JcStepScope) { - val methodResult = stepScope.calcOnState { methodResult } - - if (methodResult is JcMethodResult.NoCall) { - val method = stmt.callExpr.method.method - val sinks = configuration.sinks[method] ?: return - - sinks - .filter { it.method == method } - .forEach { - val address = stepScope.calcOnState { - memory.read(simpleValueResolver.resolveLocal(stmt.callExpr.args.first() as JcLocalVar)) - } - val markReading = readMark(address.cast(), SqlInjection, stepScope) - - val (originalStateCopy, taintedStepScope) = stepScope.calcOnState { - val originalStateCopy = clone() - originalStateCopy to JcStepScope(originalStateCopy) - } - - taintedStepScope.assert(markReading)?.let { - // TODO remove corresponding target - // TODO probably, we should terminate this state? Yes, we should - collectedStates += originalStateCopy - } - } - } else { - // TODO process taint source???????????????????????????????????????????????? - } + processTaintConfiguration(stmt.callExpr, stepScope, simpleValueResolver) } override fun onStateProcessed( @@ -216,66 +275,47 @@ class TaintAnalysis( stmt: JcInst, stepScope: JcStepScope, ) { - println("TODO()") + // TODO for now, we might process each target several times if we have a transition with the same instruction + // and different path constraints + + } + + override fun onState(parent: JcState, forks: Sequence) { + propagateIntermediateTarget(parent) + + forks.forEach { propagateIntermediateTarget(it) } } + private fun propagateIntermediateTarget(state: JcState) { + // TODO add comment why is it safe + val targets = findTaintTargets(state.pathLocation.parent!!.statement, state) + targets.forEach { + when (it) { + is TaintIntermediateTarget -> it.propagate(state) + is TaintMethodSourceTarget, is TaintMethodSinkTarget -> return@forEach + } + } + } sealed class TaintTarget( location: JcInst, ) : JcTarget(location) - class TaintSourceTarget(location: JcInst) : TaintTarget(location) - class TaintPassThroughTarget(location: JcInst) : TaintTarget(location) - class TaintCleanerTarget(location: JcInst) : TaintTarget(location) - - // TODO is it important? Or we track every possible mark? - class TaintSinkTarget( + class TaintMethodSourceTarget( location: JcInst, -// val accessPath????? TODO - val markType: JcTaintMark, + val condition: Condition, + val configurationRule: TaintMethodSource, ) : TaintTarget(location) + // TODO add field source targets + class TaintIntermediateTarget(location: JcInst) : TaintTarget(location) - private val environmentMethod = ctx.cp - .findClassOrNull()!! - .declaredMethods - .single { it.name == "getProperty" && it.parameters.size == 1 } - - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - private val concatMethod = ctx.cp - .findClassOrNull()!! - .declaredMethods - .single { it.name == "concat" } - - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - private val cleanerMethod = ctx.cp - .findClassOrNull()!! - .declaredMethods - .single { it.name == "isEmpty" } - - // TODO remove these, they are needed for demonstration purposes only - private val taintMethodSourceWithMethod = TaintMethodSource( - ctx, - conditionWithAction = ConstantTrue to AssignMark(Result, SqlInjection), - method = ctx.cp.findClassOrNull("org.usvm.samples.taint.Taint")!!.declaredMethods.single { "Producer" in it.name } - ) - - private val taintPassThroughMethod = TaintPassThrough( - ctx, - // TODO it is not complete - conditionWithAction = CallParameterContainsMark(Argument(number = 0u), SqlInjection) to CopyAllMarks( - Argument(number = 0u), - ThisArgument - ), - methodInfo = concatMethod - ) - -// private val taintSinkMethod = TaintSink( -// ctx, -// method = ctx.cp.findClassOrNull("org.usvm.samples.taint.Taint")!!.declaredMethods.single { "Consumer" in it.name }, -// condition = ContainsMark(Argument(0u), SqlInjection) -// ) - + // TODO is it important? Or we track every possible mark? + class TaintMethodSinkTarget( + location: JcInst, + val condition: Condition, + val configRule: TaintMethodSink, + ) : TaintTarget(location) private fun resolveCallInstance( callExpr: JcCallExpr, @@ -288,40 +328,3 @@ sealed interface JcTaintMark object SqlInjection : JcTaintMark object SensitiveData : JcTaintMark - - -// TODO REMOVE IT, FOR DEMONSTRATION PURPOSES ONLY -fun constructSampleTaintAnalysis(ctx: JcContext): TaintAnalysis { - fun findMethod(className: String, methodName: String) = ctx - .cp - .findClassOrNull(className)!! - .declaredMethods - .first { it.name == methodName } - - val sampleClassName = "org.usvm.samples.taint.Taint" - - val targetForTaintedEntrySource = TaintAnalysis.TaintSinkTarget( - findMethod(sampleClassName, "taintedEntrySource") - .instList - .first { "consumerOfInjections" in it.toString() }, - SqlInjection - ) - - val sourceTargetForSimpleTaint = TaintAnalysis.TaintSourceTarget( - findMethod(sampleClassName, "simpleTaint") - .instList - .last { "stringProducer" in it.toString() } - ) - - val sinkTargetForSimpleTaint = TaintAnalysis.TaintSinkTarget( - findMethod(sampleClassName, "simpleTaint") - .instList - .first { "consumerOfInjections" in it.toString() }, - SqlInjection - ) - sourceTargetForSimpleTaint.addChild(sinkTargetForSimpleTaint) - - return TaintAnalysis(ctx, sampleConfiguration(ctx)) - .addTarget(targetForTaintedEntrySource) - .addTarget(sourceTargetForSimpleTaint) -} \ No newline at end of file diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt index 6acd0d84b8..5a32950c56 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/targets/TaintConfigurationItem.kt @@ -2,8 +2,6 @@ package org.usvm.api.targets import org.jacodb.api.JcField import org.jacodb.api.JcMethod -import org.usvm.UContext -import org.usvm.machine.JcContext data class TaintConfiguration( val entryPoints: Map>, @@ -11,152 +9,46 @@ data class TaintConfiguration( val fieldSources: Map>, val passThrough: Map>, val cleaners: Map>, - val sinks: Map>, + val methodSinks: Map>, + val fieldSinks: Map>, ) -sealed class TaintConfigurationItem( - val ctx: UContext, -) +sealed interface TaintConfigurationItem class TaintEntryPointSource( - ctx: UContext, val method: JcMethod, val conditionWithAction: Pair, -) : TaintConfigurationItem(ctx) +) : TaintConfigurationItem // TODO probably conditions should be stored in some other way // source is it either a method call or a field reading class TaintMethodSource( - ctx: UContext, - val method: JcMethod? = null, + val method: JcMethod, val conditionWithAction: Pair, // TODO replace with a specific class -) : TaintConfigurationItem(ctx) +) : TaintConfigurationItem class TaintFieldSource( - ctx: UContext, - val field: JcField? = null, + val field: JcField, val conditionWithAction: Pair, -) : TaintConfigurationItem(ctx) +) : TaintConfigurationItem + +class TaintMethodSink( + val condition: Condition, + val method: JcMethod, +) : TaintConfigurationItem -class TaintSink( - ctx: UContext, +class TaintFieldSink( val condition: Condition, - val method: JcMethod? = null, - val field: JcField? = null, -) : TaintConfigurationItem(ctx) + val field: JcField, +) : TaintConfigurationItem class TaintPassThrough( - ctx: UContext, val methodInfo: JcMethod, val conditionWithAction: Pair, -) : TaintConfigurationItem(ctx) +) : TaintConfigurationItem class TaintCleaner( - ctx: UContext, val methodInfo: JcMethod, val conditionWithAction: Pair, -) : TaintConfigurationItem(ctx) - - -// TODO for demonstration purposes only, must be either moved to another place or removed completely -fun sampleConfiguration(ctx: JcContext): TaintConfiguration { - fun findMethod(className: String, methodName: String) = ctx - .cp - .findClassOrNull(className)!! - .declaredMethods - .first { it.name == methodName } - - val sampleClassName = "org.usvm.samples.taint.Taint" - - val taintEntryPointSourceMethod = findMethod(sampleClassName, "taintedEntrySource") - val taintEntryPointSourceCondition = ConstantTrue // TODO could be replaced with a check for emptiness - - val sampleEntryPointsSources = mapOf( - taintEntryPointSourceMethod to listOf( - TaintEntryPointSource( - ctx, - taintEntryPointSourceMethod, - taintEntryPointSourceCondition to AssignMark(Argument(0u), SqlInjection) - ) - ) - ) - - val sampleSourceMethod = findMethod(sampleClassName, "stringProducer") - val sampleCondition = BooleanFromArgument(Argument(0u)) - val sampleMethodSources = mapOf( - sampleSourceMethod to listOf( - TaintMethodSource( - ctx, - sampleSourceMethod, - sampleCondition to AssignMark(Result, SqlInjection) - ), - TaintMethodSource( - ctx, - sampleSourceMethod, - sampleCondition to AssignMark(Result, SensitiveData) // TODO replace with a bulk operation - ), - ) - ) - - // TODO - val sampleFieldSources = emptyMap>() - - - val samplePassThoughMethod = findMethod("java.lang.String", "concat") - val samplePassThroughCondition = ConstantTrue - val samplePassThrough = mapOf( - samplePassThoughMethod to listOf( - TaintPassThrough( - ctx, - samplePassThoughMethod, - samplePassThroughCondition to CopyAllMarks(Argument(0u), Result) - ), // TODO replace with a bulk operation - TaintPassThrough( - ctx, - samplePassThoughMethod, - samplePassThroughCondition to CopyAllMarks(Argument(1u), Result) - ), // TODO replace with a bulk operation - ) - ) - - val sampleCleanerMethod = findMethod(sampleClassName, "cleaner") - val cleanerCondition = ConstantTrue - val sampleCleaners = mapOf( - sampleCleanerMethod to listOf( - TaintCleaner( - ctx, - sampleCleanerMethod, - cleanerCondition to RemoveAllMarks(Result) - ) - ) - ) - - val consumerOfInjections = findMethod(sampleClassName, "consumerOfInjections") - val consumerOfSensitiveData = findMethod(sampleClassName, "consumerOfSensitiveData") - val consumerWithReturningValue = findMethod(sampleClassName, "consumerWithReturningValue") - - val sampleSinks = mapOf( - consumerOfInjections to listOf(TaintSink(ctx, CallParameterContainsMark(Argument(0u), SqlInjection), consumerOfInjections)), - consumerOfSensitiveData to listOf( - TaintSink( - ctx, - CallParameterContainsMark(Argument(0u), SqlInjection), - consumerOfSensitiveData - ) - ), - consumerWithReturningValue to listOf( - TaintSink(ctx, CallParameterContainsMark(Argument(0u), SqlInjection), consumerWithReturningValue), - TaintSink(ctx, CallParameterContainsMark(Argument(0u), SensitiveData), consumerWithReturningValue) - ), - ) - - return TaintConfiguration( - sampleEntryPointsSources, - sampleMethodSources, - sampleFieldSources, - samplePassThrough, - sampleCleaners, - sampleSinks - ) -} +) : TaintConfigurationItem diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt index e2e23d6082..6db36fadee 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt @@ -10,7 +10,6 @@ import org.usvm.StateCollectionStrategy import org.usvm.UMachine import org.usvm.UMachineOptions import org.usvm.api.targets.JcTarget -import org.usvm.api.targets.constructSampleTaintAnalysis import org.usvm.machine.interpreter.JcInterpreter import org.usvm.machine.state.JcMethodResult import org.usvm.machine.state.JcState @@ -32,7 +31,8 @@ val logger = object : KLogging() {}.logger class JcMachine( cp: JcClasspath, - private val options: UMachineOptions + private val options: UMachineOptions, + private val interpreterObserver: JcInterpreterObserver? = null ) : UMachine() { private val applicationGraph = JcApplicationGraph(cp) @@ -41,7 +41,7 @@ class JcMachine( private val ctx = JcContext(cp, components) // TODO change the way of observers creation - private val interpreter = JcInterpreter(ctx, applicationGraph, constructSampleTaintAnalysis(ctx)) + private val interpreter = JcInterpreter(ctx, applicationGraph, interpreterObserver) private val cfgStatistics = CfgStatisticsImpl(applicationGraph) @@ -102,6 +102,11 @@ class JcMachine( val observers = mutableListOf>(coverageStatistics) observers.add(TerminatedStateRemover()) + if (interpreterObserver is UMachineObserver<*>) { + @Suppress("UNCHECKED_CAST") + observers.add(interpreterObserver as UMachineObserver) + } + if (options.coverageZone == CoverageZone.TRANSITIVE) { observers.add( TransitiveCoverageZoneObserver( diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt index 936b0b198b..a753a060f2 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt @@ -11,7 +11,6 @@ import org.jacodb.api.JcRefType import org.jacodb.api.JcType import org.jacodb.api.cfg.JcArgument import org.jacodb.api.cfg.JcAssignInst -import org.jacodb.api.cfg.JcCallExpr import org.jacodb.api.cfg.JcCallInst import org.jacodb.api.cfg.JcCatchInst import org.jacodb.api.cfg.JcEnterMonitorInst @@ -29,6 +28,7 @@ import org.jacodb.api.cfg.JcSwitchInst import org.jacodb.api.cfg.JcThis import org.jacodb.api.cfg.JcThrowInst import org.jacodb.api.ext.boolean +import org.jacodb.api.ext.cfg.callExpr import org.jacodb.api.ext.void import org.usvm.INITIAL_INPUT_ADDRESS import org.usvm.StepResult @@ -75,7 +75,7 @@ typealias JcStepScope = StepScope class JcInterpreter( private val ctx: JcContext, private val applicationGraph: JcApplicationGraph, - private val observer: JcInterpreterObserver, + private val observer: JcInterpreterObserver? = null, ) : UInterpreter() { companion object { @@ -145,7 +145,7 @@ class JcInterpreter( else -> error("Unknown stmt: $ stmt") } - observer.onStateProcessed(lazySimpleValueResolverWithScope(scope), stmt, scope) + observer?.onStateProcessed(lazySimpleValueResolverWithScope(scope), stmt, scope) return scope.stepResult() } @@ -194,7 +194,7 @@ class JcInterpreter( val catchSectionMiss = typeConditionToMiss to functionBlockOnMiss - // TODO observer.onCatchStatement + // TODO observer?.onCatchStatement scope.forkMulti(catchForks + catchSectionMiss) } @@ -206,14 +206,14 @@ class JcInterpreter( when (stmt) { is JcMethodEntrypointInst -> { - observer.onEntryPoint(resolver, stmt, scope) + observer?.onEntryPoint(resolver, stmt, scope) scope.doWithState { addEntryMethodCall(applicationGraph, stmt) } } is JcConcreteMethodCallInst -> { - observer.onMethodCallWithResolvedArguments(resolver, stmt, scope) + observer?.onMethodCallWithResolvedArguments(resolver, stmt, scope) if (approximateMethod(scope, stmt)) { return } @@ -229,7 +229,7 @@ class JcInterpreter( } is JcVirtualMethodCallInst -> { - observer.onMethodCallWithResolvedArguments(resolver, stmt, scope) + observer?.onMethodCallWithResolvedArguments(resolver, stmt, scope) if (approximateMethod(scope, stmt)) { return @@ -243,17 +243,16 @@ class JcInterpreter( private fun visitAssignInst(scope: JcStepScope, stmt: JcAssignInst) { val exprResolver = exprResolverWithScope(scope) - val methodResult = scope.calcOnState { methodResult } - // TODo comment - if (stmt.rhv is JcCallExpr && methodResult is JcMethodResult.NoCall) { - observer.onMethodCallWithUnresolvedArguments( - exprResolver.simpleValueResolver, - stmt.rhv as JcCallExpr, - scope - ) - } else { - observer.onAssignStatement(exprResolver.simpleValueResolver, stmt, scope) + stmt.callExpr?.let { + val methodResult = scope.calcOnState { methodResult } + + // TODo comment + when (methodResult) { + is JcMethodResult.NoCall -> observer?.onMethodCallWithUnresolvedArguments(exprResolver.simpleValueResolver, it, scope) + is JcMethodResult.Success -> observer?.onAssignStatement(exprResolver.simpleValueResolver, stmt, scope) + is JcMethodResult.JcException -> error("TODO") + } } val lvalue = exprResolver.resolveLValue(stmt.lhv) ?: return @@ -269,7 +268,7 @@ class JcInterpreter( private fun visitIfStmt(scope: JcStepScope, stmt: JcIfInst) { val exprResolver = exprResolverWithScope(scope) - observer.onIfStatement(exprResolver.simpleValueResolver, stmt, scope) + observer?.onIfStatement(exprResolver.simpleValueResolver, stmt, scope) val boolExpr = exprResolver .resolveJcExpr(stmt.condition) @@ -289,7 +288,7 @@ class JcInterpreter( private fun visitReturnStmt(scope: JcStepScope, stmt: JcReturnInst) { val exprResolver = exprResolverWithScope(scope) - observer.onReturnStatement(exprResolver.simpleValueResolver, stmt, scope) + observer?.onReturnStatement(exprResolver.simpleValueResolver, stmt, scope) val method = requireNotNull(scope.calcOnState { callStack.lastMethod() }) val returnType = with(applicationGraph) { method.typed }.returnType @@ -304,7 +303,7 @@ class JcInterpreter( } private fun visitGotoStmt(scope: JcStepScope, stmt: JcGotoInst) { - observer.onGotoStatement(lazySimpleValueResolverWithScope(scope), stmt, scope) + observer?.onGotoStatement(lazySimpleValueResolverWithScope(scope), stmt, scope) val nextStmt = stmt.location.method.instList[stmt.target.index] scope.doWithState { newStmt(nextStmt) } @@ -318,7 +317,7 @@ class JcInterpreter( private fun visitSwitchStmt(scope: JcStepScope, stmt: JcSwitchInst) { val exprResolver = exprResolverWithScope(scope) - observer.onSwitchStatement(exprResolver.simpleValueResolver, stmt, scope) + observer?.onSwitchStatement(exprResolver.simpleValueResolver, stmt, scope) val switchKey = stmt.key // Note that the switch key can be an rvalue, for example, a simple int constant. @@ -345,7 +344,7 @@ class JcInterpreter( private fun visitThrowStmt(scope: JcStepScope, stmt: JcThrowInst) { val exprResolver = exprResolverWithScope(scope) - observer.onThrowStatement(exprResolver.simpleValueResolver, stmt, scope) + observer?.onThrowStatement(exprResolver.simpleValueResolver, stmt, scope) val address = exprResolver.resolveJcExpr(stmt.throwable)?.asExpr(ctx.addressSort) ?: return @@ -356,10 +355,17 @@ class JcInterpreter( private fun visitCallStmt(scope: JcStepScope, stmt: JcCallInst) { val exprResolver = exprResolverWithScope(scope) + val callExpr = stmt.callExpr + val methodResult = scope.calcOnState { methodResult } - observer.onCallStatement(exprResolver.simpleValueResolver, stmt, scope) + // TODO comment + when (methodResult) { + is JcMethodResult.NoCall -> observer?.onMethodCallWithUnresolvedArguments(exprResolver.simpleValueResolver, callExpr, scope) + is JcMethodResult.Success -> observer?.onCallStatement(exprResolver.simpleValueResolver, stmt, scope) + is JcMethodResult.JcException -> error("TODO") + } - exprResolver.resolveJcExpr(stmt.callExpr) ?: return + exprResolver.resolveJcExpr(callExpr) ?: return scope.doWithState { val nextStmt = stmt.nextStmt @@ -371,7 +377,7 @@ class JcInterpreter( val exprResolver = exprResolverWithScope(scope) exprResolver.resolveJcNotNullRefExpr(stmt.monitor, stmt.monitor.type) ?: return - observer.onEnterMonitorStatement(exprResolver.simpleValueResolver, stmt, scope) + observer?.onEnterMonitorStatement(exprResolver.simpleValueResolver, stmt, scope) // Monitor enter makes sense only in multithreaded environment @@ -384,7 +390,7 @@ class JcInterpreter( val exprResolver = exprResolverWithScope(scope) exprResolver.resolveJcNotNullRefExpr(stmt.monitor, stmt.monitor.type) ?: return - observer.onExitMonitorStatement(exprResolver.simpleValueResolver, stmt, scope) + observer?.onExitMonitorStatement(exprResolver.simpleValueResolver, stmt, scope) // Monitor exit makes sense only in multithreaded environment diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt index bb089e7663..61b3b860e9 100644 --- a/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt @@ -11,6 +11,7 @@ import org.usvm.api.JcParametersState import org.usvm.api.JcTest import org.usvm.api.targets.JcTarget import org.usvm.api.util.JcTestResolver +import org.usvm.machine.JcInterpreterObserver import org.usvm.machine.JcMachine import org.usvm.targets.UTargetController import org.usvm.test.util.TestRunner @@ -31,17 +32,25 @@ import kotlin.reflect.jvm.javaMethod open class JavaMethodTestRunner : TestRunner, KClass<*>?, JcClassCoverage>() { private var targets: List> = emptyList() + private var interpreterObserver: JcInterpreterObserver? = null /** * Sets JcTargets to run JcMachine with in the scope of [action]. */ - protected fun withTargets(targets: List>, action: () -> T): T { + protected fun withTargets( + targets: List>, + interpreterObserver: JcInterpreterObserver, + action: () -> T, + ): T { val prevTargets = this.targets + val prevInterpreterObserver = this.interpreterObserver try { this.targets = targets + this.interpreterObserver = interpreterObserver return action() } finally { this.targets = prevTargets + this.interpreterObserver = prevInterpreterObserver } } @@ -749,7 +758,7 @@ open class JavaMethodTestRunner : TestRunner, KClass<*>?, J val jcClass = cp.findClass(declaringClassName).toType() val jcMethod = jcClass.declaredMethods.first { it.name == method.name } - JcMachine(cp, options).use { machine -> + JcMachine(cp, options, interpreterObserver).use { machine -> val states = machine.analyze(jcMethod.method, targets) states.map { testResolver.resolve(jcMethod, it) } } diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt index 2d10298615..5a20e82bd2 100644 --- a/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/taint/TaintTest.kt @@ -1,11 +1,37 @@ package org.usvm.samples.taint +import io.ksmt.utils.cast +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcField import org.usvm.PathSelectionStrategy import org.usvm.UMachineOptions +import org.usvm.api.targets.Argument +import org.usvm.api.targets.AssignMark +import org.usvm.api.targets.BooleanFromArgument +import org.usvm.api.targets.CallParameterContainsMark +import org.usvm.api.targets.ConstantTrue +import org.usvm.api.targets.CopyAllMarks +import org.usvm.api.targets.JcTarget +import org.usvm.api.targets.RemoveAllMarks +import org.usvm.api.targets.Result +import org.usvm.api.targets.SensitiveData +import org.usvm.api.targets.SqlInjection +import org.usvm.api.targets.TaintAnalysis +import org.usvm.api.targets.TaintCleaner +import org.usvm.api.targets.TaintConfiguration +import org.usvm.api.targets.TaintEntryPointSource +import org.usvm.api.targets.TaintFieldSource +import org.usvm.api.targets.TaintMethodSink +import org.usvm.api.targets.TaintMethodSource +import org.usvm.api.targets.TaintPassThrough import org.usvm.samples.JavaMethodTestRunner import org.usvm.test.util.checkers.eq +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults import org.usvm.util.Options import org.usvm.util.UsvmTest +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue class TaintTest : JavaMethodTestRunner() { @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) @@ -21,10 +47,25 @@ class TaintTest : JavaMethodTestRunner() { @UsvmTest([Options([PathSelectionStrategy.TARGETED])]) fun testSimpleTaint(options: UMachineOptions) { withOptions(options) { - checkDiscoveredProperties( - Taint::simpleTaint, - eq(2) - ) + val sampleAnalysis = constructSampleTaintAnalysis(cp) + + withTargets(sampleAnalysis.targets.toList().cast(), sampleAnalysis) { + checkDiscoveredProperties( + Taint::simpleTaint, + ignoreNumberOfAnalysisResults, + ) + } + + val collectedStates = sampleAnalysis.collectedStates + assertEquals(expected = 1, actual = collectedStates.size) + + val reachedTargets = collectedStates.single().reachedTerminalTargets.singleOrNull() as? JcTarget<*> + + assertNotNull(reachedTargets) + assertTrue { reachedTargets.isTerminal } + assertTrue { reachedTargets.isRemoved } + assertTrue { reachedTargets is TaintAnalysis.TaintMethodSinkTarget } + assertTrue { reachedTargets.parent is TaintAnalysis.TaintMethodSourceTarget } } } @@ -47,4 +88,156 @@ class TaintTest : JavaMethodTestRunner() { ) } } -} \ No newline at end of file + + + // TODO separate cleaning actions +// TODO for demonstration purposes only, must be either moved to another place or removed completely + fun sampleConfiguration(cp: JcClasspath): TaintConfiguration { + fun findMethod(className: String, methodName: String) = cp + .findClassOrNull(className)!! + .declaredMethods + .first { it.name == methodName } + + val sampleClassName = "org.usvm.samples.taint.Taint" + + val taintEntryPointSourceMethod = findMethod(sampleClassName, "taintedEntrySource") + val taintEntryPointSourceCondition = ConstantTrue // TODO could be replaced with a check for emptiness + + val sampleEntryPointsSources = mapOf( + taintEntryPointSourceMethod to listOf( + TaintEntryPointSource( + taintEntryPointSourceMethod, + taintEntryPointSourceCondition to AssignMark(Argument(0u), SqlInjection) + ) + ) + ) + + val sampleSourceMethod = findMethod(sampleClassName, "stringProducer") + val sampleCondition = BooleanFromArgument(Argument(0u)) + val sampleMethodSources = mapOf( + sampleSourceMethod to listOf( + TaintMethodSource( + sampleSourceMethod, + sampleCondition to AssignMark(Result, SqlInjection) + ), + TaintMethodSource( + sampleSourceMethod, + sampleCondition to AssignMark(Result, SensitiveData) // TODO replace with a bulk operation + ), + ) + ) + + // TODO + val sampleFieldSources = emptyMap>() + + + val samplePassThoughMethod = findMethod("java.lang.String", "concat") + val samplePassThroughCondition = ConstantTrue + val samplePassThrough = mapOf( + samplePassThoughMethod to listOf( + TaintPassThrough( + samplePassThoughMethod, + samplePassThroughCondition to CopyAllMarks(Argument(0u), Result) + ), // TODO replace with a bulk operation + TaintPassThrough( + samplePassThoughMethod, + samplePassThroughCondition to CopyAllMarks(Argument(1u), Result) + ), // TODO replace with a bulk operation + ) + ) + + val sampleCleanerMethod = findMethod(sampleClassName, "cleaner") + val cleanerCondition = ConstantTrue + val sampleCleaners = mapOf( + sampleCleanerMethod to listOf( + TaintCleaner( + sampleCleanerMethod, + cleanerCondition to RemoveAllMarks(Result) + ) + ) + ) + + val consumerOfInjections = findMethod(sampleClassName, "consumerOfInjections") + val consumerOfSensitiveData = findMethod(sampleClassName, "consumerOfSensitiveData") + val consumerWithReturningValue = findMethod(sampleClassName, "consumerWithReturningValue") + + val sampleSinks = mapOf( + consumerOfInjections to listOf( + TaintMethodSink( + CallParameterContainsMark(Argument(0u), SqlInjection), + consumerOfInjections + ) + ), + consumerOfSensitiveData to listOf( + TaintMethodSink( + CallParameterContainsMark(Argument(0u), SqlInjection), + consumerOfSensitiveData + ) + ), + consumerWithReturningValue to listOf( + TaintMethodSink(CallParameterContainsMark(Argument(0u), SqlInjection), consumerWithReturningValue), + TaintMethodSink(CallParameterContainsMark(Argument(0u), SensitiveData), consumerWithReturningValue) + ), + ) + + return TaintConfiguration( + sampleEntryPointsSources, + sampleMethodSources, + sampleFieldSources, + samplePassThrough, + sampleCleaners, + sampleSinks, + emptyMap() // TODO field sinks + ) + } + + + // TODO REMOVE IT, FOR DEMONSTRATION PURPOSES ONLY + fun constructSampleTaintAnalysis(cp: JcClasspath): TaintAnalysis { + fun findMethod(className: String, methodName: String) = cp + .findClassOrNull(className)!! + .declaredMethods + .first { it.name == methodName } + + val sampleClassName = "org.usvm.samples.taint.Taint" + + val configuration = sampleConfiguration(cp) + + val consumerOfInjections = findMethod(sampleClassName, "consumerOfInjections") + val consumerSinkRule = configuration.methodSinks[consumerOfInjections]!!.single() + + val targetForTaintedEntrySink = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "taintedEntrySource") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + + val sampleSourceMethod = findMethod(sampleClassName, "stringProducer") + val stringProducerRule = configuration.methodSources[sampleSourceMethod]!!.first() + + val sourceTargetForSimpleTaint = TaintAnalysis.TaintMethodSourceTarget( + findMethod(sampleClassName, "simpleTaint") + .instList + .last { "stringProducer" in it.toString() }, + stringProducerRule.conditionWithAction.first, + stringProducerRule + ) + + val sinkTargetForSimpleTaint = TaintAnalysis.TaintMethodSinkTarget( + findMethod(sampleClassName, "simpleTaint") + .instList + .first { "consumerOfInjections" in it.toString() }, + consumerSinkRule.condition, + consumerSinkRule + ) + sourceTargetForSimpleTaint.addChild(sinkTargetForSimpleTaint) + + + return TaintAnalysis(configuration) + .addTarget(targetForTaintedEntrySink) + .addTarget(sourceTargetForSimpleTaint) + } +} +