Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track last peeked state, tweak coverage #13

Merged
merged 3 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions usvm-core/src/main/kotlin/org/usvm/ps/AIPathSelector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class AIPathSelector<Statement, State, Block>(
State : UState<*, *, Statement, *, *, State>,
Block : BasicBlock {
private val statesMap = mutableMapOf<State, StateWrapper<Statement, State, Block>>()
private var lastPeekedState: State? = null
private val lastPeekedState: State?
get() = stepsStatistics.lastPeekedState
private val totalSteps
get() = stepsStatistics.totalSteps.toInt()

Expand All @@ -29,6 +30,8 @@ Block : BasicBlock {
private var touchedBlocks = mutableSetOf<Block>()
private var newBlocks = listOf<Block>()

private var firstSend = true

private fun predict(): State {
val wrappers = statesMap.values
val vertices = blockGraph.blocks
Expand All @@ -46,9 +49,10 @@ Block : BasicBlock {
val game = when (predictor) {
is OnnxModel<Game<Block>> -> Game(vertices, wrappers, blockGraph)
is Oracle<Game<Block>> -> {
// if we played with default searcher before
// if we played with default searcher before,
// client has no information about the game
if (lastPeekedState == null) {
if (firstSend) {
firstSend = false
Game(vertices, wrappers, blockGraph)
} else {
if (blockGraph.newBlocks != newBlocks) {
Expand All @@ -71,12 +75,11 @@ Block : BasicBlock {

override fun peek(): State {
if (totalSteps == 0) {
return statesMap.keys.single().also { lastPeekedState = it }
return statesMap.keys.single()
}

val predictedState = predict()
lastPeekedState = predictedState
val wrapper = checkNotNull(statesMap[lastPeekedState])
val wrapper = checkNotNull(statesMap[predictedState])
touchedBlocks.add(wrapper.currentBlock)
touchedStates.add(wrapper)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ class StepsStatistics<Method, State : UState<*, Method, *, *, *, State>> : UMach
var totalSteps = 0UL
private set

var lastPeekedState: State? = null

private val methodSteps = mutableMapOf<Method, ULong>()

/**
* Returns number of steps machine made during [method] exploration.
*/
fun getMethodSteps(method: Method) = methodSteps.getOrDefault(method, 0UL)

override fun onStatePeeked(state: State) {
lastPeekedState = state
}

override fun onState(parent: State, forks: Sequence<State>) {
totalSteps++
parent.callStack.forEach { (m, _) ->
Expand Down
6 changes: 3 additions & 3 deletions usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class JcMachine(

private val cfgStatistics = CfgStatisticsImpl(applicationGraph)

fun analyze(methods: List<JcMethod>, targets: List<JcTarget> = emptyList()): Pair<List<JcState>, Int> {
fun analyze(methods: List<JcMethod>, targets: List<JcTarget> = emptyList()): Pair<List<JcState>, Float> {
logger.debug("{}.analyze({})", this, methods)
val initialStates = mutableMapOf<JcMethod, JcState>()
methods.forEach {
Expand Down Expand Up @@ -204,13 +204,13 @@ class JcMachine(
stopStrategy = stopStrategy,
)

val percentageCoverage = coverageStatistics.getMethodCoverage(methods.single()).roundToInt()
val percentageCoverage = coverageStatistics.getMethodCoverage(methods.single())

return statesCollector.collectedStates to percentageCoverage

}

fun analyze(method: JcMethod, targets: List<JcTarget> = emptyList()): Pair<List<JcState>, Int> =
fun analyze(method: JcMethod, targets: List<JcTarget> = emptyList()): Pair<List<JcState>, Float> =
analyze(listOf(method), targets)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class JavaMethodRunner(gameMap: GameMap, oracle: Predictor<*>? = null) {
val options: UMachineOptions = UMachineOptions(
pathSelectionStrategies = listOf(defaultSearcher, PathSelectionStrategy.AI),
pathSelectorCombinationStrategy = PathSelectorCombinationStrategy.SEQUENTIAL,
coverageZone = CoverageZone.TRANSITIVE,
coverageZone = CoverageZone.METHOD,
exceptionsPropagation = true,
solverTimeout = Duration.INFINITE,
timeout = Duration.INFINITE,
Expand All @@ -45,7 +45,7 @@ class JavaMethodRunner(gameMap: GameMap, oracle: Predictor<*>? = null) {
oracle = oracle
)

fun cover(className: String, methodName: String): Pair<List<JcState>, Int> {
fun cover(className: String, methodName: String): Pair<List<JcState>, Float> {
val jcMethod = cp.getJcMethodByName(className, methodName)

val (states, percentageCoverage) = JcMachine(cp, options, interpreterObserver = null).use { machine ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package org.usvm.gameserver
import org.usvm.JavaMethodRunner
import org.usvm.OracleImpl
import org.usvm.statistics.BasicBlock
import kotlin.math.floor

fun randomExplorer(
inputBody: Start,
Expand All @@ -29,7 +30,7 @@ fun randomExplorer(
val tests = results.size - errors

return GameOver(
percentageCoverage.toUByte(),
floor(percentageCoverage).toUInt(),
test = tests.toUInt(),
error = errors.toUInt()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ sealed class OutputMessageBody

@Serializable
data class GameOver(
@SerialName("ActualCoverage") val percent: UByte,
@SerialName("ActualCoverage") val percent: UInt,
@SerialName("TestsCount") val test: UInt,
@SerialName("ErrorsCount") val error: UInt
) : OutputMessageBody()
Expand Down
Loading